"""
Script that loads random forest models trained on the sider and tox21 datasets,
predicts on sweetlead, creates covariance matrix

@Author Aneesh Pappu
"""
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import os
import sys
import numpy as np
import pandas as pd
import deepchem as dc
from sklearn.ensemble import RandomForestClassifier
from deepchem.models.multitask import SingletaskToMultitask
from deepchem import metrics
from deepchem.metrics import Metric
from deepchem.models.sklearn_models import SklearnModel

tox_tasks, (tox_train, tox_valid,
            tox_test), tox_transformers = dc.molnet.load_tox21()

classification_metric = Metric(
    metrics.roc_auc_score, np.mean, mode="classification")


def model_builder(model_dir):
  sklearn_model = RandomForestClassifier(
      class_weight="balanced", n_estimators=500, n_jobs=-1)
  return dc.models.SklearnModel(sklearn_model, model_dir)


print(tox_train.get_task_names())
print(tox_tasks)
tox_model = SingletaskToMultitask(tox_tasks, model_builder)
tox_model.fit(tox_train)

# Load sider models now

sider_tasks, (
    sider_train, sider_valid,
    sider_test), sider_transformers = dc.molnet.load_sider(split="random")

sider_model = SingletaskToMultitask(sider_tasks, model_builder)
sider_model.fit(sider_train)

# Load sweetlead dataset now. Pass in dataset object and appropriate
# transformers to predict functions

sweet_tasks, (sweet_dataset, _, _), sweet_transformers = dc.molnet.load_sweet()

sider_predictions = sider_model.predict(sweet_dataset, sweet_transformers)

tox_predictions = tox_model.predict(sweet_dataset, sweet_transformers)

sider_dimensions = sider_predictions.shape[1]
tox_dimensions = tox_predictions.shape[1]

confusion_matrix = np.zeros(shape=(tox_dimensions, sider_dimensions))
for i in range(tox_predictions.shape[0]):
  nonzero_tox = np.nonzero(tox_predictions[i, :])
  nonzero_sider = np.nonzero(sider_predictions[i, :])
  for j in nonzero_tox[0]:
    for k in nonzero_sider[0]:
      confusion_matrix[j, k] += 1

df = pd.DataFrame(confusion_matrix)

df.to_csv("./tox_sider_matrix.csv")