import contextlib
import copy
import io
import os
import pprint
import random
from unittest.mock import patch

import game_2048


# === Helper functions for test system ===

def green(s):
    """Return ``s`` wrapped in the ANSI escape codes for bright-green text."""
    return f"\033[92m{s}\033[0m"
def red(s):
    """Return ``s`` wrapped in the ANSI escape codes for bright-red text."""
    return f"\033[91m{s}\033[0m"

def to_set(value):
    """Convert a list result into a set for order-insensitive comparison.

    Returns None when ``value`` is None or when its elements are unhashable.

    Raises:
        TypeError: if ``value`` is neither None nor a list.
    """
    if value is None:
        return None
    if not isinstance(value, list):
        raise TypeError("Output must be a list")
    try:
        converted = set(value)
    except TypeError:
        # Unhashable elements (e.g. nested lists) cannot go into a set.
        converted = None
    return converted
    
# === Core test harness ===

def test(function, arguments, ground_truth, test_name, out_transform=None):
    """Run one test case and print a coloured pass/fail report.

    Args:
        function: Callable under test; its ``__name__`` labels the report.
        arguments: Tuple of positional arguments. A deep copy is passed to
            the call, so the function cannot mutate the caller's data or the
            arguments shown in a failure report.
        ground_truth: Expected result, compared with ``==``.
        test_name: Short description printed next to the function name.
        out_transform: Optional callable applied to the raw output before
            the comparison (e.g. ``to_set`` for order-insensitive results).

    Any exception raised while calling the function, transforming its output
    or comparing it is reported as a failure rather than propagated. The
    report carries the same test label and arguments as an ordinary failure,
    and one broken function does not stop the remaining tests.
    """
    # getattr: callables such as functools.partial objects have no __name__.
    label = f"{getattr(function, '__name__', repr(function))} - {test_name}"
    try:
        output = function(*copy.deepcopy(arguments))
        if out_transform is not None:
            output = out_transform(output)
        if output == ground_truth:
            print(f"{green('Test passed:')} {label}")
        else:
            print(f"{red('Test failed:')} {label}")
            print("=== Arguments: ===")
            pprint.pprint(arguments, width=160)
            print("=== Expected: ===")
            print(ground_truth)
            print("=== Received: ===")
            print(output)

    except Exception as e:
        print(f"{red('Test failed:')} {label}")
        print("=== Arguments: ===")
        pprint.pprint(arguments, width=160)
        print(f"Unhandled Exception: {repr(e)}\n")


# === TESTS START HERE ===

def test_display_grid():
    """Check that display_grid prints the expected boxed 4x4 grid.

    The text actually written to stdout is captured, rather than
    reconstructed from mocked print() arguments. As a result, a zero-argument
    ``print()``, ``sep``/``end`` keyword arguments and direct ``sys.stdout``
    writes are all reflected exactly as a user would see them. An exception
    raised by display_grid is reported as a failure instead of aborting the
    whole test run.
    """
    grid = [
        [2, 4, 0, 8],
        [0, 0, 64, 0],
        [2, 0, 16, 32],
        [0, 8, 0, 0]
    ]

    expected_output = (
        "+----+----+----+----+\n"
        "|    |    |    |    |\n"
        "| 2  | 4  |    | 8  |\n"
        "|    |    |    |    |\n"
        "+----+----+----+----+\n"
        "|    |    |    |    |\n"
        "|    |    | 64 |    |\n"
        "|    |    |    |    |\n"
        "+----+----+----+----+\n"
        "|    |    |    |    |\n"
        "| 2  |    | 16 | 32 |\n"
        "|    |    |    |    |\n"
        "+----+----+----+----+\n"
        "|    |    |    |    |\n"
        "|    | 8  |    |    |\n"
        "|    |    |    |    |\n"
        "+----+----+----+----+\n\n"
    )

    # Capture the real console output of display_grid.
    buffer = io.StringIO()
    try:
        with contextlib.redirect_stdout(buffer):
            game_2048.display_grid(grid)
    except Exception as e:
        print(f"{red('Test failed:')} display_grid - Raised an exception")
        print(f"Unhandled Exception: {repr(e)}\n")
        return
    printed_output = buffer.getvalue()

    if printed_output == expected_output:
        print(f"{green('Test passed:')} display_grid - Correctly prints 4x4 grid")
    else:
        print(f"{red('Test failed:')} display_grid - Incorrect console output")
        print("=== Expected ===")
        print(expected_output)
        print("=== Received ===")
        print(printed_output)


