Python tensorflow.keras Examples

The following are 30 code examples of tensorflow.keras. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: clustering_registry_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testMakeClusterableWorksOnKerasRNNLayer(self):
    """
    Checks make_clusterable() on a stock keras LSTM layer.

    Before registration the layer has no get_clusterable_weights(); after
    registration and building, it reports the LSTM cell's kernel followed by
    its recurrent_kernel.
    """
    lstm = layers.LSTM(10)
    with self.assertRaises(AttributeError):
      lstm.get_clusterable_weights()

    ClusterRegistry.make_clusterable(lstm)
    # Building through a Sequential creates the cell's weight variables.
    keras.Sequential([lstm]).build(input_shape=(2, 3, 4))

    cell = lstm.cell
    expected = [
        ('kernel', cell.kernel),
        ('recurrent_kernel', cell.recurrent_kernel),
    ]
    self.assertEqual(expected, lstm.get_clusterable_weights())
Example #2
Source File: pwl_calibration_test.py    From lattice with Apache License 2.0 6 votes vote down vote up
def testConvexityNonUniformKeypoints(self, units, convexity, expected_loss):
    """
    Trains a calibrator whose only constraint is convexity, using
    non-uniformly spaced input keypoints, and checks the final loss against
    expected_loss (within self._loss_eps). When units > 1 the same config is
    trained again with use_multi_calibration_layer enabled.
    """
    if self._disable_all:
      return

    config = dict(
        units=units,
        num_training_records=100,
        num_training_epoch=200,
        optimizer=tf.keras.optimizers.Adagrad,
        learning_rate=1.0,
        x_generator=self._ScatterXUniformly,
        y_function=self._WavyParabola,
        monotonicity=0,
        convexity=convexity,
        input_keypoints=[-1.0, -0.9, -0.3, -0.2, 0.0, 0.3, 0.31, 0.35, 1.0],
        output_min=None,
        output_max=None,
    )

    def train_and_check():
      trained_loss = self._TrainModel(config)
      self.assertAlmostEqual(trained_loss, expected_loss, delta=self._loss_eps)

    train_and_check()
    if units > 1:
      config["use_multi_calibration_layer"] = True
      train_and_check()
Example #3
Source File: cluster_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testClusterSequentialModelPreservesBuiltStateNoInput(self):
    """
    Verifies that clustering a sequential model without an input layer
    preserves the built state: the input model stays unbuilt, the clustered
    model is unbuilt, and a model deserialized from the clustered model's
    JSON config is unbuilt as well.
    """
    # No InputLayer
    model = keras.Sequential([
        layers.Dense(10),
        layers.Dense(10),
    ])
    self.assertEqual(model.built, False)
    clustered_model = cluster.cluster_weights(model, **self.params)
    # Clustering must not build the input model...
    self.assertEqual(model.built, False)
    # ...and the clustered model itself must carry over the unbuilt state.
    self.assertEqual(clustered_model.built, False)

    # Test built state is preserved across serialization
    with cluster.cluster_scope():
      loaded_model = keras.models.model_from_config(
          json.loads(clustered_model.to_json()))
      self.assertEqual(loaded_model.built, False)
Example #4
Source File: pwl_calibration_test.py    From lattice with Apache License 2.0 6 votes vote down vote up
def testInputKeypoints(self, keypoints):
    """
    Trains an unconstrained calibrator on the _SmallWaves target using the
    given input keypoints and checks that the final loss is 0.009650 within
    self._loss_eps.
    """
    if self._disable_all:
      return

    target_loss = 0.009650
    config = dict(
        num_training_records=100,
        num_training_epoch=200,
        optimizer=tf.keras.optimizers.Adagrad,
        learning_rate=0.15,
        x_generator=self._ScatterXUniformly,
        y_function=self._SmallWaves,
        monotonicity=0,
        input_keypoints=keypoints,
        output_min=None,
        output_max=None,
    )
    self.assertAlmostEqual(
        self._TrainModel(config), target_loss, delta=self._loss_eps)
