# -*- coding: utf-8 -*-
"""
Created on Tue Mar  7 12:44:30 2017

@author: Quantum Liu
"""
from keras import __version__ as kv
kv=int(kv[0])# major Keras version, taken from the first character of the version string
import platform
pv=int(platform.python_version()[0])# major Python version, first character of the version string
import numpy as np
import scipy.io as sio
import itchat
from keras.callbacks import Callback
import time
import matplotlib  
matplotlib.use('Agg') # select the Agg backend before pyplot is imported
import matplotlib.pyplot as plt
from math import ceil
from itchat.content import TEXT
# the low-level thread module is imported as _thread on Python 3 and thread on Python 2
if pv>2:
    import _thread as th
else:
    import thread as th
import os
from os import system
import re
import traceback
import platform
from requests.exceptions import ConnectionError
#==============================================================================
#==============================================================================
# Login function: call it first, before using anything else in this module.
#==============================================================================
def login():
    """Log in through itchat and dump the login status.

    ``enableCmdQR`` is 1 on Windows and 2 on every other platform;
    ``hotReload=True`` is passed in both cases.
    """
    qr_mode = 1 if 'Windows' in platform.system() else 2
    itchat.auto_login(enableCmdQR=qr_mode, hotReload=True)
    itchat.dump_login_status()
#==============================================================================
# 
#==============================================================================
def send_text(text):
    """Send a text message to the 'filehelper' account through itchat.

    text: message body passed to itchat.send_msg.

    ConnectionError, NotImplementedError and KeyError are not propagated:
    the traceback and a short notice are printed instead, so a failed
    notification does not interrupt the caller. Always returns None.
    """
    try:
        itchat.send_msg(msg=text,toUserName='filehelper')
    except (ConnectionError,NotImplementedError,KeyError):
        traceback.print_exc()
        print('\nConection error,failed to send the message!\n')
def send_img(filename):
    """Send an image file to the 'filehelper' account through itchat.

    filename: path of the image passed to itchat.send_image.

    ConnectionError, NotImplementedError and KeyError are not propagated:
    the traceback and a short notice are printed instead. Always returns None.
    """
    try:
        itchat.send_image(filename,toUserName='filehelper')
    except (ConnectionError,NotImplementedError,KeyError):
        traceback.print_exc()
        print('\nConection error,failed to send the figure!\n')
#==============================================================================
#     
#==============================================================================
class sendmessage(Callback):
    #A subclass of keras.callbacks.Callback that sends training status to
    #'filehelper' through itchat and reacts to text commands sent there
    def __init__(self,savelog=True,fexten=''):
        """Create the callback.

        savelog: when true, the batch and epoch logs are written to .mat files
            at the end of every epoch
        fexten: base name used for the log, figure and model files; when empty,
            a name derived from the training start time is used instead
        """
        #initialise the Callback base class as well (explicit form works on Python 2 and 3)
        super(sendmessage,self).__init__()
        self.fexten=(fexten if fexten else '')#the name of log and figure files
        self.savelog=bool(savelog)#save log or not
        
    def t_send(self,msg,toUserName='filehelper'):
        """Send a text message through itchat without letting send errors escape.

        msg: message body passed to itchat.send_msg
        toUserName: recipient, 'filehelper' by default

        ConnectionError, NotImplementedError and KeyError are reported by
        printing the traceback and a short notice. Always returns None.
        """
        try:
            itchat.send_msg(msg=msg,toUserName=toUserName)
        except (ConnectionError,NotImplementedError,KeyError):
            traceback.print_exc()
            print('\nConection error,failed to send the message!\n')
    def t_send_img(self,filename,toUserName='filehelper'):
        """Send an image file through itchat without letting send errors escape.

        filename: path of the image passed to itchat.send_image
        toUserName: recipient, 'filehelper' by default

        ConnectionError, NotImplementedError and KeyError are reported by
        printing the traceback and a short notice. Always returns None.
        """
        try:
            itchat.send_image(filename,toUserName=toUserName)
        except (ConnectionError,NotImplementedError,KeyError):
            traceback.print_exc()
            print('\nConection error,failed to send the figure!\n')
       
            
    def shutdown(self,sec,save=True,filepath='temp.h5'):
        """Optionally save the model, then schedule a shutdown of this computer.

        sec: waiting time before the shutdown, in seconds
        save: whether to save self.model before shutting down
        filepath: file the model is saved to when save is true

        The shutdown command is run through os.system on a separate thread.
        On Windows the delay is given in seconds; elsewhere it is converted to
        whole minutes, with a minimum of one minute.
        """
        if save:
            self.model.save(filepath, overwrite=True)
            self.t_send('Command accepted,the model has already been saved,shutting down the computer....', toUserName='filehelper')
        else:
            self.t_send('Command accepted,shutting down the computer....', toUserName='filehelper')
        if 'Windows' in platform.system():
            command='shutdown -s -t %d' %sec
        else:
            #shutdown(8) takes the delay as a positional '+minutes' argument;
            #its -t option does not set the delay
            minutes=max(int(sec/60),1)
            command='shutdown -h +%d' %minutes
        th.start_new_thread(system, (command,))
            
