# -*- coding: utf-8 -*-
import tensorflow as tf

from distribution_base import DistributionBase


class CategoricalDistribution(DistributionBase):
    """Product of independent categorical distributions, one per dimension.

    Each dimension ``dim`` owns a probability vector over ``counts[dim]``
    categories, stored in a TensorFlow variable. The log-probability of a
    sample is the sum of its per-dimension log-probabilities.
    """

    def __init__(self, counts, means=None):
        """Store the configuration; no TensorFlow objects are created here.

        Args:
            counts: sequence with the number of categories of each
                dimension; its length is the number of dimensions.
            means: optional sequence of per-dimension initial probability
                vectors (``means[dim]`` is expected to have ``counts[dim]``
                entries). When ``None``, ``initialize`` draws random
                vectors normalised to sum to one.
        """
        self.dims = len(counts)
        self.counts = counts
        self.means = means
        # Filled lazily by ``initialize``: one tf.Variable per dimension.
        self.tf_means = None

    def initialize(self, dtype=tf.float64):
        """Create the per-dimension probability variables.

        Calls after the first successful one do nothing.

        Args:
            dtype: TensorFlow dtype of the created variables.
        """
        if self.tf_means is not None:
            return
        # Build into a local list and publish only once every variable was
        # created, so a failure part-way through does not leave a partial
        # list that would block a later retry.
        tf_means = []
        for dim in range(self.dims):
            if self.means is not None:
                tf_mean = tf.Variable(self.means[dim], dtype=dtype)
            else:
                # Uniform random draws normalised so the entries sum to one.
                tf_rand = tf.random_uniform(
                    [self.counts[dim]], maxval=1.0, dtype=dtype)
                tf_mean = tf.Variable(tf_rand / tf.reduce_sum(tf_rand))
            tf_means.append(tf_mean)
        self.tf_means = tf_means

    def get_parameters(self):
        """Return the list of per-dimension probability variables.

        This is ``None`` until ``initialize`` has been called.
        """
        return self.tf_means

    def get_log_probabilities(self, data):
        """Return the per-sample log-probability of ``data``.

        Args:
            data: sequence whose first element is an integer tensor of
                category indices. Column ``dim`` of ``data[0]`` indexes the
                probability vector of dimension ``dim`` (presumably shaped
                ``(num_samples, dims)`` -- confirm against callers).

        Returns:
            Tensor shaped like ``data[0][:, 0]``. For each sample it holds
            the sum over dimensions of the log-probability of that sample's
            category.
        """
        tf_log_probabilities = [
            tf.gather(tf.log(self.tf_means[dim]), data[0][:, dim])
            for dim in range(self.dims)
        ]
        # tf.add_n sums element-wise without needing a statically known
        # batch size (tf.parallel_stack does) and supports gradients.
        return tf.add_n(tf_log_probabilities)

    def get_parameter_updaters(self, data, gamma_weighted, gamma_sum):
        """Return ops that re-estimate each dimension's category probabilities.

        For every dimension, the entries of ``gamma_weighted`` are grouped
        by the sample's category in that dimension and summed. The vector of
        ``counts[dim]`` sums is then assigned to that dimension's variable.

        Args:
            data: same layout as for ``get_log_probabilities``.
            gamma_weighted: tensor of per-sample weights aligned with the
                rows of ``data[0]``. NOTE(review): the sums are assigned
                directly as probabilities, so this is presumably already
                normalised (e.g. divided by ``gamma_sum``) -- confirm with
                the caller.
            gamma_sum: not used by this implementation (presumably part of
                the ``DistributionBase`` interface).

        Returns:
            List of assign ops, one per dimension.
        """
        tf_parameter_updaters = []
        for dim in range(self.dims):
            # dynamic_partition yields one group per category, in category
            # order, so the stacked sums line up with the probability vector.
            tf_partition = tf.dynamic_partition(
                gamma_weighted, data[0][:, dim], self.counts[dim])
            tf_new_means = tf.stack([tf.reduce_sum(p) for p in tf_partition])
            tf_parameter_updaters.append(
                self.tf_means[dim].assign(tf_new_means))
        return tf_parameter_updaters