# Copyright 2018 Guanshuo Wang. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saver builders that map per-tower ('replicated_N/...') variables onto a
single set of checkpoint names."""
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.framework import ops
from tensorflow.python.training.saver import BaseSaverBuilder, _set_cpu0
import six
from tensorflow.python.ops import variables
from tensorflow.python.ops import state_ops
# Needed by DataParallelSaverBuilder._AddRestoreOps (array_ops.shape).
from tensorflow.python.ops import array_ops

_REPLICA_PREFIX = 'replicated_'
_FIRST_REPLICA_PREFIX = 'replicated_0'


def _strip_first_scope(name):
    """Return `name` without its first '/'-separated component.

    Equivalent to '/'.join(name.split('/')[1:]); yields '' when `name`
    contains no '/'.
    """
    return name.partition('/')[2]


class DataParallelSaverBuilder(BaseSaverBuilder):
    """SaverBuilder that stores replicated variables under tower-free names.

    Variables whose name starts with 'replicated_' are written to / read from
    the checkpoint with their leading scope component removed; all other
    variables use their own name.
    """

    def __init__(self):
        super(DataParallelSaverBuilder, self).__init__()

    def save_op(self, filename_tensor, saveables):
        """Create the op that writes `saveables` to a checkpoint.

        Selection rules, per spec:
          * name starts with 'replicated_': saved (with the first scope
            component stripped) only if it starts with 'replicated_0' or
            contains 'avg'; otherwise it is skipped.
          * any other name: saved unchanged.

        Args:
          filename_tensor: String tensor. A filename for V1 checkpoints, a
            checkpoint prefix for V2 checkpoints.
          saveables: Iterable of saveable objects exposing `.specs`.

        Returns:
          The save op / tensor produced by `io_ops._save` or
          `io_ops.save_v2`.

        Raises:
          RuntimeError: If `self._write_version` is neither V1 nor V2.
        """
        tensor_names = []
        tensors = []
        tensor_slices = []
        for saveable in saveables:
            for spec in saveable.specs:
                if spec.name.startswith(_REPLICA_PREFIX):
                    keep = (spec.name.startswith(_FIRST_REPLICA_PREFIX)
                            or 'avg' in spec.name)
                    if not keep:
                        continue
                    # NOTE(review): an 'avg' variable present in several
                    # towers maps to the same stripped name more than once --
                    # confirm that cannot happen in practice.
                    saved_name = _strip_first_scope(spec.name)
                else:
                    saved_name = spec.name
                tensor_names.append(saved_name)
                tensors.append(spec.tensor)
                tensor_slices.append(spec.slice_spec)

        if self._write_version == saver_pb2.SaverDef.V1:
            return io_ops._save(
                filename=filename_tensor,
                tensor_names=tensor_names,
                tensors=tensors,
                tensor_slices=tensor_slices)
        elif self._write_version == saver_pb2.SaverDef.V2:
            # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a
            # prefix of a V2 checkpoint: e.g.
            # "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
            return io_ops.save_v2(filename_tensor, tensor_names,
                                  tensor_slices, tensors)
        else:
            # str() is required: the write version is an enum (int) value and
            # cannot be concatenated to a string directly.
            raise RuntimeError(
                "Unexpected write_version: " + str(self._write_version))

    def restore_op(self, filename_tensor, saveable, preferred_shard):
        """Create the tensors that read `saveable`'s specs from a checkpoint.

        Args:
          filename_tensor: String tensor naming the checkpoint to read.
          saveable: Saveable object exposing `.specs`.
          preferred_shard: Unused by this implementation.

        Returns:
          A list of restored tensors. It may be shorter than
          `saveable.specs` (possibly empty), because 'BatchNorm/moving_'
          specs of towers other than 'replicated_0' are skipped.
        """
        tensors = []
        for spec in saveable.specs:
            if spec.name.startswith(_REPLICA_PREFIX):
                # Ignore the moving_mean and moving_variance in other towers.
                if (not spec.name.startswith(_FIRST_REPLICA_PREFIX)
                        and 'BatchNorm/moving_' in spec.name):
                    continue
                checkpoint_name = _strip_first_scope(spec.name)
            else:
                checkpoint_name = spec.name
            tensors.append(
                io_ops.restore_v2(
                    filename_tensor,
                    [checkpoint_name],
                    [spec.slice_spec],
                    [spec.tensor.dtype])[0])
        return tensors

    def _AddRestoreOps(self,
                       filename_tensor,
                       saveables,
                       restore_sequentially,
                       reshape,
                       preferred_shard=-1,
                       name="restore_all"):
        """Add ops that restore `saveables` and group them into one op.

        Args:
          filename_tensor: String tensor naming the checkpoint to read.
          saveables: Iterable of saveable objects.
          restore_sequentially: If true, each restore is given a control
            dependency on the previously created assign op.
          reshape: If true, pass each spec's shape to `saveable.restore`
            (the static shape when fully defined, else a dynamic shape op).
          preferred_shard: Forwarded to `restore_op`.
          name: Name of the returned grouped op.

        Returns:
          An op grouping every assign op that was created.
        """
        assign_ops = []
        for saveable in saveables:
            restore_control_inputs = (
                assign_ops[-1:] if restore_sequentially else [])
            # Device is rewritten through _set_cpu0 (presumably to pin the
            # restore to CPU 0 -- confirm against TensorFlow's saver module).
            device = _set_cpu0(saveable.device) if saveable.device else None
            with ops.device(device):
                with ops.control_dependencies(restore_control_inputs):
                    tensors = self.restore_op(
                        filename_tensor, saveable, preferred_shard)
                    # restore_op returns nothing for saveables whose specs
                    # were all skipped; there is nothing to assign then.
                    if not tensors:
                        continue
                    shapes = None
                    if reshape:
                        shapes = []
                        for spec in saveable.specs:
                            v = spec.tensor
                            shape = v.get_shape()
                            if not shape.is_fully_defined():
                                shape = array_ops.shape(v)
                            shapes.append(shape)
                    assign_ops.append(saveable.restore(tensors, shapes))
        return control_flow_ops.group(*assign_ops, name=name)


class DebugSaverBuilder(BaseSaverBuilder):
    """SaverBuilder that prints each spec name while building restore ops."""

    def __init__(self):
        super(DebugSaverBuilder, self).__init__()

    def restore_op(self, filename_tensor, saveable, preferred_shard):
        """Create restore tensors for `saveable`, printing every spec name.

        Args:
          filename_tensor: String tensor naming the checkpoint to read.
          saveable: Saveable object exposing `.specs`.
          preferred_shard: Unused by this implementation.

        Returns:
          A list with one restored tensor per spec.
        """
        tensors = []
        for spec in saveable.specs:
            print(spec.name)
            tensors.append(
                io_ops.restore_v2(
                    filename_tensor,
                    [spec.name],
                    [spec.slice_spec],
                    [spec.tensor.dtype])[0])
        return tensors