import random
import matplotlib.pyplot as plt 
import numpy as np
import my_test as stud

def always_cooperate(my_hist, op_hist):
    """Toujours coopérer"""
    # The docstring above is this strategy's display label in the tournament
    # output ("Always cooperate"); keep it verbatim.
    # Unconditional cooperator: neither history influences the move.
    cooperate = True
    return cooperate

def always_defect(my_hist, op_hist):
    """Toujours trahir"""
    # The docstring above is this strategy's display label in the tournament
    # output ("Always defect"); keep it verbatim.
    # Unconditional defector: neither history influences the move.
    cooperate = False
    return cooperate

def coinflipper(my_hist, op_hist):
    """Flipper"""
    # The docstring above is this strategy's display label; keep it verbatim.
    # Ignores both histories. Cooperates only when the uniform draw in [0, 1)
    # exceeds 0.8, i.e. roughly 20% of the time -- a biased coin, not a fair one.
    draw = random.random()
    return draw > 0.8

def tft(my_hist, op_hist):
    """Chill Guy"""
    # The docstring above is this strategy's display label; keep it verbatim.
    # Tit-for-tat: echo the opponent's previous move once one exists ...
    if len(op_hist) > 0:
        return op_hist[-1]
    # ... and open the match by cooperating.
    return True

# Payoff to "me" for a single round, keyed by (my move, opponent's move),
# where True means cooperate.
_PAYOFFS = {
    (True, True): 3,    # mutual cooperation
    (False, True): 5,   # I defect against a cooperator
    (True, False): 0,   # I cooperate against a defector
    (False, False): 1,  # mutual defection
}


def reward(my_coop, op_coop):
    """Return my payoff for one round of the prisoner's dilemma.

    Args:
        my_coop: my move; any truthy value means "cooperate".
        op_coop: the opponent's move; any truthy value means "cooperate".

    Returns:
        3, 5, 0 or 1 according to the ``_PAYOFFS`` table.
    """
    # bool() reproduces plain `if` truthiness, so non-bool moves still work.
    return _PAYOFFS[bool(my_coop), bool(op_coop)]


def _play_match(p1_strat, p2_strat, n_rounds):
    """Play one iterated match of ``n_rounds`` rounds; return ``(p1_score, p2_score)``.

    Each strategy is called as ``strat(my_history, opponent_history)`` and
    receives fresh copies of the histories, so a strategy that mutates its
    arguments cannot corrupt the recorded moves.
    """
    p1_history = []
    p2_history = []
    p1_score = 0
    p2_score = 0
    for _ in range(n_rounds):
        # Both moves are chosen before either history grows, so neither player
        # can see the other's move for the current round.
        p1_move = p1_strat(list(p1_history), list(p2_history))
        p2_move = p2_strat(list(p2_history), list(p1_history))
        p1_history.append(p1_move)
        p2_history.append(p2_move)
        p1_score += reward(p1_move, p2_move)
        p2_score += reward(p2_move, p1_move)
    return p1_score, p2_score


def _strategy_label(strat):
    """Return a strategy's display label: its docstring, else its name."""
    return strat.__doc__ or getattr(strat, "__name__", repr(strat))


def _plot_heatmap(match_results, labels):
    """Show ``match_results`` as a heatmap (row = scoring strategy, column = opponent)."""
    ticks = np.arange(len(labels))
    plt.figure(figsize=(16, 12))
    plt.imshow(match_results, cmap='YlGnBu', interpolation='nearest')
    plt.colorbar(label="Score")
    # Labels are full docstrings and can be long; tilt them to avoid overlap.
    plt.xticks(ticks, labels, rotation=45, ha="right")
    plt.yticks(ticks, labels)
    plt.tight_layout()
    plt.show()


def tourney(strategies, weights, n_rounds):
    """Run a round-robin iterated prisoner's dilemma tournament.

    Every ordered pair (p1, p2) of strategies, including each strategy against
    itself, plays one match of ``n_rounds`` rounds. Progress and the final
    ranking are printed, and the result matrix is shown as a heatmap.

    Args:
        strategies: callables ``strat(my_hist, op_hist)`` returning a truthy
            value to cooperate and a falsy value to defect.
        weights: one weight per strategy. Points scored against opponent j
            count with weight ``weights[j] / sum(weights)`` in the ranking.
            The weights need not be normalized.
        n_rounds: number of rounds in each match; must be positive.

    Returns:
        Square list-of-lists ``match_results``. ``match_results[i][j]`` is the
        total number of points strategy i scored against strategy j over both
        matches (i as p1 and i as p2), i.e. over ``2 * n_rounds`` rounds.

    Raises:
        ValueError: if ``weights`` and ``strategies`` differ in length, if the
            weights sum to zero, or if ``n_rounds`` is not positive.
    """
    n_strats = len(strategies)
    if len(weights) != n_strats:
        raise ValueError(
            f"expected {n_strats} weights (one per strategy), got {len(weights)}"
        )
    if n_rounds <= 0:
        raise ValueError(f"n_rounds must be positive, got {n_rounds}")
    total_weight = sum(weights)
    if strategies and total_weight == 0:
        raise ValueError("weights must not sum to zero")

    labels = [_strategy_label(strat) for strat in strategies]
    match_results = [[0] * n_strats for _ in range(n_strats)]

    for p1, p1_strat in enumerate(strategies):
        for p2, p2_strat in enumerate(strategies):
            print(f"Match: {labels[p1]} contre {labels[p2]}")
            p1_score, p2_score = _play_match(p1_strat, p2_strat, n_rounds)
            print(f'Résultat : {p1_score/n_rounds:.2f} vs {p2_score/n_rounds:.2f}')
            # Accumulate instead of assigning. The loop visits both (p1, p2)
            # and (p2, p1), and plain assignment would let the second match
            # overwrite the first. After the loop every cell (including the
            # self-play diagonal) covers exactly 2 * n_rounds rounds.
            match_results[p1][p2] += p1_score
            match_results[p2][p1] += p2_score

    # Tournament score of strategy i:
    #   sum_j match_results[i][j] * (weights[j] / total_weight) / (2 * n_rounds)
    # i.e. the opponent-weighted average payoff per round.
    norm = total_weight * 2 * n_rounds
    final_scores = [
        (labels[i], sum(pts * w for pts, w in zip(row, weights)) / norm)
        for i, row in enumerate(match_results)
    ]
    # Highest score first; sorted() stays stable with reverse=True, so ties
    # keep their roster order.
    ranking = sorted(final_scores, key=lambda entry: entry[1], reverse=True)

    print('========RANKING FINAL========')
    for rank, (label, score) in enumerate(ranking, start=1):
        print(f"{rank}. {label} avec un score de {score:.3f}")

    if match_results:  # imshow cannot render an empty matrix
        _plot_heatmap(match_results, labels)

    return match_results


if __name__ == "__main__":
    # Roster: the four reference strategies plus the student's entry from the
    # my_test module (imported as ``stud``).
    strategies = [
        always_cooperate,
        always_defect,
        tft,
        coinflipper,
        stud.my_strategy,
    ]

    # Uniform weights, already normalized so they sum to 1.
    n_strategies = len(strategies)
    weights = [1 / n_strategies for _ in range(n_strategies)]

    # Every match in this tournament lasts the same number of rounds, drawn
    # once, uniformly from 10..20 inclusive.
    tourney(strategies, weights, random.randint(10, 20))



                
