# -*- coding:utf-8 -*-
"""
#====#====#====#====
# Project Name:     U-net 
# File Name:        unet-TF-withBatchNormal 
# Date:             2/17/18 8:18 PM 
# Using IDE:        PyCharm Community Edition  
# From HomePage:    https://github.com/DuFanXin/U-net
# Author:           DuFanXin 
# BlogPage:         http://blog.csdn.net/qq_30239975  
# E-mail:           18672969179@163.com
# Copyright (c) 2018, All Rights Reserved.
#====#====#====#==== 
"""
import tensorflow as tf
import argparse
import os

# File names of the TFRecord data sets; they are joined onto FLAGS.data_dir at run time.
TRAIN_SET_NAME = 'train_set.tfrecords'
VALIDATION_SET_NAME = 'validation_set.tfrecords'
TEST_SET_NAME = 'test_set.tfrecords'
# Directory that Unet.predict() scans for '*.tif' images.
ORIGIN_PREDICT_DIRECTORY = '../data_set/test'
# Spatial size and channel count used for the network input and for the label maps.
INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL = 512, 512, 1
OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT, OUTPUT_IMG_CHANNEL = 512, 512, 1
TRAIN_SET_SIZE = 8  # not referenced by the code in this file
EPOCH_NUM = 1  # num_epochs of the training filename queue in Unet.train()
# Batch sizes handed to read_image_batch() / set_up_unet() for each mode.
TRAIN_BATCH_SIZE = 1
VALIDATION_BATCH_SIZE = 1
TEST_SET_SIZE = 30  # divisor of the average accuracy printed by Unet.test()
TEST_BATCH_SIZE = 1
PREDICT_BATCH_SIZE = 1
# Directory where Unet.test() and Unet.predict() write the predicted masks.
PREDICT_SAVED_DIRECTORY = '../data_set/predictions'
# Default variance epsilon of Unet.batch_norm(); note that 10e-5 equals 1e-4.
EPS = 10e-5
FLAGS = None  # replaced by the parsed command-line arguments in the __main__ block
CLASS_NUM = 2  # number of output channels (classes) of the final 1x1 convolution
# Checkpoint restored by Unet.validate() and Unet.test().
CHECK_POINT_PATH = '../data_set/saved_models/train_5th/model.ckpt'


def calculate_unet_input_and_output(bottom=0):
	"""Print the input size and then the output size of the network.

	The computation starts from the size ``bottom`` at the lowest layer and
	walks outwards through four levels, growing the input-side size and the
	output-side size in lock step.

	Args:
		bottom: Size at the lowest layer of the network.

	Returns:
		None. The two sizes are printed, input size first.
	"""
	in_size, out_size = bottom + 2, bottom * 2
	for level in range(4):
		# Per level: +2 then x2 on the input side, -2 twice on the output side.
		in_size = (in_size + 2) * 2
		out_size -= 4
		if level < 3:
			# Transition to the next level (skipped after the outermost one).
			in_size += 2
			out_size *= 2
	in_size += 4
	print(in_size)
	print(out_size)


def convert_from_color_segmentation(image_3d=None):
	"""Convert a colour-coded segmentation image into a 2-D class-index map.

	Args:
		image_3d: Image whose last axis holds three colour components; it is
			converted with ``skimage.img_as_ubyte`` before matching.

	Returns:
		A ``uint8`` array with the first two dimensions of ``image_3d``. Each
		pixel holds the index of the matching palette colour, or 0 when no
		palette colour matches.
	"""
	import warnings
	import numpy as np
	from skimage import img_as_ubyte

	# Position in this tuple is the class index written to the output.
	palette = (
		(0, 0, 0),        # 0  background
		(128, 0, 0),      # 1  aeroplane
		(0, 128, 0),      # 2  bicycle
		(128, 128, 0),    # 3  bird
		(0, 0, 128),      # 4  boat
		(128, 0, 128),    # 5  bottle
		(0, 128, 128),    # 6  bus
		(128, 128, 128),  # 7  car
		(64, 0, 0),       # 8  cat
		(192, 0, 0),      # 9  chair
		(64, 128, 0),     # 10 cow
		(192, 128, 0),    # 11 dining table
		(64, 0, 128),     # 12 dog
		(192, 0, 128),    # 13 horse
		(64, 128, 128),   # 14 motorbike
		(192, 128, 128),  # 15 person
		(0, 64, 0),       # 16 potted plant
		(128, 64, 0),     # 17 sheep
		(0, 192, 0),      # 18 sofa
		(128, 192, 0),    # 19 train
		(0, 64, 128),     # 20 tv / monitor
	)

	# Silence the warnings raised during the conversion.
	with warnings.catch_warnings():
		warnings.simplefilter("ignore")
		image_3d = img_as_ubyte(image_3d)

	height, width = image_3d.shape[0], image_3d.shape[1]
	image_2d = np.zeros((height, width), dtype=np.uint8)
	for class_index, colour in enumerate(palette):
		colour_pixel = np.array(colour).reshape(1, 1, 3)
		# A pixel belongs to the class only if all three components match.
		matches = np.all(image_3d == colour_pixel, axis=2)
		image_2d[matches] = class_index

	return image_2d


