#!/usr/bin/env python

__author__ = 'Florian Hase'

#========================================================================

import theano
import theano.tensor as T
import numpy as np
import pymc3 as pm

from Utils.utils import VarDictParser
from BayesianNeuralNetwork.distributions import DiscreteLaplace

#========================================================================

class Pymc3Network(VarDictParser):

    def __init__(self, var_dicts, observed_params, observed_losses, batch_size, model_details):
        VarDictParser.__init__(self, var_dicts)

        self.observed_params = observed_params
        self.observed_losses = observed_losses
        self.num_obs         = len(self.observed_losses)
        self.batch_size      = batch_size
        self.model_details   = model_details

        # expose all model details (num_layers, hidden_shape, weight_loc, ...) as attributes
        for key, value in self.model_details.items():
            setattr(self, str(key), value)

        self._process_network_inputs()
        self._get_weights_and_bias_shapes()


    def __get_weights(self, index, shape, scale = None):
        return pm.Normal('w%d' % index, self.weight_loc, self.weight_scale, shape = shape)


    def __get_biases(self, index, shape, scale = None):
        return pm.Normal('b%d' % index, self.weight_loc, self.weight_scale, shape = shape)


    def weight(self, index):
        return getattr(self, 'w%d' % index)


    def bias(self, index):
        return getattr(self, 'b%d' % index)


    def _get_weights_and_bias_shapes(self):
        # input layer --> first hidden layer
        self.weight_shapes = [[self.network_input.shape[1], self.hidden_shape]]
        self.bias_shapes   = [[self.hidden_shape]]
        # intermediate hidden layers
        for index in range(1, self.num_layers - 1):
            self.weight_shapes.append([self.hidden_shape, self.hidden_shape])
            self.bias_shapes.append([self.hidden_shape])
        # last hidden layer --> output layer (same dimensionality as the network input)
        self.weight_shapes.append([self.hidden_shape, self.network_input.shape[1]])
        self.bias_shapes.append([self.network_input.shape[1]])


    def _process_network_inputs(self):
        # network input: one-hot encoded categorical variables (expanded dimensions)
        # network output: raw observed parameter values (observed targets)
        self.network_input  = np.zeros((self.num_obs, self.complete_size))  #+ 10.**-4
        self.network_output = np.zeros((self.num_obs, self.total_size))
        for obs_index, obs in enumerate(self.observed_params):
            current_index = 0
            for var_index, value in enumerate(obs):
                if self.var_p_types[var_index] == 'categorical':
                    self.network_input[obs_index, int(current_index + value)] += 1.  #- 2 * 10.**-4
                    self.network_output[obs_index, var_index] = value
                    current_index += len(self.var_p_options[var_index])
                else:
                    self.network_input[obs_index, current_index] = value
                    self.network_output[obs_index, var_index]    = value
                    current_index += 1

        # boolean masks marking which expanded dimensions are floats, integers or categories
        for att in ['floats', 'ints', 'cats']:
            setattr(self, att, np.array([False for i in range(self.complete_size)]))

        self.upper_rescalings = np.empty(self.complete_size)
        self.lower_rescalings = np.empty(self.complete_size)
        for var_e_index, var_e_name in enumerate(self.var_e_names):
            high = self.var_e_highs[var_e_index]
            low  = self.var_e_lows[var_e_index]
            if self.var_e_types[var_e_index] == 'float':
                # pad float ranges by 10% on either side
                self.upper_rescalings[var_e_index] = high + 0.1 * (high - low)
                self.lower_rescalings[var_e_index] = low  - 0.1 * (high - low)
                self.floats[var_e_index] = True
            elif self.var_e_types[var_e_index] == 'integer':
                self.upper_rescalings[var_e_index] = high  # + np.ceil(0.1 * (high - low))
                self.lower_rescalings[var_e_index] = low   # - np.ceil(0.1 * (high - low))
                self.ints[var_e_index] = True
            elif self.var_e_types[var_e_index] == 'categorical':
                self.upper_rescalings[var_e_index] = 1.
                self.lower_rescalings[var_e_index] = 0.
                self.cats[var_e_index] = True

        # rescale the network input to [-1, 1]
        self.network_input = 2. * (self.network_input - self.lower_rescalings) / (self.upper_rescalings - self.lower_rescalings) - 1.

    def _create_model(self):
        with pm.Model() as self.model:

            # getting the location primers
            for layer_index in range(self.num_layers):
                setattr(self, 'w%d' % layer_index, self.__get_weights(layer_index, self.weight_shapes[layer_index]))
                setattr(self, 'b%d' % layer_index, self.__get_biases(layer_index, self.bias_shapes[layer_index]))

                if layer_index == 0:
                    # first layer acts on the (rescaled) network input
                    fc = pm.Deterministic('fc%d' % layer_index, pm.math.tanh(pm.math.dot(self.network_input, self.weight(layer_index)) + self.bias(layer_index)))
                    setattr(self, 'fc%d' % layer_index, fc)
                elif 0 < layer_index < self.num_layers - 1:
                    # intermediate layers act on the previous layer's activations
                    fc = pm.Deterministic('fc%d' % layer_index, pm.math.tanh(pm.math.dot(getattr(self, 'fc%d' % (layer_index - 1)), self.weight(layer_index)) + self.bias(layer_index)))
                    setattr(self, 'fc%d' % layer_index, fc)
                else:
                    # output layer: sigmoid squashes the location primers into (0, 1)
                    self._loc = pm.Deterministic('bnn_out', pm.math.sigmoid(pm.math.dot(getattr(self, 'fc%d' % (layer_index - 1)), self.weight(layer_index)) + self.bias(layer_index)))

            # getting the precision / standard deviation / variance
            self.tau_rescaling = np.zeros((self.num_obs, self.network_input.shape[1]))
            for obs_index in range(self.num_obs):
                self.tau_rescaling[obs_index] += self.var_e_ranges
            self.tau_rescaling = self.tau_rescaling**2

            tau        = pm.Gamma('tau', self.num_obs**2, 1., shape = (self.num_obs, self.network_input.shape[1]))
            self.tau   = tau / self.tau_rescaling
            self.scale = pm.Deterministic('scale', 1. / pm.math.sqrt(self.tau))

            # learn the floats
            self.loc        = pm.Deterministic('loc', (self.upper_rescalings - self.lower_rescalings) * self._loc + self.lower_rescalings)
            self.out_floats = pm.Normal('out_floats', self.loc[:, self.floats], tau = self.tau[:, self.floats], observed = self.network_output[:, self.floats])

            # learn the integers
            self.int_scale = pm.Deterministic('int_scale', 1. * self.scale)
            self.out_ints  = DiscreteLaplace('out_ints', loc = self.loc[:, self.ints], scale = self.int_scale[:, self.ints], observed = self.network_output[:, self.ints])

            # learn the categories
            dist_counter, cat_var_index = 0, 0

            self.alpha    = pm.Deterministic('alpha', (self.loc + 1.) * self.scale)
            self.num_cats = 0
            for var_e_index, var_e_type in enumerate(self.var_e_types):
                # build one Dirichlet / Categorical pair per categorical variable,
                # i.e. only at the first expanded dimension of each one-hot block
                if var_e_type == 'categorical' and self.var_e_begin[var_e_index] == var_e_index:
                    begin, end  = self.var_e_begin[var_e_index], self.var_e_end[var_e_index]
                    var_e_name  = self.var_e_names[var_e_index]
                    param_index = np.argwhere(self.var_p_names == var_e_name)[0, 0]
                    self.param_index = param_index

                    out_dirichlet = pm.Dirichlet('dirich_%d' % dist_counter, a = self.alpha[:, begin : end], shape = (self.num_obs, int(end - begin)))
                    out_cats      = pm.Categorical('out_cats_%d' % dist_counter, p = out_dirichlet, observed = self.network_output[:, param_index])
                    self.num_cats += 1
                    dist_counter  += 1


    def _sample(self, num_epochs = None, num_draws = None):
        if not num_epochs:
            num_epochs = self.num_epochs
        if not num_draws:
            num_draws = self.num_draws

        with self.model:
            # fit the network with variational inference (pm.fit defaults to ADVI)
            # and draw posterior samples from the fitted approximation
            # approx = pm.fit(method = 'svgd', n = 1000, obj_optimizer = pm.adam(learning_rate = self.learning_rate))
            # self.trace = approx.sample(draws = num_draws)
            approx     = pm.fit(n = num_epochs, obj_optimizer = pm.adam(learning_rate = self.learning_rate))
            self.trace = approx.sample(draws = num_draws)
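
#========================================================================
# Illustrative driver -- not part of the original module. A minimal sketch of
# how this class might be used, assuming `var_dicts` is filled in with the
# format expected by Utils.utils.VarDictParser (not shown here) and that
# `model_details` supplies the attributes accessed above. All concrete values
# below are placeholders.

if __name__ == '__main__':

    # keys mirror the attributes accessed in Pymc3Network; values are placeholders
    model_details = {'num_layers': 3, 'hidden_shape': 6,
                     'weight_loc': 0., 'weight_scale': 1.,
                     'num_epochs': 2 * 10**4, 'num_draws': 10**3,
                     'learning_rate': 10.**-2}

    var_dicts = []  # placeholder: must follow the format expected by VarDictParser

    # one row per observation; columns follow the order of the parsed variables
    observed_params = np.array([[0.1], [0.7], [0.4]])
    observed_losses = np.array([0.3, 0.9, 0.5])

    network = Pymc3Network(var_dicts, observed_params, observed_losses,
                           batch_size = len(observed_losses),
                           model_details = model_details)
    network._create_model()
    network._sample()
    # posterior samples of the predicted locations are then available in network.trace['loc']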