from helpers import truncate, truncate3, xor
from copy import deepcopy
from random import shuffle
from scipy import stats
import numpy


def within_range(x):
    """Tell whether the integer *x* is the code point of an ASCII letter.

    Returns True for values in ``ord('A')..ord('Z')`` or ``ord('a')..ord('z')``
    (both ends inclusive) and False otherwise.
    """
    if ord('A') <= x <= ord('Z'):
        return True
    return ord('a') <= x <= ord('z')


def decode(cipher, key):
    """Decrypt *cipher* with *key* and return the plaintext as a ``str``.

    Both arguments are byte sequences. They are first passed through the
    ``truncate`` helper (presumably trimming them to a common length --
    confirm in ``helpers``), then XOR-ed position by position; ``zip`` stops
    at the shorter of the two.

    A key byte of 0 marks a position whose key is still unknown, so ``"_"`` is
    emitted there as a placeholder instead of a decrypted character.
    """
    c, k = truncate(cipher, key)
    # Build all characters and join once; repeated ``str +=`` is quadratic
    # in general (CPython's in-place shortcut is not guaranteed).
    return ''.join(
        "_" if key_byte == 0 else chr(cipher_byte ^ key_byte)
        for cipher_byte, key_byte in zip(c, k)
    )


def build_key(key, c1, c2, c3):
    """Fill unknown bytes of *key* in place using three ciphertexts.

    The ciphertexts are assumed to be encrypted with the same key. XOR-ing two
    of them cancels the key and leaves the XOR of the two plaintext bytes. An
    ASCII space (0x20) XOR-ed with an ASCII letter gives another letter (the
    case flips). So when both pairings that involve one cipher yield letters,
    that cipher's plaintext byte is guessed to be a space, and the key byte is
    set to ``ord(" ") ^ cipher_byte``.

    Args:
        key: mutable byte sequence (e.g. ``bytearray``). A 0 entry is treated
            as "unknown"; non-zero entries are never overwritten. (A genuine
            key byte of 0 is therefore indistinguishable from "unknown".)
        c1, c2, c3: ciphertext byte sequences.

    Only positions present in all four sequences are examined, so inputs of
    unequal length no longer raise ``IndexError``.

    Returns:
        None; *key* is mutated in place.
    """
    limit = min(len(key), len(c1), len(c2), len(c3))
    space = ord(" ")

    for i in range(limit):
        if key[i] != 0:
            continue  # already recovered; keep the earlier guess

        b1, b2, b3 = c1[i], c2[i], c3[i]
        # Equal cipher bytes mean equal plaintext bytes: no information here.
        if b1 == b2 or b1 == b3 or b2 == b3:
            continue

        # The shared key cancels out: m_xy == plaintext_x ^ plaintext_y.
        m12, m13, m23 = b1 ^ b2, b1 ^ b3, b2 ^ b3

        if within_range(m13) and within_range(m23):
            key[i] = space ^ b3  # plaintext 3 is probably a space
        elif within_range(m12) and within_range(m23):
            key[i] = space ^ b2  # plaintext 2 is probably a space
        elif within_range(m12) and within_range(m13):
            key[i] = space ^ b1  # plaintext 1 is probably a space


def find_key(streamciphers, iterations=50):
    """Recover a key shared by several stream ciphers.

    Each iteration shuffles the ciphertexts and builds one candidate key. It
    does this by calling ``build_key`` on every triple of ciphertexts, each
    triple first passed through ``truncate3``. Once a position is filled,
    ``build_key`` does not overwrite it, so the shuffle changes which triple
    decides each position and the candidates differ.

    The final key takes, at every byte position, the value that occurs most
    often across the candidates. Ties go to the smallest value, the same rule
    ``scipy.stats.mode`` uses.

    Args:
        streamciphers: iterable of ciphertext byte strings encrypted with the
            same key. The caller's container is not reordered.
        iterations: number of candidate keys generated for the vote.

    Returns:
        ``bytes`` as long as the longest ciphertext. Positions that were never
        recovered are 0.

    Raises:
        ValueError: if *streamciphers* is empty or *iterations* < 1.
    """
    from itertools import combinations

    # A shallow copy suffices: the list is only reordered, never its items.
    ciphers = list(streamciphers)
    if not ciphers:
        raise ValueError("find_key() requires at least one cipher")
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    # The key must cover the longest cipher.
    ksize = max(len(c) for c in ciphers)
    possiblekeys = []

    for _ in range(iterations):
        shuffle(ciphers)
        k = bytearray(ksize)
        # combinations() yields index triples a < b < c in the same order as
        # the equivalent three nested loops.
        for first, second, third in combinations(ciphers, 3):
            x, y, z = truncate3(first, second, third)
            build_key(k, x, y, z)
        possiblekeys.append(k)

    # Per-position majority vote. bincount(...).argmax() returns the smallest
    # most-frequent value. It replaces scipy.stats.mode(...)[0][0], which
    # breaks on SciPy >= 1.11: keepdims now defaults to False, so [0][0]
    # yields a scalar instead of the key row.
    votes = numpy.array([list(k) for k in possiblekeys], dtype=numpy.uint8)
    return bytes(
        int(numpy.bincount(column, minlength=256).argmax()) for column in votes.T
    )