from pathlib import Path

import numpy as np
from PIL import Image

def correlation(A, B):
    """Return the cross-correlation of ``A`` and ``B`` at a single position.

    This is the sum of the element-wise product ``A * B``. Unlike
    ``convolution``, the second operand is used as-is (not rotated), and
    NumPy broadcasting rules apply to the product.
    """
    products = A * B
    return np.sum(products)

def convolution(A, B):
    """Return the convolution of ``A`` and ``B`` at a single position.

    ``B`` is rotated by 180 degrees in the plane of its first two axes (the
    kernel flip that distinguishes convolution from correlation). The result
    is the sum of the element-wise product with ``A``.
    """
    flipped = np.rot90(B, k=2)
    return np.sum(A * flipped)

def convolution_over_2dimage(image, filter):
    """Convolve a 2-D image with a 2-D kernel over the "valid" region only.

    The kernel is rotated by 180 degrees (true convolution, like
    ``convolution``). It is evaluated only where it fits entirely inside the
    image, so no padding is applied and the output shrinks by
    ``kernel_size - 1`` along each axis.

    Args:
        image: 2-D array of shape (H, W), of any numeric dtype (e.g. uint8).
        filter: 2-D kernel of shape (kh, kw). The name is kept for caller
            compatibility even though it shadows the ``filter`` builtin.

    Returns:
        ``np.int32`` array of shape (H - kh + 1, W - kw + 1). Floating-point
        results (e.g. from a normalised blur kernel) are rounded to the
        nearest integer instead of being truncated toward zero.

    Raises:
        ValueError: if the kernel is larger than the image along either axis.
    """
    image = np.asarray(image)
    kernel = np.rot90(np.asarray(filter), k=2)
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    if out_h < 0 or out_w < 0:
        raise ValueError(
            f"kernel of shape {kernel.shape} does not fit in image of shape "
            f"{image.shape}"
        )

    # Accumulate in a dtype wide enough for both operands (at least int64),
    # so a uint8 image combined with e.g. a sharpen kernel cannot overflow.
    acc_dtype = np.result_type(image.dtype, kernel.dtype, np.int64)
    source = image.astype(acc_dtype, copy=False)
    acc = np.zeros((out_h, out_w), dtype=acc_dtype)

    # Iterate over the kh*kw kernel taps instead of the H*W output pixels.
    # Each step adds one weighted, shifted window of the whole image, which
    # keeps the per-pixel work inside NumPy.
    for i in range(kh):
        for j in range(kw):
            acc += kernel[i, j] * source[i:i + out_h, j:j + out_w]

    if np.issubdtype(acc_dtype, np.floating):
        # Round rather than truncate: a floating-point sum that should be an
        # exact integer can land a hair below it, and truncation would then
        # drop a whole intensity level (and bias blurs dark on average).
        acc = np.rint(acc)
    return acc.astype(np.int32)

def convolution_over_3dimage(image, filter):
    """Convolve every channel of an (H, W, C) image with the same 2-D kernel.

    Each channel ``image[:, :, c]`` is processed independently by
    ``convolution_over_2dimage`` ("valid" region, kernel rotated 180 degrees),
    so any channel count works (RGB, RGBA, ...).

    Args:
        image: 3-D array of shape (H, W, C).
        filter: 2-D kernel of shape (kh, kw). The name is kept for caller
            compatibility even though it shadows the ``filter`` builtin.

    Returns:
        ``np.int32`` array of shape (H - kh + 1, W - kw + 1, C).

    Raises:
        ValueError: if ``image`` is not 3-D, or the kernel does not fit.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if image.ndim != 3:
        raise ValueError(f"expected an (H, W, C) image, got shape {image.shape}")

    channels = image.shape[2]
    result = np.empty(
        (image.shape[0] - filter.shape[0] + 1,
         image.shape[1] - filter.shape[1] + 1,
         channels),
        dtype=np.int32,
    )
    for channel in range(channels):
        result[:, :, channel] = convolution_over_2dimage(image[:, :, channel], filter)
    return result

A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

box_blur = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) / 9
sharpen = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])

# Open inside a context manager so the file handle is released. Normalise to
# RGB so palette/greyscale/RGBA/CMYK inputs still give an (H, W, 3) array.
# convert() returns an independent copy, so it stays usable after the file
# is closed.
with Image.open('in/starry.jpg') as source_pil:
    color_pil = source_pil.convert("RGB")
grayscale_pil = color_pil.convert("L")
# np.array on an "RGB"/"L" PIL image is already uint8, so no int32 round-trip
# is needed.
color_np = np.array(color_pil)
gray_np = np.array(grayscale_pil)

print(correlation(A, B))
print(convolution(A, B))

gray_blur = convolution_over_2dimage(gray_np, box_blur).clip(0, 255).astype(np.uint8)
gray_sharpen = convolution_over_2dimage(gray_np, sharpen).clip(0, 255).astype(np.uint8)

color_blur = convolution_over_3dimage(color_np, box_blur).clip(0, 255).astype(np.uint8)
color_sharpen = convolution_over_3dimage(color_np, sharpen).clip(0, 255).astype(np.uint8)

# "Valid" convolution with an odd k x k kernel trims k // 2 pixels from every
# border. Crop the originals by the same margin so all panels line up.
# (Both kernels here share the same 3x3 shape.)
margin_h, margin_w = box_blur.shape[0] // 2, box_blur.shape[1] // 2
gray_cropped = gray_np[margin_h:gray_np.shape[0] - margin_h,
                       margin_w:gray_np.shape[1] - margin_w]
color_cropped = color_np[margin_h:color_np.shape[0] - margin_h,
                         margin_w:color_np.shape[1] - margin_w]

# Despite the names, each "row" is a vertical strip: original / blur / sharpen
# stacked top to bottom. Greyscale is replicated to 3 channels so it can sit
# beside the colour strip.
first_row = np.vstack([
    np.repeat(gray_cropped[:, :, np.newaxis], 3, axis=2),
    np.repeat(gray_blur[:, :, np.newaxis], 3, axis=2),
    np.repeat(gray_sharpen[:, :, np.newaxis], 3, axis=2),
])
second_row = np.vstack([
    color_cropped,
    color_blur,
    color_sharpen,
])

# Create the output directory if missing; otherwise save() fails on a fresh checkout.
Path('out').mkdir(parents=True, exist_ok=True)
Image.fromarray(np.hstack([first_row, second_row])).save('out/exo01.jpg')