# Copyright 2013 The LUCI Authors. All rights reserved.
# Use of this source code is governed under the Apache License, Version 2.0
# that can be found in the LICENSE file.

import base64
import contextlib
import datetime
import json
import logging
import time

import webtest

from google.appengine.datastore import datastore_stub_util
from google.appengine.ext import ndb
from google.appengine.ext import testbed

from components import endpoints_webapp2
from components import utils
from depot_tools import auto_stub

# W0212: Access to a protected member XXX of a client class
# pylint: disable=W0212


def mock_now(test, now, seconds):
  """Freezes the clock at `now` + `seconds` for the duration of `test`.

  Patches utils.utcnow() as well as the `_now` hooks of ndb.DateTimeProperty
  and ndb.DateProperty, so entities using auto_now / auto_now_add observe the
  same frozen instant.

  Args:
    test: object exposing a mock(obj, attr, value) method (e.g. a TestCase).
    now: base datetime.datetime.
    seconds: offset in seconds added to `now`.

  Returns:
    The frozen datetime that was installed.
  """
  frozen = now + datetime.timedelta(seconds=seconds)
  frozen_date = frozen.date()
  patches = (
      (utils, 'utcnow', lambda: frozen),
      (ndb.DateTimeProperty, '_now', lambda _: frozen),
      (ndb.DateProperty, '_now', lambda _: frozen_date),
  )
  for target, attr, replacement in patches:
    test.mock(target, attr, replacement)
  return frozen


