import traceback
import sys
import tensorflow as tf

from import graph_summary

from nao.compiler.retvalbag import RetvalBag

def eprint(*args, **kwargs):
  """Print ``args`` to standard error.

  Takes the same keyword arguments as the built-in print(), except ``file``,
  which is always sys.stderr (passing ``file`` explicitly raises TypeError).
  """
  stream = sys.stderr  # looked up per call, so redirect_stderr() is honoured
  print(*args, file=stream, **kwargs)

class ReplSession:
  """Long-lived REPL session that compiles nao snippets and evaluates them.

  Each run() call stores the snippet as "main" + ".nao" in the compiler,
  re-resolves the "main" package, writes the graph to a summary sink, and
  hands variables / queue runners that appeared since the previous run to
  the _init_new_* helpers. If the resulting value is a tf.Tensor,
  tf.Variable or tf.Operation, it is evaluated in the shared session.

  NOTE(review): several lines of this class are missing or truncated in
  this copy (see the NOTE(review) markers below); it does not parse as-is.
  """
  def __init__(self, compiler, log_dir_fn):
    """Create a session bound to ``compiler``.

    Args:
      compiler: object providing new_session(), put_source() and
        resolve_import_path() -- the methods this class calls on it.
      log_dir_fn: presumably a callable returning the summary log directory.
        NOTE(review): it is stored but never read in the visible code; the
        truncated FileWriter(...) call in run() is the likely consumer.
    """
    self._log_dir_fn = log_dir_fn
    self._suffix = ".nao"
    self._compiler = compiler
    self._session = compiler.new_session()
    self._graph = self._session.graph
    # Snapshots from the previous run. _run() diffs the current graph against
    # these so only newly created queue runners / variables are handled.
    self._previous_queue_runners = frozenset()
    self._previous_vars = frozenset()
    self._threads = []  # NOTE(review): never appended to in the visible code.
    self._coord = tf.train.Coordinator()  # shared by all queue-runner threads
    self._next_run_id = 0  # incremented per run(); used as the summary step
    self._summary_writer = None  # created lazily by the first run()

  def _vars(self):
    """Return the global variables of this session's graph."""
    with self._graph.as_default():
      return tf.global_variables()

  def _queue_runners(self):
    """Return the queue runners registered in this session's graph."""
    with self._graph.as_default():
      return tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)

  def _init_new_vars(self, new_vars):
    """Handle variables that appeared since the previous run.

    In the visible code this only prints them. NOTE(review): despite the
    name, no initializer is run here -- confirm whether a
    tf.variables_initializer(...) call was lost from this copy.
    """
    if len(new_vars) == 0:
      # NOTE(review): this guard's body is missing (presumably `return`);
      # as written this is a syntax error.

    print("New variables", new_vars)

  def _init_new_queue_runners(self, new_queue_runners):
    """Start threads for queue runners that appeared since the previous run.

    Threads are created as daemons, started immediately, and share
    self._coord. They are only printed, not stored in self._threads.
    """
    if len(new_queue_runners) == 0:
      # NOTE(review): this guard's body is missing (presumably `return`);
      # as written this is a syntax error.

    print("new_queue_runners", new_queue_runners)
    for qr in new_queue_runners:
      threads = qr.create_threads(self._session, coord=self._coord, daemon=True, start=True)
      print("started", threads)

  def run(self, src, summary_fn=None):
    """Compile and evaluate ``src``; return the value produced by _run().

    Allocates the next run id, lazily creates the summary FileWriter, wraps
    it in a graph_summary.Multiplexer and delegates to _run().

    NOTE(review): this method is garbled in this copy. The FileWriter(...)
    call is unterminated, the ``summary_fn`` branch has no body, and the
    multiplexer/return lines are indented inside the
    ``if self._summary_writer is None:`` branch. As indented they would only
    execute on the first call, which looks like an indentation error.
    Confirm against the original.
    """
    run_id = self._next_run_id
    self._next_run_id = self._next_run_id + 1
    if self._summary_writer is None:
      self._summary_writer = tf.summary.FileWriter(
      # NOTE(review): arguments and closing paren are missing; presumably the
      # log directory from self._log_dir_fn() -- confirm.

      multiplexer = graph_summary.Multiplexer([self._summary_writer])
      if summary_fn is not None:
        # NOTE(review): body missing; how summary_fn is used cannot be
        # determined from this copy.

      return self._run(multiplexer, run_id, src)

  def _run(self, summary_writer, run_id, src):
    """Compile ``src`` as the "main" package and evaluate its result.

    Args:
      summary_writer: sink exposing add_graph() and add_run_metadata();
        run() passes a graph_summary.Multiplexer.
      run_id: integer passed as the global step to the summary calls and
        used in the run-metadata tag "repl-%04d".
      src: nao source text.

    Returns:
      The value of pkg.ctx().get_above(). A RetvalBag is unwrapped with
      .get(None). A tf.Tensor/Variable/Operation is replaced by the result of
      evaluating it (that call is garbled below). Any other value is returned
      unchanged.
    """
    self._compiler.put_source("main%s" % self._suffix, src)

    above = None  # NOTE(review): overwritten two lines below; appears redundant.
    # reimport=True presumably forces the just-written source to be recompiled.
    pkg = self._compiler.resolve_import_path("main", reimport=True)
    above = pkg.ctx().get_above()

    # Write graph once we've generated it.
    summary_writer.add_graph(self._session.graph, run_id)

    # Hand only variables created since the last run to _init_new_vars.
    # NOTE(review): `vars` shadows the builtin of the same name.
    vars = frozenset(self._vars())
    self._init_new_vars(vars - self._previous_vars)
    self._previous_vars = vars

    # Same diffing for queue runners, so each one's threads start only once.
    queue_runners = frozenset(self._queue_runners())
    self._init_new_queue_runners(queue_runners - self._previous_queue_runners)
    self._previous_queue_runners = queue_runners

    if isinstance(above, RetvalBag):
      above = above.get(None)

    # Only TF graph values are evaluated; anything else is returned as-is.
    if isinstance(above, (tf.Tensor, tf.Variable, tf.Operation)):
      run_metadata = tf.RunMetadata()
      # NOTE(review): the next line is garbled; presumably
      # `above = self._session.run(above, options=..., run_metadata=run_metadata)`.
      above =, run_metadata=run_metadata)
      summary_writer.add_run_metadata(run_metadata, "repl-%04d" % run_id, run_id)

    return above

  def __del__(self):
    # Shutdown threads, if any.
    # NOTE(review): this finalizer is truncated in this copy. Nothing visible
    # stops self._coord or joins queue-runner threads, and the branch below
    # has no body (presumably it closes self._summary_writer) -- confirm.
    if self._summary_writer is not None:
      # NOTE(review): body missing; as written this is a syntax error.

