import os
import  numpy as np
import  matplotlib.pyplot as plt

from sklearn.metrics import jaccard_similarity_score

from data_reader import  get_patches_dir
from model import  get_unet

# File-name stems of hand-picked "gold" reference patches used for qualitative
# evaluation. Only GOLD_IMG_4Classes_4 is referenced below (check_predict_gold);
# GOLD_IMG_4Classes_3 is presumably an alternative set — unused in this view.
GOLD_IMG_4Classes_3 = ['fieldborders_7_4109150_merged', 'fieldborders_28_4109150_merged', 'wsb_9_4109150_merged']
GOLD_IMG_4Classes_4 = ['waterways_2285_4109150_merged', 'waterways_2633_4109150_merged',
                       'waterways_1285_4109150_merged', 'fieldborders_19_4109150_merged']
# Patch count for the small validation run. NOTE(review): appears unused here —
# check_predict_small_test reads config.AMT_SMALL_VAL instead; confirm intent.
AMT_SMALL_VAL = 100


def predict_from_val(model, amt_pred, trs, config, x):
    """Run the model on `x` and binarize each class channel with its threshold.

    Parameters
    ----------
    model : object exposing `predict(x, batch_size=...)` (keras-style).
    amt_pred : int
        Number of leading patches to keep from the prediction batch.
    trs : sequence of float
        One threshold per class channel (length config.NUM_CLASSES).
    config : object providing NUM_CLASSES (and ISZ, the patch size).
    x : array of input patches; only the first `amt_pred` results are used.

    Returns
    -------
    np.ndarray
        float32 array of shape (amt_pred, ISZ, ISZ, NUM_CLASSES) holding
        0.0/1.0 values (1.0 where the raw prediction exceeds the threshold).
    """
    # Predict the whole batch once; batch_size=1 keeps per-step memory low.
    raw = model.predict(x, batch_size=1)
    # Vectorized thresholding: broadcasting `trs` over the trailing class axis
    # replaces the original per-patch / per-class Python loops.
    thresholds = np.asarray(trs, dtype=np.float32)
    return (raw[:amt_pred] > thresholds).astype(np.float32)

def display_pred(pred_res, true_masks, config, modelName, amt_pred, trs, min_pred_sum, output_folder):
    """Save per-class predicted and ground-truth masks as heat-map PNGs.

    For each patch and class channel, the predicted mask (and its matching
    ground-truth mask) is written to `output_folder` when the predicted pixel
    sum exceeds `min_pred_sum`. File names embed the Jaccard score, model
    name, patch index, class index, and the threshold vector so different
    runs do not overwrite each other.

    Parameters
    ----------
    pred_res : array (amt_pred, H, W, NUM_CLASSES) of binarized predictions.
    true_masks : array of the same shape with ground-truth masks.
    config : object providing NUM_CLASSES.
    modelName : str used in output file names.
    amt_pred : number of patches to iterate.
    trs : threshold vector (or None) — only used for the file-name tag.
    min_pred_sum : minimum predicted pixel sum for a mask to be saved.
    output_folder : destination directory (created if missing).
    """
    os.makedirs(output_folder, exist_ok=True)
    trs_str = "trs_none"
    # `is not None` (was `!= None`): identity check per PEP 8, and safe even
    # when `trs` is a numpy array (where `!=` yields an ambiguous array).
    if trs is not None:
        trs_str = "trs_" + '_'.join(str(t) for t in trs)

    nothing_saved = True
    for p in range(amt_pred):
        for i in range(config.NUM_CLASSES):
            pred = pred_res[p, :, :, i]
            true_mask = true_masks[p, :, :, i]

            # Skip near-empty predictions to keep the output folder small.
            if np.sum(pred) > min_pred_sum:
                # NOTE(review): sklearn's jaccard_similarity_score was removed
                # in modern scikit-learn (jaccard_score is the replacement) —
                # confirm the pinned sklearn version before upgrading.
                jk = jaccard_similarity_score(true_mask, pred)
                fn = os.path.join(output_folder, "{4}{0}_p{1}_cl{2}_{3}.png".format(modelName, p, i, trs_str, jk))
                plt.imsave(fn, pred, cmap='hot')

                fn_tr = os.path.join(output_folder, "{4}{0}_p{1}_TRUE_cl{2}_{3}.png".format(modelName, p, i, trs_str, jk))
                plt.imsave(fn_tr, true_mask, cmap='hot')

                nothing_saved = False

    if nothing_saved:
        print("All predictions did not satisfy: sum_pred > min_pred_sum, nothing saved. Min_pred_sum:", min_pred_sum)