class TestCase(auto_stub.TestCase):
  """Support class to enable more unit testing in GAE.

  Adds support for:
    - google.appengine.api.mail.send_mail_to_admins().
    - Running task queues.
  """
  # Set APP_DIR to the root directory containing index.yaml and queue.yaml. It
  # will be used to assert the indexes and task queues are properly defined. It
  # can be left to None if no index or task queue is used for the test case.
  APP_DIR = None

  # A test can explicitly acknowledge it depends on composite indexes that may
  # not be defined in index.yaml by setting this to True. It is valid only for
  # components unit tests that are running outside of a context of some app
  # (APP_DIR is None in this case). If APP_DIR is provided, GAE testbed silently
  # overwrites index.yaml, and it's not what we want.
  SKIP_INDEX_YAML_CHECK = False

  # If taskqueues are enqueued during the unit test, self.app must be set to a
  # webtest.TestApp instance. It will be used to do the HTTP post when executing
  # the enqueued tasks via the taskqueue module.
  app = None

  def setUp(self):
    """Initializes the commonly used stubs.

    Using init_all_stubs() costs ~10ms more to run all the tests so only enable
    the ones known to be required. Test cases requiring more stubs can enable
    them in their setUp() function.
    """
    super(TestCase, self).setUp()
    self.testbed = testbed.Testbed()
    self.testbed.activate()

    # If you have a NeedIndexError, here is the switch you need to flip to make
    # the new required indexes to be automatically added. Change
    # train_index_yaml to True to have index.yaml automatically updated, then
    # run your test case. Do not forget to put it back to False.
    train_index_yaml = False

    if self.SKIP_INDEX_YAML_CHECK:
      # See comment for SKIP_INDEX_YAML_CHECK above.
      self.assertIsNone(self.APP_DIR)

    self.testbed.init_app_identity_stub()
    self.testbed.init_datastore_v3_stub(
        require_indexes=not train_index_yaml and not self.SKIP_INDEX_YAML_CHECK,
        root_path=self.APP_DIR,
        # probability=1 makes every write immediately visible to queries.
        consistency_policy=datastore_stub_util.PseudoRandomHRConsistencyPolicy(
            probability=1))
    self.testbed.init_logservice_stub()
    self.testbed.init_memcache_stub()
    self.testbed.init_modules_stub()

    # Use mocked time in memcache.
    memcache = self.testbed.get_stub(testbed.MEMCACHE_SERVICE_NAME)
    memcache._gettime = lambda: int(utils.time_time())

    # Email support.
    self.testbed.init_mail_stub()
    self.mail_stub = self.testbed.get_stub(testbed.MAIL_SERVICE_NAME)
    self.old_send_to_admins = self.mock(
        self.mail_stub, '_Dynamic_SendToAdmins', self._SendToAdmins)

    self.testbed.init_taskqueue_stub()
    self._taskqueue_stub = self.testbed.get_stub(testbed.TASKQUEUE_SERVICE_NAME)
    self._taskqueue_stub._root_path = self.APP_DIR

    self.testbed.init_user_stub()

  def tearDown(self):
    """Drains ready tasks for passing tests, then deactivates the testbed.

    The testbed is deactivated in its own `finally` so that a failing task
    or the pending-task assertion below cannot leave the GAE stubs active
    for subsequent tests.
    """
    try:
      try:
        if not self.has_failed():
          remaining = self.execute_tasks()
          self.assertEqual(0, remaining,
              'Passing tests must leave behind no pending tasks, found %d.'
              % remaining)
      finally:
        self.testbed.deactivate()
    finally:
      super(TestCase, self).tearDown()

  def mock_now(self, now, seconds=0):
    """Freezes time at `now` + `seconds`; see module-level mock_now()."""
    return mock_now(self, now, seconds)

  def _SendToAdmins(self, request, *args, **kwargs):
    """Make sure the request is logged.

    See google_appengine/google/appengine/api/mail_stub.py around line 299,
    MailServiceStub._SendToAdmins().
    """
    self.mail_stub._CacheMessage(request)
    return self.old_send_to_admins(request, *args, **kwargs)

  def execute_tasks(self, **kwargs):
    """Executes enqueued tasks that are ready to run and return the number run.

    A task may trigger another task, so the queues are scanned repeatedly until
    a pass runs nothing. Pull queues are skipped. Each executed task is deleted
    from its queue.

    Sadly, taskqueue_stub implementation does not provide a nice way to run
    them so run the pending tasks manually.

    Args:
      kwargs: forwarded to self.app.post() for every task.

    Returns:
      Total number of tasks executed.
    """
    # list() keeps this comparison valid whether keys() returns a list or a
    # view.
    self.assertEqual([None], list(self._taskqueue_stub._queues.keys()))
    ran_total = 0
    while True:
      # Do multiple loops until no task was run.
      ran = 0
      for queue in self._taskqueue_stub.GetQueues():
        if queue['mode'] == 'pull':
          continue
        for task in self._taskqueue_stub.GetTasks(queue['name']):
          # Remove 2 seconds for jitter.
          eta = task['eta_usec'] / 1e6 - 2
          # NOTE(review): this compares against the real wall clock, not a
          # mocked utils clock; confirm this is intended for tests mocking time.
          if eta >= time.time():
            continue
          self.assertEqual('POST', task['method'])
          logging.info('Task: %s', task['url'])

          self._post_task(task, **kwargs)
          self._taskqueue_stub.DeleteTask(queue['name'], task['name'])
          ran += 1
      if not ran:
        return ran_total
      ran_total += ran

  def execute_task(self, url, queue_name, payload):
    """Executes a specified task.

    The task is looked up by url, queue name and decoded body. It is NOT
    removed from the queue after execution.

    Raises:
      AssertionError: if no matching task is enqueued.
    """
    task = self._find_task(url, queue_name, payload)
    expected = {'url': url, 'queue_name': queue_name, 'payload': payload}
    if not task:
      raise AssertionError("Task is not enqueued. expected: %r" % expected)
    self._post_task(task)

  def _post_task(self, task, **kwargs):
    """POSTs a taskqueue stub task dict to self.app; logs the task on error."""
    # Not 100% sure why the Content-Length hack is needed, nor why the
    # stub returns unicode values that break webtest's assertions.
    body = base64.b64decode(task['body'])
    headers = {k: str(v) for k, v in task['headers']}
    headers['Content-Length'] = str(len(body))
    try:
      self.app.post(task['url'], body, headers=headers, **kwargs)
    except Exception:
      # Log the offending task for diagnosis, then propagate the failure.
      logging.error(task)
      raise

  def _find_task(self, url, queue_name, payload):
    """Returns the first task in `queue_name` matching url and payload, or None.
    """
    for t in self._taskqueue_stub.GetTasks(queue_name):
      if t['url'] != url:
        continue
      if t['queue_name'] != queue_name:
        continue
      if base64.b64decode(t['body']) != payload:
        continue
      return t
    return None


