from dataclasses import dataclass
from typing import Tuple
import PIL.Image
import numpy as np
import os

# Type aliases shared by the functions in this module
Image = np.ndarray  # 2-D array for greyscale, 3-D (with a channel axis) for RGB
Kernel = list[list[float]]  # rows of weights; presumably a filter kernel (e.g. Sobel) — confirm
Position = tuple[int, int]  # (row, col) pixel coordinate
Seam = list[int]  # seam[row] is the column the seam passes through in that row

# Constant definitions
Infinity = np.inf  # sentinel for "no energy computed yet"

class PixelData:
    """Per-pixel bookkeeping: best energy reaching this pixel and where it came from.

    A fresh instance means "not reached yet": ``min_energy`` is ``Infinity``
    and ``parent`` is the out-of-image position ``(-1, -1)``.
    """

    # Smallest accumulated energy found so far for this pixel. Annotated as
    # float because the "unreached" default is infinity.
    min_energy: float = Infinity
    # (row, col) of the predecessor pixel.
    parent: Position = (-1, -1)

    def __repr__(self) -> str:
        return f"{self.min_energy} from {self.parent}"

def split_name_ext(filename: str) -> tuple[str, str]:
    """Split a filename (or filepath) into the part before its extension and the extension.

    The extension keeps its leading dot, e.g. ``"img/photo.png"`` ->
    ``("img/photo", ".png")``. A name without an extension yields ``""`` as
    the extension. Dots in directory names (``"my.dir/photo"``) and the
    leading dot of a hidden file (``".bashrc"``) are not treated as
    extension separators.
    """
    # os.path.splitext only looks at the last path component, unlike a plain
    # rfind(".") over the whole path.
    name, ext = os.path.splitext(filename)
    return name, ext

def load_image(filename: str) -> Image:
    """Load an image from a file and return its pixels as a numpy array."""
    print(">> Reading image from", filename)
    # Open the image as a context manager so the file handle is released once
    # the pixels have been copied into the array.
    with PIL.Image.open(filename) as picture:
        return np.array(picture)

def save_image(img: Image, filename: str) -> None:
    """Save an image to a file; PIL picks the format from the filename's extension.

    Greyscale images are first brought into the 0-255 range and converted to
    uint8 so they map onto an 8-bit greyscale PIL image.
    """
    print("<< Writing image to", filename)
    if not is_rgb(img):
        # Clip before narrowing: a bare uint8 cast would wrap out-of-range
        # values (e.g. negative filter responses: -1 -> 255).
        img = np.clip(img, 0, 255).astype(np.uint8)
    PIL.Image.fromarray(img).save(filename)

def save_image_greyscale_details(img: Image, filename: str, seam: Seam | None = None, cell_data: list[list[PixelData]] | None = None) -> None:
    """Save a greyscale image as an SVG file with zoomed pixels and indicated greyscale values.

    Each pixel is drawn as a 10x10 square in its own grey level, with its
    value written on top.

    :param img: greyscale image; RGB images are rejected with a message.
    :param filename: output path; ".svg" is appended if it is missing.
    :param seam: optional seam (``seam[row]`` = column), drawn as a red line
        from the top edge through each seam pixel to the bottom edge.
    :param cell_data: optional per-pixel data indexed ``cell_data[row][col]``.
        It may be smaller than the image. Cells whose ``min_energy`` is still
        ``Infinity`` are skipped. For the others, ``min_energy`` is printed
        and, below the first row, an orange marker points towards the
        ``parent`` column.
    """
    if is_rgb(img):
        print("** Cannot write non-greyscale images this way")
        # The drawing code below only handles a single grey level per pixel.
        return

    if not filename.endswith(".svg"):
        filename = filename + ".svg"
    height, width = dimensions(img)

    def center_for(coord: int) -> int:
        """Centre, in SVG user units, of the 10-unit cell at index ``coord``."""
        return 10 * coord + 5

    has_pixel_data = cell_data is not None

    print("<< Writing SVG image details to", filename)
    with open(filename, "w", encoding="utf-8") as file:
        # Displayed at 10x the viewBox size, i.e. 100 px per image pixel.
        file.write(
            f"""<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n<svg width="{width * 100}" height="{height * 100}" viewBox="0 0 {width * 10} {height * 10}" xmlns="http://www.w3.org/2000/svg">\n"""
        )
        file.write("""<style type="text/css">\n""")
        file.write("""  * { font-family: sans-serif; }\n""")
        file.write(f"""  .val {{ font-size: 2.5px; font-family: Helvetica, Arial, sans-serif; font-weight: bold; text-anchor: middle; dominant-baseline: middle; }}\n""")
        file.write("""  .path { text-anchor: middle; dominant-baseline: middle;  font-size: 4px; }\n""")
        file.write("""  .pred { stroke: orange; stroke-width: 0.6px; }\n""")
        file.write("""  .darkpx { fill: white; }\n""")
        file.write("""  .brightpx { fill: black; }\n""")
        file.write("""</style>\n""")

        # Pixel squares, each filled with its own grey level.
        for row in range(height):
            for col in range(width):
                value = img[row, col]
                file.write(
                    f"""<rect x="{col * 10}" y="{row * 10}" width="10" height="10" fill="rgb({value}, {value}, {value})" />\n"""
                )

        # Seam: a single path from the top edge, through every seam pixel
        # centre, down to the bottom edge.
        if seam:
            file.write(f"""<path d="M {center_for(seam[0])} 0""")
            for row, col in enumerate(seam):
                file.write(f" L {center_for(col)} {center_for(row)}")
            file.write(
                f""" L {center_for(seam[-1])} {height * 10}" stroke="red" stroke-width="2" fill="none" />\n"""
            )

        # Text overlays: optional path data plus every pixel's own value.
        for row in range(height):
            for col in range(width):
                value = img[row, col]
                # Text colour chosen to contrast with the pixel's fill.
                cls = "brightpx" if value > 128 else "darkpx"
                if has_pixel_data and row < len(cell_data) and col < len(cell_data[row]) and (cell := cell_data[row][col]).min_energy != Infinity:
                    file.write(
                        f"""<text class="path {cls}" x="{center_for(col)}" y="{center_for(row) + 2}">{cell.min_energy}</text>\n"""
                    )
                    if row > 0:
                        # Column offset of the parent: -1 = left, 0 = straight
                        # up, +1 = right. Other offsets draw no marker.
                        pred_rel = cell.parent[1] - col
                        x, y = center_for(col), center_for(row) - 1
                        if pred_rel == 0:
                            file.write(
                                f"""<path class="pred" d="M {x} {y-4} l 0.5 1 l -1 0 z L {x}, {y}" />\n"""
                            )
                        elif pred_rel == 1:
                            file.write(
                                f"""<path class="pred" d="M {x + 5} {y-4.5} l -0.3 1.1 l -0.9 -.8 z L {x}, {y}" />\n"""
                            )
                        elif pred_rel == -1:
                            file.write(
                                f"""<path class="pred" d="M {x - 5} {y-4.5} l 0.3 1.1 l 0.9 -.8 z L {x}, {y}" />\n"""
                            )
                file.write(
                    f"""<text class="val {cls}" x="{center_for(col) - 2.6}" y="{center_for(row) - 3}">{value}</text>\n"""
                )
        file.write("</svg>\n")


