# Copyright (C) 2018 Seoul National University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import contextlib
from multiprocessing.managers import BaseManager
import os
import threading
import time

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow as print_mdl
from tensorflow.python.client import session
from tensorflow.python.util import compat

from parallax.core.python.common.partitions import *
# Inclusive window of session runs, counted from the session context's start
# step, during which `_parallax_run` accumulates execution time for the stat
# collection master.
COLLECT_STAT_START, COLLECT_STAT_END = 50, 100

def _parallax_init(self, target='', graph=None, config=None):
    """Stands in for `session.BaseSession.__init__` while Parallax is active.

    The constructor arguments are handed, positionally and unchanged, to the
    original `BaseSession.__init__`, which `set_parallax_session_context`
    keeps on the class under the name `_init_internal`.
    """
    original_init = self._init_internal  # pylint: disable=protected-access
    original_init(target, graph, config)

def _report_exec_time(context):
    """Sends `context._exec_time` to the stat collection master.

    Connects to the `multiprocessing` manager at
    (`context._master['hostname']`, `context._master['port'][0]`) and puts the
    accumulated execution time on the manager's registered `queue`.
    """
    # pylint: disable=protected-access
    host = context._master['hostname']
    port = int(context._master['port'][0])
    BaseManager.register('queue')
    # The authkey must be bytes: Python 3's multiprocessing rejects a str
    # authkey with TypeError, and b'...' is still a plain str on Python 2.
    manager = BaseManager(address=(host, port), authkey=b'parallax_auth')
    manager.connect()
    manager.queue().put(context._exec_time)


def _parallax_run(self,
                  fetches,
                  feed_dict=None,
                  options=None,
                  run_metadata=None):
    """Replacement for `session.BaseSession.run` installed by Parallax.

    Fetches and feeds are translated to the replicated graph through the
    session's `parallax_session_context`, then the original `run` (stored on
    the class as `_run_internal`) is invoked. When exec-time collection or
    profiling is enabled, the run that acquires the context's step lock
    either times itself inside the COLLECT_STAT window (reporting the total
    to the master at the end of the window) or, on a profiled global step,
    runs with FULL_TRACE and dumps the resulting RunMetadata.

    Args:
      fetches: Fetches expressed against the original (single-replica) graph.
      feed_dict: Optional feeds expressed against the original graph.
      options: Optional `tf.RunOptions`, forwarded to the underlying run.
      run_metadata: Optional `tf.RunMetadata`, forwarded to the underlying
        run.

    Returns:
      The result of the original `BaseSession.run` for the converted fetches.
    """
    # pylint: disable=protected-access
    context = self.parallax_session_context
    fetches = context._convert_fetch(fetches)
    feed_dict = context._convert_feed(feed_dict)

    profiling_enabled = (context._profile_dir is not None and
                         (context._profile_steps is not None or
                          context._profile_range is not None))
    if not context._send_exec_time and not profiling_enabled:
        # Fast path: nothing to time or trace. Forward the caller's
        # options/run_metadata, and skip the extra run that reads the global
        # step, which only the profiling path needs.
        return self._run_internal(fetches, feed_dict, options, run_metadata)

    with context._new_step() as (step, locked):
        if not locked:
            # Another run currently holds the step lock; don't time or trace.
            return self._run_internal(fetches, feed_dict, options,
                                      run_metadata)

        if context._send_exec_time:
            relative_step = step - context._start_step
            if not COLLECT_STAT_START <= relative_step <= COLLECT_STAT_END:
                return self._run_internal(fetches, feed_dict, options,
                                          run_metadata)
            start = time.time()
            ret = self._run_internal(fetches, feed_dict, options,
                                     run_metadata)
            context._exec_time += time.time() - start
            # Compare the step relative to the start step, like the window
            # check above: the absolute step only equals COLLECT_STAT_END
            # when the context started at step 0.
            if relative_step == COLLECT_STAT_END:
                _report_exec_time(context)
            return ret

        global_step = self._run_internal(
            context._convert_fetch(context._global_step_op))[0]
        if not context._is_profile_step(global_step):
            return self._run_internal(fetches, feed_dict, options,
                                      run_metadata)

        if not run_metadata:
            run_metadata = context._run_metadata()
        if not options:
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        old_trace_level = options.trace_level
        options.trace_level = tf.RunOptions.FULL_TRACE
        try:
            ret = self._run_internal(fetches, feed_dict, options,
                                     run_metadata)
        finally:
            # Never leave FULL_TRACE behind on a caller-owned RunOptions,
            # even if the traced run raises.
            options.trace_level = old_trace_level
        context._dump_profile(run_metadata, 'run_meta_%d' % global_step)
        return ret
                      