class Endpoints(object):
  """Handles endpoints API calls.

  Serves `api_service_cls` through endpoints_webapp2.api_server inside a
  webtest.TestApp, so tests can invoke API methods by name.
  """
  def __init__(self, api_service_cls, regex=None, source_ip='127.0.0.1'):
    """Builds the test application.

    Args:
      api_service_cls: endpoints service class to serve.
      regex: optional regular expression passed as `regex` to
          endpoints_webapp2.api_server; omitted when falsy.
      source_ip: value used as REMOTE_ADDR for every request.
    """
    super(Endpoints, self).__init__()
    self._api_service_cls = api_service_cls
    kwargs = {}
    if regex:
      kwargs['regex'] = regex
    self._api_app = webtest.TestApp(
        endpoints_webapp2.api_server([self._api_service_cls], **kwargs),
        extra_environ={'REMOTE_ADDR': source_ip})

  def call_api(self, method, body=None, status=(200, 204)):
    """Calls endpoints API method identified by its name.

    Args:
      method: name of the method attribute on the service class.
      body: optional dict of parameters. Keys matching a `{key}` placeholder in
          the method path are substituted into the path; every other key is
          appended as a query string parameter (list values produce repeated
          parameters). For methods other than GET/DELETE the whole dict is also
          sent as the JSON body.
      status: expected HTTP status code(s), passed through to webtest.

    Returns:
      The webtest response.
    """
    # Because body is a dict and not a ResourceContainer, there's no way to tell
    # which parameters belong in the URL and which belong in the body when the
    # HTTP method supports both. However there's no harm in supplying parameters
    # in both the URL and the body since ResourceContainers don't allow the same
    # parameter name to be used in both places. Supplying parameters in both
    # places produces no ambiguity and extraneous parameters are safely ignored.
    assert hasattr(self._api_service_cls, method), method
    info = getattr(self._api_service_cls, method).method_info
    path = info.get_path(self._api_service_cls.api_info)

    # Identify which arguments are path parameters and which are query strings.
    # NOTE(review): neither path nor query values are URL-escaped; values with
    # reserved characters ('&', '=', '+', ' ', ...) may be misparsed.
    body = body or {}
    query_strings = []
    for key, value in sorted(body.items()):
      placeholder = '{%s}' % key
      if placeholder in path:
        # Format like the query string values below so non-string values
        # (e.g. integer ids) work instead of raising TypeError in replace().
        path = path.replace(placeholder, '%s' % (value,))
      else:
        # We cannot tell if the parameter is a repeated field from a dict.
        # Allow all query strings to be multi-valued.
        if not isinstance(value, list):
          value = [value]
        for val in value:
          query_strings.append('%s=%s' % (key, val))
    if query_strings:
      path = '%s?%s' % (path, '&'.join(query_strings))

    path = '/_ah/api/%s/%s/%s' % (self._api_service_cls.api_info.name,
                                  self._api_service_cls.api_info.version,
                                  path)
    try:
      # NOTE(review): DELETE methods are issued as HTTP GET here; confirm the
      # endpoints_webapp2 routing accepts that.
      if info.http_method in ('GET', 'DELETE'):
        return self._api_app.get(path, status=status)
      return self._api_app.post_json(path, body, status=status)
    except Exception as e:
      # Useful for diagnosing issues in test cases.
      logging.info('%s failed: %s', path, e)
      raise


class EndpointsTestCase(TestCase):
  """Base class for test cases exercising a Cloud Endpoints service.

  Usage:
    class MyTestCase(test_case.EndpointsTestCase):
      api_service_cls = MyEndpointsService

      def test_stuff(self):
        response = self.call_api('my_method')
        self.assertEqual(...)

      def test_expected_fail(self):
        with self.call_should_fail(403):
          self.call_api('protected_method')
  """
  # Subclasses must set this to a subclass of remote.Service.
  api_service_cls = None
  # Subclasses may set this to a regular expression used to match path
  # parameters. See components.endpoints_webapp2.adapter.api_server.
  api_service_regex = None

  # Status forced onto call_api() while inside call_should_fail().
  expected_fail_status = None

  # Endpoints wrapper built in setUp().
  _endpoints = None

  def setUp(self):
    """Runs the base GAE setup, then wraps api_service_cls in Endpoints."""
    super(EndpointsTestCase, self).setUp()
    service_cls = self.api_service_cls
    path_regex = self.api_service_regex
    self._endpoints = Endpoints(service_cls, regex=path_regex)

  def call_api(self, method, body=None, status=(200, 204)):
    """Calls `method` on the service.

    Inside call_should_fail(), the expected failure status replaces `status`.
    """
    effective_status = self.expected_fail_status or status
    return self._endpoints.call_api(method, body, effective_status)

  @contextlib.contextmanager
  def call_should_fail(self, status):
    """Asserts that Endpoints call inside the guarded region of code fails.

    While active, call_api() expects `status` instead of its own argument.
    Nesting is not allowed.
    """
    # TODO(vadimsh): Get rid of this function and just use
    # call_api(..., status=...). It existed as a workaround for bug that has
    # been fixed:
    # https://code.google.com/p/googleappengine/issues/detail?id=10544
    assert self.expected_fail_status is None, 'nested call_should_fail'
    assert status is not None
    self.expected_fail_status = int(status)
    try:
      yield
    except AssertionError:
      # On GAE < 1.9.31 the endpoints bug makes webapp internals raise an
      # assertion. That case should now be rare, so any AssertionError raised
      # in the guarded region is deliberately ignored without checking where
      # it came from.
      pass
    finally:
      self.expected_fail_status = None