# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Prints to stdout different curriculum questions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import textwrap

# Dependency imports
from absl import app
from absl import flags
from absl import logging
from mathematics_dataset import generate_settings
from mathematics_dataset.modules import modules
import six
from six.moves import range


# Shared absl flag container; the flags defined below are read from it as
# attributes (e.g. `FLAGS.filter`).
FLAGS = flags.FLAGS

# --filter: a leaf module is kept only if this string is a substring of its
# flattened name; the empty default therefore matches every module.
flags.DEFINE_string('filter', '', 'restrict to matching module names')
# --per_train_module: examples printed per module in the 'train' regime; the
# split regimes ('train-easy', ...) each use a third of it (integer division).
flags.DEFINE_integer('per_train_module', 10, 'Num of examples per train module')
# --per_test_module: examples printed per module in the 'interpolate' and
# 'extrapolate' regimes.
flags.DEFINE_integer('per_test_module', 10, 'Num of examples per test module')
# --show_dropped: when set, each dropped (overly long) sample is logged.
flags.DEFINE_bool('show_dropped', False, 'Whether to print dropped questions')


# Filled in by `init_modules`: maps regime name (e.g. 'train', 'interpolate')
# to an ordered dict of flattened module name -> module callable.
filtered_modules = collections.OrderedDict([])
# Filled in by `init_modules`: maps regime name to the number of examples to
# sample per module in that regime.
counts = {}


def _make_entropy_fn(level, num_levels):
  """This returns a function that returns a subrange of entropy.

  E.g., if level=1 (medium) and num_levels=3, then the returned function will
  map the range [x, x + y] to [x + y/3, x + 2y/3].

  Args:
    level: Integer in range [0, num_levels - 1].
    num_levels: Number of difficulty levels.

  Returns:
    Function to restrict entropy range.
  """
  lower = level / num_levels
  upper = (level + 1) / num_levels
  def modify_entropy(range_):
    assert len(range_) == 2
    length = range_[1] - range_[0]
    return (range_[0] + lower * length, range_[0] + upper * length)
  return modify_entropy


def _filter_and_flatten(modules_):
  """Flattens a nested dict of modules and filters it by `FLAGS.filter`.

  Nested dict keys are joined with '__' to form each leaf's full name. A leaf
  is kept only if `FLAGS.filter` is a substring of that full name.

  Args:
    modules_: Dict whose values are either further dicts or leaf entries.

  Returns:
    `collections.OrderedDict` mapping full name to leaf, sorted by full name.
  """

  def walk(tree, prefix):
    """Yields `(full_name, leaf)` for every leaf passing the filter."""
    for name, node in six.iteritems(tree):
      path = name if prefix is None else prefix + '__' + name
      if isinstance(node, dict):
        for item in walk(node, path):
          yield item
      elif FLAGS.filter in path:
        yield path, node

  # Collect into a dict first so that a repeated full name keeps the last leaf
  # seen, and so that sorting only ever compares the names.
  collected = dict(walk(modules_, None))

  # Make sure list of modules are in deterministic order. This is important when
  # generating across multiple machines.
  return collections.OrderedDict(
      [(path, collected[path]) for path in sorted(collected)])


def init_modules(train_split=False):
  """Fills the module-level `filtered_modules` and `counts` dicts.

  Does nothing if `filtered_modules` is already non-empty.

  Args:
    train_split: If true, training modules are registered as three regimes
      ('train-easy', 'train-medium', 'train-hard'), each restricted to one
      third of the entropy range; otherwise a single 'train' regime covering
      the whole range is registered.
  """
  if filtered_modules:
    return  # already initialized

  regimes = collections.OrderedDict()
  if not train_split:
    regimes['train'] = modules.train(_make_entropy_fn(0, 1))
  else:
    split_names = ('train-easy', 'train-medium', 'train-hard')
    for level, split_name in enumerate(split_names):
      regimes[split_name] = modules.train(_make_entropy_fn(level, 3))

  regimes['interpolate'] = modules.test()
  regimes['extrapolate'] = modules.test_extra()

  per_train = FLAGS.per_train_module
  per_test = FLAGS.per_test_module
  per_train_split = per_train // 3
  # Counts are registered for every regime name, whether or not it is in use.
  counts.update([
      ('train', per_train),
      ('train-easy', per_train_split),
      ('train-medium', per_train_split),
      ('train-hard', per_train_split),
      ('interpolate', per_test),
      ('extrapolate', per_test),
  ])

  for regime_name, regime_modules in six.iteritems(regimes):
    filtered_modules[regime_name] = _filter_and_flatten(regime_modules)


def sample_from_module(module):
  """Draws problems from `module` until one fits within the length limits.

  A sample is rejected when its question text is longer than
  `generate_settings.MAX_QUESTION_LENGTH` or, failing that, when its answer
  text is longer than `generate_settings.MAX_ANSWER_LENGTH`. Rejected samples
  are logged only if `FLAGS.show_dropped` is set.

  Args:
    module: Callable returning a `Problem`.

  Returns:
    Pair `(problem, num_dropped)`, where `problem` is an instance of `Problem`
    and `num_dropped` is an integer >= 0 indicating the number of samples that
    were dropped.
  """
  dropped = 0
  while True:
    candidate = module()
    question_text = str(candidate.question)
    if len(question_text) <= generate_settings.MAX_QUESTION_LENGTH:
      # The answer is only rendered once the question has been accepted.
      answer_text = str(candidate.answer)
      if len(answer_text) <= generate_settings.MAX_ANSWER_LENGTH:
        return candidate, dropped
      if FLAGS.show_dropped:
        logging.warning('Dropping question with answer: %s', answer_text)
    elif FLAGS.show_dropped:
      logging.warning('Dropping question: %s', question_text)
    dropped += 1


def main(unused_argv):
  """Prints sampled questions and answers for every filtered module.

  For each regime and each of its modules, prints a header line followed by
  `counts[regime]` wrapped question/answer lines, then logs a warning if any
  samples were dropped for that module.

  Args:
    unused_argv: Ignored.
  """
  init_modules()

  wrapper = textwrap.TextWrapper(
      width=80, initial_indent=' ', subsequent_indent='  ')
  # ANSI escape sequences: bold for the header, bright green for the answer.
  header_template = '\033[1m{}/{}\033[0m'
  qa_template = '{}  \033[92m{}\033[0m'

  for regime, flat_modules in six.iteritems(filtered_modules):
    num_examples = counts[regime]
    for module_name, module in six.iteritems(flat_modules):
      print(header_template.format(regime, module_name))
      total_dropped = 0
      for _ in range(num_examples):
        problem, dropped = sample_from_module(module)
        total_dropped += dropped
        line = qa_template.format(problem.question, problem.answer)
        print(wrapper.fill(line))
      if total_dropped > 0:
        logging.warning('Dropped %d examples', total_dropped)

# Script entry point: hand `main` to absl's `app.run`, which is expected to
# parse the command-line flags before invoking it.
if __name__ == '__main__':
  app.run(main)