#==============================================================================
#         
#==============================================================================
    def cancel(self):
        """Announce the cancellation, then abort a pending shutdown.

        Runs 'shutdown -a' on Windows and 'shutdown -c' elsewhere, through
        os.system on a separate thread.
        """
        self.t_send('Command accepted,cancel shutting down the computer....', toUserName='filehelper')
        abort_command='shutdown -a' if 'Windows' in platform.system() else 'shutdown -c'
        th.start_new_thread(system, (abort_command,))
#==============================================================================
#         
#==============================================================================
    def GetMiddleStr(self,content,startStr,endStr):
        """Return the text between the first startStr and the next endStr.

        content: string to search
        startStr: opening delimiter
        endStr: closing delimiter, searched for only after the opening one

        Returns '' when either delimiter is missing or when the arguments are
        not strings.
        """
        try:
            startIndex = content.index(startStr) + len(startStr)
            #search after the opening delimiter, so an earlier endStr or
            #identical delimiters do not produce an empty or wrong slice
            endIndex = content.index(endStr, startIndex)
        except (ValueError, AttributeError, TypeError):
            return ''
        return content[startIndex:endIndex]
#==============================================================================
# 
#==============================================================================
    def validateTitle(self,title):
        """Turn a string into a valid file name.

        Removes every space and every one of the characters / \\ : * ? " < > |
        from title and returns the result.
        """
        without_spaces = title.replace(' ','')
        forbidden = re.compile(r"[\/\\\:\*\?\"\<\>\|]")  # '/\:*?"<>|'
        return forbidden.sub("", without_spaces)
#==============================================================================
#         
#==============================================================================
    def prog(self):#Show progress
        """Send and print the training progress and estimated finishing times.

        Progress is the share of processed batches, overall and within the
        current epoch. The overall ETA is extrapolated from the time elapsed
        since training started. The epoch ETA uses the mean duration of the
        finished epochs, or the elapsed time while no epoch has finished yet.
        """
        nb_epochs=(self.params['epochs'] if kv-1 else self.params['nb_epoch'])
        nb_samples=(self.params['samples'] if kv-1 else self.params['nb_sample'])
        #float() keeps Python 2 from truncating the ratios below to integers
        nb_batches_epoch=float(nb_samples)/self.params['batch_size']
        nb_batches_total=nb_epochs*nb_batches_epoch
        #c_batches is only set once an epoch has begun
        c_batches=getattr(self,'c_batches',0)
        #the 0.01 offset keeps 1/prog from dividing by zero before the first batch
        prog_total=(self.t_batches/nb_batches_total if nb_batches_total else 0)+0.01
        prog_epoch=(c_batches/nb_batches_epoch if nb_batches_epoch else 0)+0.01
        now=time.time()
        elapsed=now-self.train_start
        if self.t_epochs:
            t_mean=float(sum(self.t_epochs)) / len(self.t_epochs)
            eta_e=t_mean*(1-prog_epoch)
        else:
            eta_e=elapsed*((1/prog_epoch)-1)
        #progress can pass 100% because of the offset, never report a negative ETA
        eta_t=max(elapsed*((1/prog_total)-1),0.)
        eta_e=max(eta_e,0.)
        t_end=time.asctime(time.localtime(now+eta_t))
        e_end=time.asctime(time.localtime(now+eta_e))
        #1-based number of the running epoch, 0 when no epoch has started
        current_epoch=(self.epoch[-1]+1 if self.epoch else 0)
        m=('\nTotal:\nProg:'+str(prog_total*100.)[:5]+'%\nEpoch:'+str(current_epoch)+'/'+str(self.stopped_epoch)
           +'\nETA:'+str(eta_t)[:8]+'sec\nTrain will be finished at '+t_end
           +'\nCurrent epoch:\nPROG:'+str(prog_epoch*100.)[:5]+'%\nETA:'+str(eta_e)[:8]+'sec\nCurrent epoch will be finished at '+e_end)
        self.t_send(msg=m)
        print(m)
            
