# -*- coding:utf-8 -*-
"""Integration tests for the TensorBoard endpoints of a notebook web app.

The tests exercise ``/api/tensorboard`` (create / list / get / delete
instances) and the ``/tensorboard/<name>/...`` routes against a real
``NotebookApp`` web application, using event files produced by ``tf_logs``.
"""
import sys
import time
import logging
import json

import pytest
from tornado.testing import AsyncHTTPTestCase


@pytest.fixture(scope="session")
def tf_logs(tmpdir_factory):
    """Train a tiny linear model and write TensorFlow summaries to disk.

    Returns the path (as ``str``) of a session-wide temp directory holding
    the session graph plus ``loss``, ``a`` and ``b`` scalar summaries for
    1000 gradient-descent steps.
    """
    import numpy as np
    try:
        # Newer TensorFlow: use the graph-mode v1 API that the code below needs.
        import tensorflow.compat.v1 as tf
        tf.disable_v2_behavior()
    except ImportError:
        import tensorflow as tf

    # Noisy samples of y = 3x + 1; the model fits a and b.
    x = np.random.rand(5)
    y = 3 * x + 1 + 0.05 * np.random.rand(5)

    a = tf.Variable(0.1)
    b = tf.Variable(0.)
    err = a * x + b - y
    loss = tf.norm(err)
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("a", a)
    tf.summary.scalar("b", b)
    merged = tf.summary.merge_all()

    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    with tf.Session() as sess:
        log_dir = str(tmpdir_factory.mktemp("logs", numbered=False))
        writer = tf.summary.FileWriter(log_dir, sess.graph)
        try:
            tf.global_variables_initializer().run()
            for step in range(1000):
                _, summary = sess.run([optimizer, merged])
                writer.add_summary(summary, step)
        finally:
            # FileWriter buffers events; close() flushes them to the event
            # file so the server can see the graph and scalars.
            writer.close()
    return log_dir


@pytest.fixture(scope="session")
def nb_app():
    """Build and initialize a NotebookApp and return its tornado web app.

    Authentication is disabled (empty token/password, no XSRF check) so the
    tests can call the REST API directly.
    """
    from notebook.notebookapp import NotebookApp

    app = NotebookApp()
    app.log_level = logging.DEBUG
    app.ip = '127.0.0.1'
    # TODO: Add auth check tests
    app.token = ''
    app.password = ''
    app.disable_check_xsrf = True
    # Pass the command-line options explicitly rather than overwriting
    # sys.argv: without an argv the app parses sys.argv[1:], which silently
    # dropped the first option ("--port=6005"), and the global assignment
    # leaked into the rest of the test session.
    app.initialize(
        argv=["--port=6005", "--ip=127.0.0.1", "--no-browser", "--debug"])
    return app.web_app


class TestJupyterExtension(AsyncHTTPTestCase):
    """HTTP-level tests for the TensorBoard REST API and instance routes."""

    @pytest.fixture(autouse=True)
    def init_jupyter(self, tf_logs, nb_app, tmpdir_factory):
        """Inject the session fixtures into this unittest-style test case."""
        self.app = nb_app
        self.log_dir = tf_logs
        self.tmpdir_factory = tmpdir_factory

    def get_app(self):
        return self.app

    def _create_instance(self, **options):
        """POST a TensorBoard instance for ``self.log_dir``; return its JSON.

        Extra keyword arguments (e.g. ``reload_interval``) are sent as
        additional fields of the JSON request body.
        """
        content = dict(logdir=self.log_dir, **options)
        response = self.fetch(
            '/api/tensorboard', method='POST',
            body=json.dumps(content),
            headers={"Content-Type": "application/json"})
        assert response.code == 200, response.body
        return json.loads(response.body.decode())

    def test_tensorboard(self):
        """Create, list, fetch, browse and finally delete an instance."""
        # Use the name the server assigned instead of assuming it is "1".
        name = self._create_instance()["name"]
        instance_url = '/api/tensorboard/{}'.format(name)

        response = self.fetch('/api/tensorboard')
        instances = json.loads(response.body.decode())
        assert len(instances) > 0

        response = self.fetch(instance_url)
        instance = json.loads(response.body.decode())
        # The single-instance view must match its entry in the listing.
        listed = next(
            (inst for inst in instances if inst["name"] == instance["name"]),
            None)
        assert instance == listed

        response = self.fetch('/tensorboard/{}/#graphs'.format(name))
        assert response.code == 200
        response = self.fetch(
            '/tensorboard/{}/data/plugins_listing'.format(name))
        plugins_list = json.loads(response.body.decode())
        assert plugins_list["graphs"]
        assert plugins_list["scalars"]

        response = self.fetch(instance_url, method='DELETE')
        assert response.code == 204

        response = self.fetch(instance_url)
        error_msg = json.loads(response.body.decode())
        assert error_msg["message"].startswith(
            "TensorBoard instance not found:")

    def test_instance_reload(self):
        """An instance created with reload_interval=4 updates reload_time."""
        instance = self._create_instance(reload_interval=4)
        assert instance is not None
        name = instance["name"]
        reload_time = instance["reload_time"]

        # Sleep past the 4 s reload interval so the server has had a chance
        # to reload the logdir.
        time.sleep(5)
        response = self.fetch('/api/tensorboard/{}'.format(name))
        instance2 = json.loads(response.body.decode())
        assert instance2["reload_time"] != reload_time