class ParallaxSessionContext(object):
    """A context that wraps `tf.Session` for Parallax.

    `set_parallax_session_context` patches `session.BaseSession` so that
    `run` goes through `_parallax_run`, which uses this object to translate
    fetches and feeds from original-graph tensors to their per-replica
    counterparts and, for selected steps, to trace runs and dump their
    RunMetadata or to accumulate execution time.

    This class is modeled on `tf.contrib.tfprof.ProfileContext`.
    """

    def __init__(self,
                 step,
                 global_step_op,
                 profile_dir,
                 profile_steps,
                 profile_range,
                 replica_dict,
                 num_replicas_per_worker,
                 master):
        """Constructs a `ParallaxSessionContext` instance.

        Args:
          step: Initial value of the per-context run counter; also kept as
            the start step from which the exec-time window is measured.
          global_step_op: Global step tensor/op (original-graph naming) read
            to decide whether a run should be profiled.
          profile_dir: Directory to store profiles.
          profile_steps: A list of steps for tracing and saving as a file.
          profile_range: A tuple of tracing start and end step (end
            exclusive).
          replica_dict: A dictionary to map old tensor(operation) name
            to new tensor(operation) names. Entries holding a single name
            are expanded in place to one identical name per replica.
          num_replicas_per_worker: Number of replicas per worker.
          master: Stat collection master for partitioning; read through its
            'hostname' and 'port' entries when reporting exec time.
        """
        self._start_step = step
        self._step = step
        self._global_step_op = global_step_op
        self._profile_dir = profile_dir
        self._profile_steps = profile_steps
        self._profile_range = profile_range
        assert self._profile_steps is None or self._profile_range is None
        self._replica_dict = replica_dict
        self._num_replicas_per_worker = num_replicas_per_worker
        # Lazily created RunMetadata reused for profiled runs. It must NOT be
        # stored as `self._run_metadata`: an instance attribute with that
        # name shadows the `_run_metadata()` method below, and every profiled
        # run without a caller-supplied RunMetadata would then fail with
        # "'NoneType' object is not callable".
        self._cached_run_metadata = None
        # An unset variable means "not searching" rather than a KeyError.
        self._send_exec_time = os.environ.get(PARALLAX_SEARCH) == 'True'
        self._exec_time = 0
        self._master = master

        # NOTE: this rewrites the caller's dict in place (same object).
        for key, values in self._replica_dict.items():
            if len(values) == 1:
                self._replica_dict[key] = \
                    [values[0]] * self._num_replicas_per_worker
        self._lock = threading.Lock()

    @contextlib.contextmanager
    def _new_step(self):
        """Yields `(step, acquired)` for one session run.

        The lock is taken without blocking, so only one concurrent run at a
        time gets `acquired=True` (and with it the timing/profiling paths in
        `_parallax_run`). The counter is advanced and the lock released in a
        `finally` clause: an exception raised by the wrapped run must not
        leave the lock held, which would silently disable profiling and
        timing for the rest of the session.
        """
        acquired = self._lock.acquire(False)
        try:
            yield (self._step, acquired)
        finally:
            self._step += 1
            if acquired:
                self._lock.release()

    def _is_profile_step(self, step):
        """Returns True if `step` is in `profile_steps` or in the half-open
        `profile_range` [start, end)."""
        if self._profile_steps and step in self._profile_steps:
            return True
        if self._profile_range:
            return self._profile_range[0] <= step < self._profile_range[1]
        return False

    def _run_metadata(self):
        """Returns the cached `tf.RunMetadata`, creating it on first use."""
        if self._cached_run_metadata is None:
            self._cached_run_metadata = tf.RunMetadata()
        return self._cached_run_metadata

    def _dump_profile(self, metadata, basename):
        """Writes `metadata` serialized to `<profile_dir>/<basename>`.

        The profile directory is created if missing, and the cached
        RunMetadata is dropped so the next profiled run gets a fresh one.
        """
        if not tf.gfile.Exists(self._profile_dir):
            tf.gfile.MakeDirs(self._profile_dir)
        path = os.path.join(self._profile_dir, basename)
        with tf.gfile.Open(path, 'wb') as f:
            f.write(metadata.SerializeToString())
        self._cached_run_metadata = None

    def _read_converted_names(self, target):
        """Returns the replica names stored in `replica_dict` for `target`.

        `target` is either a name or an object with a `.name` attribute. If
        that name is not a key of `replica_dict`, `target` itself is
        returned unchanged.
        """
        if isinstance(target, compat.bytes_or_text_types):
            target_name = target
        else:
            target_name = target.name
        return self._replica_dict.get(target_name, target)

    def _convert_fetch(self, fetch):
        """Maps a (possibly nested) fetch structure onto the replicas.

        Lists and tuples become lists and dicts keep their keys, with every
        element converted recursively. Leaf fetches are resolved through
        `_read_converted_names`; SparseTensor / IndexedSlices leaves become
        one object per replica.

        Raises:
          TypeError: If `fetch` is None.
        """
        if fetch is None:
            raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
                                                                 type(fetch)))
        if isinstance(fetch, (list, tuple)):
            return [self._convert_fetch(f) for f in fetch]
        if isinstance(fetch, dict):
            return {key: self._convert_fetch(value)
                    for key, value in fetch.items()}
        # NOTE(review): the two branches below index `_replica_dict` with
        # the component tensors themselves, while `_read_converted_names`
        # keys it by name -- confirm which key type the dict actually holds.
        if isinstance(fetch, tf.SparseTensor):
            return [tf.SparseTensor(self._replica_dict[fetch.indices][i],
                                    self._replica_dict[fetch.values][i],
                                    self._replica_dict[fetch.dense_shape][i])
                    for i in range(self._num_replicas_per_worker)]
        if isinstance(fetch, tf.IndexedSlices):
            return [tf.IndexedSlices(
                        self._replica_dict[fetch.values][i],
                        self._replica_dict[fetch.indices][i],
                        None if fetch.dense_shape is None
                        else self._replica_dict[fetch.dense_shape][i])
                    for i in range(self._num_replicas_per_worker)]
        return self._read_converted_names(fetch)

    def _convert_feed(self, feed_dict):
        """Rewrites `feed_dict` so that its keys refer to replica tensors.

        A key that maps to a list of replica names is fed per replica: the
        i-th replica receives `feed_val[i]`. Other keys keep `feed_val`.
        An empty/None `feed_dict` is returned as is.
        """

        def _feed_fn(feed):
            for tensor_type, _, _, feed_fn in session._REGISTERED_EXPANSIONS:
                if isinstance(feed, tensor_type):
                    return feed_fn(feed)
            raise TypeError('Feed argument %r has invalid type %r' % (feed,
                                                                   type(feed)))

        def _add(new_feed_dict, converted, feed_val):
            # Per-replica values when the key was mapped to a list of names.
            if isinstance(converted, list):
                for i in range(self._num_replicas_per_worker):
                    new_feed_dict[converted[i]] = feed_val[i]
            else:
                new_feed_dict[converted] = feed_val

        if not feed_dict:
            return feed_dict
        new_feed_dict = {}
        for feed, feed_val in feed_dict.items():
            if isinstance(feed, compat.bytes_or_text_types):
                subfeeds = [feed]
            else:
                # NOTE(review): if a feed expands into several component
                # tensors (e.g. a SparseTensor), each component is given the
                # same whole `feed_val` -- confirm this is intended.
                subfeeds = _feed_fn(feed)
            for subfeed in subfeeds:
                _add(new_feed_dict, self._read_converted_names(subfeed),
                     feed_val)
        return new_feed_dict

    def set_parallax_session_context(self):
        """Installs this context by monkey-patching `session.BaseSession`.

        The original `run` and `__init__` are saved in `self.old_run` and
        `self.old_init` and re-exposed on the class as `_run_internal` and
        `_init_internal`. `run` and `__init__` are replaced by
        `_parallax_run` / `_parallax_init`, and this object is attached as
        `parallax_session_context`.

        Returns:
          self.

        Raises:
          tf.errors.InternalError: If `BaseSession` lacks `run` or
            `__init__`, or if `_run_internal`/`_init_internal` already exist
            (a context is installed or was not cleaned up).
        """
        self.old_run = getattr(session.BaseSession, 'run', None)
        self.old_init = getattr(session.BaseSession, '__init__', None)
        if not self.old_run:
            raise tf.errors.InternalError(None, None,
                                          'BaseSession misses run method.')
        if not self.old_init:
            raise tf.errors.InternalError(None, None,
                                          'BaseSession misses __init__ method.')
        if (getattr(session.BaseSession, '_run_internal', None) or
                getattr(session.BaseSession, '_init_internal', None)):
            raise tf.errors.InternalError(
                None, None, 'Already in context or context not cleaned.')
        setattr(session.BaseSession, 'run', _parallax_run)
        setattr(session.BaseSession, '__init__', _parallax_init)
        setattr(session.BaseSession, '_run_internal', self.old_run)
        setattr(session.BaseSession, '_init_internal', self.old_init)
        setattr(session.BaseSession, 'parallax_session_context', self)
        return self