#==============================================================================
# 
#==============================================================================
    def get_fig(self,level='all',metrics=None):
        """Plot the recorded training logs and send the figures to 'filehelper'.

        level: 'batches', 'epochs' or 'all' (both figures); any other value is
            treated like 'all'. While no epoch-level logs have been recorded,
            'all' and 'epochs' fall back to 'batches'.
        metrics: names of the metrics to plot; None or a list containing 'all'
            plots every recorded metric. Names that were never recorded are
            skipped; if none of the requested names was recorded, every
            recorded metric is plotted.

        Each figure is saved as '<name>_batches.jpg' / '<name>_epochs.jpg',
        where <name> is self.fexten or, if that is empty, a name derived from
        the training start time. Any exception is printed and reported to
        'filehelper' instead of being raised. The method simply returns when
        done, so it can be called directly as well as used as a thread target.
        """
        if metrics is None:
            metrics=['all']
        color_list='rgbyck'

        def plot_logs(history,singular,plural):
            #draw one subplot per metric of history, save the figure and send it
            names=list(history.keys())
            if 'all' not in metrics:
                names=[name for name in names if name in metrics] or names
            nb_rows=int(ceil(len(names)/2.0))
            fig=plt.figure('all_subs_'+plural)
            try:
                for i,name in enumerate(names):
                    ax=fig.add_subplot(nb_rows,2,i+1)
                    #snapshot the list, training may append to it while plotting
                    data=list(history[name])
                    ax.plot(range(len(data)),data,color_list[i%len(color_list)]+'-',label=name)
                    ax.set_title(name+' in '+plural,fontsize=14)
                    ax.set_xlabel(singular,fontsize=10)
                    ax.set_ylabel(name,fontsize=10)
                filename=(self.fexten if self.fexten else self.validateTitle(self.localtime))+'_'+plural+'.jpg'
                fig.tight_layout()
                fig.savefig(filename)
            finally:
                #close only the figure created here, even when plotting failed
                plt.close(fig)
            self.t_send_img(filename,toUserName='filehelper')
            time.sleep(.5)
            self.t_send('Sending '+plural+' figure',toUserName='filehelper')

        try:
            if not self.logs_epochs and level in ('all','epochs'):
                level='batches'
            if level!='epochs':
                plot_logs(self.logs_batches,'batch','batches')
            if level!='batches':
                plot_logs(self.logs_epochs,'epoch','epochs')
        except Exception:
            traceback.print_exc()
            self.t_send('Failed to send figure',toUserName='filehelper')
#==============================================================================
#             
#==============================================================================
    def gpu_status(self,av_type_list):
        """Send the output of 'nvidia-smi -q --display=<type>' for each given type.

        av_type_list: display types, each one appended to the nvidia-smi
            command line. The values end up in a shell command, so callers
            should pass only known nvidia-smi display type names.

        For every type the output is cut at 'Attached GPUs' (the whole output
        is used when that marker is absent), stripped of spaces and sent to
        'filehelper', followed by a half-second pause.
        """
        for t in av_type_list:
            cmd='nvidia-smi -q --display='+t
            #the with block closes the pipe even when reading fails
            with os.popen(cmd) as pipe:
                content=" ".join(pipe.readlines())
            #find() returns -1 when the marker is missing; fall back to the
            #whole output instead of slicing off only the last character
            index=max(content.find('Attached GPUs'),0)
            s=content[index:].replace(' ','').rstrip('\n')
            self.t_send(s, toUserName='filehelper')
            time.sleep(.5)