Example #5
Source File: cluster_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testClusterWeightsStrippedWeights(self):
    """
    Clusters a functional model containing a single BatchNormalization layer,
    strips the clustering wrappers, and checks that no clustered layers remain
    and that the number of weight arrays is unchanged by stripping.
    """
    inp = keras.Input(shape=(10,))
    bn_out = layers.BatchNormalization()(inp)
    model = keras.Model(inputs=[inp], outputs=bn_out)

    clustered_model = cluster.cluster_weights(model, **self.params)
    n_weights_clustered = len(clustered_model.get_weights())
    stripped_model = cluster.strip_clustering(clustered_model)

    self.assertEqual(self._count_clustered_layers(stripped_model), 0)
    n_weights_stripped = len(stripped_model.get_weights())
    self.assertEqual(n_weights_stripped, n_weights_clustered)
Example #6
Source File: cluster_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testStrippedKernel(self):
    """
    Clusters a functional model with a single Conv2D layer, strips the
    clustering wrappers, and checks that:
      * no clustered layers remain,
      * the stripped layer's kernel is a different object from the kernel
        of the layer inside the clustering wrapper, and
      * the stripped layer's kernel equals its first weight.
    """
    inp = keras.Input(shape=(1, 1, 1))
    conv_out = layers.Conv2D(1, 1)(inp)
    model = keras.Model(inputs=[inp], outputs=conv_out)

    clustered_model = cluster.cluster_weights(model, **self.params)
    # Index 1 holds the (wrapped) Conv2D; index 0 is presumably the InputLayer.
    wrapped_conv = clustered_model.layers[1]
    kernel_before_strip = wrapped_conv.layer.kernel

    stripped_model = cluster.strip_clustering(clustered_model)
    plain_conv = stripped_model.layers[1]

    self.assertEqual(self._count_clustered_layers(stripped_model), 0)
    self.assertIsNot(plain_conv.kernel, kernel_before_strip)
    self.assertEqual(plain_conv.kernel, plain_conv.weights[0])
Example #7
Source File: cluster_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testClusterFunctionalModelPreservesBuiltState(self):
    """
    Verifies that clustering a functional model preserves the built state:
    the input model stays built, the clustered model is built, and a model
    deserialized from the clustered model's JSON config is built as well.
    """
    i1 = keras.Input(shape=(10,))
    i2 = keras.Input(shape=(10,))
    x1 = layers.Dense(10)(i1)
    x2 = layers.Dense(10)(i2)
    outputs = layers.Add()([x1, x2])
    model = keras.Model(inputs=[i1, i2], outputs=outputs)
    self.assertEqual(model.built, True)
    clustered_model = cluster.cluster_weights(model, **self.params)
    # The input model must be left built...
    self.assertEqual(model.built, True)
    # ...and the clustered model itself must be built too.
    self.assertEqual(clustered_model.built, True)

    # Test built state preserves across serialization
    with cluster.cluster_scope():
      loaded_model = keras.models.model_from_config(
          json.loads(clustered_model.to_json()))
    self.assertEqual(loaded_model.built, True)
Example #8
Source File: pruning_wrapper.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def from_config(cls, config):
    """
    Create an instance of ``cls`` from a serialized ``config`` dict.

    ``config['pruning_schedule']`` is deserialized with
    keras.utils.deserialize_keras_object, which is given ConstantSparsity and
    PolynomialDecay as custom objects. ``config['layer']`` is deserialized
    with the Keras layer deserializer. All resulting entries are passed to
    ``cls`` as keyword arguments. The caller's dict is not mutated because
    the method works on a copy.
    """
    cfg = config.copy()

    # TODO(pulkitb): This should ideally be fetched from pruning_schedule,
    # which should maintain a list of all the pruning_schedules.
    schedule_classes = {
        'ConstantSparsity': pruning_sched.ConstantSparsity,
        'PolynomialDecay': pruning_sched.PolynomialDecay
    }
    # The RHS (including the pop) is evaluated before the key is re-assigned.
    cfg['pruning_schedule'] = keras.utils.deserialize_keras_object(
        cfg.pop('pruning_schedule'),
        module_objects=globals(),
        custom_objects=schedule_classes)

    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    cfg['layer'] = deserialize_layer(cfg.pop('layer'))

    return cls(**cfg)