def test_init_grid():
    """Check that init_grid(), called with no arguments, returns a 4x4 grid of zeros."""
    empty_board = [[0] * 4 for _ in range(4)]
    test(game_2048.init_grid, (), empty_board, "Creates correct 4x4 grid")




def test_get_empty_positions():
    """Check get_empty_positions on a grid containing seven empty cells.

    The result goes through to_set, so the (row, col) pairs may come back in
    any order, but they must be returned as a list.
    """
    board = [
        [2, 0, 0, 4],
        [0, 2, 2, 16],
        [0, 4, 64, 0],
        [4, 0, 32, 0],
    ]
    empty_cells = {
        (0, 1), (0, 2),
        (1, 0),
        (2, 0), (2, 3),
        (3, 1), (3, 3),
    }
    test(
        game_2048.get_empty_positions,
        (board,),
        empty_cells,
        "Finds all empty cells",
        out_transform=to_set,
    )

def test_get_user_input():
    """Check that get_user_input rejects invalid input and then accepts valid input.

    The scripted inputs are 'x', followed by '  a  '. The first must produce
    the "Invalid input, try again." message. The second must make the
    function return 'A'. Every prompt passed to input() must be exactly
    "Enter move (W/A/S/D/Q): ".

    If get_user_input raises, the test is reported as failed instead of
    aborting the remaining tests. For example, it raises StopIteration when
    it asks for more input than was scripted.
    """
    result = None
    error = None

    # --- Case: invalid input followed by valid input ---
    with patch('builtins.input', side_effect=['x', '  a  ']) as mock_input, \
         patch('builtins.print') as mock_print:

        try:
            result = game_2048.get_user_input()
        except Exception as e:  # reported below instead of crashing the run
            error = e

        # Capture printed lines
        printed_texts = [' '.join(map(str, args)) for args, _ in mock_print.call_args_list]
        printed_error = any("Invalid input, try again." in line for line in printed_texts)

        # Extract prompts; an input() call without a prompt is recorded as None
        input_prompts = [args[0] if args else None for (args, _) in mock_input.call_args_list]
        correct_prompts = all(p == "Enter move (W/A/S/D/Q): " for p in input_prompts)

        # Evaluate correctness
        passed = error is None and result == 'A' and printed_error and correct_prompts

    # ---- Output (outside patch, so real print works) ----
    if passed:
        print(f"{green('Test passed:')} get_user_input - Rejects invalid then accepts valid")
    else:
        print(f"{red('Test failed:')} get_user_input")
        if error is not None:
            print(f"Unhandled Exception: {repr(error)}")
        print(f"Returned value: {result}")
        print(f"Printed output: {printed_texts}")
        print(f"Input prompts: {input_prompts}")



def test_move():
    """Check move() for a left ('A') move and an upward ('W') move, both involving merges."""
    left_before = [
        [2, 0, 2, 0],
        [4, 4, 4, 4],
        [0, 0, 0, 0],
        [0, 2, 16, 2],
    ]
    left_after = [
        [4, 0, 0, 0],
        [8, 8, 0, 0],
        [0, 0, 0, 0],
        [2, 16, 2, 0],
    ]
    up_before = [
        [2, 0, 2, 4],
        [2, 0, 2, 4],
        [4, 4, 8, 4],
        [0, 32, 8, 0],
    ]
    up_after = [
        [4, 4, 4, 8],
        [4, 32, 16, 4],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
    ]

    cases = (
        (left_before, 'A', left_after, "Merges correctly to the left"),
        (up_before, 'W', up_after, "Merges correctly upward"),
    )
    for before, direction, after, label in cases:
        test(game_2048.move, (before, direction), after, label)


def test_add_new_tile():
    """Check that add_new_tile places exactly one 2 in a previously empty cell.

    The grid is compared cell-by-cell before and after the in-place call.
    This way the verdict does not depend on get_empty_positions being correct.
    It also fails if any other cell was changed.

    The check runs through the ``test`` harness, so an exception raised by
    add_new_tile is reported as a failure instead of aborting the run.
    """
    grid = [
        [2, 0, 0, 4],
        [0, 2, 2, 16],
        [0, 4, 64, 0],
        [4, 0, 32, 0]
    ]

    def verify_add_new_tile(original):
        random.seed(0)  # fix randomness
        new_grid = copy.deepcopy(original)
        game_2048.add_new_tile(new_grid)

        changed = [
            (r, c)
            for r, row in enumerate(original)
            for c, value in enumerate(row)
            if new_grid[r][c] != value
        ]
        if len(changed) != 1:
            return False
        r, c = changed[0]
        # The single changed cell must have been empty and now hold a 2.
        return original[r][c] == 0 and new_grid[r][c] == 2

    test(verify_add_new_tile, (grid,), True, "Adds one new 2 in empty cell")



