import cv2
import random
import numpy as np
import os
import argparse
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf
from FCN8S import dice_coef
import time
# Make only GPU 0 visible to CUDA.
# NOTE(review): this is set after TensorFlow has been imported; it only takes
# effect if TensorFlow has not initialised the GPU yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Project root. predict() reads inputs from basePath + 'train\\' and writes
# masks to basePath + 'predict/'. Every backslash is escaped explicitly:
# unescaped sequences such as \A or \D are invalid string escapes
# (SyntaxWarning on Python 3.12+).
basePath = "C:\\Users\\Administrator\\Desktop\\Project\\"
# File names (relative to basePath + 'train\\') to segment.
TEST_SET = ['1.png']

# Edge length, in pixels, of the square tiles fed to the model.
image_size = 32

# Label values written into the output masks. The encoder below maps the
# model's class index k back to the k-th value of the sorted `classes`.
classes = [0.0, 1.0, 2.0, 3.0, 4.0, 15.0]

labelencoder = LabelEncoder()
labelencoder.fit(classes)

def args_parse(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, exactly as
            before; passing a list allows programmatic use and testing.

    Returns:
        dict with keys:
            ``"model"``  -- path to the saved Keras model (default "FCN.h5");
            ``"stride"`` -- int step, in pixels, between consecutive tiles
                            (default ``image_size``).
    """
    # Construct the argument parser and parse the arguments.
    ap = argparse.ArgumentParser()
    ap.add_argument("-m", "--model", required=False, default="FCN.h5",
                    help="path to trained model")
    ap.add_argument("-s", "--stride", required=False, type=int,
                    default=image_size, help="crop slide stride")
    return vars(ap.parse_args(argv))

    
def _tile_starts(length, tile, stride):
    """Return (tile start offsets, padded length) along one image axis.

    Offsets are 0, stride, 2*stride, ... and there are just enough of them
    for the last tile to reach ``length``. The padded length is exactly where
    that last tile ends, so every tile is a full ``tile`` pixels long.
    If ``stride > tile`` the tiles leave gaps that are never predicted.
    """
    n_steps = (max(length - tile, 0) + stride - 1) // stride + 1  # ceil division
    starts = [k * stride for k in range(n_steps)]
    return starts, starts[-1] + tile


def _predict_image(model, image, stride):
    """Predict a uint8 label mask with the same height/width as ``image``.

    ``image`` is the array returned by ``cv2.imread`` (H x W x 3, uint8).
    """
    h, w = image.shape[:2]
    row_starts, padding_h = _tile_starts(h, image_size, stride)
    col_starts, padding_w = _tile_starts(w, image_size, stride)

    # Zero-pad bottom/right so every tile is a full image_size x image_size crop.
    padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
    padding_img[:h, :w, :] = image
    padding_img = img_to_array(padding_img.astype("float") / 255.0)

    mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
    for y in row_starts:
        # One predict() call per row of tiles instead of one per tile.
        batch = np.stack([padding_img[y:y + image_size, x:x + image_size, :3]
                          for x in col_starts])
        probs = model.predict(batch, verbose=0)
        # Per-pixel class index -> original label value via the fitted encoder.
        # Assumes the model outputs (N, image_size, image_size, n_classes),
        # as the original per-tile code did.
        idx = np.argmax(probs, axis=3)
        labels = labelencoder.inverse_transform(idx.ravel())
        labels = labels.reshape(len(col_starts), image_size, image_size)
        labels = labels.astype(np.uint8)
        # Row-major write order: where tiles overlap, the later tile wins.
        for x, tile_mask in zip(col_starts, labels):
            mask_whole[y:y + image_size, x:x + image_size] = tile_mask
    return mask_whole[:h, :w]


def predict(args):
    """Segment every image in TEST_SET with a sliding window and save masks.

    Each image ``basePath + 'train\\' + name`` is split into
    ``image_size`` x ``image_size`` tiles spaced ``args['stride']`` pixels
    apart. Each tile is classified pixel-wise by the model, and the label
    mask is written to ``basePath + 'predict/' + name``.

    Args:
        args: dict with ``"model"`` (path of the saved Keras model) and
            ``"stride"`` (positive int), e.g. as returned by ``args_parse``.

    Raises:
        ValueError: if the stride is not positive.
        FileNotFoundError: if an input image cannot be read.
        OSError: if a mask cannot be written.
    """
    stride = args['stride']
    if stride <= 0:
        raise ValueError("stride must be a positive integer, got %r" % (stride,))

    # load the trained convolutional neural network
    print("载入网络权重中……")
    model = load_model(args["model"], custom_objects={'dice_coef': dice_coef})
    print("进行预测分割拼图中……")
    for path in TEST_SET:
        in_path = basePath + 'train\\' + path
        image = cv2.imread(in_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError("cannot read image: " + in_path)

        mask = _predict_image(model, image, stride)

        out_path = basePath + 'predict/' + path
        if not cv2.imwrite(out_path, mask):  # returns False on failure
            raise OSError("failed to write mask: " + out_path)
        
    

    
if __name__ == '__main__':
    # Time the whole run: argument parsing plus prediction of every image.
    start_time = time.time()
    predict(args_parse())
    elapsed = time.time() - start_time
    print("运行时长:%.1f" % float(elapsed) + "s")