Example #9
Source File: prune_integration_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def _train_model(model, epochs=1, x_train=None, y_train=None, callbacks=None):
    if x_train is None:
      x_train = np.random.rand(20, 10),
    if y_train is None:
      y_train = keras.utils.to_categorical(
          np.random.randint(5, size=(20, 1)), 5)

    if model.optimizer is None:
      model.compile(
          loss='categorical_crossentropy',
          optimizer='sgd',
          metrics=['accuracy'])

    if callbacks is None:
      callbacks = []
      if PruneIntegrationTest._is_pruned(model):
        callbacks = [pruning_callbacks.UpdatePruningStep()]

    model.fit(
        x_train, y_train, epochs=epochs, batch_size=20, callbacks=callbacks) 
Example #10
Source File: pwl_calibration_test.py    From lattice with Apache License 2.0 6 votes vote down vote up
def testUnconstrainedNoMissingValue(self, units, one_d_input, expected_loss):
    """
    Trains an unconstrained calibrator with 21 keypoints over the input range
    [-1.0, 1.0] and checks the final loss against expected_loss (within
    self._loss_eps). When units > 1 and the input is not one-dimensional, the
    same config is trained again with use_multi_calibration_layer enabled.
    """
    if self._disable_all:
      return

    config = dict(
        units=units,
        one_d_input=one_d_input,
        num_training_records=100,
        num_training_epoch=2000,
        optimizer=tf.keras.optimizers.Adagrad,
        learning_rate=0.15,
        x_generator=self._ScatterXUniformly,
        y_function=self._SmallWaves,
        monotonicity=0,
        num_keypoints=21,
        input_min=-1.0,
        input_max=1.0,
        output_min=None,
        output_max=None,
    )

    def train_and_check():
      trained_loss = self._TrainModel(config)
      self.assertAlmostEqual(trained_loss, expected_loss, delta=self._loss_eps)

    train_and_check()
    if units > 1 and not one_d_input:
      config["use_multi_calibration_layer"] = True
      train_and_check()
Example #11
Source File: prune_integration_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
    """
    Wraps a single ``layer_type`` layer with prune_low_magnitude and checks
    that sparsity is 0.0 before training and 0.5 after one fit with the
    UpdatePruningStep callback. It then checks that strip_pruning matches the
    original at that sparsity. The test returns early when
    _get_params_for_layer yields no constructor args for the layer type.
    """
    model = keras.Sequential()
    layer_args, input_shape = self._get_params_for_layer(layer_type)
    if layer_args is None:
      return  # Test for layer not supported yet.

    pruned_layer = prune.prune_low_magnitude(
        layer_type(*layer_args), input_shape=input_shape, **self.params)
    model.add(pruned_layer)
    model.compile(
        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model)

    batch_size = 32
    x = np.random.randn(
        *self._batch(model.input.get_shape().as_list(), batch_size))
    y = np.random.randn(
        *self._batch(model.output.get_shape().as_list(), batch_size))
    model.fit(x, y, callbacks=[pruning_callbacks.UpdatePruningStep()])

    target_sparsity = 0.5
    test_utils.assert_model_sparsity(self, target_sparsity, model)
    self._check_strip_pruning_matches_original(model, target_sparsity)
