```# -*- coding: utf-8 -*-

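# Copyright (c) 2015-2016 MIT Probabilistic Computing Project

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.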

import math

import numpy as np

from scipy.stats import norm

from cgpm.crosscat.engine import Engine
from cgpm.crosscat.state import State
from cgpm.mixtures.dim import Dim
from cgpm.mixtures.view import View
from cgpm.utils import general as gu

def gen_data_table(n_rows, view_weights, cluster_weights, cctypes, distargs,
        separation, view_partition=None, rng=None):
    """Generates a synthetic data table and the partitions that produced it.

    Parameters
    ----------
    n_rows : int
        Number of rows (data points) to generate.
    view_weights : list<float>
        An n_views length list of floats that sum to one. The weights indicate
        the proportion of columns in each view. Not used when a non-empty
        `view_partition` is supplied.
    cluster_weights : list<list<float>>
        An n_views length list of n_cluster length lists that sum to one.
        The weights indicate the proportion of rows in each cluster.
    cctypes : list<str>
        n_columns length list of string specifying the distribution types for
        each column. Each entry must be a key of the module-level `_gen_data`
        table.
    distargs : list
        List of distargs for each column (see documentation for each data type
        for info on distargs).
    separation : list
        An n_cols length list of values between [0,1], where separation[i] is
        the separation of clusters in column i. Values closer to 1 imply higher
        separation.
    view_partition : list<int>, optional
        Explicit view assignment for every column. When given (and non-empty)
        it is copied and used as Zv instead of sampling from `view_weights`.
    rng : optional
        Random source handed to the partition and data generators (presumably
        a numpy.random.RandomState -- confirm against gu.gen_rng). When None,
        one is obtained from `gu.gen_rng()`.

    Returns
    -------
    T : np.ndarray
        An (n_cols, n_rows) matrix, where each row T[i,:] is the data for
        column i (transpose of a design matrix).
    Zv : list
        An n_cols length list of integers, where Zv[i] is the view assignment
        of column i.
    Zc : list<list>
        An n_view length list of lists, where Zc[v][r] is the cluster assignment
        of row r in view v.

    Example
    -------
    >>> n_rows = 500
    >>> view_weights = [.2, .8]
    >>> cluster_weights = [[.3, .2, .5], [.4, .6]]
    >>> cctypes = ['lognormal','normal','poisson','categorical',
    ...     'vonmises', 'bernoulli']
    >>> distargs = [None, None, None, {'k':8}, None, None]
    >>> separation = [.8, .7, .9, .6, .7, .85]
    >>> T, Zv, Zc = tu.gen_data_table(n_rows, view_weights,
    ...     cluster_weights, cctypes, distargs, separation)
    """
    if rng is None:
        rng = gu.gen_rng()

    n_cols = len(cctypes)

    # An empty or missing view_partition means "sample one from view_weights".
    if view_partition:
        Zv = list(view_partition)
    else:
        Zv = gen_partition(n_cols, view_weights, rng)

    # One row partition per entry of cluster_weights (i.e. per view).
    Zc = [gen_partition(n_rows, cw, rng) for cw in cluster_weights]

    # Sanity checks: one view label per column, one row partition per distinct
    # view, and row partitions of the requested length.
    assert len(Zv) == n_cols
    assert len(Zc) == len(set(Zv))
    assert len(Zc[0]) == n_rows

    T = np.zeros((n_cols, n_rows))

    # `range` (not the Python-2-only `xrange`) keeps this runnable on 2 and 3.
    for col in range(n_cols):
        generator = _gen_data[cctypes[col]]
        T[col] = generator(
            Zc[Zv[col]], rng, separation=separation[col],
            distargs=distargs[col])

    return T, Zv, Zc

def gen_dims_from_structure(T, Zv, Zc, cctypes, distargs):
    """Build one Dim per column and load it with that column's data.

    Parameters
    ----------
    T : indexable
        T[col] is the data of column `col`.
    Zv : list<int>
        Zv[col] is the view assignment of column `col`; its length fixes the
        number of Dims that are built.
    Zc : list<list<int>>
        Zc[v] is the row partition of view v.
    cctypes : list<str>
        Distribution type of each column, passed to the Dim constructor.
    distargs : list
        Per-column distargs, passed to the Dim constructor.

    Returns
    -------
    list
        The constructed Dim objects, in column order.
    """
    dims = []
    # enumerate replaces the Python-2-only `xrange(len(Zv))` index loop.
    for col, view in enumerate(Zv):
        dim_c = Dim(cctypes[col], col, distargs=distargs[col])
        dim_c.transition_hyper_grids(T[col])
        # Incorporate the column using the row partition of its own view.
        dim_c.bulk_incorporate(T[col], Zc[view])
        dims.append(dim_c)
    return dims

def _gen_beta_data(Z, rng, separation=.9, distargs=None):
    """Draw one beta variate per row, with shape parameters set by cluster.

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row; clusters are labelled 0..max(Z).
    rng
        Random source exposing `beta(a, b)` (presumably a
        numpy.random.RandomState -- confirm).
    separation : float
        Controls how widely the per-cluster locations are spread around .5.
    distargs
        Unused; accepted so every generator shares one signature.

    Returns
    -------
    np.ndarray
        Float array of length len(Z).
    """
    K = np.max(Z) + 1

    # Per-cluster locations, evenly spaced and symmetric about .5. With
    # separation <= 1 the .85 factor keeps them inside [.075, .925].
    half_width = .5*separation*.85
    locations = np.linspace(.5 - half_width, .5 + half_width, K)

    # The shape parameters depend only on the cluster, so compute them once
    # per cluster rather than once per row. The arithmetic is kept in the same
    # order as before so the resulting floats (and hence draws) are unchanged.
    shape_params = []
    for loc in locations:
        density = norm.pdf(loc, .5, .25)
        beta = (1. - loc) * 20. * density
        alpha = loc * (20. * density)
        shape_params.append((alpha, beta))

    Tc = np.zeros(len(Z))
    # One rng.beta call per row, in row order, to preserve the random stream.
    for r, cluster in enumerate(Z):
        alpha, beta = shape_params[cluster]
        Tc[r] = rng.beta(alpha, beta)

    return Tc

def _gen_normal_data(Z, rng, separation=.9, distargs=None):
    """Draw one unit-variance normal variate per row, centred by cluster.

    Cluster c has mean c * 5 * separation and standard deviation 1.

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row.
    rng
        Random source exposing `normal(loc=, scale=)`.
    separation : float
        Scales the distance between consecutive cluster means.
    distargs
        Unused; accepted so every generator shares one signature.

    Returns
    -------
    np.ndarray
        Float array of length len(Z).
    """
    spacing = 5.*separation
    sigma = 1.0
    # Iterating Z directly replaces the Python-2-only `xrange` index loop;
    # draws are still made one per row, in row order.
    return np.array(
        [rng.normal(loc=cluster * spacing, scale=sigma) for cluster in Z],
        dtype=float)

def _gen_normal_trunc_data(Z, rng, separation=.9, distargs=None):
    """Draw one normal variate per row, rejected until it lies in [l, h].

    The interval [l, h] is cut into K = max(Z) + 1 equal bins. Cluster c is
    centred on the midpoint of bin c, pulled toward the interval midpoint by a
    factor (1 - separation): separation = 1 leaves the centres on the bin
    midpoints, separation = 0 collapses them all onto (l + h) / 2.

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row.
    rng
        Random source exposing `normal(loc=, scale=)`.
    separation : float
        How far apart the cluster centres are kept (see above).
    distargs : dict
        Must provide the bounds under keys 'l' (lower) and 'h' (upper).

    Returns
    -------
    np.ndarray
        Float array of length len(Z) with every entry in [l, h].

    Raises
    ------
    ValueError
        If some row still has no draw inside [l, h] after max_draws + 1
        attempts.
    """
    l, h = distargs['l'], distargs['h']
    max_draws = 100
    sigma = 1

    K = max(Z) + 1
    mean = (l+h)/2.

    bins = np.linspace(l, h, K+1)
    # Midpoint of each pair of consecutive bin edges.
    bin_centers = [.5*(lo + hi) for lo, hi in zip(bins[:-1], bins[1:])]
    mus = [bc + (1-separation)*(mean - bc) for bc in bin_centers]

    Tc = np.zeros(len(Z))
    for r, cluster in enumerate(Z):
        # Rejection sampling. The original counter allowed max_draws + 1
        # attempts before giving up, so the same bound is kept here.
        for _ in range(max_draws + 1):
            x = rng.normal(loc=mus[cluster], scale=sigma)
            if l <= x <= h:
                break
        else:
            raise ValueError('Could not generate normal_trunc data.')
        Tc[r] = x

    return Tc

def _gen_vonmises_data(Z, rng, separation=.9, distargs=None):
    """Draw one von Mises variate per row, shifted by pi, centred by cluster.

    The K = max(Z) + 1 cluster means are spaced 2*pi/K apart starting at 0.
    All clusters share the concentration k = 1/std**2, where
    std = (2*pi/K) / (5 * separation**.75).

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row.
    rng
        Random source exposing `vonmises(mu, kappa)`.
    separation : float
        Larger values give a larger concentration (tighter clusters). Must be
        non-zero, since it appears in a denominator.
    distargs
        Unused; accepted so every generator shares one signature.

    Returns
    -------
    np.ndarray
        Float array of length len(Z); each entry is the raw draw plus pi.
    """
    num_clusters = max(Z)+1
    sep = 2*math.pi / num_clusters

    # `range` replaces the Python-2-only `xrange`.
    mus = [c*sep for c in range(num_clusters)]
    std = sep/(5.*separation**.75)
    k = 1 / (std*std)

    # One draw per row, in row order, each shifted by pi.
    return np.array(
        [rng.vonmises(mus[cluster], k) + math.pi for cluster in Z],
        dtype=float)

def _gen_poisson_data(Z, rng, separation=.9, distargs=None):
    """Draw one Poisson count per row, with a rate set by cluster.

    Cluster c uses rate c * 4 * separation + 1, so cluster 0 always has
    rate 1.

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row.
    rng
        Random source exposing `poisson(lam)`.
    separation : float
        Scales the gap between the rates of consecutive clusters.
    distargs
        Unused; accepted so every generator shares one signature.

    Returns
    -------
    np.ndarray
        Float array of length len(Z) holding the (integer-valued) counts.
    """
    step = 4.*separation
    # Iterating Z directly replaces the Python-2-only `xrange` index loop.
    # dtype=float matches the np.zeros buffer the counts used to be stored in.
    return np.array(
        [rng.poisson(cluster * step + 1) for cluster in Z], dtype=float)

def _gen_exponential_data(Z, rng, separation=.9, distargs=None):
    """Draw one exponential variate per row, with a scale set by cluster.

    Cluster c passes c * 4 * separation + 1 as the argument of
    `rng.exponential`, so cluster 0 always uses 1.

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row.
    rng
        Random source exposing `exponential(scale)`.
    separation : float
        Scales the gap between the parameters of consecutive clusters.
    distargs
        Unused; accepted so every generator shares one signature.

    Returns
    -------
    np.ndarray
        Float array of length len(Z).
    """
    step = 4.*separation
    # Iterating Z directly replaces the Python-2-only `xrange` index loop.
    return np.array(
        [rng.exponential(cluster * step + 1) for cluster in Z], dtype=float)

def _gen_geometric_data(Z, rng, separation=.9, distargs=None):
    """Draw one zero-based geometric count per row, by cluster.

    The K = max(Z) + 1 success probabilities are evenly spaced on
    [.5 - .425*separation, .5 + .425*separation]. `rng.geometric` results
    have 1 subtracted, so the smallest possible value is 0 when the draw's
    support starts at 1 (as numpy's does).

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row.
    rng
        Random source exposing `geometric(p)`.
    separation : float
        Controls how widely the success probabilities spread around .5.
    distargs
        Unused; accepted so every generator shares one signature.

    Returns
    -------
    np.ndarray
        Float array of length len(Z).
    """
    K = np.max(Z)+1

    half_width = .5*separation*.85
    ps = np.linspace(.5 - half_width, .5 + half_width, K)

    # The output buffer used to be allocated twice; build it once, directly
    # from the per-row draws (made in row order).
    return np.array(
        [rng.geometric(ps[cluster]) - 1 for cluster in Z], dtype=float)

def _gen_lognormal_data(Z, rng, separation=.9, distargs=None):
    """Draw one lognormal variate per row, parameterised by cluster.

    Cluster c uses mean = c * .9 * separation**2 and
    sigma = (1 - separation) / (c + 1) for the underlying normal.

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row.
    rng
        Random source exposing `lognormal(mean=, sigma=)`.
    separation : float
        Values above .9 are treated as .9, which keeps sigma at or above
        .1 / (c + 1) instead of shrinking to zero.
    distargs
        Unused; accepted so every generator shares one signature.

    Returns
    -------
    np.ndarray
        Float array of length len(Z).
    """
    if separation > .9:
        separation = .9

    mean_step = .9*separation**2
    spread = 1.-separation

    # Iterating Z directly replaces the Python-2-only `xrange` index loop.
    return np.array(
        [rng.lognormal(mean=cluster * mean_step, sigma=spread/(cluster+1.))
            for cluster in Z],
        dtype=float)

def _gen_bernoulli_data(Z, rng, separation=.9, distargs=None):
    """Draw one 0/1 value per row, with a success probability set by cluster.

    The max(Z) + 1 success probabilities are evenly spaced from 0 up to
    `separation`. Each row consumes exactly one `rng.rand()` draw and is 1.0
    when that draw falls below its cluster's probability, else 0.0.
    `distargs` is accepted for signature parity and ignored.

    Returns a float np.ndarray of length len(Z).
    """
    draws = np.zeros(len(Z))
    num_clusters = max(Z) + 1
    success_probs = np.linspace(0., separation, num_clusters)

    for row, cluster in enumerate(Z):
        draws[row] = 1.0 if rng.rand() < success_probs[cluster] else 0.0

    return draws

def _gen_categorical_data(Z, rng, separation=.9, distargs=None):
    """Draw one category index per row from a per-cluster weight vector.

    For each of the max(Z) + 1 clusters a length-k weight vector is sampled
    from a symmetric Dirichlet whose concentration is (1 - separation) in
    every component. Each row is then sampled with `gu.pflip` from the vector
    of its cluster.

    Parameters
    ----------
    Z : sequence<int>
        Cluster assignment of each row.
    rng
        Random source exposing `dirichlet`, also forwarded to `gu.pflip`.
    separation : float
        Values above .95 are treated as .95, which keeps the Dirichlet
        concentration at or above .05.
    distargs : dict
        Must provide the number of categories under key 'k'.

    Returns
    -------
    np.ndarray
        Integer array of length len(Z).
    """
    k = int(distargs['k'])
    n_rows = len(Z)

    if separation > .95:
        separation = .95

    Tc = np.zeros(n_rows, dtype=int)
    C = max(Z)+1
    # dirichlet(..., 1) returns a (1, k) array; row 0 is the weight vector.
    theta_arrays = [rng.dirichlet(np.ones(k)*(1.-separation), 1)
        for _ in range(C)]

    # enumerate replaces the Python-2-only `xrange` loop (the comprehension
    # above already used `range`).
    for r, cluster in enumerate(Z):
        thetas = theta_arrays[cluster][0]
        Tc[r] = int(gu.pflip(thetas, rng=rng))
    return Tc

def gen_partition(N, weights, rng):
    """Generate a random assignment of N items to K = len(weights) blocks.

    Every block 0..K-1 is used at least once: the assignment is seeded with
    one item per block, the remaining N - K labels come from
    `gu.pflip(weights, rng=rng)` (presumably an index sampled in proportion to
    `weights` -- confirm against gu), and the result is shuffled in place
    with `rng.shuffle`.

    Parameters
    ----------
    N : int
        Number of items to assign; must be at least K.
    weights : list<float>
        Non-zero block weights that sum to one.
    rng
        Random source exposing `shuffle`, also forwarded to `gu.pflip`.

    Returns
    -------
    list<int>
        Length-N list of block labels.

    Raises
    ------
    AssertionError
        If a weight is zero, the weights do not sum to one, or K > N.
    """
    assert all(w != 0 for w in weights)
    assert np.allclose(sum(weights), 1)
    K = len(weights)
    assert K <= N    # XXX FIXME
    # list(...) is required on Python 3, where range() is not a list and has
    # no .extend(); on Python 2 it is a harmless copy.
    Z = list(range(K))
    Z.extend(int(gu.pflip(weights, rng=rng)) for _ in range(N-K))
    rng.shuffle(Z)
    return Z

def column_average_ari(Zv, Zc, cc_state_object):
    """Return the adjusted Rand index of the row partitions, averaged by column.

    For every column, the true row partition (that of the column's true view)
    is compared with the inferred one (that of the view the state assigns the
    column to), and the per-column scores are averaged.

    Parameters
    ----------
    Zv : list<int>
        Zv[col] is the true view assignment of column `col`.
    Zc : list<list<int>>
        Zc[v] is the true row partition of view v.
    cc_state_object
        Object exposing `Zv[col]` (inferred view of a column) and
        `views[v].Z.tolist()` (inferred row partition of view v).

    Returns
    -------
    float
        Mean adjusted Rand index over the len(Zv) columns.

    Raises
    ------
    ZeroDivisionError
        If Zv is empty.
    """
    ari = 0
    n_cols = len(Zv)
    for col in range(n_cols):
        view_t = Zv[col]
        Zc_true = Zc[view_t]

        view_i = cc_state_object.Zv[col]
        Zc_inferred = cc_state_object.views[view_i].Z.tolist()

        # The two partitions were previously computed and then dropped, so
        # the function always returned 0; score and accumulate them.
        ari += _adjusted_rand_index(Zc_true, Zc_inferred)

    return ari/float(n_cols)


def _adjusted_rand_index(labels_true, labels_pred):
    """Return the adjusted Rand index between two labelings of the same items.

    Uses the pair-counting form (index - expected) / (max_index - expected).
    When that ratio is 0/0 (fewer than two items, or both labelings are a
    single block, or both are all singletons) the labelings are identical up
    to renaming and 1.0 is returned.
    """
    from collections import Counter

    n_items = len(labels_true)
    if n_items != len(labels_pred):
        raise ValueError('Labelings must have the same length.')

    def pairs(count):
        # Number of unordered pairs among `count` items.
        return count * (count - 1) / 2.

    index = sum(
        pairs(c) for c in Counter(zip(labels_true, labels_pred)).values())
    pairs_true = sum(pairs(c) for c in Counter(labels_true).values())
    pairs_pred = sum(pairs(c) for c in Counter(labels_pred).values())
    total_pairs = pairs(n_items)

    if total_pairs == 0:
        return 1.0

    expected = pairs_true * pairs_pred / total_pairs
    max_index = .5 * (pairs_true + pairs_pred)
    if max_index == expected:
        return 1.0
    return (index - expected) / (max_index - expected)

def gen_simple_engine(multiprocess=1):
    """Build a small Engine fixture over a single row of three ones.

    The engine has 20 states, is seeded with `gu.gen_rng(1)`, models all three
    columns as 'bernoulli', places columns 0 and 1 in view 0 and column 2 in
    view 1, and assigns every row of both views to cluster 0.
    `multiprocess` is forwarded to the Engine constructor unchanged.
    """
    data = np.array([[1, 1, 1]])
    n_rows, n_cols = data.shape
    outputs = range(n_cols)

    column_distargs = {}
    for col in outputs:
        column_distargs[col] = {'alpha': 1., 'beta': 1.}
    row_partitions = {view: [0]*n_rows for view in (0, 1)}

    return Engine(
        X=data,
        num_states=20,
        rng=gu.gen_rng(1),
        multiprocess=multiprocess,
        outputs=outputs,
        alpha=1.,
        cctypes=['bernoulli']*n_cols,
        distargs=column_distargs,
        Zv={0: 0, 1: 0, 2: 1},
        view_alphas=[1.]*n_cols,
        Zrv=row_partitions)

def gen_simple_state():
    """Build a small State fixture over a single row of three ones.

    All three columns are 'bernoulli' with hypers alpha = beta = 1. Columns 0
    and 1 are placed in view 0 and column 2 in view 1, and every row of both
    views is assigned to cluster 0.
    """
    data = np.array([[1, 1, 1]])
    n_rows, n_cols = data.shape
    outputs = range(n_cols)

    column_hypers = []
    for _ in outputs:
        column_hypers.append({'alpha': 1., 'beta': 1.})
    row_partitions = {view: [0]*n_rows for view in (0, 1)}

    return State(
        X=data,
        outputs=outputs,
        alpha=1.,
        cctypes=['bernoulli']*n_cols,
        hypers=column_hypers,
        Zv={0: 0, 1: 0, 2: 1},
        view_alphas=[1.]*n_cols,
        Zrv=row_partitions)

def gen_simple_view():
    """Build a small View fixture over a single row of two ones.

    Both columns are 'bernoulli' with hypers alpha = beta = 1 and the single
    row is assigned to cluster 0. The View's outputs are 1000 followed by the
    column ids (1000 is presumably the id of the view's exposed latent
    variable -- confirm against View).
    """
    data = np.array([[1, 1]])
    R = len(data)
    D = len(data[0])
    # A real list is needed: on Python 3 `[1000] + range(D)` raises TypeError.
    outputs = list(range(D))
    # Column id -> list of that column's values.
    X = {c: data[:, i].tolist() for i, c in enumerate(outputs)}
    Zr = [0] * R
    view = View(
        X,
        outputs=[1000] + outputs,
        alpha=1.,
        cctypes=['bernoulli']*D,
        hypers={i: {'alpha': 1., 'beta': 1.} for i in outputs},
        Zr=Zr)
    return view

def gen_multitype_view():
    """Build a two-row View fixture with one column of each of three types.

    Column 0 is 'bernoulli', column 1 is 'categorical' with k = 3, and column
    2 is 'normal'. Row 0 is assigned to cluster 0 and row 1 to cluster 1. The
    View's outputs are 1000 followed by the column ids (1000 is presumably
    the id of the view's exposed latent variable -- confirm against View).
    """
    data = np.array([[0, 2, 3.14],
                     [1, 1, 1]])
    D = len(data[0])
    # A real list is needed: on Python 3 `[1000] + range(D)` raises TypeError.
    dpm_outputs = list(range(D))
    # Column id -> list of that column's values.
    X = {c: data[:, i].tolist() for i, c in enumerate(dpm_outputs)}
    crp_alpha = 1.
    cctypes = ['bernoulli', 'categorical', 'normal']
    hypers = {
        0: {'alpha': 1., 'beta': 1.},
        1: {'alpha': 1.},
        2: {'m': 0, 'r': 1., 's': 1., 'nu': 1.}}
    distargs = {0: {}, 1: {'k': 3}, 2: {}}
    Zr = [0, 1]
    view = View(
        X,
        outputs=[1000] + dpm_outputs,
        alpha=crp_alpha,
        cctypes=cctypes,
        hypers=hypers,
        distargs=distargs,
        Zr=Zr)
    return view

def change_column_hyperparameters(cgpm, value):
    """Generate new cgpm from input cgpm, and alter its distribution
    hyperparameters to extreme values.

    Every hyperparameter of every column is set to `value`: alpha and beta
    for 'bernoulli', alpha for 'categorical', and m, nu, r, s for 'normal'.

    NOTE(review): in the code visible here nothing is built or returned after
    the new hyperparameters are computed; see the notes in the body.

    Raises
    ------
    ValueError
        If a column's cctype is not bernoulli, categorical or normal.
    """

    # NOTE(review): `columns` and `cctypes` are never bound in the visible
    # source, so this function raises NameError as it stands. The statements
    # that derive them from `cgpm` appear to be missing -- TODO restore.

    # Alter metadata.hypers to extreme values according to data type
    if isinstance(cgpm, View):
        # NOTE(review): the comment says "first" but pop(-1) removes the
        # last entry -- confirm which one is intended.
        columns.pop(-1)  # first output is exposed latent

    new_hypers = [{} for c in columns]
    for c in columns:
        if cctypes[c] == 'bernoulli':
            new_hypers[c] = {'alpha': value, 'beta': value}

        elif cctypes[c] == 'categorical':
            new_hypers[c] = {'alpha': value}

        elif cctypes[c] == 'normal':
            new_hypers[c] = {'m': value, 'nu': value, 'r': value, 's': value}

        else:
            # The exception used to be constructed and discarded, so an
            # unknown cctype silently kept empty hypers; raise it instead.
            raise ValueError('''
cctype not recognized (not bernoulli, categorical or normal);
''')

    # Create a new cgpm from the altered metadata
    # NOTE(review): no code follows this comment in the visible source, so
    # `new_hypers` is discarded and the function returns None -- TODO restore.

def change_concentration_hyperparameters(cgpm, value):
    """Generate new cgpm from input cgpm, and alter its crp
    concentration hyperparameter to extreme values.

    For a State, every entry `a` of `old_view_alphas` is replaced by the pair
    (a[0], value) under the 'view_alphas' key of `new_metadata`.

    NOTE(review): this block is incomplete as it stands. The View branch has
    no statement, `new_metadata` and `old_view_alphas` are never bound, and
    nothing is returned. The missing statements must be restored before this
    function can even be parsed.
    """

    # Alter metadata.alpha to extreme values according to data type
    if isinstance(cgpm, View):
    # NOTE(review): the body of the View branch is missing here -- TODO restore.
    elif isinstance(cgpm, State):
        # Keep the first element of each existing entry and swap in `value`.
        new_metadata['view_alphas'] = [(a[0], value) for a in old_view_alphas]

def restrict_evidence_to_query(query, evidence):
    """Return subset of evidence whose rows are also present in query.

    Parameters
    ----------
    query : dict
        Only its keys are used.
    evidence : dict
        Mapping to filter.

    Returns
    -------
    dict
        A new dict with the items of `evidence` whose key is also a key of
        `query`; neither input is modified.
    """
    # `.items()` works on both Python 2 and 3 (`.iteritems()` is 2-only), and
    # testing membership on the dict itself avoids building a key list for
    # every item on Python 2.
    return {i: j for i, j in evidence.items() if i in query}

# Dispatch table: cctype name -> generator taking
# (Z, rng, separation=..., distargs=...) and returning one column of data.
_gen_data = {
    'bernoulli': _gen_bernoulli_data,
    'beta': _gen_beta_data,
    'categorical': _gen_categorical_data,
    'exponential': _gen_exponential_data,
    'geometric': _gen_geometric_data,
    'lognormal': _gen_lognormal_data,
    'normal': _gen_normal_data,
    'normal_trunc': _gen_normal_trunc_data,
    'poisson': _gen_poisson_data,
    'vonmises': _gen_vonmises_data,
}
```