""" Tests for the operators of Flags instances: combining (|, ^, &, -, ~), comparison, equality, containment and disjointness. """
import operator
from unittest import TestCase

from flags import Flags


class MyOtherFlags(Flags):
    # A second Flags class, unrelated to MyFlags. Its member is used in the
    # tests as an operand of the "wrong" Flags type: MyFlags operators must
    # raise TypeError for it, `in` must report it as not contained, and `==`
    # must report it as unequal. test_member_bits expects int(of0) == 1.
    of0 = ()


class MyFlags(Flags):
    # The Flags class under test, with three members.
    # test_member_bits expects int(f0) == 1, int(f1) == 2 and int(f2) == 4.
    # NOTE(review): the flags library presumably assigns bits in definition
    # order, so keep this order; `()` presumably declares a member with no
    # extra data -- confirm both against the flags library documentation.
    f0 = ()
    f1 = ()
    f2 = ()


# Short module-level names for the MyFlags values used by the tests: the empty
# and full values, the three single members, and every two-member union.
no_flags, all_flags = MyFlags.no_flags, MyFlags.all_flags
f0, f1, f2 = MyFlags.f0, MyFlags.f1, MyFlags.f2

# Two-member unions, built from the single-member aliases above.
f01 = f0 | f1
f02 = f0 | f2
f12 = f1 | f2