import atexit
import os
import readline

# Basename of the readline history file; run() places it in the user's home
# directory.
HISTORY_BASENAME = '.nao_history'

def run(parser, log_fn):
  """Run an interactive read-eval-print loop over nao source.

  Loads readline history from ``~/HISTORY_BASENAME``, creating an empty file
  if it does not exist yet. Registers an atexit hook that appends only the
  lines entered during this session. Each non-empty input line is evaluated
  by a ReplSession, and the result is printed when it is not None.
  Exceptions raised while evaluating are printed as a traceback on stdout,
  and the loop continues.

  KeyboardInterrupt at the prompt discards the current line; EOF ends the
  loop.

  Args:
    parser: compiler object handed to ReplSession.
    log_fn: callable handed to ReplSession as its ``log_dir_fn``.
  """
  histfile = os.path.join(os.path.expanduser("~"), HISTORY_BASENAME)
  try:
    readline.read_history_file(histfile)
    # get_current_history_length() is the number of entries actually held.
    # get_history_length() is the configured file-size limit (-1 by default)
    # and would make the delta computed in save() meaningless.
    h_len = readline.get_current_history_length()
  except FileNotFoundError:
    open(histfile, 'wb').close()
    h_len = 0

  def save(prev_h_len, histfile):
    # Append only the entries added since the history file was loaded.
    new_h_len = readline.get_current_history_length()
    readline.append_history_file(new_h_len - prev_h_len, histfile)

  atexit.register(save, h_len, histfile)

  repl_session = ReplSession(parser, log_fn)
  while True:
    try:
      src = input("> ")
    except KeyboardInterrupt:
      # Abandon the partially typed line and prompt again.
      print()
      continue
    except EOFError:
      # End of input: leave the REPL.
      print()
      break

    if src == "":
      continue

    try:
      result = repl_session.run(src)
    except Exception as e:
      # Top-level REPL boundary: report the error and keep the session alive.
      print("".join(traceback.format_exception(None, e, e.__traceback__)),
            file=sys.stdout, flush=True)
      continue

    if result is not None:
      print(result)