Example #12
Source File: prune_integration_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testRNNLayersSingleCell_ReachesTargetSparsity(self, layer_type):
    """
    Wraps a single ``layer_type(10)`` recurrent layer with input shape (3, 4)
    in prune_low_magnitude. It checks that sparsity is 0.0 before training
    and 0.5 after one fit with UpdatePruningStep, then checks that
    strip_pruning matches the original at that sparsity.
    """
    model = keras.Sequential()
    rnn_layer = layer_type(10)
    model.add(prune.prune_low_magnitude(
        rnn_layer, input_shape=(3, 4), **self.params))

    model.compile(
        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    test_utils.assert_model_sparsity(self, 0.0, model)

    batch_size = 32
    x = np.random.randn(
        *self._batch(model.input.get_shape().as_list(), batch_size))
    y = np.random.randn(
        *self._batch(model.output.get_shape().as_list(), batch_size))
    model.fit(x, y, callbacks=[pruning_callbacks.UpdatePruningStep()])

    target_sparsity = 0.5
    test_utils.assert_model_sparsity(self, target_sparsity, model)
    self._check_strip_pruning_matches_original(model, target_sparsity)
Example #13
Source File: prune_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testPruneFunctionalModelPreservesBuiltState(self):
    """
    Verifies that pruning a functional model preserves the built state: the
    input model stays built, the pruned model is built, and a model
    deserialized from the pruned model's JSON config is built as well.
    """
    i1 = keras.Input(shape=(10,))
    i2 = keras.Input(shape=(10,))
    x1 = layers.Dense(10)(i1)
    x2 = layers.Dense(10)(i2)
    outputs = layers.Add()([x1, x2])
    model = keras.Model(inputs=[i1, i2], outputs=outputs)
    self.assertEqual(model.built, True)
    pruned_model = prune.prune_low_magnitude(model, **self.params)
    # The input model must be left built...
    self.assertEqual(model.built, True)
    # ...and the pruned model itself must be built too.
    self.assertEqual(pruned_model.built, True)

    # Test built state preserves across serialization
    with prune.prune_scope():
      loaded_model = keras.models.model_from_config(
          json.loads(pruned_model.to_json()))
    self.assertEqual(loaded_model.built, True)
Example #14
Source File: layers.py    From astroNN with MIT License 6 votes vote down vote up
def __call__(self, model):
        """
        Wrap a Keras model so that each input is repeated ``self.n`` times
        (FastMCRepeat), run through the model via TimeDistributed, and reduced
        by FastMCInferenceMeanVar.

        :param model: Keras model to be accelerated
        :type model: Union[keras.Model, keras.Sequential]
        :return: Accelerated Keras model
        :rtype: Union[keras.Model, keras.Sequential]
        :raises TypeError: if ``model`` is not a tf.keras Model/Sequential
        """
        if not isinstance(model, (tfk.Model, tfk.Sequential)):
            raise TypeError(f'FastMCInference expects tensorflow.keras Model, you gave {type(model)}')
        self.model = model

        new_input = tfk.layers.Input(shape=self.model.input_shape[1:], name='input')
        mc_model = tfk.models.Model(inputs=self.model.inputs, outputs=self.model.outputs)

        # Layers are constructed in the same order as a single nested call
        # expression would construct them: reducer, wrapper, repeater.
        reduce_mean_var = FastMCInferenceMeanVar()
        per_sample = tfk.layers.TimeDistributed(mc_model)
        repeat_n = FastMCRepeat(self.n)
        mc = reduce_mean_var(per_sample(repeat_n(new_input)))

        return tfk.models.Model(inputs=new_input, outputs=mc)
Example #15
Source File: prune_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testPruneInferenceWorks_PruningStepCallbackNotRequired(self):
    """
    Checks that predict() and evaluate() run on a pruned, compiled model
    without passing the UpdatePruningStep callback.
    """
    base_model = keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        layers.Dense(2, activation='sigmoid')
    ])
    model = prune.prune_low_magnitude(base_model, **self.params)
    model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer=keras.optimizers.SGD(),
        metrics=['accuracy'])

    model.predict(np.random.rand(1000, 100))

    eval_x = np.random.rand(1000, 100)
    eval_y = keras.utils.to_categorical(np.random.randint(2, size=(1000, 1)))
    model.evaluate(eval_x, eval_y)
Example #16
Source File: cluster_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testStripSelectivelyClusteredFunctionalModel(self):
    """
    Builds a two-input functional model where only one Dense branch is
    clustered, strips it, and checks that no clustered layers remain and
    that layer index 2 is a plain Dense layer.
    """
    in_a = keras.Input(shape=(10,))
    in_b = keras.Input(shape=(10,))
    branch_a = cluster.cluster_weights(layers.Dense(10), **self.params)(in_a)
    branch_b = layers.Dense(10)(in_b)
    summed = layers.Add()([branch_a, branch_b])
    clustered_model = keras.Model(inputs=[in_a, in_b], outputs=summed)

    stripped_model = cluster.strip_clustering(clustered_model)

    self.assertEqual(self._count_clustered_layers(stripped_model), 0)
    self.assertIsInstance(stripped_model.layers[2], layers.Dense)
Example #17
Source File: cluster_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testClusterModelValidLayersSuccessful(self):
    """
    Clusters a sequential model built from a clusterable keras layer, a
    non-clusterable keras layer and a custom clusterable layer. It checks
    that the layer count is unchanged and validates each original/clustered
    layer pair with _validate_clustered_layer.
    """
    model = keras.Sequential([
        self.keras_clusterable_layer,
        self.keras_non_clusterable_layer,
        self.custom_clusterable_layer
    ])
    clustered_model = cluster.cluster_weights(model, **self.params)
    clustered_model.build(input_shape=(1, 28, 28, 1))

    originals = model.layers
    clustered = clustered_model.layers
    self.assertEqual(len(originals), len(clustered))
    for idx, original_layer in enumerate(originals):
      self._validate_clustered_layer(original_layer, clustered[idx])