def _write_examples_to_tfrecords(file_paths, writer, set_name):
	"""Serialize every combined image in ``file_paths`` into ``writer``.

	Channel 0 of each image is used as the sample and channel 2 as the label.
	Both are resized to the configured network size and the label is
	binarized to 0 / 1 with a threshold of 100.

	Args:
		file_paths: List of image file paths to convert.
		writer: An open ``tf.python_io.TFRecordWriter``.
		set_name: Name used in the progress messages (e.g. ``'train_set'``).

	Raises:
		IOError: If an image file can not be read by OpenCV.
	"""
	import cv2
	import numpy as np

	total = len(file_paths)
	for index, file_path in enumerate(file_paths):
		combined_image = cv2.imread(file_path)
		if combined_image is None:
			# cv2.imread reports failure by returning None instead of raising.
			raise IOError('Can not read image file: %s' % file_path)
		sample_image = np.asarray(combined_image[:, :, 0], dtype=np.uint8)
		sample_image = cv2.resize(src=sample_image, dsize=(INPUT_IMG_WIDE, INPUT_IMG_HEIGHT))
		label_image = np.asarray(combined_image[:, :, 2], dtype=np.uint8)
		label_image = cv2.resize(src=label_image, dsize=(OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT))
		# Binarize after resizing; class ids must stay below CLASS_NUM for the
		# sparse softmax loss, so foreground is always written as 1.
		label_image = (label_image > 100).astype(np.uint8)
		# Wrap label and image bytes into one Example record.
		example = tf.train.Example(features=tf.train.Features(feature={
			'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_image.tobytes()])),
			'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[sample_image.tobytes()]))
		}))
		writer.write(example.SerializeToString())
		if index % 100 == 0:
			print('Done %s writing %.2f%%' % (set_name, index * 100.0 / total))


def write_img_to_tfrecords():
	"""Convert the combined sample/label images into TFRecord data sets.

	The first 10000 matched files go to the training set file and the rest
	to the validation set file, both located in ``FLAGS.data_dir``.

	Raises:
		IOError: If a matched image file can not be read.
	"""
	import glob

	path = glob.glob(os.path.join('/home/dufanxin/PycharmProjects/Image-Augmentor/inputdata/image', 'combine0.png'))
	print(len(path))
	train_path = path[:10000]
	validation_path = path[10000:]

	train_set_writer = tf.python_io.TFRecordWriter(os.path.join(FLAGS.data_dir, TRAIN_SET_NAME))
	try:
		_write_examples_to_tfrecords(train_path, train_set_writer, 'train_set')
	finally:
		# Always release the file handle, even when conversion fails.
		train_set_writer.close()
	print("Done train_set writing")

	validation_set_writer = tf.python_io.TFRecordWriter(os.path.join(FLAGS.data_dir, VALIDATION_SET_NAME))
	try:
		_write_examples_to_tfrecords(validation_path, validation_set_writer, 'validation_set')
	finally:
		validation_set_writer.close()
	print("Done validation_set writing")


def read_check_tfrecords():
	"""Read one example from the training TFRecord file and display it.

	The image and its label (multiplied by 100 so that it is visible) are
	shown in two OpenCV windows; the function blocks until a key is pressed.
	"""
	import cv2
	train_file_path = os.path.join(FLAGS.data_dir, TRAIN_SET_NAME)
	train_image_filename_queue = tf.train.string_input_producer(
		string_tensor=tf.train.match_filenames_once(train_file_path), num_epochs=1, shuffle=True)
	train_images, train_labels = read_image(train_image_filename_queue)
	with tf.Session() as sess:
		sess.run(tf.global_variables_initializer())
		# match_filenames_once and num_epochs rely on local variables.
		sess.run(tf.local_variables_initializer())
		coord = tf.train.Coordinator()
		threads = tf.train.start_queue_runners(coord=coord)
		try:
			example, label = sess.run([train_images, train_labels])
			cv2.imshow('image', example)
			cv2.imshow('label', label * 100)
			cv2.waitKey(0)
		finally:
			# Stop the queue-runner threads even if reading or display fails.
			coord.request_stop()
			coord.join(threads)
	print("Done reading and checking")


def read_image(file_queue):
	"""Read and decode one (image, label) pair from a TFRecord filename queue.

	Args:
		file_queue: Filename queue, as produced by
			``tf.train.string_input_producer``.

	Returns:
		A tuple ``(image, label)`` of ``uint8`` tensors shaped
		``[INPUT_IMG_HEIGHT, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL]`` and
		``[OUTPUT_IMG_HEIGHT, OUTPUT_IMG_WIDE]``.
	"""
	reader = tf.TFRecordReader()
	_, serialized_example = reader.read(file_queue)
	features = tf.parse_single_example(
		serialized_example,
		features={
			'label': tf.FixedLenFeature([], tf.string),
			'image_raw': tf.FixedLenFeature([], tf.string)
			})

	# The records hold the raw bytes of row-major (height, width) arrays, so
	# the height axis comes first. This also matches read_image_batch().
	image = tf.decode_raw(features['image_raw'], tf.uint8)
	image = tf.reshape(image, [INPUT_IMG_HEIGHT, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL])

	label = tf.decode_raw(features['label'], tf.uint8)
	label = tf.reshape(label, [OUTPUT_IMG_HEIGHT, OUTPUT_IMG_WIDE])
	return image, label


