import operator
import unittest
from mpyc import gfpx
from mpyc import finfields


class Arithmetic(unittest.TestCase):
    """Unit tests for finite field arithmetic provided by mpyc.finfields.

    Fields under test are binary fields (modulus given as a polynomial over
    GF(2)), prime fields (modulus given as an int), and extension fields of
    characteristic 3.
    """

    def setUp(self):
        """Create the fields shared by the tests.

        The is_signed flag of the prime fields of order 19 and 101 is switched
        off for the duration of a test and restored afterwards.
        """
        self.f2 = finfields.GF(gfpx.GFpX(2)(2))
        self.f256 = finfields.GF(gfpx.GFpX(2)(283))  # AES polynomial (283)_2 = X^8+X^4+X^3+X+1

        self.f2p = finfields.GF(2)
        self.f19 = finfields.GF(19)    # 19 % 4 = 3
        self.f101 = finfields.GF(101)  # 101 % 4 = 1
        # GF() appears to hand out shared (cached) field types, see
        # test_field_caching, so the flag set here would otherwise stay in
        # effect for every later user of these fields in the same process.
        # NOTE(review): is_signed presumably selects a signed/unsigned
        # representation of field elements -- confirm against mpyc.finfields.
        for field in (self.f19, self.f101):
            self.addCleanup(setattr, field, 'is_signed', field.is_signed)
            field.is_signed = False

        self.f27 = finfields.GF(gfpx.GFpX(3)(46))  # irreducible polynomial X^3 + 2X^2 + 1
        self.f81 = finfields.GF(gfpx.GFpX(3)(115))  # irreducible polynomial X^4 + X^3 + 2X + 1

    def test_field_caching(self):
        """Elements of a field obtained by a repeated GF() call interoperate."""
        # Binary field of order 2 and prime field of order 2 are distinct fields.
        self.assertNotEqual(self.f2(1), self.f2p(1))
        f2_cached = finfields.GF(gfpx.GFpX(2)(2))
        self.assertEqual(self.f2(1), f2_cached(1))
        self.assertEqual(self.f2(1) * f2_cached(1), self.f2(1))
        f256_cached = finfields.GF(gfpx.GFpX(2)(283))
        self.assertEqual(self.f256(3), f256_cached(3))
        self.assertEqual(self.f256(3) * f256_cached(3), self.f256(5))
        self.assertEqual(self.f256(48) * f256_cached(16), self.f256(45))

        f2_cached = finfields.GF(2)
        self.assertEqual(self.f2p(1), f2_cached(1))
        self.assertEqual(self.f2p(1) * f2_cached(1), 1)
        f19_cached = finfields.GF(19)
        self.assertEqual(self.f19(3), f19_cached(3))
        self.assertEqual(self.f19(3) * f19_cached(3), 9)
        f101_cached = finfields.GF(101)
        self.assertEqual(self.f101(3), f101_cached(3))
        self.assertEqual(self.f101(3) * f101_cached(23), 69)

    def test_to_from_bytes(self):
        """from_bytes(to_bytes(x)) returns x for empty, small and maximal values."""
        for F in [self.f2, self.f256, self.f2p, self.f19, self.f101]:
            # subTest makes a failure report which field it occurred for.
            with self.subTest(field=F):
                self.assertEqual(F.from_bytes(F.to_bytes([])), [])
                self.assertEqual(F.from_bytes(F.to_bytes([0, 1])), [0, 1])
                self.assertEqual(F.from_bytes(F.to_bytes([F.order - 1])), [F.order - 1])

    def test_find_prime_root(self):
        """Check the (p, n, w) triples returned by find_prime_root.

        For n=3 and n=4 only the property w^n = 1 (mod p) is verified.
        """
        f = finfields.find_prime_root
        pnw = f(2, False)
        self.assertEqual(pnw, (2, 1, 1))
        pnw = f(2)
        self.assertEqual(pnw, (3, 2, -1))
        pnw = f(5, n=1)
        self.assertEqual(pnw, (19, 1, 1))
        pnw = f(5, n=2)
        self.assertEqual(pnw, (19, 2, -1))
        p, n, w = f(5, n=3)
        self.assertEqual((w**n) % p, 1)
        p, n, w = f(10, n=4)
        self.assertEqual((w**n) % p, 1)

    def test_f2(self):
        """Arithmetic in the binary field of order 2 (modulus polynomial X)."""
        f2 = self.f2
        self.assertFalse(f2(0))
        self.assertTrue(f2(1))
        self.assertEqual(f2(1) + f2(0), f2(0) + f2(1))
        # An int as left operand exercises the reflected operators.
        self.assertEqual(1 + f2(0), 0 + f2(1))
        self.assertEqual(1 + f2(1), 0)
        self.assertEqual(1 - f2(1), 0)
        self.assertEqual(f2(1) / f2(1), f2(1))
        self.assertIs(bool(f2(0)), False)
        self.assertIs(bool(f2(1)), True)

        # In-place operators; each step depends on the value left by the previous one.
        a = f2(1)
        b = f2(1)
        a += b
        self.assertEqual(a, f2(0))
        a -= b
        self.assertEqual(a, f2(1))
        a *= b
        self.assertEqual(a, f2(1))
        a /= b
        self.assertEqual(a, f2(1))

    def test_f256(self):
        """Arithmetic, shifts, powers and square roots in binary fields of order 256."""
        f256 = self.f256
        self.assertFalse(f256(0))
        self.assertTrue(f256(1))
        self.assertEqual(f256(1) + 0, f256(0) + f256(1))
        self.assertEqual(f256(1) + 1, f256(0))
        self.assertEqual(f256(3) * 0, f256(0))
        self.assertEqual(f256(3) * 1, f256(3))
        self.assertEqual(f256(16) * f256(16), f256(27))
        self.assertEqual(f256(32) * f256(16), f256(54))
        self.assertEqual(f256(57) * f256(67), f256(137))
        self.assertEqual(f256(67) * f256(57), f256(137))
        self.assertEqual(f256(137) / f256(57), f256(67))
        self.assertEqual(f256(137) / f256(67), f256(57))

        # In-place operators with field elements and with ints as right operand.
        a = f256(0)
        b = f256(1)
        a += b
        self.assertEqual(a, f256(1))
        a += 1
        self.assertEqual(a, f256(0))
        a -= b
        self.assertEqual(a, f256(1))
        a *= b
        self.assertEqual(a, f256(1))
        a *= 1
        self.assertEqual(a, f256(1))
        a /= 1
        self.assertEqual(a, f256(1))
        a <<= 0
        a >>= 0
        self.assertEqual(a, f256(1))
        a <<= 2
        self.assertEqual(a, f256(4))
        a >>= 2
        self.assertEqual(a, f256(1))

        # The powers of a generator enumerate all 255 nonzero elements,
        # for nonnegative as well as for negative exponents.
        a = f256(3)  # generator X + 1
        s = [int((a**i).value) for i in range(255)]
        self.assertListEqual(sorted(s), list(range(1, 256)))
        s = [int((a**i).value) for i in range(-255, 0)]
        self.assertListEqual(sorted(s), list(range(1, 256)))

        f256 = finfields.GF(gfpx.GFpX(2)(391))  # primitive polynomial X^8 + X^7 + X^2 + X + 1
        a = f256(2)  # generator X
        s = [int((a**i).value) for i in range(255)]
        self.assertListEqual(sorted(s), list(range(1, 256)))

        a = f256(177)
        self.assertTrue(a.is_sqr())
        self.assertEqual(a.sqrt()**2, a)
        a = f256(255)
        self.assertEqual(a.sqrt()**2, a)

    def test_f2p(self):
        """Arithmetic in the prime field of order 2, compared against plain ints."""
        f2 = self.f2p
        self.assertEqual(f2.nth, 1)
        self.assertEqual(f2.root, 1)
        self.assertEqual(f2.root ** f2.nth, 1)
        self.assertFalse(f2(0))
        self.assertTrue(f2(1))
        self.assertEqual(f2(1) + f2(0), f2(0) + f2(1))
        self.assertEqual(1 + f2(0), 0 + f2(1))
        self.assertEqual(1 + f2(1), 0)
        self.assertEqual(1 - f2(1), 0)
        self.assertEqual(f2(1) / f2(1), 1)
        self.assertEqual(f2(1).sqrt(), 1)
        self.assertIs(bool(f2(0)), False)
        self.assertIs(bool(f2(1)), True)

        # In-place operators; each step depends on the value left by the previous one.
        a = f2(1)
        b = f2(1)
        a += b
        self.assertEqual(a, 0)
        a -= b
        self.assertEqual(a, 1)
        a *= b
        self.assertEqual(a, 1)
        a /= b
        self.assertEqual(a, 1)

    def test_f19(self):
        """Arithmetic, square roots and shifts in the prime field of order 19."""
        f19 = self.f19
        self.assertEqual(f19.nth, 2)
        self.assertEqual(f19.root, 19 - 1)
        self.assertEqual(f19(f19.root) ** f19.nth, 1)
        self.assertIs(bool(f19(0)), False)
        self.assertIs(bool(f19(1)), True)
        self.assertIs(bool(f19(-1)), True)

        a = f19(12)
        b = f19(11)
        c = a + b
        self.assertEqual(c, (a.value + b.value) % 19)
        c = c - b
        self.assertEqual(c, a)
        c = c - a
        self.assertEqual(c, 0)
        self.assertEqual(a / a, 1)
        self.assertEqual((f19(1).sqrt())**2, 1)
        self.assertEqual(((a**2).sqrt())**2, a**2)
        self.assertNotEqual(((a**2).sqrt())**2, -a**2)
        self.assertEqual(a**f19.modulus, a)  # a^p = a in a field of prime order p
        b = -a
        self.assertEqual(-b, a)

        # In-place operators; each step depends on the value left by the previous one.
        a = f19(12)
        b = f19(11)
        a += b
        self.assertEqual(a, 4)
        a -= b
        self.assertEqual(a, 12)
        a *= b
        self.assertEqual(a, 18)
        a <<= 2
        self.assertEqual(a, 15)
        a <<= 0
        self.assertEqual(a, 15)
        a >>= 2
        self.assertEqual(a, 18)
        a >>= 0
        self.assertEqual(a, 18)

    def test_f101(self):
        """Arithmetic and square roots in the prime field of order 101."""
        f101 = self.f101
        self.assertEqual(f101.nth, 2)
        self.assertEqual(f101.root, 101 - 1)
        self.assertEqual(f101(f101.root) ** f101.nth, 1)

        a = f101(12)
        b = f101(11)
        c = a + b
        self.assertEqual(c, (a.value + b.value) % 101)
        c = c - b
        self.assertEqual(c, a)
        c = c - a
        self.assertEqual(c, 0)
        self.assertEqual(a / a, 1)
        self.assertEqual((f101(1).sqrt())**2, 1)
        self.assertEqual((f101(4).sqrt())**2, 4)
        self.assertEqual(((a**2).sqrt())**2, a**2)
        self.assertNotEqual(((a**2).sqrt())**2, -a**2)
        self.assertEqual(a**f101.modulus, a)  # a^p = a in a field of prime order p
        b = -a
        self.assertEqual(-b, a)

        # Constructor arguments above the modulus: 120 = 19 and 110 = 9 (mod 101).
        a = f101(120)
        b = f101(110)
        a += b
        self.assertEqual(a, 28)
        a -= b
        self.assertEqual(a, 19)
        a *= b
        self.assertEqual(a, 70)
        a /= b
        self.assertEqual(a, 19)

    def test_f27(self):
        """Squares and square roots in the field of order 27."""
        f27 = self.f27  # 27 == 3 (mod 4)
        a = f27(10)
        self.assertTrue((a**2).is_sqr())
        self.assertFalse((-a**2).is_sqr())
        b = (a**2).sqrt()
        self.assertEqual(b**2, a**2)
        # With INV=True the result b satisfies (a * b)^2 = 1.
        b = (a**2).sqrt(INV=True)
        self.assertEqual((a * b)**2, 1)

    def test_f81(self):
        """Squares and square roots in the field of order 81."""
        f81 = self.f81  # 81 == 1 (mod 4)
        a = f81(21)
        self.assertTrue((a**2).is_sqr())
        self.assertTrue((-a**2).is_sqr())
        b = (a**2).sqrt()
        self.assertEqual(b**2, a**2)
        # With INV=True the result b satisfies (a * b)^2 = 1.
        b = (a**2).sqrt(INV=True)
        self.assertEqual((a * b)**2, 1)

    def test_operatorerrors(self):
        """Operators raise TypeError for operands of mismatched or unsupported type."""
        f2 = self.f2
        f2p = self.f2p
        f256 = self.f256
        f19 = self.f19
        # Elements of different fields cannot be combined.
        self.assertRaises(TypeError, operator.add, f2(1), f2p(2))
        self.assertRaises(TypeError, operator.iadd, f2(1), f2p(2))
        self.assertRaises(TypeError, operator.sub, f2(1), f256(2))
        self.assertRaises(TypeError, operator.isub, f2(1), f256(2))
        self.assertRaises(TypeError, operator.mul, f2(1), f19(2))
        self.assertRaises(TypeError, operator.imul, f2(1), f19(2))
        self.assertRaises(TypeError, operator.truediv, f256(1), f19(2))
        self.assertRaises(TypeError, operator.itruediv, f256(1), f19(2))
        # A float is not accepted as the other operand.
        self.assertRaises(TypeError, operator.truediv, 3.14, f19(2))
        # A field element is not accepted as shift amount.
        self.assertRaises(TypeError, operator.lshift, f2(1), f2(1))
        self.assertRaises(TypeError, operator.ilshift, f2(1), f2(1))
        self.assertRaises(TypeError, operator.lshift, 1, f2(1))
        self.assertRaises(TypeError, operator.rshift, f19(1), f19(1))
        self.assertRaises(TypeError, operator.irshift, f19(1), f19(1))
        self.assertRaises(TypeError, operator.irshift, f256(1), f256(1))
        # A field element or a float is not accepted as exponent.
        self.assertRaises(TypeError, operator.pow, f2(1), f19(2))
        self.assertRaises(TypeError, operator.pow, f19(1), 3.14)