def test_can_move():
    """Check can_move on three grids.

    The grids are: a stuck grid, a grid with one empty cell, and a full grid
    whose only adjacent equal pair is the two 4s in the first column.
    """
    cases = (
        (
            [
                [2, 4, 2, 8],
                [8, 16, 8, 2],
                [4, 2, 16, 4],
                [2, 8, 4, 2],
            ],
            False,
            "Detects no moves left",
        ),
        (
            [
                [2, 4, 2, 8],
                [8, 0, 8, 2],
                [4, 2, 16, 4],
                [2, 8, 4, 2],
            ],
            True,
            "Detects available move",
        ),
        (
            [
                [2, 4, 2, 8],
                [8, 16, 8, 2],
                [4, 2, 16, 4],
                [4, 8, 4, 2],
            ],
            True,
            "Detects adjacent merge move",
        ),
    )
    for board, expected, label in cases:
        test(game_2048.can_move, (board,), expected, label)


def test_game_won():
    """Check game_won against the module-level MAX_TILE threshold.

    The first two cases assume game_2048.MAX_TILE is 64, as their names
    state. The last case raises the threshold to 128 with patch.object, and
    patch.object restores the module's original value afterwards. The
    original code instead forced MAX_TILE back to a hard-coded 64.
    """
    win_64 = [
        [2, 4, 2, 8],
        [8, 16, 8, 2],
        [4, 2, 64, 4],
        [2, 8, 4, 2]
    ]
    test(game_2048.game_won, (win_64,), True, "Detects win condition with MAX_TILE=64")

    no_win_64 = [
        [2, 4, 2, 8],
        [8, 0, 8, 2],
        [4, 2, 16, 4],
        [2, 8, 4, 2]
    ]
    test(game_2048.game_won, (no_win_64,), False, "Detects no win condition with MAX_TILE=64")

    # Temporarily change the threshold. create=True keeps the original
    # behaviour of working even if the module does not define MAX_TILE.
    with patch.object(game_2048, 'MAX_TILE', 128, create=True):
        test(game_2048.game_won, (win_64,), False, "Detects no win condition with MAX_TILE=128")

def test_exceptions(function_to_test):
    """
    Check that function_to_test rejects malformed grids with the right exception.

    Three inputs are tried, each expected to raise a specific exception:
    a string instead of a grid (TypeError), a 2x2 grid (ValueError), and a
    4x4 grid containing a non-numeric cell (ValueError). A pass/fail line is
    printed for each.
    """
    invalid_cell_grid = [
        [2, 4, 0, 8],
        [0, 0, 'invalid', 0],
        [2, 0, 16, 32],
        [0, 8, 0, 0]
    ]
    cases = [
        # (args, expected_exception, description)
        (("not a grid",), TypeError, "Raises TypeError on non-list grid"),
        (([[1, 2], [3, 4]],), ValueError, "Raises ValueError on wrong grid size"),
        ((invalid_cell_grid,), ValueError, "Raises ValueError on invalid cell value"),
    ]

    for call_args, wanted, description in cases:
        try:
            function_to_test(*call_args)
        except wanted:
            print(f"{green('Test passed:')} {function_to_test.__name__} - {description}")
        except Exception as err:
            print(f"{red('Test failed:')} {function_to_test.__name__} - {description}")
            print(f"Raised {type(err).__name__}, expected {wanted.__name__}")
            print("Exception message:", err)
        else:
            print(f"{red('Test failed:')} {function_to_test.__name__} - {description}")
            print("Expected exception:", wanted.__name__, "but none was raised.")

def main():
    """Run every test in a fixed order, printing a coloured result line for each."""
    # NOTE(review): presumably enables ANSI escape handling in the Windows
    # console so the green/red output renders; confirm.
    if os.name == 'nt':
        os.system('color')

    suite = (
        test_display_grid,
        # Add test_exceptions for other functions as needed
        lambda: test_exceptions(game_2048.display_grid),
        test_init_grid,
        test_get_empty_positions,
        test_get_user_input,
        test_move,
        test_add_new_tile,
        test_can_move,
        test_game_won,
    )
    for run_test in suite:
        run_test()


if __name__ == '__main__':
    main()