#==============================================================================
# 
#==============================================================================
    def on_train_begin(self, logs=None):
        """Reset the training records, announce the start and listen for commands.

        Registers an itchat handler for text messages and starts itchat.run on
        a separate thread. Only messages addressed to 'filehelper' are treated
        as commands; a command is accepted when any of its keywords occurs in
        the message, and one message may trigger several commands.
        """
        self.epoch=[]
        self.t_epochs=[]
        self.t_batches=0
        self.c_batches=0#batches done in the current epoch, reset when an epoch begins
        self.logs_batches={}
        self.logs_epochs={}
        self.train_start=time.time()
        self.localtime = time.asctime( time.localtime(self.train_start) )
        self.mesg = 'Train started at: '+self.localtime
        self.t_send(self.mesg, toUserName='filehelper')
        self.stopped_epoch = (self.params['epochs'] if kv-1 else self.params['nb_epoch'])
        #keywords that stop training after the current epoch
        stop_training_cmdlist=['Stop now',"That's enough",u'停止训练',u'放弃治疗']
        #keywords that shut the computer down
        shut_down_cmdlist=[u'关机','Shut down','Shut down the computer',u'别浪费电了',u'洗洗睡吧']
        #keywords that cancel a pending shutdown
        cancel_cmdlist=[u'取消','cancel','aaaa']
        #keywords that request the training figures
        get_fig_cmdlist=[u'获取图表','Show me the figure']
        #keywords that request the GPU status
        gpu_cmdlist=['GPU','gpu',u'显卡']
        #display types accepted for the GPU status command
        type_list=['MEMORY', 'UTILIZATION', 'ECC', 'TEMPERATURE', 'POWER', 'CLOCK', 'COMPUTE', 'PIDS', 'PERFORMANCE', 'SUPPORTED_CLOCKS', 'PAGE_RETIREMENT', 'ACCOUNTING']
        #keywords that request a progress report
        prog_cmdlist=[u'进度','Progress']

        @itchat.msg_register(TEXT)
        def manualstop(msg):
            #message handler, plays the role of a main function for the commands
            text=msg['Text']
            if msg['ToUserName']!='filehelper':
                return
            print('\n',text,'\n')
            if 'Stop at' in text:
                #Specify the stop epoch, training stops after that epoch
                #Example: send 'Stop at:8' and training stops after epoch 8
                numbers=re.findall(r"\d+",text)
                if numbers:
                    self.stopped_epoch = int(numbers[0])
                    if kv-1:
                        self.params['epochs']=self.stopped_epoch
                    else:
                        self.params['nb_epoch']=self.stopped_epoch
                    self.t_send('Command accepted,training will be stopped at epoch'+str(self.stopped_epoch), toUserName='filehelper')
                else:
                    self.t_send("Got no epoch number after 'Stop at',command ignored", toUserName='filehelper')
            if any((k in text) for k in stop_training_cmdlist):
                #Stop training after the current epoch has finished
                #Example: send 'Stop now'
                self.model.stop_training = True
                current_epoch=(self.epoch[-1]+1 if self.epoch else 0)
                self.t_send('Command accepted,stop training now at epoch'+str(current_epoch), toUserName='filehelper')
            if any((k in text) for k in shut_down_cmdlist):
                #Shut the computer down after {sec} seconds and save the model as [name].h5
                #Example: 'Shut down now [test]{120}' saves test.h5 and shuts down after 120s
                #'Shut down now{120},don't save' shuts down without saving the model
                save=not any((k in text) for k in [u'不保存模型',"don't save"])
                #computed even when not saving, so the name is always bound
                filepath=(self.GetMiddleStr(text,'[',']') or (self.fexten if self.fexten else self.validateTitle(self.localtime)))+'.h5'
                if save:
                    print('\n',filepath,'\n')
                #compare the delay as a number; missing, invalid or <=30s falls back to 120s
                try:
                    sec=int(self.GetMiddleStr(text,'{','}'))
                except ValueError:
                    sec=120
                if sec<=30:
                    sec=120
                self.shutdown(sec,save=save,filepath=filepath)
            if any((k in text) for k in cancel_cmdlist):
                #Cancel shutting down the computer
                self.cancel()
            if any((k in text) for k in get_fig_cmdlist):
                #Get figures of the training records; [metrics] and {level} select
                #what to plot, both default to 'all'
                #Example: 'Show me the figure [loss]{batches}' returns the loss per batch
                metrics=self.GetMiddleStr(text,'[',']').split() or ['all']
                level=self.GetMiddleStr(text,'{','}') or 'all'
                if level not in ['all','epochs','batches']:
                    print("\nGot no level,using default 'all'\n")
                    self.t_send("Got no level,using default 'all'", toUserName='filehelper')
                    level='all'
                th.start_new_thread(self.get_fig,(level,metrics))
            if any((k in text) for k in gpu_cmdlist):
                #[types] selects the nvidia-smi display types, default is MEMORY;
                #types outside type_list are dropped
                requested=self.GetMiddleStr(text,'[',']').split() or ['MEMORY']
                self.gpu_status([val for val in requested if val in type_list])
            if any((k in text) for k in prog_cmdlist):
                try:
                    self.prog()
                except Exception:
                    traceback.print_exc()
        #NOTE(review): a new itchat.run thread is started on every call of this
        #method; confirm this is harmless when fit() is called more than once
        th.start_new_thread(itchat.run, ())
