# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mathematics_dataset.sample.ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from absl.testing import absltest
from mathematics_dataset.sample import ops
from six.moves import range
import sympy


class OpsTest(absltest.TestCase):
  """Tests rendering and evaluation of `ops` expression trees.

  Each case builds an expression tree from `ops` nodes and asserts both its
  `str(...)` form (in particular where brackets are or are not inserted) and
  the value produced by its `.sympy()` method.
  """

  def testNeg(self):
    op = ops.Neg(2)
    self.assertEqual(str(op), '-2')
    self.assertEqual(op.sympy(), -2)

    # A leading negated term renders with a unary minus...
    op = ops.Add(ops.Neg(2), 3)
    self.assertEqual(str(op), '-2 + 3')
    self.assertEqual(op.sympy(), 1)

    # ...while a non-leading one renders as a subtraction.
    op = ops.Add(3, ops.Neg(2))
    self.assertEqual(str(op), '3 - 2')
    self.assertEqual(op.sympy(), 1)

    op = ops.Add(ops.Add(ops.Neg(2), 5), 3)
    self.assertEqual(str(op), '-2 + 5 + 3')
    self.assertEqual(op.sympy(), 6)

    op = ops.Add(3, ops.Add(ops.Identity(ops.Neg(2)), 5))
    self.assertEqual(str(op), '3 - 2 + 5')
    self.assertEqual(op.sympy(), 6)

    op = ops.Add(3, ops.Add(2, ops.Neg(5)))
    self.assertEqual(str(op), '3 + 2 - 5')
    self.assertEqual(op.sympy(), 0)

  def testAdd(self):
    # An empty sum renders as, and evaluates to, zero.
    add = ops.Add()
    self.assertEqual(str(add), '0')
    self.assertEqual(add.sympy(), 0)

    add = ops.Add(2, 3)
    self.assertEqual(str(add), '2 + 3')
    self.assertEqual(add.sympy(), 5)

    # Left-nested addition needs no brackets.
    add = ops.Add(ops.Add(1, 2), 3)
    self.assertEqual(str(add), '1 + 2 + 3')
    self.assertEqual(add.sympy(), 6)

  def testSub(self):
    sub = ops.Sub(2, 3)
    self.assertEqual(str(sub), '2 - 3')
    self.assertEqual(sub.sympy(), -1)

    sub = ops.Sub(ops.Sub(1, 2), 3)
    self.assertEqual(str(sub), '1 - 2 - 3')
    self.assertEqual(sub.sympy(), -4)

    # Subtraction is not associative, so a right-nested difference must be
    # bracketed.
    sub = ops.Sub(1, ops.Sub(2, 3))
    self.assertEqual(str(sub), '1 - (2 - 3)')
    self.assertEqual(sub.sympy(), 2)

    sub = ops.Sub(ops.Neg(1), 2)
    self.assertEqual(str(sub), '-1 - 2')
    self.assertEqual(sub.sympy(), -3)

  def testMul(self):
    # An empty product renders as, and evaluates to, one.
    mul = ops.Mul()
    self.assertEqual(str(mul), '1')
    self.assertEqual(mul.sympy(), 1)

    mul = ops.Mul(2, 3)
    self.assertEqual(str(mul), '2*3')
    self.assertEqual(mul.sympy(), 6)

    mul = ops.Mul(ops.Identity(ops.Constant(-2)), 3)
    self.assertEqual(str(mul), '-2*3')
    self.assertEqual(mul.sympy(), -6)

    # A sum used as a factor is bracketed.
    mul = ops.Mul(ops.Add(1, 2), 3)
    self.assertEqual(str(mul), '(1 + 2)*3')
    self.assertEqual(mul.sympy(), 9)

    mul = ops.Mul(ops.Mul(2, 3), 5)
    self.assertEqual(str(mul), '2*3*5')
    self.assertEqual(mul.sympy(), 30)

    # TODO(b/124038946): reconsider how we want brackets in these cases:
    # mul = ops.Mul(ops.Div(2, 3), 5)
    # self.assertEqual(str(mul), '(2/3)*5')
    # self.assertEqual(mul.sympy(), sympy.Rational(10, 3))
    #
    # mul = ops.Mul(sympy.Rational(2, 3), 5)
    # self.assertEqual(str(mul), '(2/3)*5')
    # self.assertEqual(mul.sympy(), sympy.Rational(10, 3))

  def testDiv(self):
    div = ops.Div(2, 3)
    self.assertEqual(str(div), '2/3')
    self.assertEqual(div.sympy(), sympy.Rational(2, 3))

    # A rational divisor is bracketed.
    div = ops.Div(2, sympy.Rational(4, 5))
    self.assertEqual(str(div), '2/(4/5)')
    self.assertEqual(div.sympy(), sympy.Rational(5, 2))

    div = ops.Div(1, ops.Div(2, 3))
    self.assertEqual(str(div), '1/(2/3)')
    self.assertEqual(div.sympy(), sympy.Rational(3, 2))

    # A fraction used as a dividend is bracketed as well.
    div = ops.Div(ops.Div(2, 3), 4)
    self.assertEqual(str(div), '(2/3)/4')
    self.assertEqual(div.sympy(), sympy.Rational(1, 6))

    # A product divisor is bracketed; 2/(3*4) == 1/6.
    div = ops.Div(2, ops.Mul(3, 4))
    self.assertEqual(str(div), '2/(3*4)')
    self.assertEqual(div.sympy(), sympy.Rational(1, 6))

    # A function application as divisor needs no brackets.
    div = ops.Div(2, sympy.Function('f')(sympy.Symbol('x')))
    self.assertEqual(str(div), '2/f(x)')

  def testPow(self):
    pow_ = ops.Pow(2, 3)
    self.assertEqual(str(pow_), '2**3')
    self.assertEqual(pow_.sympy(), 8)

    pow_ = ops.Pow(4, sympy.Rational(1, 2))
    self.assertEqual(str(pow_), '4**(1/2)')
    self.assertEqual(pow_.sympy(), 2)

    # Compare against an exact Rational rather than the float 1/8, in line
    # with the other rational-valued checks in this file.
    pow_ = ops.Pow(sympy.Rational(1, 2), 3)
    self.assertEqual(str(pow_), '(1/2)**3')
    self.assertEqual(pow_.sympy(), sympy.Rational(1, 8))

    # Nested powers are bracketed on either side.
    pow_ = ops.Pow(3, ops.Pow(2, 1))
    self.assertEqual(str(pow_), '3**(2**1)')
    self.assertEqual(pow_.sympy(), 9)

    pow_ = ops.Pow(ops.Pow(2, 3), 4)
    self.assertEqual(str(pow_), '(2**3)**4')
    self.assertEqual(pow_.sympy(), 4096)

    # A negative base is bracketed so it is not read as -(5**2).
    pow_ = ops.Pow(-5, 2)
    self.assertEqual(str(pow_), '(-5)**2')
    self.assertEqual(pow_.sympy(), 25)

  def testEq(self):
    op = ops.Eq(ops.Add(2, 3), 4)
    self.assertEqual(str(op), '2 + 3 = 4')
    self.assertEqual(op.sympy(), False)

  def testDescendants(self):
    constants = [ops.Constant(i) for i in range(6)]

    # (0 + 1*2**3) / 4 - 5, where each number is the matching entry of
    # `constants`.
    expression = ops.Sub(
        ops.Div(
            ops.Add(
                constants[0],
                ops.Mul(
                    constants[1],
                    ops.Pow(
                        constants[2],
                        constants[3]))),
            constants[4]),
        constants[5])
    descendants = expression.descendants()
    descendants = ops._flatten(descendants)

    # Every leaf constant must be reached exactly once.
    for constant in constants:
      self.assertIn(constant, descendants)
      self.assertEqual(descendants.count(constant), 1)

    # A leaf's descendants consist of just the leaf itself.
    self.assertEqual(constants[0].descendants(), [constants[0]])

    # An internal node's descendants include both the node and its child.
    constant = ops.Constant(3)
    expression = ops.Neg(constant)
    self.assertEqual(set(expression.descendants()), {constant, expression})

  def testNumberConstants(self):
    constant = ops.Constant(3)
    expression = ops.Neg(constant)
    constants = ops.number_constants([expression])
    self.assertEqual(constants, [constant])


if __name__ == '__main__':
  # Run the test cases in this module through absl's test runner.
  absltest.main()