def dimensions(img: Image) -> Tuple[int, int]:
    """Give the ``(height, width)`` of an image; any channel axis is ignored."""
    shape = img.shape
    rows = shape[0]
    cols = shape[1]
    return (rows, cols)

def new_image_grey(height: int, width: int) -> Image:
    """Return an all-zero (black) greyscale image of the given size.

    Signed 16-bit pixels hold every 8-bit value as well as the negative
    results the Sobel filter needs.
    """
    shape = (height, width)
    return np.zeros(shape, np.int16)

def new_image_grey_with_data(data: list[list[int]]) -> Image:
    """Build a greyscale image from rows of pixel values.

    int16 rather than uint8, so the dtype matches new_image_grey.
    """
    pixels = np.array(data, dtype=np.int16)
    return pixels

def new_random_grey_image(height: int, width: int) -> Image:
    """Create a new greyscale image with random pixel values in 0-255.

    Uses int16, like new_image_grey, so that arithmetic on the pixels
    (e.g. filters producing negative values) does not wrap around as it
    would with an unsigned dtype.
    """
    return np.random.randint(0, 256, (height, width), dtype=np.int16)

def new_image_rgb(height: int, width: int) -> Image:
    """Return an all-black RGB image: ``height`` x ``width`` x 3 uint8 channels."""
    channels = 3
    return np.zeros(shape=(height, width, channels), dtype=np.uint8)

def is_rgb(img: Image) -> bool:
    """Tell whether ``img`` is RGB (3-D, with a channel axis) rather than greyscale."""
    return img.ndim == 3

def copy_image(img: Image) -> Image:
    """Return an independent duplicate of ``img``; changing it leaves ``img`` intact."""
    duplicate = np.array(img, copy=True, order="K", subok=False)
    return duplicate

def highlight_seam(img: Image, seam: Seam) -> Image:
    """Return a copy of ``img`` with every pixel on ``seam`` painted over.

    Seam pixels become pure red in RGB images and 255 in greyscale ones.
    ``img`` itself is left unmodified.
    """
    print("   Highlighting seam...")
    marked = copy_image(img)
    if is_rgb(img):
        colour = (255, 0, 0)
    else:
        colour = 255
    for y, x in enumerate(seam):
        marked[y, x] = colour
    return marked

def remove_seam(img: Image, seam: Seam) -> Image:
    """Return a copy of the given image with the given seam removed.

    ``seam[row]`` is the column deleted from that row. Pixels to its right
    shift one place left, so the result is one column narrower. The result is
    allocated with new_image_rgb for RGB input and new_image_grey otherwise.
    """
    print("   Removing seam...")
    height, width = dimensions(img)
    if is_rgb(img):
        result = new_image_rgb(height, width - 1)
    else:
        result = new_image_grey(height, width - 1)
    for row in range(height):
        cut = seam[row]
        # Pixels left of the seam are copied unchanged; pixels right of it move
        # one column left. Whole-row slices replace the per-pixel loop.
        result[row, :cut] = img[row, :cut]
        result[row, cut:] = img[row, cut + 1:]
    return result
# Last modified: Friday, 25 November 2022, 11:44