From 28eb6b48f9759a843253f6a5f69ad0e9a15d4a0a Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Mon, 22 Jan 2024 04:26:25 -0800 Subject: [PATCH] Add more TFv1 variable types. PiperOrigin-RevId: 600417849 --- ffn/training/variables.py | 91 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 2 deletions(-) diff --git a/ffn/training/variables.py b/ffn/training/variables.py index eb4d993..904a5d6 100644 --- a/ffn/training/variables.py +++ b/ffn/training/variables.py @@ -14,13 +14,14 @@ # ============================================================================== """Customized variables for tracking ratios, rates, etc.""" +import numpy as np import tensorflow.compat.v1 as tf -class FractionTracker(object): +class FractionTracker: """Helper for tracking fractions.""" - def __init__(self, name='fraction'): + def __init__(self, name: str = 'fraction'): # Values are: total, hits. self.var = tf.get_variable(name, [2], tf.int64, tf.constant_initializer([0, 0]), trainable=False) @@ -42,3 +43,89 @@ def get_hit_rate(self): update_var = self.var.assign_add([-total, -hits]) with tf.control_dependencies([update_var]): return tf.identity(hit_rate) + + +class DistributionTracker: + """Helper for tracking distributions.""" + + def __init__(self, num_classes: int, name: str = 'distribution'): + self.num_classes = num_classes + self.var = tf.get_variable( + name, [num_classes], + tf.int64, + tf.constant_initializer([0] * num_classes), + trainable=False) + + def record_class(self, class_id, count=1): + return self.var.assign_add( + tf.one_hot(class_id, self.num_classes, dtype=tf.int64) * count) + + def record_classes(self, labels): + delta = tf.math.bincount( + labels, + minlength=self.num_classes, + maxlength=self.num_classes, + dtype=tf.int64) + return self.var.assign_add(delta) + + def get_rates(self, reset=True): + """Queries the class frequencies. + + Args: + reset: whether to reset the class counters to 0 after query + + Returns: + TF op for class frequencies + """ + total = tf.reduce_sum(self.var) + rates = tf.cast(self.var, tf.float32) / tf.maximum( + tf.constant(1, dtype=tf.float32), tf.cast(total, tf.float32)) + if not reset: + return rates + + with tf.control_dependencies([rates]): + update_var = self.var.assign_add(-self.var) + with tf.control_dependencies([update_var]): + return tf.identity(rates) + + +def get_and_reset_value(var): + readout = var + 0 + with tf.control_dependencies([readout]): + update_var = var.assign_add(-readout) + with tf.control_dependencies([update_var]): + return tf.identity(readout) + + +class TFSyncVariable: + """A local variable which can be periodically synchronized to a TF one.""" + + def __init__(self, name, shape, dtype): + self._value = np.zeros(shape, dtype=dtype.as_numpy_dtype) + self._tf_var = tf.get_variable( + name, + shape, + dtype, + tf.constant_initializer(self.value), + trainable=False) + self._update_placeholder = tf.placeholder( + dtype, shape, name='plc_%s' % name) + self._to_tf = self._tf_var.assign_add(self._update_placeholder) + self._from_tf = get_and_reset_value(self._tf_var) + self.tf_value = None + + @property + def from_tf(self): + return self._from_tf + + @property + def value(self): + return self._value + + def to_tf(self, ops, feed_dict): + ops.append(self._to_tf) + feed_dict[self._update_placeholder] = self._value + self._value = np.zeros_like(self._value) + + def reset(self): + self._value = np.zeros_like(self._value)