import copy

try:
    import pymc3 as pm
    from pymc3 import Discrete
    from pymc3.distributions.dist_math import bound
    from pymc3.variational.callbacks import CheckParametersConvergence
except ImportError:
    from csrank.util import MissingExtraError

    raise MissingExtraError("pymc3", "probabilistic")

try:
    import theano
    from theano import tensor as tt
except ImportError:
    from csrank.util import MissingExtraError

    raise MissingExtraError("theano", "probabilistic")

# Imported after the guarded theano import so that a missing theano raises
# the MissingExtraError above rather than a bare ImportError from this module.
from csrank.theano_util import normalize

def categorical_crossentropy(p, y_true):
    # Negated so that higher values correspond to a better fit (log-likelihood).
    return -tt.nnet.categorical_crossentropy(p, y_true)

def binary_crossentropy(p, y_true):
    # Average the element-wise binary cross-entropy over the object axis,
    # negated to act as a log-likelihood.
    axis = 1 if p.ndim > 1 else 0
    return -tt.nnet.binary_crossentropy(p, y_true).mean(axis=axis)

def categorical_hinge(p, y_true):
    # Hinge loss between the score of the true class and the best-scoring
    # wrong class, negated so it can serve as a (pseudo) log-likelihood.
    pos = tt.sum(y_true * p, axis=-1)
    neg = tt.max((1.0 - y_true) * p, axis=-1)
    return -tt.maximum(0.0, neg - pos + 1.0)

likelihood_dict = {
    "categorical_crossentropy": categorical_crossentropy,
    "binary_crossentropy": binary_crossentropy,
    "categorical_hinge": categorical_hinge,
}
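
# A minimal usage sketch: look up a likelihood by name and build the symbolic
# log-likelihood for hypothetical score and label tensors. This helper is
# illustrative only and not part of the library API.
def _example_likelihood_lookup():
    loss_func = likelihood_dict["categorical_crossentropy"]
    scores = tt.dmatrix("scores")  # predicted class probabilities
    one_hot = tt.dmatrix("one_hot")  # one-hot encoded true classes
    return loss_func(scores, one_hot)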

class LogLikelihood(Discrete):
    """
    Categorical log-likelihood.

    The most general discrete distribution.

    .. math:: f(x \\mid p) = p_x

    ========  ======================================
    Support   :math:`x \\in \\{0, 1, \\ldots, |p|-1\\}`
    ========  ======================================

    Parameters
    ----------
    loss_func : callable
        Symbolic loss function mapping ``(p, y_true)`` to a log-likelihood.
        Defaults to ``categorical_crossentropy`` when ``None``.
    p : array of floats
        p > 0 and the elements of p must sum to 1. They will be automatically
        rescaled otherwise.
    """

    def __init__(self, loss_func, p, *args, **kwargs):
        super(LogLikelihood, self).__init__(*args, **kwargs)
        if loss_func is None:
            loss_func = categorical_crossentropy
        self.loss_func = loss_func
        try:
            self.k = tt.shape(p)[-1].tag.test_value
        except AttributeError:
            self.k = tt.shape(p)[-1]
        self.p = tt.as_tensor_variable(p)
        self.mode = tt.argmax(p)

    def random(self, **kwargs):
        # Sampling from this distribution is not supported.
        raise NotImplementedError()

    def logp(self, value):
        p = self.p
        k = self.k
        a = self.loss_func(p, value)
        p = normalize(p)
        # Check that the normalized probabilities sum to one; zero_grad keeps
        # this check out of the gradient computation.
        sum_to1 = theano.gradient.zero_grad(
            tt.le(abs(tt.sum(p, axis=-1) - 1), 1e-5)
        )

        value_k = tt.argmax(value, axis=1)
        return bound(a, value_k >= 0, value_k <= (k - 1), sum_to1)
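
# A minimal sketch of attaching ``LogLikelihood`` to observed one-hot labels
# inside a PyMC3 model. The prior, shapes, and data below are hypothetical and
# chosen only for illustration; this helper is not part of the library API.
def _example_loglikelihood_model():
    import numpy as np

    y = np.eye(3)[np.array([0, 2, 1])]  # three observations, three classes
    with pm.Model():
        # One probability simplex per observation.
        p = pm.Dirichlet("p", a=np.ones((3, 3)), shape=(3, 3))
        LogLikelihood("yl", loss_func=None, p=p, observed=y)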

def create_weight_dictionary(model_args, shapes):
    weights_dict = dict()
    for key, value in model_args.items():
        prior, params = copy.deepcopy(value)
        # Parameters given as (distribution, params) tuples are hyperpriors;
        # instantiate them with a name derived from the weight and parameter.
        for k in params.keys():
            if isinstance(params[k], tuple):
                params[k][1]["name"] = "{}_{}".format(key, k)
                params[k] = params[k][0](**params[k][1])
        params["name"] = key
        params["shape"] = shapes[key]
        weights_dict[key] = prior(**params)
    return weights_dict
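
# A minimal sketch of the expected ``model_args``/``shapes`` structure,
# assuming a Normal prior with a Normal hyperprior on its mean; the concrete
# distributions and shapes are hypothetical, not part of the library API.
def _example_weight_dictionary():
    model_args = {
        "weights": (
            pm.Normal,
            {"mu": (pm.Normal, {"mu": 0.0, "sd": 10.0}), "sd": 1.0},
        )
    }
    shapes = {"weights": 5}
    with pm.Model():
        # Creates "weights_mu" as a hyperprior and "weights" with shape 5.
        return create_weight_dictionary(model_args, shapes)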

def fit_pymc3_model(self, sampler, draws, tune, vi_params, **kwargs):
    callbacks = vi_params.get("callbacks", [])
    for i, c in enumerate(callbacks):
        if isinstance(c, CheckParametersConvergence):
            # Rebuild the callback with an absolute difference criterion,
            # dropping the internal state of the old instance.
            params = c.__dict__
            params.pop("_diff")
            params.pop("prev")
            params.pop("ord")
            params["diff"] = "absolute"
            callbacks[i] = CheckParametersConvergence(**params)
    if sampler == "variational":
        with self.model:
            try:
                # Short warm-up run (default sampler) to get a starting
                # point for ADVI.
                self.trace = pm.sample(chains=2, cores=8, tune=5, draws=5)
                vi_params["start"] = self.trace[-1]
                self.trace_vi = pm.fit(**vi_params)
                self.trace = self.trace_vi.sample(draws=draws)
            except Exception as e:
                if hasattr(e, "message"):
                    message = e.message
                else:
                    message = e
                self.logger.error(message)
                self.trace_vi = None
        if self.trace_vi is None and self.trace is None:
            with self.model:
                self.logger.info(
                    "Error in VI ADVI sampler, falling back to NUTS sampler with draws {}".format(
                        draws
                    )
                )
                self.trace = pm.sample(
                    chains=1, cores=4, tune=20, draws=20, step=pm.NUTS()
                )
    elif sampler == "metropolis":
        with self.model:
            start = pm.find_MAP()
            self.trace = pm.sample(
                chains=2,
                cores=8,
                tune=tune,
                draws=draws,
                **kwargs,
                step=pm.Metropolis(),
                start=start,
            )
    else:
        with self.model:
            self.trace = pm.sample(
                chains=2, cores=8, tune=tune, draws=draws, **kwargs, step=pm.NUTS()
            )
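
# A minimal sketch of driving ``fit_pymc3_model``; ``learner`` is assumed to
# be a csrank estimator exposing ``model`` and ``logger`` attributes, and the
# sampler settings below are hypothetical.
def _example_fit(learner):
    fit_pymc3_model(
        learner,
        sampler="metropolis",
        draws=500,
        tune=500,
        vi_params={"callbacks": []},
    )
    return learner.trace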