def check_predict_folder(model, model_name, val_dir,  predict_dir,  config, loss_mode, amt_pred, verbose, imageNames=None):
    """Load validation patches and sweep thresholds 0.1 .. 0.9 over them.

    For each threshold step, the same value is applied to every class channel
    and the resulting masks are written to `predict_dir` via
    check_predict_model.

    Parameters
    ----------
    model : trained model passed through to prediction.
    model_name : str used for output file naming.
    val_dir : directory to read validation patches from.
    predict_dir : output directory for prediction images.
    config : object providing NUM_CLASSES (and patch settings).
    loss_mode : unused here — kept for signature compatibility with callers.
    amt_pred : number of patches to load and predict.
    verbose : forwarded to the patch loader.
    imageNames : optional list of image names to restrict loading to.
    """
    # None sentinel instead of a mutable default argument (`=[]`), which is
    # shared across calls; normalize to a fresh empty list here.
    if imageNames is None:
        imageNames = []
    x_val, y_val = get_patches_dir(val_dir, config, shuffleOn=False, amt=amt_pred, verbose=verbose,
                                imageNames=imageNames)
    # Sweep a uniform per-class threshold from 0.1 to 0.9.
    for i in range(1, 10):
        tr = [i / 10.0 for _ in range(config.NUM_CLASSES)]
        check_predict_model(model, model_name, config, predic_cnt=amt_pred,
                      trs=tr, x=x_val, y=y_val, min_pred_sum=10,
                      output_folder=predict_dir)

def check_predict_small_test(model, model_name, predict_dir, config, loss_mode):
    """Run the threshold sweep on the small validation split.

    Output lands in a dedicated "<predict_dir>_small_test" folder; the patch
    count and source directory come from `config`.
    """
    return check_predict_folder(
        model,
        model_name,
        config.SMAL_VAL_DIR,
        predict_dir + "_small_test",
        config,
        loss_mode,
        amt_pred=config.AMT_SMALL_VAL,
        verbose=False,
    )


def check_predict_gold(model, model_name, predict_dir, config, loss_mode):
    """Run the threshold sweep on the fixed "gold" reference images.

    Results are written to "<predict_dir>_gold"; patches are loaded from
    config.VAL_DIR restricted to the gold image list.
    """
    imageNames = GOLD_IMG_4Classes_4
    # BUG FIX: amt_pred was hard-coded to 3 while GOLD_IMG_4Classes_4 holds 4
    # image names, silently dropping the last gold image. Derive the count
    # from the list so the two can never drift apart.
    amt_pred = len(imageNames)
    verbose = False
    predict_dir = predict_dir + "_gold"
    val_dir = config.VAL_DIR
    return check_predict_folder(model, model_name, val_dir, predict_dir, config,
                                loss_mode, amt_pred, verbose, imageNames=imageNames)

def check_predict_model(model, model_name, config, predic_cnt, trs, x, y, min_pred_sum, output_folder):
    """Threshold-predict on `x` and save the resulting masks alongside `y`.

    Thin pipeline glue: binarize predictions with predict_from_val, then hand
    them to display_pred for filtering and PNG output.
    """
    thresholded = predict_from_val(model, predic_cnt, trs, config, x)
    display_pred(
        thresholded,
        y,
        config,
        modelName=model_name,
        amt_pred=predic_cnt,
        trs=trs,
        min_pred_sum=min_pred_sum,
        output_folder=output_folder,
    )


def check_predict(model_name, weights_folder, config, loss_mode, predic_cnt, trs, x, y, min_pred_sum, output_folder):
    """Build a U-Net, load `model_name` weights from `weights_folder`, and
    run check_predict_model on `x`/`y` with the given thresholds."""
    net = get_unet(config, loss_mode)
    weights_path = os.path.join(weights_folder, model_name)
    net.load_weights(weights_path)
    check_predict_model(net, model_name, config, predic_cnt, trs, x, y,
                        min_pred_sum, output_folder)