#==============================================================================
#     
#==============================================================================
    def on_batch_end(self, batch, logs=None):
        """Record the batch metrics and advance both batch counters.

        For every name in self.params['metrics'] that is present in logs, the
        value is appended to self.logs_batches[name]. self.c_batches and
        self.t_batches are each incremented by one.
        """
        batch_logs = logs if logs else {}
        reported = (name for name in self.params['metrics'] if name in batch_logs)
        for name in reported:
            if name not in self.logs_batches:
                self.logs_batches[name] = []
            self.logs_batches[name].append(batch_logs[name])
        self.c_batches = self.c_batches + 1
        self.t_batches = self.t_batches + 1
#==============================================================================
#                 
#==============================================================================
    def on_epoch_begin(self, epoch, logs=None):
        """Start the epoch timer, reset the epoch batch counter and announce the epoch.

        Appends epoch to self.epoch and starts a new summary in self.mesg.
        The announced epoch number is epoch+1.
        """
        self.t_s=time.time()
        self.epoch.append(epoch)
        self.c_batches=0
        epoch_label=str(epoch+1)
        announcement=''.join(['Epoch',epoch_label,'/',str(self.stopped_epoch),' started'])
        self.t_send(announcement, toUserName='filehelper')
        self.mesg = 'Epoch:'+epoch_label+' '
#==============================================================================
#         
#==============================================================================
    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch metrics, save the logs, send the figures and a summary.

        For every name in self.params['metrics'] that is present in logs, the
        value is appended to self.logs_epochs[name] and added to the summary
        message. Training is flagged to stop once epoch+1 reaches
        self.stopped_epoch. When self.savelog is set, the batch and epoch logs
        are written to .mat files. get_fig is started on a separate thread and
        the summary is sent to 'filehelper'.
        """
        #normalise logs before it is searched, it may be None
        logs = logs or {}
        for k in self.params['metrics']:
            if k in logs:
                self.mesg+=(k+': '+str(logs[k])[:5]+' ')
                self.logs_epochs.setdefault(k, []).append(logs[k])
        if epoch+1>=self.stopped_epoch:
            self.model.stop_training = True
        #epoch is not appended to self.epoch here, on_epoch_begin already does that
        self.t_epochs.append(time.time()-self.t_s)
        if self.savelog:
            prefix=(self.fexten if self.fexten else self.validateTitle(self.localtime))
            sio.savemat(prefix+'_logs_batches'+'.mat',{'log':np.array(self.logs_batches)})
            sio.savemat(prefix+'_logs_epochs'+'.mat',{'log':np.array(self.logs_epochs)})
        th.start_new_thread(self.get_fig,())
        self.t_send(self.mesg, toUserName='filehelper')
#==============================================================================
#         
#==============================================================================
    def on_train_end(self, logs=None):
        """Tell 'filehelper' at which epoch training stopped.

        The reported number is the last started epoch plus one, or 0 when
        training ended before any epoch began.
        """
        last_epoch=(self.epoch[-1]+1 if self.epoch else 0)
        self.t_send('Train stopped at epoch'+str(last_epoch), toUserName='filehelper')