class TestArithmetic(TestCase):
    """Test the operators of MyFlags on every ordered pair of MyFlags values.

    MyFlags has three members, so ``_VALUES`` below lists all eight of its
    values together with the integer bit pattern each one is expected to
    have.  Every operator test checks all 8 x 8 ordered operand pairs and
    derives the expected outcome from plain integer bit arithmetic on those
    patterns.  This oracle does not depend on the flags library.
    ``test_member_bits`` verifies the bit patterns against the library itself.
    """

    # (label, MyFlags value, integer bits the value is expected to have).
    _VALUES = (
        ('no_flags', no_flags, 0b000),
        ('f0', f0, 0b001),
        ('f1', f1, 0b010),
        ('f2', f2, 0b100),
        ('f01', f01, 0b011),
        ('f02', f02, 0b101),
        ('f12', f12, 0b110),
        ('all_flags', all_flags, 0b111),
    )
    # Bit pattern of all_flags; truncates the model's bitwise NOT to 3 bits.
    _ALL_BITS = 0b111
    # Operands that are not MyFlags instances: a member of an unrelated Flags
    # class plus assorted builtin values.
    _INCOMPATIBLE = (MyOtherFlags.of0, False, True, '', 'my_string', 4, 5.5, None)

    def _flags_from_bits(self, bits):
        """Return the MyFlags value in ``_VALUES`` whose bit pattern is *bits*."""
        for _, value, value_bits in self._VALUES:
            if value_bits == bits:
                return value
        raise AssertionError('no MyFlags value has bits %r' % (bits,))

    def _pairs(self):
        """Yield ``(left_name, left, left_bits, right_name, right, right_bits)``
        for every ordered pair of entries in ``_VALUES``."""
        for left_name, left, left_bits in self._VALUES:
            for right_name, right, right_bits in self._VALUES:
                yield left_name, left, left_bits, right_name, right, right_bits

    def _check_operator(self, operator_, bits_model, symbol):
        """Assert that ``operator_(a, b)`` is the MyFlags value whose bits are
        ``bits_model(bits(a), bits(b))``, for every ordered pair of values.

        *symbol* is only used to build a readable failure message.
        """
        for left_name, left, left_bits, right_name, right, right_bits in self._pairs():
            expected = self._flags_from_bits(bits_model(left_bits, right_bits))
            self.assertEqual(operator_(left, right), expected,
                             msg='%s %s %s' % (left_name, symbol, right_name))

    def _check_relation(self, relation, bits_model, symbol):
        """Assert that the truth value of ``relation(a, b)`` equals
        ``bits_model(bits(a), bits(b))``, for every ordered pair of values.

        *symbol* is only used to build a readable failure message.
        """
        for left_name, left, left_bits, right_name, right, right_bits in self._pairs():
            expected = bits_model(left_bits, right_bits)
            self.assertEqual(bool(relation(left, right)), expected,
                             msg='%s %s %s' % (left_name, symbol, right_name))

    def _test_incompatible_types_fail(self, operator_):
        """Assert ``operator_(f0, other)`` raises TypeError for each non-MyFlags operand."""
        for other in self._INCOMPATIBLE:
            with self.assertRaises(TypeError, msg='other operand: %r' % (other,)):
                operator_(f0, other)

    def test_member_bits(self):
        self.assertEqual(int(MyOtherFlags.of0), 1)
        # Validates the oracle's bit table before the other tests rely on it.
        for name, value, bits in self._VALUES:
            self.assertEqual(int(value), bits, msg=name)

    def test_contains(self):
        # Items that are not MyFlags instances are never contained.
        self.assertNotIn(MyOtherFlags.of0, MyFlags.all_flags)
        for other in self._INCOMPATIBLE:
            self.assertNotIn(other, MyFlags.f0)

        # ``item in container`` is the subset relation -- the same relation
        # as operator.__le__(item, container).
        self._check_relation(lambda item, container: item in container,
                             lambda item, container: (item & container) == item,
                             'in')

    def test_is_disjoint(self):
        self._check_relation(lambda a, b: a.is_disjoint(b),
                             lambda a, b: (a & b) == 0,
                             'is disjoint from')

    def test_or(self):
        self._test_incompatible_types_fail(operator.__or__)
        self._check_operator(operator.__or__, lambda a, b: a | b, '|')

    def test_xor(self):
        self._test_incompatible_types_fail(operator.__xor__)
        self._check_operator(operator.__xor__, lambda a, b: a ^ b, '^')

    def test_and(self):
        self._test_incompatible_types_fail(operator.__and__)
        self._check_operator(operator.__and__, lambda a, b: a & b, '&')

    def test_sub(self):
        self._test_incompatible_types_fail(operator.__sub__)
        # Flags subtraction is set difference, not integer subtraction.
        self._check_operator(operator.__sub__, lambda a, b: a & ~b, '-')

    def test_eq(self):
        # A MyFlags value never equals an object of another type.
        for other in self._INCOMPATIBLE:
            self.assertFalse(MyFlags.f0 == other, msg='other operand: %r' % (other,))
        self._check_relation(operator.__eq__, lambda a, b: a == b, '==')

    def test_ne(self):
        for other in self._INCOMPATIBLE:
            self.assertTrue(MyFlags.f0 != other, msg='other operand: %r' % (other,))
        self._check_relation(operator.__ne__, lambda a, b: a != b, '!=')

    def test_ge(self):
        self._test_incompatible_types_fail(operator.__ge__)
        # Superset.
        self._check_relation(operator.__ge__, lambda a, b: (a & b) == b, '>=')

    def test_gt(self):
        self._test_incompatible_types_fail(operator.__gt__)
        # Proper superset.
        self._check_relation(operator.__gt__,
                             lambda a, b: (a & b) == b and a != b, '>')

    def test_le(self):
        self._test_incompatible_types_fail(operator.__le__)
        # Subset.
        self._check_relation(operator.__le__, lambda a, b: (a & b) == a, '<=')

    def test_lt(self):
        self._test_incompatible_types_fail(operator.__lt__)
        # Proper subset.
        self._check_relation(operator.__lt__,
                             lambda a, b: (a & b) == a and a != b, '<')

    def test_invert(self):
        for name, value, bits in self._VALUES:
            # Python's ~ on ints yields a negative number; mask to the 3 member bits.
            expected = self._flags_from_bits(~bits & self._ALL_BITS)
            self.assertEqual(~value, expected, msg='~%s' % name)