# -*- coding: utf-8 -*-
"""
@author: Satoshi Hara
"""

import sys
import os
sys.path.append(os.path.abspath('./'))
sys.path.append(os.path.abspath('./baselines/'))
sys.path.append(os.path.abspath('../'))

import numpy as np
import paper_sub
from RForest import RForest

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colorbar as colorbar

def plotTZ(filename=None):
    """Plot the noise-free synthetic labelling as two filled regions.

    The unit square is split by the curve
    ``x2 = 0.25 + 0.5 / (1 + exp(-20 (x1 - 0.5))) + 0.05 cos(2 pi x1)``.
    The region below the curve is filled with the low end of the 'cool'
    colormap and the region above it with the high end. A colorbar for the
    same colormap is drawn in a narrow second axes.

    Parameters
    ----------
    filename : str or None
        If given, the figure is written to this path as a PDF and closed
        afterwards. If None, the figure is only shown.
    """
    t = np.linspace(0, 1, 101)
    z = 0.25 + 0.5 / (1 + np.exp(- 20 * (t - 0.5))) + 0.05 * np.cos(t * 2 * np.pi)

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap is available in both old and new releases.
    cmap = plt.get_cmap('cool')
    fig, (ax1, ax2) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [19, 1]})

    # Both polygons share the boundary curve; they differ only in whether
    # they are closed along the bottom (x2 = 0) or the top (x2 = 1) edge.
    curve = [[ti, zi] for ti, zi in zip(t, z)]
    below = [[0, 0]] + curve + [[1, 0], [0, 0]]
    above = [[0, 1]] + curve + [[1, 1], [0, 1]]
    ax1.add_patch(plt.Polygon(below, fc=cmap(0.0)))
    ax1.add_patch(plt.Polygon(above, fc=cmap(1.0)))
    ax1.set_xlabel('x1', size=22)
    ax1.set_ylabel('x2', size=22)
    ax1.set_title('True Data', size=28)

    colorbar.ColorbarBase(ax2, cmap=cmap, format='%.1f')
    ax2.set_ylabel('Output y', size=22)

    # Save through the figure handle *before* showing: on an interactive
    # backend the figure is destroyed once show() returns, so saving
    # afterwards would write an empty page.
    if filename is not None:
        fig.savefig(filename, format="pdf", bbox_inches="tight")
    plt.show()
    if filename is not None:
        plt.close(fig)
    
def plotForest(filename=None):
    """Plot the predictions of the stored forest over the unit square.

    An RForest classifier is fitted from the directory
    './result/result_synthetic2/forest/' and evaluated on a 201 x 201 grid
    covering [0, 1] x [0, 1]; the drawing itself is delegated to
    ``RForest.plot``, which receives ``filename`` unchanged.
    """
    model = RForest(modeltype='classification')
    model.fit('./result/result_synthetic2/forest/')

    # Cartesian product of the axis grid with itself, one point per row:
    # the first column varies slowly, the second column varies fastest.
    axis = np.linspace(0, 1, 201)
    unit = np.ones(201)
    first = np.kron(axis, unit)
    second = np.kron(unit, axis)
    X = np.c_[first, second]

    bounds = np.array([[0.0, 0.0], [1.0, 1.0]])
    model.plot(X, 0, 1, box0=bounds, filename=filename)
    
if __name__ == "__main__":

    # experiment configuration
    prefix = 'synthetic2'
    seed = 0
    num = 1000
    dim = 2

    # a label is flipped whenever a uniform draw exceeds b
    b = 0.9

    def _boundary(x1):
        """Return the height of the true class boundary at position(s) x1."""
        return 0.25 + 0.5 / (1 + np.exp(- 20 * (x1 - 0.5))) + 0.05 * np.cos(x1 * 2 * np.pi)

    def _sample():
        """Draw `num` uniform points and their noisy above-boundary labels.

        The inputs are drawn first and the flip noise second; this order
        must not change, otherwise the seeded data set would differ.
        """
        X = np.random.rand(num, dim)
        clean = X[:, 1] > _boundary(X[:, 0])
        flip = np.random.rand(num) > b
        return X, np.logical_xor(clean, flip)

    # boundary curve on a regular grid, passed on as an overlay line
    t = np.linspace(0, 1, 101)
    z = _boundary(t)

    # training set first, then test set, from one seeded random stream
    np.random.seed(seed)
    Xtr, ytr = _sample()
    Xte, yte = _sample()

    # create the output directories (parent first) and write the CSV files
    dirname = './result/result_%s' % (prefix,)
    for folder in ('./result/', dirname):
        if not os.path.exists(folder):
            os.mkdir(folder)
    trfile = '%s/%s_train.csv' % (dirname, prefix)
    tefile = '%s/%s_test.csv' % (dirname, prefix)
    for path, X, y in ((trfile, Xtr, ytr), (tefile, Xte, yte)):
        np.savetxt(path, np.c_[X, y], delimiter=',')

    # demo_R
    Kmax = 10
    restart = 20
    treenum = 100
    M = range(1, 11)
    # variant without the M / compare arguments:
    #paper_sub.run(prefix, Kmax, restart, treenum=treenum, modeltype='classification', plot=True, plot_line=[[t, z]])
    paper_sub.run(
        prefix, Kmax, restart,
        treenum=treenum,
        modeltype='classification',
        plot=True,
        plot_line=[[t, z]],
        M=M,
        compare=True,
    )

    # figures: noise-free ground truth and the forest's prediction surface
    plotTZ('%s/%s_true.pdf' % (dirname, prefix))
    plotForest('%s/%s_rf_tree05_seed00.pdf' % (dirname, prefix))