# -*- coding: utf-8 -*-

# Copyright (c) 2015-2016 MIT Probabilistic Computing Project

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run crosscat ("lovecat") inference on a cgpm.state.State.

`transition` converts the state into crosscat's M_c, T, X_D and X_L
structures, runs crosscat.LocalEngine.LocalEngine.analyze on them, and writes
the resulting partitions (and optional diagnostics) back into the state.
"""

import sys

import numpy as np

from cgpm.mixtures.view import View
from cgpm.utils import timer as tu


def _crosscat_M_c(state):
    """Create crosscat's M_c (column metadata) from a cgpm.state.State.

    Every column must have cctype 'normal' or 'categorical'. Column i is
    named u'c<output>' where output = state.outputs[i].

    Returns a dict with keys:
      - 'name_to_idx': column name -> zero-based column index.
      - 'idx_to_name': unicode column index -> column name.
      - 'column_metadata': one metadata dict per column, in the order of
        state.outputs.
    """
    T = state.X
    outputs = state.outputs
    cctypes = state.cctypes()
    distargs = state.distargs()

    assert len(T) == len(outputs) == len(cctypes) == len(distargs)
    assert all(c in ['normal', 'categorical'] for c in cctypes)
    ncols = len(outputs)

    def create_metadata_numerical():
        return {
            unicode('modeltype'): unicode('normal_inverse_gamma'),
            unicode('value_to_code'): {},
            unicode('code_to_value'): {},
        }

    def create_metadata_categorical(col, k):
        # Drop NaNs (missing values) *before* sorting: NaN compares false
        # against everything, so sorting a sequence that still contains NaN
        # can leave the real categories out of order.
        categories = sorted(v for v in set(T[col]) if not np.isnan(v))
        assert all(0 <= c < k for c in categories)
        # Observed category labels rendered as unicode integers, e.g. u'3'.
        codes = [unicode('%d') % (c,) for c in categories]
        ncodes = range(len(codes))
        return {
            unicode('modeltype'): unicode('symmetric_dirichlet_discrete'),
            # unicode(code index) -> observed category label.
            unicode('value_to_code'): dict(zip(map(unicode, ncodes), codes)),
            # observed category label -> code index.
            unicode('code_to_value'): dict(zip(codes, ncodes)),
        }

    column_names = [unicode('c%d') % (i,) for i in outputs]
    # Convert all numerical datatypes to normal for lovecat.
    column_metadata = [
        create_metadata_numerical() if cctype != 'categorical' else
        create_metadata_categorical(output, distarg['k'])
        for output, cctype, distarg in zip(outputs, cctypes, distargs)
    ]

    return {
        unicode('name_to_idx'): dict(zip(column_names, range(ncols))),
        unicode('idx_to_name'):
            dict(zip(map(unicode, range(ncols)), column_names)),
        unicode('column_metadata'): column_metadata,
    }


def _crosscat_T(state, M_c):
    """Create crosscat's T (the data table) from a cgpm.state.State.

    Returns a row-major list of lists: entry [row][i] is the value of
    state.outputs[i] at that row position of state.X. NaN (missing) values
    are passed through unchanged; categorical values are replaced by their
    crosscat code index (as a float) using M_c.
    """
    T = state.X

    def crosscat_value_to_code(val, col):
        if np.isnan(val):
            return val
        # For historical reasons, code_to_value and value_to_code are
        # backwards, so to convert from a raw value to a crosscat value we
        # need to do code->value.
        lookup = M_c['column_metadata'][col]['code_to_value']
        if lookup:
            assert unicode(int(val)) in lookup
            return float(lookup[unicode(int(val))])
        else:
            # Numerical columns have an empty lookup and are used as-is.
            return val

    ordering = state.outputs
    rows = range(len(T[ordering[0]]))
    return [
        [crosscat_value_to_code(T[col][row], i)
            for (i, col) in enumerate(ordering)]
        for row in rows
    ]


def _crosscat_X_D(state, M_c):
    """Create crosscat's X_D (row partitions) from a cgpm.state.State.

    Returns one list per view, in sorted order of the cgpm view ids. Each
    list holds the cluster of every row, with the view's cluster ids
    re-coded to 0..K-1 in their sorted order.
    """
    view_assignments = state.Zv().values()
    views_unique = sorted(set(view_assignments))

    # NOTE(review): row order here is the iteration order of Zr().values();
    # this assumes it matches the row positions of state.X used by
    # _crosscat_T -- confirm against View.Zr.
    cluster_assignments = [
        state.views[v].Zr().values() for v in views_unique
    ]
    cluster_assignments_unique = [
        sorted(set(assgn)) for assgn in cluster_assignments
    ]
    cluster_assignments_to_code = [
        {k: i for (i, k) in enumerate(assgn)}
        for assgn in cluster_assignments_unique
    ]
    cluster_assignments_remapped = [
        [coder[v] for v in assgn] for (coder, assgn)
        in zip(cluster_assignments_to_code, cluster_assignments)
    ]

    # cluster_assignments_remapped[i] contains the row partition for the
    # views_unique[i].
    return cluster_assignments_remapped


def _crosscat_X_L(state, M_c, X_D):
    """Create crosscat's X_L (latent structure) from a cgpm.state.State.

    Returns a dict with keys 'column_hypers', 'column_partition',
    'view_state' and 'col_ensure'. X_D must be the output of _crosscat_X_D
    for the same state, so that view codes agree between X_L and X_D.
    """

    # -- Generates X_L['column_hypers'] --
    def column_hypers_numerical(index, hypers):
        assert state.cctypes()[index] != 'categorical'
        return {
            unicode('fixed'): 0.0,
            unicode('mu'): hypers['m'],
            unicode('nu'): hypers['nu'],
            unicode('r'): hypers['r'],
            unicode('s'): hypers['s'],
        }

    def column_hypers_categorical(index, hypers):
        assert state.cctypes()[index] == 'categorical'
        K = len(M_c['column_metadata'][index]['code_to_value'])
        assert K > 0
        return {
            unicode('fixed'): 0.0,
            unicode('dirichlet_alpha'): hypers['alpha'],
            unicode('K'): K
        }

    # Retrieve the column_hypers.
    # NOTE(review): assumes state.dims() is ordered like state.cctypes().
    column_hypers = [
        column_hypers_numerical(i, state.dims()[i].hypers)
        if cctype != 'categorical'
        else column_hypers_categorical(i, state.dims()[i].hypers)
        for i, cctype in enumerate(state.cctypes())
    ]

    # -- Generates X_L['column_partition'] --
    view_assignments = state.Zv().values()
    views_unique = sorted(set(view_assignments))
    views_to_code = {v: i for (i, v) in enumerate(views_unique)}
    # views_remapped[i] contains the zero-based view index for
    # state.outputs[i].
    views_remapped = [views_to_code[state.Zv(o)] for o in state.outputs]
    counts = list(np.bincount(views_remapped))
    assert 0 not in counts
    column_partition = {
        unicode('assignments'): views_remapped,
        unicode('counts'): counts,
        unicode('hypers'): {unicode('alpha'): state.alpha()}
    }

    # -- Generates X_L['view_state'] --
    def view_state(v):
        view = state.views[v]
        row_partition = X_D[views_to_code[v]]

        # Generate X_L['view_state'][v]['column_component_suffstats']
        # (empty placeholders: one dict per cluster, per dim in the view).
        numcategories = len(set(row_partition))
        column_component_suffstats = [
            [{} for c in xrange(numcategories)]
            for d in view.dims]

        # Generate X_L['view_state'][v]['column_names']. The first entry of
        # view.outputs is the view's own row-assignment output (views are
        # built with outputs=[state.crp_id_view + index] in _update_state),
        # so it is skipped.
        column_names = \
            [unicode('c%d' % (o,)) for o in view.outputs[1:]]

        # Generate X_L['view_state'][v]['row_partition_model']
        counts = list(np.bincount(row_partition))
        assert 0 not in counts

        return {
            unicode('column_component_suffstats'):
                column_component_suffstats,
            unicode('column_names'):
                column_names,
            unicode('row_partition_model'): {
                unicode('counts'): counts,
                unicode('hypers'): {unicode('alpha'): view.alpha()}
            }
        }

    # view_states[i] is the view for code views_to_code[i], so we need to
    # iterate in the same order of views_unique to agree with both X_D (the row
    # partition in each view), as well as X_L['column_partition']['assignments']
    view_states = [view_state(v) for v in views_unique]

    # Generates X_L['col_ensure'] from the state's dependence (Cd) and
    # independence (Ci) constraints.
    col_ensure = dict()
    if state.Cd:
        col_ensure['dependent'] = {
            str(column): list(block) for block in state.Cd
            for column in block
        }
    if state.Ci:
        from crosscat.utils.general_utils import get_scc_from_tuples
        col_ensure['independent'] = {
            str(column): list(block) for
            column, block in get_scc_from_tuples(state.Ci).iteritems()
        }

    return {
        unicode('column_hypers'): column_hypers,
        unicode('column_partition'): column_partition,
        unicode('view_state'): view_states,
        unicode('col_ensure'): col_ensure
    }


def _update_state(state, M_c, X_L, X_D):
    """Write crosscat's (X_L, X_D) back into `state` in place.

    Sets the state's column-CRP alpha, appends one new View per row
    partition in X_D (at view ids starting after the current maximum), and
    migrates every dim into the new view given by
    X_L['column_partition']['assignments']. Dim hyperparameters are NOT
    copied back (see the disabled code below).
    """
    # Perform checking on M_c.
    assert all(c in ['normal', 'categorical'] for c in state.cctypes())
    assert len(M_c['name_to_idx']) == len(state.outputs)

    def _check_model_type(i):
        reference = 'normal_inverse_gamma' if state.cctypes()[i] == 'normal'\
            else 'symmetric_dirichlet_discrete'
        return M_c['column_metadata'][i]['modeltype'] == reference

    assert all(_check_model_type(i) for i in xrange(len(state.cctypes())))
    # Perform checking on X_D.
    assert all(len(partition) == state.n_rows() for partition in X_D)
    assert len(X_D) == len(X_L['view_state'])
    # Perform checking on X_L.
    assert len(X_L['column_partition']['assignments']) == len(state.outputs)

    # Update the global state alpha.
    state.crp.set_hypers(
        {'alpha': X_L['column_partition']['hypers']['alpha']}
    )

    assert state.alpha() == X_L['column_partition']['hypers']['alpha']
    assert state.crp.clusters[0].alpha ==\
        X_L['column_partition']['hypers']['alpha']

    # Create the new views, using ids beyond every existing view id so they
    # cannot collide with the views they replace.
    offset = max(state.views) + 1
    new_views = []
    for v in xrange(len(X_D)):
        alpha = X_L['view_state'][v]['row_partition_model']['hypers']['alpha']
        index = v + offset

        assert index not in state.views
        view = View(
            state.X,
            outputs=[state.crp_id_view + index],
            Zr=X_D[v],
            alpha=alpha,
            rng=state.rng
        )
        new_views.append(view)
        state._append_view(view, index)

    # Migrate the dims to their view partitions.
    for i, c in enumerate(state.outputs):
        v_a = state.Zv(c)
        v_b = X_L['column_partition']['assignments'][i] + offset
        state._migrate_dim(v_a, v_b, state.dim_for(c))

    # Update the dim hyperparameters.
    # This code is disabled because lovecat may give hypers which result in
    # math domain errors!
    # for i, c in enumerate(state.outputs):
    #     dim = state.dim_for(c)
    #     if dim.cctype == 'categorical':
    #         dim.hypers['alpha'] = X_L['column_hypers'][i]['dirichlet_alpha']
    #     elif dim.cctype == 'normal':
    #         dim.hypers['m'] = X_L['column_hypers'][i]['mu']
    #         dim.hypers['r'] = X_L['column_hypers'][i]['r']
    #         dim.hypers['s'] = X_L['column_hypers'][i]['s']
    #         dim.hypers['nu'] = X_L['column_hypers'][i]['nu']
    #     else:
    #         assert False

    # Only the new views should remain; this presumes _migrate_dim drops
    # views that become empty -- NOTE(review): confirm in cgpm.state.
    assert len(state.views) == len(new_views)
    state._check_partitions()


def _update_diagnostics(state, diagnostics):
    """Append crosscat diagnostics to state.diagnostics in place.

    Extends state.diagnostics['logscore'], ['column_crp_alpha'] and
    ['column_partition']. Each column_partition entry is a list of
    (output, view_code) pairs, one per output.
    """
    # Update logscore.
    cc_logscore = diagnostics.get('logscore', np.array([]))
    new_logscore = map(float, np.ravel(cc_logscore).tolist())
    state.diagnostics['logscore'].extend(new_logscore)

    # Update column_crp_alpha.
    cc_column_crp_alpha = diagnostics.get('column_crp_alpha', [])
    new_column_crp_alpha = map(float, np.ravel(cc_column_crp_alpha).tolist())
    state.diagnostics['column_crp_alpha'].extend(list(new_column_crp_alpha))

    # Update column_partition.
    def convert_column_partition(assignments):
        return [
            (col, int(assgn))
            for col, assgn in zip(state.outputs, assignments)
        ]

    new_column_partition = diagnostics.get('column_partition_assignments', [])
    if len(new_column_partition) > 0:
        assert len(new_column_partition) == len(state.outputs)
        # NOTE(review): assumes crosscat returns an array whose first axis is
        # the column; after transposing, [0] yields one assignment vector per
        # checkpoint -- confirm against crosscat's diagnostics layout.
        trajectories = np.transpose(new_column_partition)[0].tolist()
        state.diagnostics['column_partition'].extend(
            map(convert_column_partition, trajectories))


def _progress(n_steps, max_time, step_idx, elapsed_secs, end=None):
    """Default progress callback handed to LocalEngine.analyze.

    When `end` is truthy, prints a completion summary. Otherwise reports the
    larger of the iteration fraction (step_idx / n_steps) and, when a time
    budget is set (max_time > 0), the time fraction (elapsed_secs /
    max_time) through cgpm.utils.timer.progress.
    """
    if end:
        print('\rCompleted: %d iterations in %f seconds.' %
              (step_idx, elapsed_secs))
    else:
        # Force float division (this module runs under Python 2) and guard
        # against a zero budget, which previously raised ZeroDivisionError.
        p_seconds = float(elapsed_secs) / max_time if max_time > 0 else 0
        p_iters = float(step_idx) / n_steps if n_steps > 0 else 0
        percentage = max(p_iters, p_seconds)
        tu.progress(percentage, sys.stdout)


def _crosscat_n_steps_max_time(N, S):
    """Translate transition's (N, S) budget into lovecat's (n_steps, max_time).

    max_time == -1 means "no time limit" to lovecat.
    """
    if N is None and S is None:
        return 1, -1
    if S is None:
        return N, -1
    if N is None:
        # This is a hack, lovecat has no way to specify just max_seconds.
        return 150000, S
    return N, S


def transition(
        state, N=None, S=None, kernels=None, rowids=None, cols=None,
        seed=None, checkpoint=None, progress=None):
    """Runs full Gibbs sweeps of all kernels on the cgpm.state.State object.

    Permittable kernels:
       'column_partition_hyperparameter'
       'column_partition_assignments'
       'column_hyperparameters'
       'row_partition_hyperparameters'
       'row_partition_assignments'

    Parameters
    ----------
    state : cgpm.state.State
        Updated in place with the partitions returned by lovecat.
    N : int, optional
        Number of lovecat iterations. With neither N nor S, one iteration.
    S : float, optional
        Wall-clock budget in seconds. With S only, lovecat runs until the
        budget expires (capped at 150000 iterations).
    kernels : sequence of str, optional
        Kernels to run (see above); the empty default lets lovecat choose.
    rowids : sequence, optional
        Passed to lovecat as the rows to transition.
    cols : sequence, optional
        cgpm output ids to transition; converted to column positions in
        state.outputs.
    seed : int, optional
        Seed for lovecat (default 1).
    checkpoint : int, optional
        When given, diagnostics are collected every `checkpoint` iterations
        and appended to state.diagnostics.
    progress : optional
        None or any truthy value uses the built-in progress printer; a
        falsy value (e.g. False) is passed through to lovecat as-is.
    """
    if seed is None:
        seed = 1
    if kernels is None:
        kernels = ()
    if (progress is None) or progress:
        progress = _progress

    # Previously, calling with neither N nor S fell through to an
    # `assert False`; the helper handles all four cases explicitly.
    n_steps, max_time = _crosscat_n_steps_max_time(N, S)

    if cols is None:
        cols = ()
    else:
        cols = [state.outputs.index(i) for i in cols]
    if rowids is None:
        rowids = ()

    M_c = _crosscat_M_c(state)
    T = _crosscat_T(state, M_c)
    X_D = _crosscat_X_D(state, M_c)
    X_L = _crosscat_X_L(state, M_c, X_D)

    from crosscat.LocalEngine import LocalEngine
    LE = LocalEngine(seed=seed)

    if checkpoint is None:
        X_L_new, X_D_new = LE.analyze(
            M_c, T, X_L, X_D, seed,
            kernel_list=kernels, n_steps=n_steps, max_time=max_time,
            c=cols, r=rowids, progress=progress)
        diagnostics_new = dict()
    else:
        X_L_new, X_D_new, diagnostics_new = LE.analyze(
            M_c, T, X_L, X_D, seed,
            kernel_list=kernels, n_steps=n_steps, max_time=max_time,
            c=cols, r=rowids, do_diagnostics=True,
            diagnostics_every_N=checkpoint, progress=progress)

    _update_state(state, M_c, X_L_new, X_D_new)

    if diagnostics_new:
        _update_diagnostics(state, diagnostics_new)