from abc import ABCMeta, abstractmethod
import glob
import os
import tarfile
from tempfile import TemporaryDirectory

import tensorflow as tf


def affine(input_tensor, output_size, bias=True, bias_start=0.0,
           input_size=None, scope="affine", sparse_input=False):
    """Add an affine transformation of `input_tensor` to the current graph.

    Note: This op is loosely based on tensorflow.python.ops.rnn_cell.linear.

    An affine transformation is a linear transformation with a shift,
    `t = tf.matmul(input_tensor, W) + b`.

    Parameters
    ----------
    input_tensor : tensorflow Tensor object, rank 2
        Input tensor to be transformed.
    output_size : int
        The output will be size [a, output_size] where `input_tensor` has
        shape [a, b].
    bias : bool, optional
        If True, apply a bias to the transformation. If False, only a linear
        transformation is applied (i.e., `t = tf.matmul(input_tensor, W)`).
    bias_start : float, optional
        The initial value for the bias `b`.
    input_size : int, optional
        Second dimension of the rank 2 input tensor. Required for sparse input
        tensors.
    scope : str, optional
        Name of the variable scope in which the weight and bias variables
        ("weights0" and "bias0") are created.
    sparse_input : bool, optional
        Set to True if `input_tensor` is sparse.

    Returns
    -------
    t : tensorflow tensor object
        The affine transformation of `input_tensor`.
    """

    # The input size is needed for sparse matrices.
    if input_size is None:
        input_size = input_tensor.get_shape().as_list()[1]

    with tf.variable_scope(scope):
        W_0 = tf.get_variable(
            "weights0",
            [input_size, output_size])

        # If the input is sparse, then use a special matmul routine.
        matmul = tf.sparse_tensor_dense_matmul if sparse_input else tf.matmul
        t = matmul(input_tensor, W_0)

        if bias:
            b_0 = tf.get_variable(
                "bias0",
                [output_size],
                initializer=tf.constant_initializer(bias_start))
            t = tf.add(t, b_0)

    return t


