import numpy as np
from PIL import Image
from scipy.signal import convolve2d


def grayscale(numpy_image):
    """Convert a colour image to grayscale using the NTSC luma weights.

    The formula (seen in the week-14 exercises) is
    ``0.299 * R + 0.587 * G + 0.114 * B``.

    Parameters
    ----------
    numpy_image : array-like
        Image of shape (H, W, C) with C >= 3 (an alpha channel, if any, is
        ignored), or an already single-channel image of shape (H, W).

    Returns
    -------
    np.ndarray
        int32 array of shape (H, W) with the rounded luma values.
    """
    # Work in float64: weighting uint8 pixels directly would overflow/truncate.
    pixels = np.asarray(numpy_image, dtype=np.float64)
    if pixels.ndim == 2:
        # Already single-channel: nothing to mix, just match the output dtype.
        return np.rint(pixels).astype(np.int32)

    ntsc_weights = np.array([0.299, 0.587, 0.114])
    # (H, W, 3) @ (3,) -> (H, W): weighted sum of R, G, B for every pixel.
    luma = pixels[..., :3] @ ntsc_weights
    return np.rint(luma).astype(np.int32)


def apply_blur(numpy_image):
    """Smooth a single-channel image with a 3x3 box (mean) filter.

    Parameters
    ----------
    numpy_image : 2-D array-like
        Grayscale image (in this script, the output of ``grayscale``).

    Returns
    -------
    np.ndarray
        int32 array with the same shape as the input; each pixel is the
        rounded mean of its 3x3 neighbourhood.
    """
    # Box filter given in the assignment statement.
    # NOTE(review): this is the standard normalised 3x3 box kernel -- confirm
    # it matches the statement exactly.
    box_blur = np.array([
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]]
    ) / 9

    # mode='same' keeps the input shape. boundary='symm' mirrors the border, so
    # edge pixels are not darkened by implicit zero padding (that darkening
    # would later show up as spurious Sobel edges along the image frame).
    blurred = convolve2d(numpy_image, box_blur, mode='same', boundary='symm')
    return np.rint(blurred).astype(np.int32)

def sobel_analysis(numpy_image):
    """Compute Sobel derivatives, gradient magnitude and gradient direction.

    Parameters
    ----------
    numpy_image : 2-D array-like
        Single-channel (grayscale, typically blurred) image.

    Returns
    -------
    tuple of four int32 arrays, each with the input's shape
        sobel_x   -- horizontal derivative, positive where intensity
                     increases with the column index.
        sobel_y   -- vertical derivative, positive where intensity
                     increases with the row index (downwards).
        magnitude -- sqrt(sobel_x**2 + sobel_y**2), rounded.
        gradient  -- direction arctan2(sobel_y, sobel_x) in degrees,
                     in [-180, 180], rounded.
    """
    image = np.asarray(numpy_image, dtype=np.float64)

    # Sobel kernels (classical form).
    # NOTE(review): confirm they match the assignment statement.
    kernel_x = np.array([
        [-1, 0, 1],
        [-2, 0, 2],
        [-1, 0, 1]])
    kernel_y = kernel_x.T

    def _correlate(kernel):
        # convolve2d flips its kernel; pre-flipping turns the operation into a
        # correlation, so each kernel acts exactly as written above (right
        # minus left, bottom minus top).
        return convolve2d(image, kernel[::-1, ::-1], mode='same', boundary='symm')

    grad_x = _correlate(kernel_x)
    grad_y = _correlate(kernel_y)

    # Magnitude and direction formulas (NOTE(review): the statement may expect
    # radians rather than degrees for the direction -- confirm).
    magnitude = np.hypot(grad_x, grad_y)
    direction = np.degrees(np.arctan2(grad_y, grad_x))

    return tuple(
        np.rint(channel).astype(np.int32)
        for channel in (grad_x, grad_y, magnitude, direction)
    )