Example #18
Source File: utils.py    From keras-adamw with MIT License 6 votes vote down vote up
def _update_t_cur_eta_t_v2(self, lr_t=None, var=None):  # tf.keras
    """Advance the `t_cur` counter and, with cosine annealing, `eta_t`.

    Counts processed `(grad, var)` updates in `self._updates_processed`. On
    the last update of an iteration (`self._updates_per_iter - 1`) it adds 1
    to `self.t_cur` and resets the counter; otherwise it only increments the
    counter. If `self.use_cosine_annealing` is set and the iteration is done,
    it also assigns `_compute_eta_t(self)` to `self.eta_t` (ordered after the
    `t_cur` update) and sets `self.lr_t = lr_t * self.eta_t`.

    Args:
        lr_t: learning rate multiplied by `self.eta_t` on the
            cosine-annealing path. NOTE(review): it defaults to None, and
            `None * self.eta_t` would raise there, so callers presumably
            always pass it. Confirm.
        var: not used in this function body.

    Returns:
        Tuple `(iteration_done, t_cur_update, eta_t_update)`. Each update op
        is None when it was not created on this call.
    """
    t_cur_update, eta_t_update = None, None  # in case not assigned

    # update `t_cur` if iterating last `(grad, var)`
    iteration_done = self._updates_processed == (self._updates_per_iter - 1)
    if iteration_done:
        t_cur_update = state_ops.assign_add(self.t_cur, 1,
                                            use_locking=self._use_locking)
        self._updates_processed = 0  # reset
    else:
        self._updates_processed += 1

    # Cosine annealing
    if self.use_cosine_annealing and iteration_done:
        # ensure eta_t is updated AFTER t_cur
        with ops.control_dependencies([t_cur_update]):
            eta_t_update = state_ops.assign(self.eta_t, _compute_eta_t(self),
                                            use_locking=self._use_locking)
        self.lr_t = lr_t * self.eta_t  # for external tracking

    return iteration_done, t_cur_update, eta_t_update
Example #19
Source File: cluster_wrapper_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testValuesAreClusteredAfterStripping(self, number_of_clusters, cluster_centroids_init):
    """
    For the given cluster count and centroid initialization, clusters a
    one-layer Dense model and strips it. It checks that the first weight
    array holds at most ``number_of_clusters`` distinct values and that the
    stripped model's first layer is a plain Dense layer.
    """
    original_model = tf.keras.Sequential([
        layers.Dense(32, input_shape=(10,)),
    ])
    stripped_model = cluster.strip_clustering(
        cluster.cluster_weights(
            original_model,
            number_of_clusters=number_of_clusters,
            cluster_centroids_init=cluster_centroids_init))

    first_weights = stripped_model.get_weights()[0]
    distinct_values = set(first_weights.flatten().tolist())
    self.assertLessEqual(len(distinct_values), number_of_clusters)

    # The wrapper must be gone: the first layer is the plain Dense layer.
    self.assertIsInstance(stripped_model.layers[0], layers.Dense)
Example #20
Source File: mnist_cnn.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def build_layerwise_model(input_shape, **pruning_params):
  """Build a CNN classifier whose Conv2D and Dense layers are all pruned.

  Every Conv2D and Dense layer is wrapped with prune.prune_low_magnitude
  using ``pruning_params``. The first wrapper also receives ``input_shape``.
  The output layer size is read from the module-level ``num_classes``.
  """
  def _pruned(layer, **extra):
    return prune.prune_low_magnitude(layer, **extra, **pruning_params)

  def _pool():
    return l.MaxPooling2D((2, 2), (2, 2), padding='same')

  stack = [
      _pruned(l.Conv2D(32, 5, padding='same', activation='relu'),
              input_shape=input_shape),
      _pool(),
      l.BatchNormalization(),
      _pruned(l.Conv2D(64, 5, padding='same', activation='relu')),
      _pool(),
      l.Flatten(),
      _pruned(l.Dense(1024, activation='relu')),
      l.Dropout(0.4),
      _pruned(l.Dense(num_classes, activation='softmax')),
  ]
  return tf.keras.Sequential(stack)