class TFPicklingBase(object, metaclass=ABCMeta):
    """Base class for pickling TensorFlow-based scikit-learn estimators.

    This base class defines a few standard attributes to enable fairly
    transparent pickling of TensorFlow models. Note that TensorFlow has a
    custom saving mechanism that makes pickling (and thus using it in
    scikit-learn, etc.) not straightforward.

    NOTE: This base class must come first in the list of classes any child
    class inherits from.

    When pickling an object, if the `self._is_fitted` property is True:

        1. The session at `self._session` is saved using the saver at
           `self._saver` to a temporary file.

        2. The saved data is then read into memory and attached to the
           object state at '_saved_model'.

        3. The fitted state of the model is saved at '_fitted' as True.

    When unpickling the object:

        1. All variables in the state of the object are set using
           `self.__dict__` except the '_saved_model' entry.

        2. If the '_fitted' key is in the state of the object and is True:

            2a. The '_saved_model' entry is written to a temporary file.

            2b. A new TF graph is instantiated at `self.graph_`.

            2c. `self._build_tf_graph()` is called. This instantiates a
                `tf.train.Saver` at `self._saver` and a `tf.Session` at
                `self._session`.

            2d. The `self._saver` is used to restore previous session to the
                current one.

    To use this base class properly, the child class needs to

        1. Implement the abstract method `self._set_up_graph`. This method
           should build the required TF graph.

        2. Exactly once (e.g., in the `fit` method), instantiate a `tf.Graph`
           at `self.graph_` and then call `self._build_tf_graph` inside the
           `tf.Graph` context block. `self._build_tf_graph` will call
           `self._set_up_graph` and further instantiate the `tf.train.Saver`
           and `tf.Session`.

        3. After 2. is done, set `self._is_fitted = True`.

        4. Make sure to override `__getstate__` to store any extra
           information about your estimator to the state of the object. When
           doing this, call `state = super().__getstate__()` and then append
           to the `state`.

    See the example below and also the MLP classes and base class,
    MLPBaseEstimator.

    Example
    -------

    ```python
    # example class for using TFPicklingBase - adds a scalar to input 1d
    # arrays
    class TFAdder(TFPicklingBase):
        def __init__(self, add_val):
            # real scikit-learn estimators should do all of this work in the
            # fit method
            self.add_val = float(add_val)
            self.graph_ = tf.Graph()
            with self.graph_.as_default():
                self._build_tf_graph()
                self._session.run(tf.initialize_all_variables())
            self._is_fitted = True

        def _set_up_graph(self):
            self._a = tf.placeholder(tf.float32, shape=[None], name='a')
            self._add_val = tf.Variable(self.add_val,
                                        name='add_val',
                                        dtype=tf.float32)
            self._sum = tf.add(self._a, self._add_val, name='sum')

        def add(self, a):
            with self.graph_.as_default():
                val = self._session.run(self._sum, feed_dict={self._a: a})
            return val

        def __getstate__(self):
            state = super().__getstate__()

            # add add_val to state
            state['add_val'] = self.add_val

            return state
    ```
    """

    @property
    def _is_fitted(self):
        """Return True if the model has been at least partially fitted.

        Returns
        -------
        bool

        Notes
        -----
        This is to indicate whether, e.g., the TensorFlow graph for the model
        has been created.
        """
        return getattr(self, '_fitted', False)

    @_is_fitted.setter
    def _is_fitted(self, b):
        """Set whether the model has been at least partially fitted.

        Parameters
        ----------
        b : bool
            True if the model has been fitted.
        """
        self._fitted = b

    def __getstate__(self):
        """Return the picklable state of the estimator.

        Returns
        -------
        state : dict
            Empty if the model has not been fitted. Otherwise it holds
            '_fitted' (True) and '_saved_model', the bytes of a tar archive
            of the files written by `self._saver.save`.

        Notes
        -----
        The TF graph itself is not included since it is recreated when the
        object is unpickled.
        """
        # Override __getstate__ so that TF model parameters are pickled
        # properly.
        state = {}

        if self._is_fitted:
            with TemporaryDirectory() as tmpdir:
                # Serialize the model.
                self._saver.save(
                    self._session, os.path.join(tmpdir, 'saved_model'))

                # TF writes a bunch of files so tar them. The file list is
                # collected before the archive is created so that the
                # archive does not try to include itself.
                fnames = glob.glob(os.path.join(tmpdir, '*'))
                tarname = os.path.join(tmpdir, 'saved_model.tar')
                with tarfile.open(tarname, "w") as tar:
                    for f in fnames:
                        tar.add(f, arcname=os.path.split(f)[-1])

                # Now read the state back into memory.
                with open(tarname, 'rb') as f:
                    saved_model = f.read()

            # Add fitted attributes since the model has been fitted.
            state['_fitted'] = True
            state['_saved_model'] = saved_model

        return state

    def __setstate__(self, state):
        """Restore the estimator from a pickled state.

        Parameters
        ----------
        state : dict
            State produced by `__getstate__`. Every entry except
            '_saved_model' is copied into `self.__dict__`. If '_fitted' is
            True, the graph is rebuilt at `self.graph_` and the session is
            restored from the archive stored at '_saved_model'.

        Notes
        -----
        Unpickling data from an untrusted source is unsafe in general; only
        load pickles that you trust.
        """
        # Override __setstate__ so that TF model parameters are unpickled
        # properly.
        for k, v in state.items():
            if k != '_saved_model':
                self.__dict__[k] = v

        if state.get('_fitted', False):
            with TemporaryDirectory() as tmpdir:
                # Write out the serialized tarfile.
                tarname = os.path.join(tmpdir, 'saved_model.tar')
                with open(tarname, 'wb') as f:
                    f.write(state['_saved_model'])

                # Untar it. Where extraction filters are available, use the
                # 'data' filter so that a crafted archive cannot write
                # outside of tmpdir; older Pythons fall back to the plain
                # call.
                with tarfile.open(tarname, 'r') as tar:
                    if hasattr(tarfile, 'data_filter'):
                        tar.extractall(path=tmpdir, filter='data')
                    else:
                        tar.extractall(path=tmpdir)

                # And restore.
                self.graph_ = tf.Graph()
                with self.graph_.as_default():
                    self._build_tf_graph()
                    self._saver.restore(
                        self._session, os.path.join(tmpdir, 'saved_model'))

    def _build_tf_graph(self):
        """Build the TF graph, setup model saving and setup a TF session.

        Notes
        -----
        This method initializes a TF Saver and a TF Session via

            ```python
            self._saver = tf.train.Saver()
            self._session = tf.Session()
            ```

        These calls are made after `self._set_up_graph()` is called.

        See the main class docs for how to properly call this method from a
        child class.
        """
        self._set_up_graph()
        self._saver = tf.train.Saver()
        self._session = tf.Session()

    @abstractmethod
    def _set_up_graph(self):
        """Assemble the TF graph for estimator.

        Notes
        -----
        Child classes should add the TF ops to the graph they want to
        implement here.
        """
        pass