"""
This script was generated by the train.py script in this repository:
https://github.com/ecohealthalliance/geoname-annotator-training
"""
import numpy as np
from numpy import array, int32


# Score cutoffs emitted by the training script together with the models below.
# NOTE(review): neither constant is read in this module; presumably callers
# compare predict_proba's positive-class probability against them
# (HIGH_CONFIDENCE_THRESHOLD for confident matches, GEONAME_SCORE_THRESHOLD
# as a minimum acceptable score) — confirm against the consuming code.
HIGH_CONFIDENCE_THRESHOLD = 0.5
GEONAME_SCORE_THRESHOLD = 0.13
# Serialized base classifier. The keys mirror scikit-learn's LogisticRegression
# hyperparameters ('penalty' through 'l1_ratio') followed by its fitted
# attributes ('classes_', 'coef_', 'intercept_', 'n_iter_'). Storing them as a
# plain dict means this module needs only numpy at runtime. predict_proba reads
# just 'coef_' and 'intercept_'; the other entries are informational.
base_classifier =\
{
    'penalty': 'l1',
    'dual': False,
    'tol': 0.0001,
    'C': 0.1,
    'fit_intercept': True,
    'intercept_scaling': 1,
    'class_weight': None,
    'random_state': None,
    'solver': 'liblinear',
    'max_iter': 100,
    'multi_class': 'auto',
    'verbose': 0,
    'warm_start': False,
    'n_jobs': None,
    'l1_ratio': None,
    # Class labels; predict_proba's two output columns follow this order,
    # i.e. [P(False), P(True)].
    'classes_': array([False,  True]),
    # One row of 26 weights. predict_proba computes np.dot(X, coef_.T), so
    # column i of X is multiplied by weight i: the feature vector must be
    # built in exactly the order named by the comment above each weight.
    'coef_': array([[
        # log_population
        0.2833264880250032,
        # name_count
        0.496850674834566,
        # names_used
        0.6182820633848923,
        # exact_name_match
        1.2348109901150883,
        # multiple_spans
        1.0027275126871056,
        # span_length
        0.1556028330182338,
        # all_acronyms
        -1.723855098135833,
        # cannonical_name_used
        2.4448401433682614,
        # loc_NE_portion
        1.100324515118662,
        # other_NE_portion
        -0.006789540638115011,
        # noun_portion
        1.7846512186329173,
        # num_tokens
        0.4449857820783286,
        # med_token_prob
        -0.35487604034835174,
        # exact_alternatives
        -0.7357808282729202,
        # PPL_feature_code
        -3.313119817081297,
        # ADM_feature_code
        -3.8508572602598976,
        # PCL_feature_code
        0.002625732530454115,
        # other_feature_code
        -2.5725460514389598,
        # first_order
        0.8582848733971942,
        # combined_span
        0.7845557182369852,
        # The remaining six weights are exactly 0.0, so these columns have no
        # effect on this model's output (they still must be present so X has
        # 26 columns). NOTE(review): presumably these contextual features are
        # not populated at the base stage — confirm in train.py.
        # close_locations
        0.0,
        # very_close_locations
        0.0,
        # base_score
        0.0,
        # base_score_margin
        0.0,
        # contained_locations
        0.0,
        # containing_locations
        0.0,
    ]]),
    # Bias term added to every decision value.
    'intercept_': array([-14.6902496]),
    # Solver iteration count recorded during training; not read at prediction.
    'n_iter_': array([44], dtype=int32),
}