Example #21
Source File: cluster_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testClusterModelDoesNotWrapAlreadyWrappedLayer(self):
    """
    Clusters a model whose second layer is already clustered. It checks that
    the layer count is unchanged and the first layer is validated as usual.
    It also checks that the pre-clustered layer is reused as-is (not wrapped
    again) and validates it against its inner layer.
    """
    flatten = layers.Flatten()
    already_clustered = cluster.cluster_weights(layers.Dense(10), **self.params)
    model = keras.Sequential([flatten, already_clustered])

    clustered_model = cluster.cluster_weights(model, **self.params)
    clustered_model.build(input_shape=(10, 10, 1))

    src_layers = model.layers
    out_layers = clustered_model.layers
    self.assertEqual(len(src_layers), len(out_layers))
    self._validate_clustered_layer(src_layers[0], out_layers[0])
    # The second layer was already clustered, so it must be passed through.
    self.assertEqual(src_layers[1], out_layers[1])
    self._validate_clustered_layer(src_layers[1].layer, out_layers[1])
Example #22
Source File: mnist_e2e.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def build_layerwise_model(input_shape, **pruning_params):
  """Build a CNN classifier in which only the two Dense layers are pruned.

  Layers, in order: Conv2D then max-pool; Conv2D, BatchNormalization, ReLU,
  then max-pool; Flatten. These are followed by a pruned Dense(1024),
  Dropout, and a pruned softmax Dense layer whose size is the module-level
  ``num_classes``. Pruning uses ``pruning_params``.
  """
  def _pruned(layer):
    return prune.prune_low_magnitude(layer, **pruning_params)

  def _pool():
    return l.MaxPooling2D((2, 2), (2, 2), padding='same')

  feature_layers = [
      l.Conv2D(
          32, 5, padding='same', activation='relu', input_shape=input_shape),
      _pool(),
      l.Conv2D(64, 5, padding='same'),
      l.BatchNormalization(),
      l.ReLU(),
      _pool(),
      l.Flatten(),
  ]
  head_layers = [
      _pruned(l.Dense(1024, activation='relu')),
      l.Dropout(0.4),
      _pruned(l.Dense(num_classes, activation='softmax')),
  ]
  return tf.keras.Sequential(feature_layers + head_layers)
Example #23
Source File: clustering_registry_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testMakeClusterableWorksOnKerasRNNLayerWithClusterableCell(self):
    """
    Checks make_clusterable() on a stacked keras RNN built from an LSTMCell
    and a custom clusterable cell. After building, the layer should report
    each cell's kernel and recurrent_kernel, cell by cell, in stack order.
    """
    lstm_cell = layers.LSTMCell(10)
    custom_cell = ClusterRegistryTest.MinimalRNNCellClusterable(5)
    rnn = layers.RNN([lstm_cell, custom_cell])
    with self.assertRaises(AttributeError):
      rnn.get_clusterable_weights()

    ClusterRegistry.make_clusterable(rnn)
    keras.Sequential([rnn]).build(input_shape=(2, 3, 4))

    expected_weights = [
        (attr, getattr(cell, attr))
        for cell in (lstm_cell, custom_cell)
        for attr in ('kernel', 'recurrent_kernel')
    ]
    self.assertEqual(expected_weights, rnn.get_clusterable_weights())
Example #24
Source File: cluster_integration_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testValuesRemainClusteredAfterTraining(self):
    """
    Trains a clustered two-layer Dense model for one step, strips it, and
    checks that the first weight array still holds at most
    ``self.params["number_of_clusters"]`` distinct values.
    """
    original_model = keras.Sequential([
        layers.Dense(2, input_shape=(2,)),
        layers.Dense(2),
    ])
    clustered_model = cluster.cluster_weights(original_model, **self.params)
    clustered_model.compile(
        loss=keras.losses.categorical_crossentropy,
        optimizer="adam",
        metrics=["accuracy"],
    )
    clustered_model.fit(x=self.dataset_generator(), steps_per_epoch=1)

    first_weights = cluster.strip_clustering(clustered_model).get_weights()[0]
    distinct_values = set(first_weights.flatten().tolist())
    max_clusters = self.params["number_of_clusters"]
    self.assertLessEqual(len(distinct_values), max_clusters)