def non_maximum_suppression(magnitude, grad):
    """Thin edges by keeping only local maxima along the gradient direction.

    Each pixel's direction is quantised to 0, 45, 90 or 135 degrees (modulo
    180). The pixel's magnitude is kept only if it is >= both neighbours
    along that direction; otherwise it is set to 0. Neighbours outside the
    image count as 0.

    Parameters
    ----------
    magnitude : 2-D array-like
        Gradient magnitude.
    grad : 2-D array-like, same shape as ``magnitude``
        Gradient direction in degrees, with the ``arctan2(gy, gx)`` convention
        where x grows with the column index and y with the row index.
        NOTE(review): confirm this unit/convention matches what
        ``sobel_analysis`` returns.

    Returns
    -------
    np.ndarray
        int32 array with the input shape: magnitude at local maxima, 0 elsewhere.
    """
    mag = np.asarray(magnitude, dtype=np.float64)
    if mag.size == 0:
        return np.zeros(mag.shape, dtype=np.int32)
    rows, cols = mag.shape

    # Sector index 0..3 for directions around 0, 45, 90 and 135 degrees.
    # The direction is taken modulo 180 because opposite gradients share the
    # same pair of neighbours.
    angle = np.mod(np.asarray(grad, dtype=np.float64), 180.0)
    sector = np.floor((angle + 22.5) / 45.0).astype(np.int64) % 4

    # Zero padding lets every pixel look up its neighbours via shifted slices.
    padded = np.pad(mag, 1, mode='constant')
    # (row, col) step along the gradient for each sector.
    steps = ((0, 1), (1, 1), (1, 0), (1, -1))

    keep = np.zeros(mag.shape, dtype=bool)
    for index, (d_row, d_col) in enumerate(steps):
        ahead = padded[1 + d_row:1 + d_row + rows, 1 + d_col:1 + d_col + cols]
        behind = padded[1 - d_row:1 - d_row + rows, 1 - d_col:1 - d_col + cols]
        keep |= (sector == index) & (mag >= ahead) & (mag >= behind)

    return np.rint(np.where(keep, mag, 0.0)).astype(np.int32)

def threshold(img, seuil):
    """Binarise an image: pixels strictly above ``seuil`` become 255, others 0.

    Parameters
    ----------
    img : array-like
        Input values (e.g. non-maximum-suppressed gradient magnitudes).
    seuil : number
        Threshold ("seuil" is French for threshold).
        NOTE(review): strict ``>`` is assumed; confirm against the assignment.

    Returns
    -------
    np.ndarray
        int32 array with the input shape containing only 0 and 255.
    """
    values = np.asarray(img)
    return np.where(values > seuil, 255, 0).astype(np.int32)

def _normalize_to_uint8(array):
    """Linearly stretch *array* to the 0-255 range and cast it to uint8.

    A constant array (max == min) has no range to stretch, so it is mapped to
    all zeros. Dividing by a zero range would produce NaNs, and casting NaN to
    uint8 is undefined.
    """
    values = np.asarray(array, dtype=np.float64)
    low, high = values.min(), values.max()
    if high == low:
        return np.zeros(values.shape, dtype=np.uint8)
    return ((values - low) / (high - low) * 255).astype(np.uint8)


# Read the image and chain the processing steps defined in the assignment.
# convert('RGB') guarantees an (H, W, 3) array even for grayscale/RGBA files;
# the context manager closes the file once the pixels have been copied.
with Image.open('in/go.jpg') as source_image:
    image_np = np.array(source_image.convert('RGB'))
gray_np = grayscale(image_np)
blurred_np = apply_blur(gray_np)
sobel_x_np, sobel_y_np, magnitude_np, gradient_np = sobel_analysis(blurred_np)
nms_np = non_maximum_suppression(magnitude_np, gradient_np)
print(np.unique(nms_np))
thresholded = threshold(nms_np, 200)

# Visualisation: grayscale/blurred are already in 0-255 (clipping guards
# against uint8 wrap-around); the other maps are stretched to the full range.
gray_np = np.clip(gray_np, 0, 255).astype(np.uint8)
blurred_np = np.clip(blurred_np, 0, 255).astype(np.uint8)
sobel_x_np = _normalize_to_uint8(sobel_x_np)
sobel_y_np = _normalize_to_uint8(sobel_y_np)
magnitude_np = _normalize_to_uint8(magnitude_np)
gradient_np = _normalize_to_uint8(gradient_np)
nms_np = _normalize_to_uint8(nms_np)
thresholded = _normalize_to_uint8(thresholded)

Image.fromarray(gray_np).save('out/exo02_gray.jpg')
Image.fromarray(blurred_np).save('out/exo02_blurred.jpg')
Image.fromarray(sobel_x_np).save('out/exo02_sobelx.jpg')
Image.fromarray(sobel_y_np).save('out/exo02_sobely.jpg')
Image.fromarray(magnitude_np).save('out/exo02_magnitude.jpg')
Image.fromarray(gradient_np).save('out/exo02_gradient.jpg')
Image.fromarray(nms_np).save('out/exo02_nms.jpg')
Image.fromarray(thresholded).save('out/exo02_thresholded.jpg')
