#!/usr/bin/env python # -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import print_function from __future__ import division from itertools import permutations import tensorflow as tf from zhusuan.framework.utils import * from zhusuan.framework.utils import Context class TestContext(tf.test.TestCase): def test_Context(self): self.assertEqual(Context.get_contexts(), []) with self.assertRaisesRegexp(RuntimeError, "No contexts on the stack"): Context.get_context() with Context() as context: self.assertEqual(Context.get_contexts(), [context]) self.assertEqual(Context.get_context(), context) with Context() as context_inner: self.assertEqual(Context.get_contexts(), [context, context_inner]) self.assertEqual(Context.get_context(), context_inner) self.assertEqual(Context.get_contexts(), [context]) self.assertEqual(Context.get_context(), context) self.assertEqual(Context.get_contexts(), []) with self.assertRaisesRegexp(RuntimeError, "No contexts on the stack"): Context.get_context() class TestGetBackwardTensors(tf.test.TestCase): def testGetBackwardOpsChain(self): # a -> b -> c a = tf.placeholder(tf.float32) b = tf.sqrt(a) c = tf.square(b) for n in range(4): for seed_tensors in permutations([a, b, c], n): if c in seed_tensors: truth = [a.op, b.op, c.op] elif b in seed_tensors: truth = [a.op, b.op] elif a in seed_tensors: truth = [a.op] else: truth = [] self.assertEqual(get_backward_ops(seed_tensors), truth) self.assertEqual(get_backward_ops([c], treat_as_inputs=[b]), [c.op]) self.assertEqual( get_backward_ops([b, c], treat_as_inputs=[b]), [c.op]) self.assertEqual( get_backward_ops([a, c], treat_as_inputs=[b]), [a.op, c.op]) def testGetBackwardOpsSplit(self): # a -> b -> c # \-> d a = tf.placeholder(tf.float32) b = tf.exp(a) c = tf.log(b) d = tf.negative(b) self.assertEqual(get_backward_ops([d]), [a.op, b.op, d.op]) self.assertEqual(get_backward_ops([c]), [a.op, b.op, c.op]) self.assertEqual( get_backward_ops([c, d]), [a.op, b.op, c.op, d.op]) self.assertEqual(get_backward_ops([b, d]), [a.op, b.op, d.op]) self.assertEqual(get_backward_ops([a, d]), [a.op, b.op, d.op]) self.assertEqual( get_backward_ops([c, d], treat_as_inputs=[b]), [c.op, d.op]) self.assertEqual( get_backward_ops([c], treat_as_inputs=[d]), [a.op, b.op, c.op]) def testGetBackwardOpsMerge(self): # a -> c -> d # b ->/ a = tf.placeholder(tf.float32) b = tf.constant(0, dtype=tf.int32) c = tf.reduce_sum(a, reduction_indices=b) d = tf.stop_gradient(c) self.assertEqual( get_backward_ops([d]), [a.op, b.op, c.op, d.op]) self.assertEqual(get_backward_ops([d], treat_as_inputs=[c]), [d.op]) self.assertEqual( get_backward_ops([d], treat_as_inputs=[a]), [b.op, c.op, d.op]) def testGetBackwardOpsBridge(self): # a -> b -> c -> d -> e # \ --- / a = tf.placeholder(tf.int32) b = tf.identity(a) c = tf.cast(b, tf.float32) d = tf.tile(c, b) e = tf.tanh(d) self.assertEqual( get_backward_ops([e]), [a.op, b.op, c.op, d.op, e.op]) self.assertEqual(get_backward_ops([c]), [a.op, b.op, c.op]) self.assertEqual(get_backward_ops([e], treat_as_inputs=[c]), [a.op, b.op, d.op, e.op]) def testGetBackwardOpsControlDeps(self): # a -> b - \ # c -> d - e # \ / # f a = tf.placeholder(tf.float32, name='a') b = tf.identity(a, name='b') c = tf.placeholder(tf.float32, name='c') d = tf.identity(c, name='d') with tf.control_dependencies([b, d]): e = tf.placeholder(tf.float32, name='e') with tf.control_dependencies([e, d]): f = tf.placeholder(tf.float32, name='f') self.assertEqual(get_backward_ops([f]), [a.op, b.op, c.op, d.op, e.op, f.op]) self.assertEqual(get_backward_ops([d, f]), [c.op, d.op, a.op, b.op, e.op, f.op]) self.assertEqual(get_backward_ops([f], treat_as_inputs=[b]), [a.op, b.op, c.op, d.op, e.op, f.op]) self.assertEqual(get_backward_ops([f], treat_as_inputs=[b, c]), [a.op, b.op, d.op, e.op, f.op]) self.assertEqual(get_backward_ops([f], treat_as_inputs=[d, e]), [a.op, b.op, c.op, d.op, e.op, f.op]) self.assertEqual(get_backward_ops([d, f], treat_as_inputs=[b]), [c.op, d.op, a.op, b.op, e.op, f.op]) def test_get_backward_ops_control_flow(self): # while_loop, scan, TensorArray pass