Example #25
Source File: clustering_registry_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testMakeClusterableWorksOnKerasRNNLayerWithRNNCellsParams(self):
    """
    Checks make_clusterable() on a stacked keras RNN built from a stock
    LSTMCell and a stock GRUCell. After building, the layer should report
    each cell's kernel and recurrent_kernel, cell by cell, in stack order.
    """
    lstm_cell = layers.LSTMCell(10)
    gru_cell = layers.GRUCell(5)
    rnn = layers.RNN([lstm_cell, gru_cell])
    with self.assertRaises(AttributeError):
      rnn.get_clusterable_weights()

    ClusterRegistry.make_clusterable(rnn)
    keras.Sequential([rnn]).build(input_shape=(2, 3, 4))

    expected_weights = [
        (attr, getattr(cell, attr))
        for cell in (lstm_cell, gru_cell)
        for attr in ('kernel', 'recurrent_kernel')
    ]
    self.assertEqual(expected_weights, rnn.get_clusterable_weights())
Example #26
Source File: prune_test.py    From model-optimization with Apache License 2.0 6 votes vote down vote up
def testPruneSequentialModel(self):
    """
    Prunes a two-Dense sequential model with and without an explicit input
    shape, and checks that both Dense layers are pruned in each case.
    """
    def without_input_layer():
      # No InputLayer
      return keras.Sequential([
          layers.Dense(10),
          layers.Dense(10),
      ])

    def with_input_layer():
      return keras.Sequential([
          layers.Dense(10, input_shape=(10,)),
          layers.Dense(10),
      ])

    for make_model in (without_input_layer, with_input_layer):
      pruned_model = prune.prune_low_magnitude(make_model(), **self.params)
      self.assertEqual(self._count_pruned_layers(pruned_model), 2)
Example #27
Source File: cluster_test.py    From model-optimization with Apache License 2.0 5 votes vote down vote up
def testStripSelectivelyClusteredSequentialModel(self):
    """
    Builds a sequential model whose first Dense layer is clustered and whose
    second is not. It strips the model and checks that no clustered layers
    remain and that the first layer is a plain Dense layer again.
    """
    wrapped_dense = cluster.cluster_weights(layers.Dense(10), **self.params)
    plain_dense = layers.Dense(10)
    clustered_model = keras.Sequential([wrapped_dense, plain_dense])
    clustered_model.build(input_shape=(1, 10))

    stripped_model = cluster.strip_clustering(clustered_model)

    self.assertEqual(self._count_clustered_layers(stripped_model), 0)
    self.assertIsInstance(stripped_model.layers[0], layers.Dense)
Example #28
Source File: clustering_registry_test.py    From model-optimization with Apache License 2.0 5 votes vote down vote up
def testSupportsKerasRNNLayerClusterableCell(self):
    """
    Checks that ClusterRegistry.supports() accepts a keras RNN layer built
    around a custom clusterable RNN cell.
    """
    custom_cell = ClusterRegistryTest.MinimalRNNCellClusterable(32)
    rnn_layer = keras.layers.RNN(custom_cell)
    self.assertTrue(ClusterRegistry.supports(rnn_layer))
Example #29
Source File: cluster_test.py    From model-optimization with Apache License 2.0 5 votes vote down vote up
def testClusterSubclassModel(self):
    """
    Checks that cluster_weights() raises ValueError when given an instance of
    a keras.Model subclass (TestModel).
    """
    subclassed_model = TestModel()
    with self.assertRaises(ValueError):
      cluster.cluster_weights(subclassed_model, **self.params)
Example #30
Source File: clustering_registry_test.py    From model-optimization with Apache License 2.0 5 votes vote down vote up
def testMakeClusterableWorksOnKerasClusterableLayer(self):
    """
    Checks make_clusterable() on a stock Dense layer. Before registration the
    layer has no get_clusterable_weights(); after registration and building,
    it reports only its kernel.
    """
    dense = layers.Dense(10)
    with self.assertRaises(AttributeError):
      dense.get_clusterable_weights()

    ClusterRegistry.make_clusterable(dense)
    # Building is required so that the layer's kernel variable exists.
    keras.Sequential([dense]).build(input_shape=(10, 1))

    expected = [('kernel', dense.kernel)]
    self.assertEqual(expected, dense.get_clusterable_weights())