# =============================================================================
# ICC/Programming -- Week 12 Exercises: SOLUTIONS
# =============================================================================
# Each solution is accompanied by a detailed explanation of the key idea.
# Run the whole file with:  python3 threads_solution.py
# =============================================================================

import random
from threading import Event, Lock, Thread, current_thread
from time import perf_counter, sleep

# =============================================================================
# EXERCISE 1 -- Creating and Starting Threads
# =============================================================================

# -----------------------------------------------------------------------------
# 1(a) -- Hello from a thread
# -----------------------------------------------------------------------------
# EXPLANATION:
#   "Main thread: both threads started" is NOT guaranteed to be the first line
#   printed.  t1.start() and t2.start() hand the threads to the OS scheduler,
#   which may immediately begin executing greet() before the main thread
#   reaches its print statement.  The exact order is non-deterministic.

def greet() -> None:
    """Print four greetings tagged with the current thread's name, 0.5 s apart."""
    thread_name = current_thread().name
    for iteration in range(4):
        print(f"Hello from {thread_name} (iteration {iteration})")
        sleep(0.5)   # give the other thread a chance to interleave its output

t1 = Thread(target=greet, name="Alpha")
t2 = Thread(target=greet, name="Beta")
for worker in (t1, t2):
    worker.start()
# Both threads may already be printing by now, so this line is not
# guaranteed to come out first.
print("Main thread: both threads started")
for worker in (t1, t2):
    worker.join()

# -----------------------------------------------------------------------------
# 1(b) -- Joining threads
# -----------------------------------------------------------------------------
# EXPLANATION:
#   t.join() blocks the calling thread (here: the main thread) until thread t
#   has finished executing.  By joining both t1 and t2 before the final print,
#   we guarantee that "Main thread: all done" is the very last line printed.

t1 = Thread(target=greet, name="Alpha")
t2 = Thread(target=greet, name="Beta")
for worker in (t1, t2):
    worker.start()
for worker in (t1, t2):   # block until Alpha, then Beta, has finished
    worker.join()
print("Main thread: all done")   # only reached after both joins return

# -----------------------------------------------------------------------------
# 1(c) -- Passing arguments to a thread
# -----------------------------------------------------------------------------
# EXPLANATION:
#   The Thread constructor's `args` keyword is a tuple forwarded to the target
#   function as positional arguments.  We complete the function body and pass
#   the right values so Alice greets 3 times and Bob 5 times.

def greet_n(name: str, n: int) -> None:
    """Print "Hi, I am <name>" n times, pausing 0.3 s after each line."""
    line = f"Hi, I am {name}"
    for _count in range(n):
        print(line)
        sleep(0.3)

t1 = Thread(target=greet_n, args=("Alice", 3))   # Alice greets 3 times
t2 = Thread(target=greet_n, args=("Bob", 5))     # Bob greets 5 times
for worker in (t1, t2):
    worker.start()
for worker in (t1, t2):
    worker.join()

# -----------------------------------------------------------------------------
# 1(d) -- Daemon threads
# -----------------------------------------------------------------------------
# EXPLANATION:
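#   A *daemon* thread does not keep the process alive: when the main thread
#   finishes and only daemon threads are left, the interpreter exits and the
#   daemons are stopped abruptly.  A normal (non-daemon) thread, by contrast,
#   is waited for before the process exits.
#
#   Version A (normal thread): the process stays alive until tick() completes
#   all 10 iterations even though the main thread reached its print statement.
#
#   Version B (daemon=True): if the main thread prints "Main thread done" and
#   then reaches the end of the script, the daemon is killed mid-loop.
#   NOTE: in this combined solutions file the main thread does NOT end after
#   Version B or after the heartbeat demo -- the remaining exercises still
#   run -- so reaching those lines does not kill the daemon threads.  To see
#   the "killed on exit" behaviour, run each version as its own short script.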
#
#   For the heartbeat:  sleep(1.0) gives roughly 1.0 / 0.2 = ~5 heartbeats.
#   The exact count varies because thread scheduling is not perfectly timed.
#   If the thread were NOT a daemon, the process would run forever (infinite
#   loop) and never exit after the main thread finishes.

def tick() -> None:
    """Print "tick 0" .. "tick 9", one line every 0.4 s (about 4 s in total)."""
    for count in range(10):
        print(f"tick {count}")
        sleep(0.4)

# Version A -- normal (non-daemon) thread
t = Thread(target=tick)
t.start()
print("Main thread done (A)")
t.join()   # needed here so Version A finishes before Version B starts

# Version B -- daemon thread
t = Thread(target=tick, daemon=True)
t.start()
print("Main thread done (B)")
# No join.  In a script that ended right here, the interpreter would exit and
# this daemon would be killed mid-loop.  In this file, however, the main
# thread keeps going (the remaining exercises still run), so the daemon keeps
# ticking in the background for up to ~4 s and its "tick" lines will be
# interleaved with the output that follows.

sleep(0.05)  # tiny pause so Version B output flushes before heartbeat starts

def heartbeat(stop: "Event | None" = None) -> None:
    """Print "heartbeat" every 0.2 s until *stop* is set.

    With the default ``stop=None`` the loop never ends, exactly as in the
    exercise: only the daemon flag lets the process exit while it runs.
    Passing a threading.Event lets the caller end the loop cleanly instead.
    """
    while stop is None or not stop.is_set():
        print("heartbeat")
        sleep(0.2)

# The main thread does NOT exit after this demo -- every later exercise in
# this file still runs -- so a daemon started here would not be killed yet;
# it would keep printing "heartbeat" all through Exercises 1(e)-4.  We
# therefore stop it explicitly with an Event once the one-second demo is
# over.  (In a standalone script that ended after sleep(1.0), the daemon flag
# alone would suffice: the interpreter would exit and kill the thread.)
hb_stop = Event()
hb = Thread(target=heartbeat, args=(hb_stop,), daemon=True)
hb.start()
sleep(1.0)       # roughly 5 heartbeats (one every 0.2 s)
hb_stop.set()    # ask the loop to finish ...
hb.join()        # ... and wait for it (at most one more 0.2 s sleep)

# -----------------------------------------------------------------------------
# 1(e) -- A thread that returns a value
# -----------------------------------------------------------------------------
# EXPLANATION:
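#   A Thread discards its target function's return value (both run() and
#   join() return None), so there is no way to retrieve it directly.  The
#   standard workaround is to write the result into a shared container (here
#   a list) before the thread exits, then read it from the main thread after
#   joining.  (concurrent.futures.ThreadPoolExecutor is the standard-library
#   tool that does hand results back, via Future.result().)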
#   sum(range(start, stop)) is the idiomatic one-liner for summing a range.

# Shared sink for the per-thread partial sums (a Thread cannot return a value).
results: list[int] = []

def sum_range(start: int, stop: int) -> None:
    """Append the sum of the integers in [start, stop) to the shared ``results``."""
    results.append(sum(range(start, stop)))

t1 = Thread(target=sum_range, args=(0, 500_000))          # lower half
t2 = Thread(target=sum_range, args=(500_000, 1_000_000))  # upper half
for worker in (t1, t2):
    worker.start()
for worker in (t1, t2):
    worker.join()
# The order of the two partial sums in `results` varies; their total does not.
print(sum(results))  # 499999500000


# =============================================================================
# EXERCISE 2 -- Race Conditions
# =============================================================================

# -----------------------------------------------------------------------------
# 2(a) -- Observing the race
# -----------------------------------------------------------------------------
# EXPLANATION:
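#   counter += 1 is not a single atomic step.  Conceptually it is three:
#     LOAD  counter          (read the current value onto the interpreter stack)
#     ADD   1                (compute new value)
#     STORE counter          (write back)
#   (Real CPython bytecode uses a few more instructions -- LOAD_GLOBAL,
#   LOAD_CONST, an add instruction, STORE_GLOBAL -- but the idea is the same.)
#   The interpreter may switch to another thread between these steps.
#
#   Example interleaving that loses one increment (initial counter = 5):
#
#     Thread 1: LOAD  -> gets 5
#     Thread 2: LOAD  -> gets 5          (T1 hasn't stored yet)
#     Thread 1: ADD 1 -> computes 6
#     Thread 1: STORE -> counter = 6
#     Thread 2: ADD 1 -> computes 6      (based on stale value 5)
#     Thread 2: STORE -> counter = 6     (overwrites T1's result!)
#
#   Both threads did one increment but the counter only went from 5 to 6
#   instead of 5 to 7 -- one increment was lost.
#
#   NOTE: on recent CPython versions (3.10+, standard GIL build) this exact
#   loop may print the expected 200000 on every run.  The interpreter only
#   checks for a thread switch at certain points (such as the jump back at the
#   end of each loop iteration), not between every bytecode.  The race is
#   still real: Python does not guarantee that += is atomic.  It shows up on
#   other versions, other implementations, free-threaded builds, or as soon as
#   more work sits between the read and the write.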

counter: int = 0   # shared, and deliberately NOT protected by a lock

def increment(n: int) -> None:
    """Add 1 to the global ``counter`` n times, without any synchronization."""
    global counter
    for _step in range(n):
        counter += 1   # read-modify-write: not atomic across threads

N = 100_000
for _ in range(50):
    counter = 0                                # fresh start for every trial
    t1 = Thread(target=increment, args=(N,))
    t2 = Thread(target=increment, args=(N,))
    for worker in (t1, t2):
        worker.start()
    for worker in (t1, t2):
        worker.join()
    print(f"[2a] Expected: {2 * N}, Got: {counter}")

# =============================================================================
# EXERCISE 3 -- Fixing Race Conditions with Locks
# =============================================================================

# -----------------------------------------------------------------------------
# 3(a) -- Thread-safe counter
# -----------------------------------------------------------------------------
# EXPLANATION:
#   Wrapping counter += 1 in `with lock:` makes the three-step LOAD/ADD/STORE
#   sequence atomic with respect to other threads: no other thread can enter
#   the block until the current thread exits it and releases the lock.

counter = 0
lock = Lock()   # guards every update of `counter` from here on

def safe_increment(n: int) -> None:
    """Add 1 to the global ``counter`` n times, holding ``lock`` around each update."""
    global counter
    for _step in range(n):
        with lock:          # acquired and released once per increment
            counter += 1

N = 100_000
t1 = Thread(target=safe_increment, args=(N,))
t2 = Thread(target=safe_increment, args=(N,))
for worker in (t1, t2):
    worker.start()
for worker in (t1, t2):
    worker.join()
print(f"[3a] Expected: {2 * N}, Got: {counter}")  # always 200 000

# -----------------------------------------------------------------------------
# 3(b) -- Thread-safe bank account
# -----------------------------------------------------------------------------
# EXPLANATION:
#   Each method acquires the lock before touching self.balance, so only one
#   thread at a time can read-modify-write the balance.  The 1000 deposits of
#   10 and 1000 withdrawals of 10 cancel out, leaving the balance at 1000.

class BankAccount:
    """Integer balance whose updates are serialized by a per-account lock."""

    def __init__(self, initial: int) -> None:
        self.balance: int = initial
        self.lock = Lock()   # one lock per account, held for each update

    def _adjust(self, delta: int) -> None:
        """Add delta to the balance while holding the account's lock."""
        with self.lock:
            self.balance += delta

    def deposit(self, amount: int) -> None:
        """Increase the balance by amount."""
        self._adjust(amount)

    def withdraw(self, amount: int) -> None:
        """Decrease the balance by amount (there is no overdraft check)."""
        self._adjust(-amount)

account = BankAccount(1000)

def do_deposits() -> None:
    """Deposit 10 into the shared ``account`` 1000 times."""
    deposit = account.deposit
    for _ in range(1000):
        deposit(10)

def do_withdrawals() -> None:
    """Withdraw 10 from the shared ``account`` 1000 times."""
    withdraw = account.withdraw
    for _ in range(1000):
        withdraw(10)

t1 = Thread(target=do_deposits)
t2 = Thread(target=do_withdrawals)
for worker in (t1, t2):
    worker.start()
for worker in (t1, t2):
    worker.join()
print(f"[3b] Final balance: {account.balance}")  # always 1000

# -----------------------------------------------------------------------------
# 3(c) -- Lock granularity: measure the trade-off
# -----------------------------------------------------------------------------
# EXPLANATION:
#   Fine-grained: lock acquired and released N times per thread.
#     Each acquire/release has a fixed cost (a call into the C-level lock
#     implementation; under contention a thread may also have to block and be
#     woken up by the OS).  Paid 200 000 times, this adds up.
#   Coarse-grained: lock acquired once per thread for the whole loop.
#     Far less locking overhead, so much faster.
#     Still correct here because each thread holds the lock for its entire
#     loop -- no other thread can interleave.
#
#   When would coarse-grained be a bad idea?
#   If the locked section is very long or does I/O, holding the lock the
#   whole time starves other threads and eliminates any benefit of
#   concurrency.  Fine-grained locking keeps the critical section small so
#   other threads can make progress concurrently as much as possible.

counter = 0

def safe_increment_coarse(n: int) -> None:
    """Add 1 to the global ``counter`` n times, holding ``lock`` for the whole loop."""
    global counter
    with lock:                  # one acquire/release covers all n increments
        for _step in range(n):
            counter += 1

def run(fn, n: "int | None" = None) -> float:
    """Reset ``counter``, run fn(n) in two threads, and return elapsed seconds.

    fn -- a worker such as safe_increment / safe_increment_coarse; each of the
          two threads calls it as fn(n).
    n  -- per-thread iteration count; defaults to the module-level N, so the
          existing calls run(fn) behave exactly as before.

    Only starting, running and joining the threads is timed; the Thread
    objects are created before the clock starts.
    """
    global counter
    counter = 0
    if n is None:
        n = N
    workers = [Thread(target=fn, args=(n,)) for _ in range(2)]
    start = perf_counter()
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    return perf_counter() - start

fine = run(safe_increment)
fine_result = counter            # capture now: the next run() resets counter
coarse = run(safe_increment_coarse)
coarse_result = counter
print(f"[3c] Fine-grained:   {fine:.3f}s  result={fine_result}")
print(f"[3c] Coarse-grained: {coarse:.3f}s  result={coarse_result}")
# Coarse-grained is usually several times faster on this benchmark; the exact
# ratio depends on the machine and the Python version.


# =============================================================================
# EXERCISE 4 -- Putting It All Together
# =============================================================================

# -----------------------------------------------------------------------------
# 4(a) + 4(b) -- TicketOffice setup and booking implementation
# -----------------------------------------------------------------------------
# EXPLANATION:
#   book() acquires self.lock before checking and updating self.seats.
#   This guarantees that no two threads can simultaneously pass the
#   `if self.seats >= n` check and both subtract seats -- which could
#   push seats below zero without the lock.

class TicketOffice:
    """A box office whose seat count is protected by a per-office lock."""

    def __init__(self, total_seats: int) -> None:
        self.seats: int = total_seats
        self.lock = Lock()   # serializes the check-then-update in book()

    def book(self, customer: str, n: int) -> None:
        """Try to book n seats for customer and print the outcome.

        The availability check and the decrement happen under a single lock
        acquisition, so two customers can never both pass the check and
        oversell.

        Raises:
            ValueError: if n < 1 (a negative n would otherwise *add* seats).
        """
        if n < 1:
            raise ValueError(f"n must be a positive number of seats, got {n}")
        with self.lock:
            if self.seats >= n:
                self.seats -= n
                print(f"✓ {customer} booked {n} seat(s). Remaining: {self.seats}")
            else:
                print(f"✗ {customer} failed -- only {self.seats} left")

office = TicketOffice(total_seats=20)

# -----------------------------------------------------------------------------
# 4(c) -- Launch customer threads
# -----------------------------------------------------------------------------
# EXPLANATION:
#   We build all Thread objects first, start them all, then join them all.
#   This pattern lets all threads run concurrently.  Starting and immediately
#   joining inside the same loop would serialize them -- defeating the purpose
#   of threading.  Without the lock in book(), two threads could both read
#   self.seats >= n as True and both subtract, driving the count negative.

customers = [
    Thread(target=office.book, args=(f"Customer-{idx}", random.randint(1, 4)))
    for idx in range(10)
]                                 # create every thread first ...

for customer in customers:        # ... then start them all ...
    customer.start()

for customer in customers:        # ... and only then join them all
    customer.join()

print(f"\n[4c] Seats remaining: {office.seats}")

# -----------------------------------------------------------------------------
# 4(d) BONUS -- Fix the deadlock
# -----------------------------------------------------------------------------
# EXPLANATION of the deadlock:
#   t1 acquires office_a.lock, sleeps, then tries to acquire office_b.lock.
#   t2 acquires office_b.lock, sleeps, then tries to acquire office_a.lock.
#   If the OS switches between them after each first acquisition:
#     t1 holds A's lock, waits for B's lock
#     t2 holds B's lock, waits for A's lock
#   Neither can proceed -- classic circular wait / deadlock.
#
# The buggy version is NOT run here (it would hang).  Shown for reference only:
#
#   def transfer(src, dst, n):
#       with src.lock:
#           sleep(0.01)
#           with dst.lock:           # <-- deadlock here if dst.lock already held
#               src.seats -= n
#               dst.seats += n
#
# FIX -- consistent lock ordering by object identity:
#   Always acquire the lock whose id() is smaller first.
#   Both threads now request locks in the same global order, so the second
#   thread simply blocks on the first lock rather than forming a cycle.

def transfer_fixed(src: "TicketOffice", dst: "TicketOffice", n: int) -> None:
    """Move n seats from src to dst without risking a lock-ordering deadlock.

    Both locks are always taken in the same global order (smaller id()
    first), so two opposite transfers running at once can never each hold
    one lock while waiting for the other.

    A transfer from an office to itself is a no-op: TicketOffice.lock is a
    threading.Lock, which is not reentrant, so acquiring it a second time in
    the same thread would block forever.
    """
    if src is dst:
        return
    first, second = (src, dst) if id(src) < id(dst) else (dst, src)
    with first.lock, second.lock:   # same as nesting the two `with` blocks
        src.seats -= n
        dst.seats += n

office_a = TicketOffice(50)
office_b = TicketOffice(50)

# Opposite-direction transfers at the same time: the classic deadlock setup.
t1 = Thread(target=transfer_fixed, args=(office_a, office_b, 5))
t2 = Thread(target=transfer_fixed, args=(office_b, office_a, 5))
for worker in (t1, t2):
    worker.start()
for worker in (t1, t2):
    worker.join()
print(f"[4d] office_a={office_a.seats}, office_b={office_b.seats}")  # both 50