# Serialized contextual classifier, same layout as base_classifier: scikit-learn
# LogisticRegression hyperparameters ('penalty' through 'l1_ratio') followed by
# the fitted attributes ('classes_', 'coef_', 'intercept_', 'n_iter_').
# predict_proba reads only 'coef_' and 'intercept_'.
contextual_classifier =\
{
    'penalty': 'l1',
    'dual': False,
    'tol': 0.0001,
    'C': 0.1,
    'fit_intercept': True,
    'intercept_scaling': 1,
    'class_weight': None,
    'random_state': None,
    'solver': 'liblinear',
    'max_iter': 100,
    'multi_class': 'auto',
    'verbose': 0,
    'warm_start': False,
    'n_jobs': None,
    'l1_ratio': None,
    # Class labels; predict_proba's two output columns follow this order,
    # i.e. [P(False), P(True)].
    'classes_': array([False,  True]),
    # One row of 26 weights applied via np.dot(X, coef_.T): column i of X is
    # multiplied by weight i, so features must appear in exactly the order
    # named by the comment above each weight.
    'coef_': array([[
        # log_population
        0.3787203315925731,
        # name_count
        0.47246832816657763,
        # names_used
        1.0765607603242244,
        # exact_name_match
        1.4705218728593559,
        # multiple_spans
        1.3801379355279673,
        # span_length
        0.21060539648691756,
        # all_acronyms
        -2.5491642087123516,
        # cannonical_name_used
        2.877038521477874,
        # loc_NE_portion
        1.6424801016350434,
        # other_NE_portion
        -0.38562595379247006,
        # noun_portion
        2.002630501746275,
        # num_tokens
        0.4152636087418877,
        # med_token_prob
        -0.32906537630371446,
        # exact_alternatives
        -0.8984979885089859,
        # PPL_feature_code
        -4.534782739767053,
        # ADM_feature_code
        -5.510120071836727,
        # PCL_feature_code
        0.003626341435930206,
        # other_feature_code
        -3.2323872260723783,
        # first_order
        1.451007214897749,
        # combined_span
        1.480736480667228,
        # close_locations
        0.1719448315096514,
        # very_close_locations
        -0.07886456859233734,
        # NOTE(review): base_score and base_score_margin presumably come from
        # base_classifier's output for the candidate — confirm in train.py.
        # base_score
        -5.393131594700634,
        # base_score_margin
        3.119070808829057,
        # contained_locations
        0.1825358655341298,
        # containing_locations
        0.962090412821035,
    ]]),
    # Bias term added to every decision value.
    'intercept_': array([-14.57426316]),
    # Solver iteration count recorded during training; not read at prediction.
    'n_iter_': array([52], dtype=int32),
}

# Logistic-regression probability estimation adapted from scikit-learn's
# LinearClassifierMixin._predict_proba_lr, rewritten to work on the plain-dict
# classifiers defined in this module.
def predict_proba(X, classifier):
    """Estimate class probabilities with a serialized logistic regression.

    The decision value ``d = X @ coef_.T + intercept_`` is mapped through the
    logistic sigmoid ``1 / (1 + exp(-d))``.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features) or (n_features,)
        Feature matrix. A 1-D input is treated as a single sample.
    classifier : dict
        Must provide ``'coef_'`` (shape (n_rows, n_features)) and
        ``'intercept_'`` (shape (n_rows,)); all other keys are ignored.

    Returns
    -------
    numpy.ndarray
        With a single coefficient row (binary model): an (n_samples, 2) array
        whose columns are ``[1 - p, p]``, ``p`` being the positive-class
        probability. With several rows: the per-row sigmoids normalized to sum
        to 1 for each sample (one-vs-rest, like LibLinear's
        ``predict_probability``).

    Raises
    ------
    ValueError
        If the number of feature columns in X does not match ``coef_``.
    """
    # Accept a single 1-D feature vector as one sample.
    X = np.atleast_2d(X)
    decision = np.dot(X, classifier['coef_'].T) + classifier['intercept_']
    # Force a float dtype so the sigmoid is well-defined for any input dtype.
    decision = np.asarray(decision, dtype=float)
    prob = 1.0 / (1.0 + np.exp(-decision))
    if prob.shape[1] == 1:
        positive = prob.ravel()
        return np.column_stack([1 - positive, positive])
    # OvR normalization, like LibLinear's predict_probability
    return prob / prob.sum(axis=1, keepdims=True)


def predict_proba_base(X):
    """Return predict_proba's class probabilities for X under base_classifier."""
    model = base_classifier
    return predict_proba(X, classifier=model)


def predict_proba_contextual(X):
    """Return predict_proba's class probabilities for X under contextual_classifier."""
    model = contextual_classifier
    return predict_proba(X, classifier=model)