"""We add metrics specific to extremely quantized networks using a `larq.context.metrics_scope` rather than through the `metrics` parameter of `model.compile()`, where most common metrics reside. This is because, to calculate metrics like the `flip_ratio`, we need a layer's kernel or activation and not just the `y_true` and `y_pred` that Keras passes to metrics defined in the usual way. """ import numpy as np import tensorflow as tf from larq import utils @utils.register_alias("flip_ratio") @utils.register_keras_custom_object class FlipRatio(tf.keras.metrics.Metric): """Computes the mean ratio of changed values in a given tensor. !!! example ```python m = metrics.FlipRatio() m.update_state((1, 1)) # result: 0 m.update_state((2, 2)) # result: 1 m.update_state((1, 2)) # result: 0.75 print('Final result: ', m.result().numpy()) # Final result: 0.75 ``` # Arguments name: Name of the metric. values_dtype: Data type of the tensor for which to track changes. dtype: Data type of the moving mean. """ def __init__(self, values_dtype="int8", name="flip_ratio", dtype=None): super().__init__(name=name, dtype=dtype) self.built = False self.values_dtype = tf.as_dtype(values_dtype) def build(self, input_shape): self._previous_values = self.add_weight( "previous_values", shape=input_shape, dtype=self.values_dtype, initializer=tf.keras.initializers.zeros, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, ) self.total = self.add_weight( "total", initializer=tf.keras.initializers.zeros, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, ) self.count = self.add_weight( "count", initializer=tf.keras.initializers.zeros, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, ) self._size = tf.cast(np.prod(input_shape), self.dtype) self.built = True def update_state(self, values, sample_weight=None): values = tf.cast(values, self.values_dtype) if not self.built: with tf.name_scope(self.name), tf.init_scope(): self.build(values.shape) unchanged_values = tf.math.count_nonzero( tf.equal(self._previous_values, values) ) flip_ratio = 1 - ( tf.cast(unchanged_values, self.dtype) / tf.cast(self._size, self.dtype) ) update_total_op = self.total.assign_add(flip_ratio * tf.sign(self.count)) with tf.control_dependencies([update_total_op]): update_count_op = self.count.assign_add(1) with tf.control_dependencies([update_count_op]): return self._previous_values.assign(values) def result(self): return tf.compat.v1.div_no_nan(self.total, self.count - 1) def reset_states(self): tf.keras.backend.batch_set_value( [(v, 0) for v in self.variables if v is not self._previous_values] ) def get_config(self): return {**super().get_config(), "values_dtype": self.values_dtype.name}