# =============================================================================
# ICC/Programming -- Week 13: Cryptography
# SOLUTIONS with detailed comments
# =============================================================================

from time import perf_counter
import math


# =============================================================================
# EXERCISE 1 -- Symmetric Encryption: XOR Cipher
# =============================================================================

# -----------------------------------------------------------------------------
# 1(a) -- Basic XOR encrypt/decrypt
# -----------------------------------------------------------------------------

def xor_cipher(text: str, key: int) -> str:
    """
    Return `text` with the code point of every character XOR-ed with `key`.

    XOR on single bits:
        0 XOR 0 = 0
        0 XOR 1 = 1
        1 XOR 0 = 1
        1 XOR 1 = 0   <-- identical bits cancel out

    The operation is its own inverse: (m XOR k) XOR k = m.  A key bit of 1
    flips the message bit and flips it back on the second pass; a key bit
    of 0 leaves it untouched both times.  Calling this function twice with
    the same key therefore returns the original string, so one function
    serves as both "encrypt" and "decrypt".

    Worked example:
        ord('H') = 72 = 0b01001000
        key  42       = 0b00101010
        XOR           = 0b01100010  (= 98, encrypted character)
        XOR key again = 0b01001000  (= 72 = 'H', restored)
    """
    scrambled = []
    for ch in text:
        code_point = ord(ch) ^ key      # flip exactly the bits that are set in key
        scrambled.append(chr(code_point))
    return ''.join(scrambled)


print("=== Exercise 1(a) ===")

# Encrypt a short message, then push the ciphertext through the very same
# call with the same key: XOR being self-inverse, that restores the message.
message, key = "HELLO", 42
ciphertext   = xor_cipher(message, key)
decrypted    = xor_cipher(ciphertext, key)

print("Round-trip OK:", decrypted == message)   # must be True


# -----------------------------------------------------------------------------
# 1(b) -- Encrypting a longer message (multi-byte key)
# -----------------------------------------------------------------------------

def xor_cipher_multi(text: str, key: list[int]) -> str:
    """
    XOR cipher with a list of key bytes.

    The character at position i is XOR-ed with key[i % len(key)],
    so the key repeats ("wraps around") once the message exceeds the
    key length.  As with the single-key cipher, applying the function a
    second time with the same key restores the original text.

    Security note:
    - If len(key) == len(text), each byte is truly random and the key is
      never reused, this is a one-time pad -- provably unbreakable.
    - If len(key) < len(text), the repeating pattern can be exploited by
      an attacker who knows or guesses the key length.

    Args:
        text: message to encrypt, or ciphertext to decrypt.
        key:  non-empty list of integer key values, used cyclically.

    Returns:
        The transformed string; it has the same length as `text`.

    Raises:
        ValueError: if `key` is empty while `text` is not (there is
            nothing to XOR with).
    """
    if not text:
        return ''                       # nothing to transform, whatever the key
    key_len = len(key)                  # loop-invariant, so computed once
    if key_len == 0:
        # Without this check the modulo below fails with an unhelpful
        # ZeroDivisionError.
        raise ValueError("key must contain at least one byte")
    return ''.join(
        chr(ord(c) ^ key[i % key_len])  # cycle through key bytes
        for i, c in enumerate(text)
    )


print("\n=== Exercise 1(b) ===")

# A 3-byte key against a 14-character message, so the key wraps around
# several times.
msg, key_multi = "ATTACK AT DAWN", [42, 17, 99]
enc = xor_cipher_multi(msg, key_multi)
dec = xor_cipher_multi(enc, key_multi)   # the identical call decrypts

print("Round-trip OK:", dec == msg)     # must be True


# -----------------------------------------------------------------------------
# 1(c) -- Key length and security (discussion answers)
# -----------------------------------------------------------------------------

print("\n=== Exercise 1(c) ===")

# (i) Every additional key byte multiplies the key space by 256, so it
#     grows exponentially with the key length:
#     1-byte key  : 256^1 =              256 keys
#     2-byte key  : 256^2 =           65,536 keys
#     4-byte key  : 256^4 =    4,294,967,296 keys (~4 billion)
#     n-byte key  : 256^n keys
for nb in (1, 2, 4, 8):
    print(f"  {nb}-byte key: {pow(256, nb):>25,} possible keys")

# (ii) AES-128 uses a 128-bit (16-byte) key:
#      2^128 ≈ 3.4 × 10^38 possible keys.
#      Even at 10^18 guesses per second (far beyond any hardware today),
#      trying them all would take about 3.4 × 10^20 seconds,
#      i.e. roughly 10^13 years.
print(f"\n  AES-128 key space: 2^128 = {(1 << 128):.3e} possible keys")


# =============================================================================
# EXERCISE 2 -- Public-Key Encryption: Toy RSA
# =============================================================================
#
# RSA key generation recap (see exercise sheet background):
#   1. Pick two primes p and q.  Set n = p * q.
#   2. Compute phi(n) = (p-1) * (q-1)  [Euler's totient].
#   3. Choose public exponent e with 1 < e < phi(n) and gcd(e, phi(n)) = 1.
#   4. Compute private exponent d = modular inverse of e mod phi(n),
#      i.e. e * d ≡ 1 (mod phi(n)).
#   Public key:  (e, n)   -- shared with everyone
#   Private key: (d, n)   -- kept secret
#   Encrypt: c = m^e mod n
#   Decrypt: m = c^d mod n


# -----------------------------------------------------------------------------
# 2(a) -- Verify toy parameters by hand (pen-and-paper, no code needed)
# -----------------------------------------------------------------------------

print("\n=== Exercise 2(a) -- hand verification ===")

# (i)   p=3 and q=11 are both prime: each is divisible only by 1 and itself.
#
# (ii)  n = 3 * 11 = 33.  phi(n) = (3-1)*(11-1) = 2*10 = 20.
#
# (iii) e*d = 3*7 = 21 and 21 mod 20 = 1, ✓
#       so d=7 is the modular inverse of e=3 mod 20.
#
# (iv)  Encrypt m=4:  4^3 mod 33 = 64 mod 33 = 31, hence c = 31.
#       Decrypt c=31 by repeated squaring of 31 mod 33:
#         31^2 = 961;   961 mod 33 = 4
#         31^4 = 4^2  = 16  (mod 33)
#         31^6 = 31^2 * 31^4 = 4*16 = 64 mod 33 = 31
#         31^7 = 31^6 * 31   = 31*31 mod 33 = 961 mod 33 = 4  ✓  We recover m=4.
#
# The code below repeats the same checks mechanically.
p_toy, q_toy = 3, 11
e_toy, d_toy = 3, 7
n_toy   = p_toy * q_toy
phi_toy = (p_toy - 1) * (q_toy - 1)
print(f"  n = {n_toy}, phi(n) = {phi_toy}")
print(f"  e*d mod phi(n) = {(e_toy * d_toy) % phi_toy}  (must be 1)")

m_test = 4
c_test = m_test ** e_toy % n_toy    # encrypt: m^e mod n
r_test = c_test ** d_toy % n_toy    # decrypt: c^d mod n
print(f"  Encrypt {m_test} -> {c_test}, decrypt -> {r_test}  (must be {m_test})")


# -----------------------------------------------------------------------------
# 2(b) -- Modular exponentiation: naive loop vs built-in pow()
# -----------------------------------------------------------------------------

def mod_exp_naive(base: int, exp: int, mod: int) -> int:
    """
    Compute base^exp mod `mod` using exp successive multiplications.

    Time complexity: O(exp).

    For real RSA, exp (the public or private exponent) has ~2048 bits,
    i.e. exp ~ 10^617.  Running 10^617 loop iterations is impossible.

    Python's built-in pow(base, exp, mod) uses "square-and-multiply":
    it needs only O(log exp) -- on the order of 2048 squarings plus some
    multiplications for a 2048-bit exponent -- instant even on modest
    hardware.

    Args:
        base: the number being raised to a power.
        exp:  non-negative exponent; also the number of loop iterations.
        mod:  non-zero modulus.

    Returns:
        The same value as pow(base, exp, mod).

    Raises:
        ValueError: if `exp` is negative.  A negative power would need a
            modular inverse, which this loop cannot compute; an empty
            loop would silently return a wrong answer.
        ZeroDivisionError: if `mod` is 0.
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")
    # Start from 1 reduced mod `mod`, so that base^0 mod 1 is 0 (as with
    # the built-in pow) rather than 1.
    result = 1 % mod
    for _ in range(exp):
        result = (result * base) % mod   # one multiplication per iteration
    return result


print("\n=== Exercise 2(b) ===")

# Cross-check the naive loop against the built-in three-argument pow()
# for every message value 1..9, using the toy key from 2(a).
e_toy, d_toy, n_toy = 3, 7, 33
for m in range(1, 10):
    assert pow(m, e_toy, n_toy) == mod_exp_naive(m, e_toy, n_toy), \
        f"Failed for m={m}"
print("mod_exp_naive correct for m = 1..9")


# -----------------------------------------------------------------------------
# 2(c) -- Encrypt and decrypt a word; failure when m >= n
# -----------------------------------------------------------------------------

def rsa_encrypt(plaintext: str, e: int, n: int) -> list[int]:
    """
    Encrypt `plaintext` one character at a time with the public key (e, n).

    Each character c becomes the integer ord(c)^e mod n; the results are
    returned in message order.

    CRITICAL REQUIREMENT: ord(c) must be strictly less than n for every
    character.  This function does not check it: if ord(c) >= n, the
    modular reduction loses information and decryption returns a
    different character.
    """
    blocks = []
    for ch in plaintext:
        blocks.append(pow(ord(ch), e, n))   # c = m^e mod n
    return blocks


def rsa_decrypt(ciphertext: list[int], d: int, n: int) -> str:
    """
    Decrypt a list of RSA blocks with the private key (d, n).

    Each block c is mapped back to the character chr(c^d mod n); the
    characters are joined in order into the returned string.
    """
    chars = []
    for block in ciphertext:
        chars.append(chr(pow(block, d, n)))     # m = c^d mod n
    return ''.join(chars)


print("\n=== Exercise 2(c) ===")

# "HI" cannot survive a round trip with n=33, because ord('H')=72 and
# ord('I')=73 are both >= n.
#
# Why must m < n?  Intuitive explanation:
# Encryption computes c = m^e mod n, so c always lands in {0,...,n-1}
# no matter how large m is.  Two values of m that differ by a multiple
# of n give the same ciphertext, and decryption cannot tell them apart.
# RSA only "sees" m mod n; whatever information m carries beyond that is
# permanently lost.
# Concretely: ord('H') = 72 = 2*33 + 6, so RSA treats it as 6;
# decryption recovers 6 and produces chr(6), not 'H'.
e_toy, d_toy, n_toy = 3, 7, 33
enc_fail = rsa_encrypt("HI", e_toy, n_toy)
dec_fail = rsa_decrypt(enc_fail, d_toy, n_toy)
print(f"'HI' with n=33: ord('H')={ord('H')}, ord('I')={ord('I')} -- both >= n")
print(f"  Decrypted: {dec_fail!r}  (WRONG, as expected)")


# -----------------------------------------------------------------------------
# 2(d) -- Slightly larger parameters (n=3233, covers full printable ASCII)
# -----------------------------------------------------------------------------

# How this key pair was generated:
#   p = 61, q = 53      => n = 61 * 53 = 3233,  phi(n) = 60*52 = 3120
#   e = 17              (gcd(17, 3120) = 1)
#   d = 2753            (17 * 2753 = 46801 = 15*3120 + 1,
#                        so 17 * 2753 mod 3120 = 1)
e2, d2, n2 = 17, 2753, 3233
phi_n2 = 60 * 52    # (p - 1) * (q - 1) = 3120

print("\n=== Exercise 2(d) ===")

# (i) Round trip for "Hello, Bob!", confirmed with an assert.
#     Every character of the message has a code point below 3233, so
#     the m < n requirement holds throughout.
msg2 = "Hello, Bob!"
enc2 = rsa_encrypt(msg2, e2, n2)
dec2 = rsa_decrypt(enc2, d2, n2)
print("Encrypted:", enc2)
print("Decrypted:", dec2)
assert dec2 == msg2, "Round-trip failed!"
print("(i) Round-trip OK")

# (ii) The number of distinct integers that can be encrypted is
#      n = 3233 (every m with 0 <= m < 3233).
#      Printable ASCII spans code points 32-126, all below 3233, so any
#      printable ASCII string can be encrypted character by character.
print(f"\n(ii) Distinct encryptable values: {n2}  (integers 0 to {n2 - 1})")

# (iii) Check the private exponent: e * d mod phi(n) has to be 1.
#       That is what makes d the modular inverse of e, and hence makes
#       decryption undo encryption exactly.
check = e2 * d2 % phi_n2
print(f"\n(iii) e2 * d2 mod phi(n) = {e2} * {d2} mod {phi_n2} = {check}  (must be 1)")


# -----------------------------------------------------------------------------
# 2(e) -- Factorising n to break RSA; infeasibility at real scale
# -----------------------------------------------------------------------------

def factorise(n: int) -> tuple[int, int]:
    """
    Split n into (smallest divisor >= 2, co-factor) by trial division.

    Candidates 2, 3, 4, ... are tried in increasing order up to
    floor(sqrt(n)).  The first one that divides n is returned together
    with n divided by it.  If none does, n has no divisor in that range
    and (1, n) is returned -- for n >= 2 this means n is prime.

    Time complexity: O(sqrt(n)) divisions.
    For an n with k bits, sqrt(n) ~ 2^(k/2).
    - k=12 bits (n=3233): sqrt(3233) ~ 56  -- a few dozen steps, instant.
    - k=2048 bits:        sqrt(n) ~ 2^1024 ~ 10^308  -- completely impossible.
    """
    limit = math.isqrt(n)
    candidate = 2
    while candidate <= limit:
        cofactor, remainder = divmod(n, candidate)
        if remainder == 0:
            return candidate, cofactor
        candidate += 1
    return 1, n     # no divisor found: n is prime


print("\n=== Exercise 2(e) ===")

# Time one trial-division run on the toy modulus.
start = perf_counter()
p_found, q_found = factorise(3233)
elapsed = perf_counter() - start
print(f"Factored 3233 = {p_found} * {q_found}  in {elapsed:.6f} s")

# (i) Trial division performs at most sqrt(n) divisions,
#     which for a k-bit number is at most 2^(k/2).

# (ii) 2048-bit RSA: sqrt(2^2048) = 2^1024 ~ 1.8 * 10^308 divisions.
#      Everything is kept in log10 so that no float overflows.
log10_sec = math.log10(2) * 1024 - 18          # log10(2^1024 / 10^18)
log10_yrs = log10_sec - math.log10(365.25 * 24 * 3600)   # seconds -> years
print(f"\n(ii) Steps for 2048-bit n: ~2^1024 ≈ 10^{1024 * math.log10(2):.0f}")
print(f"     At 10^18 steps/s: ~10^{log10_sec:.0f} s = ~10^{log10_yrs:.0f} years")
print("     (Age of universe ≈ 1.4×10^10 years -- completely infeasible)")

# (iii) With the factors known, rebuild the private key and check it works.
#
# Step 1: phi(n) follows directly from the two recovered primes.
# Step 2: d is the modular inverse of e2 mod phi(n); the three-argument
#         pow(a, -1, m) computes exactly that.
phi_recovered = (p_found - 1) * (q_found - 1)
d_recovered = pow(e2, -1, phi_recovered)

print(f"\n(iii) phi(n) = ({p_found}-1)*({q_found}-1) = {phi_recovered}")
print(f"      Recovered d = pow({e2}, -1, {phi_recovered}) = {d_recovered}")
print(f"      Matches known d2={d2}: {d_recovered == d2}")

# Step 3: encrypt with the public key, decrypt with the recovered d, and
#         confirm that the message survives.
test_msg = "RSA!"
test_enc = rsa_encrypt(test_msg, e2, n2)
test_dec = rsa_decrypt(test_enc, d_recovered, n2)   # recovered d, not d2
print(f"      Encrypt '{test_msg}' -> decrypt with recovered d -> '{test_dec}'")
assert test_dec == test_msg, "Decryption with recovered key failed!"
print("      Decryption with recovered key: OK")


# =============================================================================
# EXERCISE 3 -- Comparing Symmetric and Asymmetric Encryption
# =============================================================================

# -----------------------------------------------------------------------------
# 3(a) -- Speed comparison: XOR vs toy RSA
# -----------------------------------------------------------------------------

print("\n=== Exercise 3(a) ===")

# Same 1000-character input for both ciphers, so the two timings are
# directly comparable.
long_msg = "A" * 1000

# Measure XOR cipher time.
start    = perf_counter()
xor_enc  = xor_cipher(long_msg, 42)
xor_time = perf_counter() - start

# Measure toy RSA time.
start    = perf_counter()
rsa_enc  = rsa_encrypt(long_msg, e2, n2)
rsa_time = perf_counter() - start

# Print both durations and their ratio.  These are single-run timings, so
# the exact ratio varies from run to run.
print(f"XOR cipher : {xor_time:.6f} s  (1000 characters)")
print(f"Toy RSA    : {rsa_time:.6f} s  (1000 characters)")
if xor_time > 0:
    print(f"RSA is ~{rsa_time / xor_time:.0f}x slower than XOR")
else:
    # On a coarse clock the XOR run can finish within one timer tick;
    # dividing by the resulting 0.0 would raise ZeroDivisionError.
    print("RSA/XOR ratio unavailable: XOR time was below the timer resolution")

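# Why HTTPS uses a hybrid approach:
# - RSA (asymmetric) is used once at connection setup to securely transmit
#   a short random session key (e.g. 256 bits = 32 bytes).  Only a handful
#   of expensive modular exponentiations are needed.
#   (This is the classic scheme.  TLS 1.3 no longer sends the session key
#   under RSA: both sides derive it with an ephemeral Diffie-Hellman
#   exchange, and RSA or ECDSA is used only to sign the handshake.  The
#   hybrid idea is the same either way.)
# - AES (symmetric: like XOR it uses one shared secret key, but it is
#   cryptographically strong) then encrypts all actual web traffic using
#   that session key.  It is orders of magnitude faster than RSA per byte.
# Result: the key-exchange security of asymmetric crypto + the speed of
# symmetric bulk encryption.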


# -----------------------------------------------------------------------------
# 3(b) -- The key-distribution problem
# -----------------------------------------------------------------------------

print("\n=== Exercise 3(b) ===")

# (i) Pairwise keys needed when n people use symmetric crypto:
#     each unordered pair {i, j} shares one key of its own, so the count
#     is the number of pairs, C(n, 2) = n*(n-1)/2.
print("  People (n) | Keys needed  [n*(n-1)/2]")
print("  -----------+---------------------------")
for n_people in (2, 5, 10, 100):
    keys = math.comb(n_people, 2)
    print(f"  {n_people:>10} | {keys:>10,}")

# General formula: n*(n-1)/2, which grows quadratically -- 1 million users
# would need ~5 * 10^11 (500 billion) shared keys.

# (ii) Public-key solution:
# Each person generates a single key pair (public + private).
# To send Bob a private message, Alice encrypts it with Bob's public key;
# only Bob's private key can decrypt it.
# No pre-shared secret is ever needed.
# Total keys in the system: 2n (n public + n private) instead of n*(n-1)/2.
print("\n  With public-key crypto: each person holds exactly 1 key pair.")
print("  Total keys for n people: 2n  (vs n*(n-1)/2 for symmetric).")
print("  No secure channel needed to distribute keys in advance.")