def read_image_batch(file_queue, batch_size):
	"""Build shuffled batches of images and labels from a filename queue.

	Args:
		file_queue: Filename queue passed on to ``read_image``.
		batch_size: Number of examples per batch.

	Returns:
		A tuple ``(image_batch, label_batch)``. The labels are reshaped to
		``[batch_size, OUTPUT_IMG_HEIGHT, OUTPUT_IMG_WIDE]`` and are class
		indices, not one-hot vectors.
	"""
	single_image, single_label = read_image(file_queue)
	# Shuffle-queue sizing: total capacity and the minimum number of elements
	# kept in the queue after a dequeue.
	queue_capacity = 4000
	shuffle_floor = 2000
	batched_images, batched_labels = tf.train.shuffle_batch(
		tensors=[single_image, single_label], batch_size=batch_size,
		capacity=queue_capacity, min_after_dequeue=shuffle_floor)
	label_shape = [batch_size, OUTPUT_IMG_HEIGHT, OUTPUT_IMG_WIDE]
	return batched_images, tf.reshape(batched_labels, label_shape)


class Unet:

	def __init__(self):
		print('New U-net Network')
		self.input_image = None
		self.input_label = None
		self.cast_image = None
		self.cast_label = None
		self.keep_prob = None
		self.lamb = None
		self.result_expand = None
		self.is_traing = None
		self.loss, self.loss_mean, self.loss_all, self.train_step = [None] * 4
		self.prediction, self.correct_prediction, self.accuracy = [None] * 3
		self.result_conv = {}
		self.result_relu = {}
		self.result_maxpool = {}
		self.result_from_contract_layer = {}
		self.w = {}
		# self.b = {}

	def init_w(self, shape, name):
		"""Create a weight variable and register its L2 penalty.

		The weights are drawn from a truncated normal distribution with
		``stddev = sqrt(2 / (shape[0] * shape[1] * shape[2]))``.

		Args:
			shape: 4-element filter shape.
			name: Name given to the variable.

		Returns:
			The new ``tf.Variable``. As a side effect its L2 regularization
			term, scaled by ``self.lamb``, is added to the ``'loss'``
			collection.
		"""
		with tf.name_scope('init_w'):
			fan = shape[0] * shape[1] * shape[2]
			# Use a float numerator: under Python 2 ``2 / fan`` is integer
			# division and would yield a standard deviation of zero.
			stddev = tf.sqrt(x=2.0 / fan)
			w = tf.Variable(initial_value=tf.truncated_normal(shape=shape, stddev=stddev, dtype=tf.float32), name=name)
			tf.add_to_collection(name='loss', value=tf.contrib.layers.l2_regularizer(self.lamb)(w))
			return w

	@staticmethod
	def init_b(shape, name):
		"""Create a bias variable initialized from a normal distribution.

		Args:
			shape: Shape of the bias variable.
			name: Name given to the variable.

		Returns:
			The new ``tf.Variable``.
		"""
		with tf.name_scope('init_b'):
			initial_bias = tf.random_normal(shape=shape, dtype=tf.float32)
			return tf.Variable(initial_value=initial_bias, name=name)

	@staticmethod
	def batch_norm(x, is_training, eps=EPS, decay=0.9, affine=True, name='BatchNorm2d'):
		"""Batch-normalize ``x`` over all axes except the last one.

		Args:
			x: Input tensor; statistics are kept per entry of its last axis.
			is_training: Boolean tensor. When true, the statistics of the
				current batch are used and the moving averages are updated;
				otherwise the stored moving averages are used.
			eps: Variance epsilon passed to ``tf.nn.batch_normalization``.
			decay: Decay of the moving averages.
			affine: Whether to apply the learned offset ``beta`` and scale
				``gamma``.
			name: Variable scope that holds the variables of this layer.

		Returns:
			The normalized tensor.
		"""
		from tensorflow.python.training.moving_averages import assign_moving_average

		with tf.variable_scope(name):
			# One statistic per entry of the last axis.
			params_shape = x.shape[-1:]
			# Non-trainable running statistics, used when is_training is false.
			moving_mean = tf.get_variable(name='mean', shape=params_shape, initializer=tf.zeros_initializer, trainable=False)
			moving_var = tf.get_variable(name='variance', shape=params_shape, initializer=tf.ones_initializer, trainable=False)

			def mean_var_with_update():
				# Moments over every axis but the last one.
				mean_this_batch, variance_this_batch = tf.nn.moments(x, list(range(len(x.shape) - 1)), name='moments')
				# The control dependency makes the moving-average updates run
				# whenever the batch statistics are evaluated.
				with tf.control_dependencies([
					assign_moving_average(moving_mean, mean_this_batch, decay),
					assign_moving_average(moving_var, variance_this_batch, decay)
				]):
					return tf.identity(mean_this_batch), tf.identity(variance_this_batch)

			mean, variance = tf.cond(is_training, mean_var_with_update, lambda: (moving_mean, moving_var))
			if affine:  # scale and shift with the learned gamma and beta
				beta = tf.get_variable('beta', params_shape, initializer=tf.zeros_initializer)
				gamma = tf.get_variable('gamma', params_shape, initializer=tf.ones_initializer)
				normed = tf.nn.batch_normalization(x, mean=mean, variance=variance, offset=beta, scale=gamma, variance_epsilon=eps)
			else:
				normed = tf.nn.batch_normalization(x, mean=mean, variance=variance, offset=None, scale=None,  variance_epsilon=eps)
			return normed

	@staticmethod
	def copy_and_crop_and_merge(result_from_contract_layer, result_from_upsampling):
		"""Merge a contracting-path feature map with an up-sampled one.

		Despite the name, no cropping is performed: the contracting-path
		tensor is used as-is, so both inputs must already agree on every
		axis except the last one.

		Args:
			result_from_contract_layer: Feature map saved on the contracting
				path.
			result_from_upsampling: Feature map produced by up-sampling.

		Returns:
			Both tensors concatenated along the last axis, with the
			contracting-path features first.
		"""
		skip_then_upsampled = [result_from_contract_layer, result_from_upsampling]
		return tf.concat(values=skip_then_upsampled, axis=-1)

	@staticmethod
	def conv(w, bias,):
		return

	def set_up_unet(self, batch_size):
		"""Build the complete U-net graph in the default TensorFlow graph.

		The graph consists of the input placeholders, four contracting
		layers, a bottom layer, four expanding layers that are merged with
		the saved contracting results, a sparse softmax cross-entropy loss
		(plus the L2 terms collected in the ``'loss'`` collection), a
		pixel accuracy and an Adam training step. All convolutions are
		followed by batch normalization and carry no bias.

		Args:
			batch_size: Fixed batch size used for the placeholder shapes and
				for the output shapes of the transposed convolutions.

		Side effects:
			Sets ``self.input_image``, ``self.input_label``,
			``self.keep_prob``, ``self.lamb``, ``self.is_traing``,
			``self.prediction``, ``self.loss``, ``self.loss_mean``,
			``self.loss_all``, ``self.correct_prediction``,
			``self.accuracy`` and ``self.train_step``, and fills ``self.w``
			and ``self.result_from_contract_layer``.
		"""
		# input
		with tf.name_scope('input'):
			# learning_rate = tf.train.exponential_decay()
			# NOTE(review): both spatial axes use INPUT_IMG_WIDE; this is only
			# equivalent to (height, width) while the two constants are equal.
			self.input_image = tf.placeholder(
				dtype=tf.float32, shape=[batch_size, INPUT_IMG_WIDE, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL], name='input_images'
			)
			# self.cast_image = tf.reshape(
			# 	tensor=self.input_image,
			# 	shape=[batch_size, INPUT_IMG_WIDE, INPUT_IMG_WIDE, INPUT_IMG_CHANNEL]
			# )

			# For softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
			# using one-hot
			# self.input_label = tf.placeholder(
			# 	dtype=tf.uint8, shape=[OUTPUT_IMG_WIDE, OUTPUT_IMG_WIDE], name='input_labels'
			# )
			# self.cast_label = tf.reshape(
			# 	tensor=self.input_label,
			# 	shape=[batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_HEIGHT]
			# )

			# For sparse_softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
			# not using one-hot coding
			self.input_label = tf.placeholder(
				dtype=tf.int32, shape=[batch_size, OUTPUT_IMG_WIDE, OUTPUT_IMG_WIDE], name='input_labels'
			)
			# Scalars fed at run time: dropout keep probability, L2 weight
			# and the training / inference switch for batch normalization.
			self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
			self.lamb = tf.placeholder(dtype=tf.float32, name='lambda')
			self.is_traing = tf.placeholder(dtype=tf.bool, name='is_traing')
			# The raw input is batch-normalized before the first convolution.
			normed_batch = self.batch_norm(x=self.input_image, is_training=self.is_traing, name='input')

		# layer 1
		with tf.name_scope('layer_1'):
			# conv_1
			self.w[1] = self.init_w(shape=[3, 3, INPUT_IMG_CHANNEL, 64], name='w_1')
			# self.b[1] = self.init_b(shape=[64], name='b_1')
			result_conv_1 = tf.nn.conv2d(
				input=normed_batch, filter=self.w[1], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_1_conv_1')
			result_relu_1 = tf.nn.relu(normed_batch, name='relu_1')

			# conv_2
			self.w[2] = self.init_w(shape=[3, 3, 64, 64], name='w_2')
			# self.b[2] = self.init_b(shape=[64], name='b_2')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[2], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_1_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')
			self.result_from_contract_layer[1] = result_relu_2  # saved for the merge in layer 9

			# maxpool
			result_maxpool = tf.nn.max_pool(
				value=result_relu_2, ksize=[1, 2, 2, 1],
				strides=[1, 2, 2, 1], padding='VALID', name='maxpool')

			# dropout
			result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

		# layer 2
		with tf.name_scope('layer_2'):
			# conv_1
			self.w[3] = self.init_w(shape=[3, 3, 64, 128], name='w_3')
			# self.b[3] = self.init_b(shape=[128], name='b_3')
			result_conv_1 = tf.nn.conv2d(
				input=result_dropout, filter=self.w[3], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_2_conv_1')
			result_relu_1 = tf.nn.relu(features=normed_batch, name='relu_1')

			# conv_2
			self.w[4] = self.init_w(shape=[3, 3, 128, 128], name='w_4')
			# self.b[4] = self.init_b(shape=[128], name='b_4')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[4], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_2_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')
			self.result_from_contract_layer[2] = result_relu_2  # saved for the merge in layer 8

			# maxpool
			result_maxpool = tf.nn.max_pool(
				value=result_relu_2, ksize=[1, 2, 2, 1],
				strides=[1, 2, 2, 1], padding='VALID', name='maxpool')

			# dropout
			result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

		# layer 3
		with tf.name_scope('layer_3'):
			# conv_1
			self.w[5] = self.init_w(shape=[3, 3, 128, 256], name='w_5')
			# self.b[5] = self.init_b(shape=[256], name='b_5')
			result_conv_1 = tf.nn.conv2d(
				input=result_dropout, filter=self.w[5], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_3_conv_1')
			result_relu_1 = tf.nn.relu(features=normed_batch, name='relu_1')

			# conv_2
			self.w[6] = self.init_w(shape=[3, 3, 256, 256], name='w_6')
			# self.b[6] = self.init_b(shape=[256], name='b_6')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[6], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_3_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')
			self.result_from_contract_layer[3] = result_relu_2  # saved for the merge in layer 7

			# maxpool
			result_maxpool = tf.nn.max_pool(
				value=result_relu_2, ksize=[1, 2, 2, 1],
				strides=[1, 2, 2, 1], padding='VALID', name='maxpool')

			# dropout
			result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

		# layer 4
		with tf.name_scope('layer_4'):
			# conv_1
			self.w[7] = self.init_w(shape=[3, 3, 256, 512], name='w_7')
			# self.b[7] = self.init_b(shape=[512], name='b_7')
			result_conv_1 = tf.nn.conv2d(
				input=result_dropout, filter=self.w[7], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_4_conv_1')
			result_relu_1 = tf.nn.relu(features=normed_batch, name='relu_1')

			# conv_2
			self.w[8] = self.init_w(shape=[3, 3, 512, 512], name='w_8')
			# self.b[8] = self.init_b(shape=[512], name='b_8')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[8], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_4_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')
			self.result_from_contract_layer[4] = result_relu_2  # saved for the merge in layer 6

			# maxpool
			result_maxpool = tf.nn.max_pool(
				value=result_relu_2, ksize=[1, 2, 2, 1],
				strides=[1, 2, 2, 1], padding='VALID', name='maxpool')

			# dropout
			result_dropout = tf.nn.dropout(x=result_maxpool, keep_prob=self.keep_prob)

		# layer 5 (bottom)
		with tf.name_scope('layer_5'):
			# conv_1
			self.w[9] = self.init_w(shape=[3, 3, 512, 1024], name='w_9')
			# self.b[9] = self.init_b(shape=[1024], name='b_9')
			result_conv_1 = tf.nn.conv2d(
				input=result_dropout, filter=self.w[9], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_5_conv_1')
			result_relu_1 = tf.nn.relu(features=normed_batch, name='relu_1')

			# conv_2
			self.w[10] = self.init_w(shape=[3, 3, 1024, 1024], name='w_10')
			# self.b[10] = self.init_b(shape=[1024], name='b_10')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[10], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_5_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')

			# up sample
			# The spatial output size (64 here, then 128 / 256 / 512 in the
			# following layers) is hard-coded rather than derived from the
			# INPUT_IMG_* constants.
			self.w[11] = self.init_w(shape=[2, 2, 512, 1024], name='w_11')
			# self.b[11] = self.init_b(shape=[512], name='b_11')
			result_up = tf.nn.conv2d_transpose(
				value=result_relu_2, filter=self.w[11],
				output_shape=[batch_size, 64, 64, 512],
				strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
			normed_batch = self.batch_norm(x=result_up, is_training=self.is_traing, name='layer_5_conv_up')
			result_relu_3 = tf.nn.relu(features=normed_batch, name='relu_3')

			# dropout
			result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

		# layer 6
		with tf.name_scope('layer_6'):
			# copy, crop and merge
			result_merge = self.copy_and_crop_and_merge(
				result_from_contract_layer=self.result_from_contract_layer[4], result_from_upsampling=result_dropout)
			# print(result_merge)

			# conv_1
			self.w[12] = self.init_w(shape=[3, 3, 1024, 512], name='w_12')
			# self.b[12] = self.init_b(shape=[512], name='b_12')
			result_conv_1 = tf.nn.conv2d(
				input=result_merge, filter=self.w[12], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_6_conv_1')
			result_relu_1 = tf.nn.relu(features=normed_batch, name='relu_1')

			# conv_2
			# NOTE(review): from here on the variable names 'w_10' / 'w_11' /
			# 'w_12' are reused and no longer match the self.w index.
			# Renaming them would presumably change the variable names stored
			# in checkpoints - confirm before changing.
			self.w[13] = self.init_w(shape=[3, 3, 512, 512], name='w_10')
			# self.b[13] = self.init_b(shape=[512], name='b_10')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[13], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_6_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')
			# print(result_relu_2.shape[1])

			# up sample
			self.w[14] = self.init_w(shape=[2, 2, 256, 512], name='w_11')
			# self.b[14] = self.init_b(shape=[256], name='b_11')
			result_up = tf.nn.conv2d_transpose(
				value=result_relu_2, filter=self.w[14],
				output_shape=[batch_size, 128, 128, 256],
				strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
			normed_batch = self.batch_norm(x=result_up, is_training=self.is_traing, name='layer_6_conv_up')
			result_relu_3 = tf.nn.relu(features=normed_batch, name='relu_3')

			# dropout
			result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

		# layer 7
		with tf.name_scope('layer_7'):
			# copy, crop and merge
			result_merge = self.copy_and_crop_and_merge(
				result_from_contract_layer=self.result_from_contract_layer[3], result_from_upsampling=result_dropout)

			# conv_1
			self.w[15] = self.init_w(shape=[3, 3, 512, 256], name='w_12')
			# self.b[15] = self.init_b(shape=[256], name='b_12')
			result_conv_1 = tf.nn.conv2d(
				input=result_merge, filter=self.w[15], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_7_conv_1')
			result_relu_1 = tf.nn.relu(features=normed_batch, name='relu_1')

			# conv_2
			self.w[16] = self.init_w(shape=[3, 3, 256, 256], name='w_10')
			# self.b[16] = self.init_b(shape=[256], name='b_10')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[16], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_7_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')

			# up sample
			self.w[17] = self.init_w(shape=[2, 2, 128, 256], name='w_11')
			# self.b[17] = self.init_b(shape=[128], name='b_11')
			result_up = tf.nn.conv2d_transpose(
				value=result_relu_2, filter=self.w[17],
				output_shape=[batch_size, 256, 256, 128],
				strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
			normed_batch = self.batch_norm(x=result_up, is_training=self.is_traing, name='layer_7_up')
			result_relu_3 = tf.nn.relu(features=normed_batch, name='relu_3')

			# dropout
			result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

		# layer 8
		with tf.name_scope('layer_8'):
			# copy, crop and merge
			result_merge = self.copy_and_crop_and_merge(
				result_from_contract_layer=self.result_from_contract_layer[2], result_from_upsampling=result_dropout)

			# conv_1
			self.w[18] = self.init_w(shape=[3, 3, 256, 128], name='w_12')
			# self.b[18] = self.init_b(shape=[128], name='b_12')
			result_conv_1 = tf.nn.conv2d(
				input=result_merge, filter=self.w[18], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_8_conv_1')
			result_relu_1 = tf.nn.relu(features=normed_batch, name='relu_1')

			# conv_2
			self.w[19] = self.init_w(shape=[3, 3, 128, 128], name='w_10')
			# self.b[19] = self.init_b(shape=[128], name='b_10')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[19], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_8_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')

			# up sample
			self.w[20] = self.init_w(shape=[2, 2, 64, 128], name='w_11')
			# self.b[20] = self.init_b(shape=[64], name='b_11')
			result_up = tf.nn.conv2d_transpose(
				value=result_relu_2, filter=self.w[20],
				output_shape=[batch_size, 512, 512, 64],
				strides=[1, 2, 2, 1], padding='VALID', name='Up_Sample')
			normed_batch = self.batch_norm(x=result_up, is_training=self.is_traing, name='layer_8_up')
			result_relu_3 = tf.nn.relu(features=normed_batch, name='relu_3')

			# dropout
			result_dropout = tf.nn.dropout(x=result_relu_3, keep_prob=self.keep_prob)

		# layer 9
		with tf.name_scope('layer_9'):
			# copy, crop and merge
			result_merge = self.copy_and_crop_and_merge(
				result_from_contract_layer=self.result_from_contract_layer[1], result_from_upsampling=result_dropout)

			# conv_1
			self.w[21] = self.init_w(shape=[3, 3, 128, 64], name='w_12')
			# self.b[21] = self.init_b(shape=[64], name='b_12')
			result_conv_1 = tf.nn.conv2d(
				input=result_merge, filter=self.w[21], strides=[1, 1, 1, 1], padding='SAME', name='conv_1')
			normed_batch = self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='layer_9_conv_1')
			result_relu_1 = tf.nn.relu(features=normed_batch, name='relu_1')

			# conv_2
			self.w[22] = self.init_w(shape=[3, 3, 64, 64], name='w_10')
			# self.b[22] = self.init_b(shape=[64], name='b_10')
			result_conv_2 = tf.nn.conv2d(
				input=result_relu_1, filter=self.w[22], strides=[1, 1, 1, 1], padding='SAME', name='conv_2')
			normed_batch = self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='layer_9_conv_2')
			result_relu_2 = tf.nn.relu(features=normed_batch, name='relu_2')

			# 1x1 convolution down to CLASS_NUM channels (the per-pixel logits)
			self.w[23] = self.init_w(shape=[1, 1, 64, CLASS_NUM], name='w_11')
			# self.b[23] = self.init_b(shape=[CLASS_NUM], name='b_11')
			result_conv_3 = tf.nn.conv2d(
				input=result_relu_2, filter=self.w[23],
				strides=[1, 1, 1, 1], padding='VALID', name='conv_3')
			normed_batch = self.batch_norm(x=result_conv_3, is_training=self.is_traing, name='layer_9_conv_3')
			# self.prediction = tf.nn.relu(tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias'), name='relu_3')
			# self.prediction = tf.nn.sigmoid(x=tf.nn.bias_add(result_conv_3, self.b[23], name='add_bias'), name='sigmoid_1')
			# The batch-normalized logits are used directly, without activation.
			self.prediction = normed_batch

		# softmax loss
		with tf.name_scope('softmax_loss'):
			# using one-hot
			# self.loss = \
			# 	tf.nn.softmax_cross_entropy_with_logits(labels=self.cast_label, logits=self.prediction, name='loss')

			# not using one-hot
			self.loss = \
				tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_label, logits=self.prediction, name='loss')
			self.loss_mean = tf.reduce_mean(self.loss)
			# Total loss = mean cross entropy + every L2 term that init_w()
			# added to the 'loss' collection.
			tf.add_to_collection(name='loss', value=self.loss_mean)
			self.loss_all = tf.add_n(inputs=tf.get_collection(key='loss'))

		# accuracy
		with tf.name_scope('accuracy'):
			# using one-hot
			# self.correct_prediction = tf.equal(tf.argmax(self.prediction, axis=3), tf.argmax(self.cast_label, axis=3))

			# not using one-hot
			self.correct_prediction = \
				tf.equal(tf.argmax(input=self.prediction, axis=3, output_type=tf.int32), self.input_label)
			self.correct_prediction = tf.cast(self.correct_prediction, tf.float32)
			self.accuracy = tf.reduce_mean(self.correct_prediction)

		# Gradient Descent
		with tf.name_scope('Gradient_Descent'):
			self.train_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(self.loss_all)

	def train(self):
		"""Train the network on the training TFRecord file.

		Batches are read for ``EPOCH_NUM`` epochs. Loss and accuracy are
		written to TensorBoard (``FLAGS.tb_dir``) for every step and printed
		every 10 steps. The parameters are saved to
		``FLAGS.model_dir/model.ckpt`` when the loop ends, whether it
		finished normally or was interrupted by an error.
		"""
		train_file_path = os.path.join(FLAGS.data_dir, TRAIN_SET_NAME)
		train_image_filename_queue = tf.train.string_input_producer(
			string_tensor=tf.train.match_filenames_once(train_file_path), num_epochs=EPOCH_NUM, shuffle=True)
		ckpt_path = os.path.join(FLAGS.model_dir, "model.ckpt")
		train_images, train_labels = read_image_batch(train_image_filename_queue, TRAIN_BATCH_SIZE)
		tf.summary.scalar("loss", self.loss_mean)
		tf.summary.scalar('accuracy', self.accuracy)
		merged_summary = tf.summary.merge_all()
		all_parameters_saver = tf.train.Saver()
		with tf.Session() as sess:
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
			# Also dump the graph next to the checkpoints.
			graph_writer = tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
			coord = tf.train.Coordinator()
			threads = tf.train.start_queue_runners(coord=coord)
			try:
				step = 1
				while not coord.should_stop():
					# Fetch the next batch of images and labels.
					example, label = sess.run([train_images, train_labels])
					# One run performs the update and reports the metrics, so
					# each batch passes through the network (and updates the
					# batch-norm moving averages) only once.
					_, lo, acc, summary_str = sess.run(
						[self.train_step, self.loss_mean, self.accuracy, merged_summary],
						feed_dict={
							self.input_image: example, self.input_label: label, self.keep_prob: 1.0,
							self.lamb: 0.004, self.is_traing: True}
					)
					summary_writer.add_summary(summary_str, step)
					if step % 10 == 0:
						print('num %d, loss: %.6f and accuracy: %.6f' % (step, lo, acc))
					step += 1
			except tf.errors.OutOfRangeError:
				print('Done training -- epoch limit reached')
			finally:
				try:
					all_parameters_saver.save(sess=sess, save_path=ckpt_path)
				finally:
					# Ask the reader threads to stop and flush the event files
					# even if saving the checkpoint failed.
					coord.request_stop()
					summary_writer.close()
					graph_writer.close()
			coord.join(threads)
		print("Done training")

	def validate(self):
		"""Evaluate the checkpoint at ``CHECK_POINT_PATH`` on the validation set.

		One pass is made over the validation TFRecord file; loss and accuracy
		are printed for every batch. Batch normalization runs in inference
		mode (``is_traing`` is fed as False).
		"""
		records_path = os.path.join(FLAGS.data_dir, VALIDATION_SET_NAME)
		filename_queue = tf.train.string_input_producer(
			string_tensor=tf.train.match_filenames_once(records_path), num_epochs=1, shuffle=True)
		checkpoint = CHECK_POINT_PATH
		image_batch, label_batch = read_image_batch(filename_queue, VALIDATION_BATCH_SIZE)
		saver = tf.train.Saver()
		with tf.Session() as sess:
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			saver.restore(sess=sess, save_path=checkpoint)
			coord = tf.train.Coordinator()
			threads = tf.train.start_queue_runners(coord=coord)
			try:
				step = 1
				while not coord.should_stop():
					# Pull the next batch out of the input pipeline.
					example, label = sess.run([image_batch, label_batch])
					feed = {
						self.input_image: example,
						self.input_label: label,
						self.keep_prob: 1.0,
						self.lamb: 0.004,
						self.is_traing: False,
					}
					lo, acc = sess.run([self.loss_mean, self.accuracy], feed_dict=feed)
					print('num %d, loss: %.6f and accuracy: %.6f' % (step, lo, acc))
					step += 1
			except tf.errors.OutOfRangeError:
				print('Done validating -- epoch limit reached')
			finally:
				# Ask the reader threads to stop.
				coord.request_stop()
			coord.join(threads)
		print('Done validating')

	def test(self):
		"""Evaluate the checkpoint at ``CHECK_POINT_PATH`` on the test set.

		One pass is made over the test TFRecord file. For every batch the
		accuracy is printed and the predicted mask of the first image is
		written (scaled by 255) to ``PREDICT_SAVED_DIRECTORY`` as
		``<n>.jpg``. The average accuracy over the processed batches is
		printed when the input is exhausted.
		"""
		import cv2
		import numpy as np
		test_file_path = os.path.join(FLAGS.data_dir, TEST_SET_NAME)
		test_image_filename_queue = tf.train.string_input_producer(
			string_tensor=tf.train.match_filenames_once(test_file_path), num_epochs=1, shuffle=True)
		ckpt_path = CHECK_POINT_PATH
		test_images, test_labels = read_image_batch(test_image_filename_queue, TEST_BATCH_SIZE)
		# Build the arg-max op once; creating it inside the loop would add a
		# new node to the graph on every iteration.
		predicted_classes = tf.argmax(input=self.prediction, axis=3)
		# cv2.imwrite fails silently when the target directory is missing.
		if not os.path.isdir(PREDICT_SAVED_DIRECTORY):
			os.makedirs(PREDICT_SAVED_DIRECTORY)
		all_parameters_saver = tf.train.Saver()
		with tf.Session() as sess:
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
			coord = tf.train.Coordinator()
			threads = tf.train.start_queue_runners(coord=coord)
			sum_acc = 0.0
			epoch = 0
			try:
				while not coord.should_stop():
					# Fetch the next batch of images and labels.
					example, label = sess.run([test_images, test_labels])
					image, acc = sess.run(
						[predicted_classes, self.accuracy],
						feed_dict={
							self.input_image: example, self.input_label: label,
							self.keep_prob: 1.0, self.lamb: 0.004, self.is_traing: False
						}
					)
					sum_acc += acc
					epoch += 1
					# argmax yields int64; convert to an 8-bit image for OpenCV.
					mask = (image[0] * 255).astype(np.uint8)
					cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % epoch), mask)
					print('num %d accuracy: %.6f' % (epoch, acc))
			except tf.errors.OutOfRangeError:
				# Average over the batches actually evaluated rather than over
				# an assumed data-set size.
				average_acc = sum_acc / epoch if epoch else 0.0
				print('Done testing -- epoch limit reached \n Average accuracy: %.2f%%' % (average_acc * 100))
			finally:
				# Ask the reader threads to stop.
				coord.request_stop()
			coord.join(threads)
		print('Done testing')

	def predict(self):
		"""Predict a mask for every ``*.tif`` image in ``ORIGIN_PREDICT_DIRECTORY``.

		The parameters are restored from ``FLAGS.model_dir/model.ckpt``. Each
		image is read as grayscale, and its predicted class map is written
		unscaled to ``PREDICT_SAVED_DIRECTORY`` as ``<index>.jpg``, where
		``index`` is the position of the file in the sorted file list. Files
		that OpenCV can not read are reported and skipped.
		"""
		import cv2
		import glob
		import numpy as np
		# Sort so that the index used in the output name is reproducible.
		predict_file_path = sorted(glob.glob(os.path.join(ORIGIN_PREDICT_DIRECTORY, '*.tif')))
		print(len(predict_file_path))
		# makedirs also creates missing parent directories.
		if not os.path.isdir(PREDICT_SAVED_DIRECTORY):
			os.makedirs(PREDICT_SAVED_DIRECTORY)
		ckpt_path = os.path.join(FLAGS.model_dir, "model.ckpt")
		# Build the arg-max op once instead of once per image, so the graph
		# does not grow while predicting.
		predicted_classes = tf.argmax(input=self.prediction, axis=3)
		all_parameters_saver = tf.train.Saver()
		with tf.Session() as sess:
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
			for index, image_path in enumerate(predict_file_path):
				raw_image = cv2.imread(image_path, flags=0)
				if raw_image is None:
					# cv2.imread signals failure by returning None.
					print('Skip unreadable image: %s' % image_path)
					continue
				image = np.reshape(
					a=raw_image,
					newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
				predict_image = sess.run(
					predicted_classes,
					feed_dict={
						self.input_image: image, self.keep_prob: 1.0, self.lamb: 0.004, self.is_traing: False
					}
				)
				# argmax yields int64; convert to an 8-bit image for OpenCV.
				# The class ids are deliberately not scaled by 255 here.
				mask = predict_image[0].astype(np.uint8)
				cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % index), mask)
		print('Done prediction')


def main():
	"""Build the network and run the selected mode (currently prediction).

	The checkpoint path used by ``Unet.validate`` and ``Unet.test`` is
	pointed at ``FLAGS.model_dir/model.ckpt``, the location ``Unet.train``
	saves to.
	"""
	# Without ``global`` the assignment below only creates an unused local
	# and the module-level checkpoint path is never updated.
	global CHECK_POINT_PATH
	net = Unet()
	CHECK_POINT_PATH = os.path.join(FLAGS.model_dir, "model.ckpt")
	# Uncomment exactly one set_up_unet / run pair to choose the mode.
	# net.set_up_unet(TRAIN_BATCH_SIZE)
	# net.train()
	# net.set_up_unet(VALIDATION_BATCH_SIZE)
	# net.validate()
	# net.set_up_unet(TEST_BATCH_SIZE)
	# net.test()
	net.set_up_unet(PREDICT_BATCH_SIZE)
	net.predict()

if __name__ == '__main__':
	parser = argparse.ArgumentParser()
	# (option, default directory, help text) for: the input data location,
	# the model checkpoint location and the TensorBoard log location.
	directory_options = (
		('--data_dir', '../data_set/my_set', 'Directory for storing input data'),
		('--model_dir', '../data_set/saved_models', 'output model path'),
		('--tb_dir', '../data_set/logs', 'TensorBoard log path'),
	)
	for option, default_dir, help_text in directory_options:
		parser.add_argument(option, type=str, default=default_dir, help=help_text)

	# Unknown command-line arguments are ignored.
	FLAGS, _ = parser.parse_known_args()
	# write_img_to_tfrecords()
	# read_check_tfrecords()
	main()