# coding: utf-8

# Learning to learn by gradient descent by gradient descent
# =========================#

# https://arxiv.org/abs/1611.03824
# https://yangsenius.github.io/blog/LSTM_Meta/
# author:yangsen
# #### "learning to learn by gradient descent by gradient descent"
# #### make the optimizer learn knowledge like: "sometimes, in order to get something better, you first have to give it up."

import math
import torch
import torch.nn as nn
from timeit import default_timer as timer
#####################      the optimization problem (optimizee)      ##########################
USE_CUDA = False
DIM = 10
batchsize = 128

if torch.cuda.is_available():
    USE_CUDA = True  

print('\n\nUSE_CUDA = {}\n\n'.format(USE_CUDA))
 

def f(W, Y, x):
    r"""Quadratic problem: f(\theta) = \|W\theta - y\|_2^2, averaged over the batch."""
    if USE_CUDA:
        W = W.cuda()
        Y = Y.cuda()
        x = x.cuda()

    return ((torch.matmul(W, x.unsqueeze(-1)).squeeze() - Y) ** 2).sum(dim=1).mean(dim=0)
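
# A minimal shape sanity check for `f` (illustrative only; a small hypothetical batch of 4):
#   W: [B, DIM, DIM], x: [B, DIM], Y: [B, DIM]
#   matmul(W, x.unsqueeze(-1)) -> [B, DIM, 1] -> squeeze() -> [B, DIM]
#   ((Wx - Y)**2).sum(dim=1) -> [B] -> mean(dim=0) -> scalar loss
_W_demo, _Y_demo, _x_demo = torch.randn(4, DIM, DIM), torch.randn(4, DIM), torch.zeros(4, DIM)
assert f(_W_demo, _Y_demo, _x_demo).dim() == 0  # the loss is a single scalar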

###############################################################

######################    hand-crafted optimizers   ###################

def SGD(gradients, state, learning_rate=0.001):
   
    return -gradients*learning_rate, state

def RMS(gradients, state, learning_rate=0.01, decay_rate=0.9):
    if state is None:
        state = torch.zeros(DIM)
        if USE_CUDA == True:
            state = state.cuda()
            
    state = decay_rate*state + (1-decay_rate)*torch.pow(gradients, 2)
    update = -learning_rate*gradients / (torch.sqrt(state+1e-5))
    return update, state

def adam(params, learning_rate=0.1):
    # Unused here: `Learner` builds torch.optim.Adam([x]) itself when given the string 'Adam'.
    return torch.optim.Adam(params, lr=learning_rate)
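
# SGD and RMS above (and the LSTM optimizer below) share the same functional interface,
# which is how `Learner` drives them (a sketch; `state` carries optimizer memory, None at the first step):
#   update, state = RMS(x.grad.clone().detach(), state)
#   x = x + update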

##########################################################


#####################    learned LSTM optimizer model  ##########################
class LSTM_optimizer_Model(torch.nn.Module):
    """LSTM optimizer"""
    
    def __init__(self, input_size, output_size, hidden_size, num_stacks, batchsize, preprocess=True, p=10, output_scale=1):
        super(LSTM_optimizer_Model, self).__init__()
        self.preprocess_flag = preprocess
        self.p = p
        # log-and-sign preprocessing doubles the input dimension (log part + sign part)
        self.input_flag = 2 if preprocess else 1
        self.output_scale = output_scale
        self.lstm = torch.nn.LSTM(input_size * self.input_flag, hidden_size, num_stacks)
        self.Linear = torch.nn.Linear(hidden_size, output_size)  # map the hidden state to the update (output_size = DIM)
        
    def LogAndSign_Preprocess_Gradient(self, gradients):
        r"""
        Args:
          gradients: `Tensor` of gradients with shape `[d_1, ..., d_n]`.
          p (self.p): `p` > 0 controls below which magnitude gradients are disregarded.
        Returns:
          `Tensor` with shape `[d_1, ..., d_{n-1}, 2 * d_n]`. The first `d_n` elements along the
          last dimension are the `log` output (clamped to [-1, 1]) and the remaining `d_n`
          elements are the `sign` output.
        """
        p = self.p
        log = torch.log(torch.abs(gradients))
        clamp_log = torch.clamp(log / p, min=-1.0, max=1.0)
        # e^p is a scalar; the original `torch.exp(torch.Tensor(p))` built an uninitialized size-p tensor
        clamp_sign = torch.clamp(math.exp(p) * gradients, min=-1.0, max=1.0)
        return torch.cat((clamp_log, clamp_sign), dim=-1)  # concatenate along the last (input) dimension
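
    # For reference, the preprocessing in the paper's appendix (which the above roughly follows)
    # maps a gradient g to (log|g|/p, sgn(g)) when |g| >= e^{-p}, and to (-1, e^p * g) otherwise;
    # the two clamps above approximate that piecewise definition in a fully vectorized way.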
    
    def Output_Gradient_Increment_And_Update_LSTM_Hidden_State(self, input_gradients, prev_state):
        """Core operation: a coordinate-wise LSTM maps gradients to parameter increments."""
        if prev_state is None:  # initialize the hidden state (uses the globals Layers, batchsize, Hidden_nums)
            prev_state = (torch.zeros(Layers, batchsize, Hidden_nums),
                          torch.zeros(Layers, batchsize, Hidden_nums))
            if USE_CUDA:
                prev_state = tuple(s.cuda() for s in prev_state)

        update, next_state = self.lstm(input_gradients, prev_state)
        update = self.Linear(update) * self.output_scale  # map the LSTM output to the target output dimension
        return update, next_state
    
    def forward(self, input_gradients, prev_state):
        if USE_CUDA:
            input_gradients = input_gradients.cuda()
        # torch.nn.LSTM expects input of shape (seq_len, batch, input_size); here seq_len = 1,
        # so gradients of shape [batchsize, DIM] become [1, batchsize, DIM]
        gradients = input_gradients.unsqueeze(0)
        if self.preprocess_flag:
            gradients = self.LogAndSign_Preprocess_Gradient(gradients)
        update, next_state = self.Output_Gradient_Increment_And_Update_LSTM_Hidden_State(gradients, prev_state)
        # drop the sequence dimension again: [1, batchsize, DIM] -> [batchsize, DIM]
        update = update.squeeze(0)

        return update, next_state
    
#################   parameters of the optimizer   ##############################
Layers = 2
Hidden_nums = 20
Input_DIM = DIM
Output_DIM = DIM
output_scale_value=1

#######   construct an optimizer   #######
LSTM_optimizer = LSTM_optimizer_Model(Input_DIM, Output_DIM, Hidden_nums ,Layers , batchsize=batchsize,\
                preprocess=False,output_scale=output_scale_value)
print(LSTM_optimizer)

if USE_CUDA:
    LSTM_optimizer = LSTM_optimizer.cuda()
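
# A quick single-step sanity check of the (still untrained) LSTM optimizer, shown as a usage
# sketch: gradients of shape [batchsize, DIM] go in, a coordinate-wise update of the same shape
# comes out (the `_demo` names are illustrative only).
_x_demo = torch.zeros(batchsize, DIM, requires_grad=True)
_loss_demo = f(torch.randn(batchsize, DIM, DIM), torch.randn(batchsize, DIM), _x_demo)
_loss_demo.backward()
_update_demo, _state_demo = LSTM_optimizer(_x_demo.grad.detach(), None)
assert _update_demo.shape == (batchsize, DIM)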
   

######################  the learning process of optimizing the target function   ###############

class Learner( object ):
    """
    Args :
        `f` : 要学习的问题 the learning problem, also called `optimizee` in the paper
        `optimizer` : 使用的优化器 the used optimizer 
        `train_steps` : 对于其他SGD,Adam等是训练周期,对于LSTM训练时的展开周期  training steps for SGD and ADAM, unfolded step for LSTM train
        `retain_graph_flag=False`  : 默认每次loss_backward后 释放动态图 default: free the dynamic graph after the loss backward
        `reset_theta = False `  :  默认每次学习前 不随机初始化参数 default: do not initialize the theta
        `reset_function_from_IID_distirbution = True` : 默认从分布中随机采样函数  default: random sample from  distribution

    Return :
        `losses` : reserves each loss value in each iteration
        `global_loss_graph` : constructs the graph of all Unroll steps for LSTM's BPTT 
    """
    def __init__(self,    f ,   optimizer,  train_steps ,  
                                            eval_flag = False,
                                            retain_graph_flag=False,
                                            reset_theta = False ,
                                            reset_function_from_IID_distirbution = True):
        self.f = f
        self.optimizer = optimizer
        self.train_steps = train_steps
        #self.num_roll=num_roll
        self.eval_flag = eval_flag
        self.retain_graph_flag = retain_graph_flag
        self.reset_theta = reset_theta
        self.reset_function_from_IID_distirbution = reset_function_from_IID_distirbution  
        self.init_theta_of_f()
        self.state = None

        self.global_loss_graph = 0 # global loss for optimizing LSTM
        self.losses = []   # keeps the loss of every step

    def init_theta_of_f(self,):
        '''initialize the parameters theta of the optimization problem f'''
        self.DIM = 10
        self.batchsize = 128
        self.W = torch.randn(batchsize, DIM, DIM)  # sampled IID from a standard Gaussian
        self.Y = torch.randn(batchsize, DIM)
        self.x = torch.zeros(self.batchsize,self.DIM)
        self.x.requires_grad = True
        if USE_CUDA:
            self.W = self.W.cuda()
            self.Y = self.Y.cuda()
            self.x = self.x.cuda()
        
            
    def Reset_Or_Reuse(self, x, W, Y, state, num_roll):
        '''re-initialize `W, Y, x, state` at the beginning of each global training run,
           i.e. when `num_roll == 0`'''

        reset_theta =self.reset_theta
        reset_function_from_IID_distirbution = self.reset_function_from_IID_distirbution

       
        if num_roll == 0 and reset_theta == True:
            theta = torch.zeros(batchsize,DIM)
            theta_init_new =  theta.clone().detach().requires_grad_(True)
            x = theta_init_new
            
            
        ################   at the first unroll of each global iteration, sample a new function from the IID Gaussian    ##################
        if num_roll == 0 and reset_function_from_IID_distirbution == True :
            W = torch.randn(batchsize,DIM,DIM) # sampled IID
            Y = torch.randn(batchsize,DIM)     # sampled IID
         
            
        if num_roll == 0:
            state = None
            print('reset the values of `W`, `x`, `Y` and `state` for this optimizer')
            
        if USE_CUDA:
            W = W.cuda()
            Y = Y.cuda()
            x = x.cuda()
            x.retain_grad()
          
            
        return  x , W , Y , state

    def __call__(self, num_roll=0):
        '''
        Total training steps = train_steps * (number of times this `Learner` is called)

        SGD, RMS and the LSTM optimizer follow the definitions above;
        Adam is taken from torch.optim (this could be unified later).
        '''
        f  = self.f 
        x , W , Y , state =  self.Reset_Or_Reuse(self.x , self.W , self.Y , self.state , num_roll )
        self.global_loss_graph = 0   #at the beginning of unroll, reset to 0 
        optimizer = self.optimizer
     
        if optimizer!='Adam':
            
            for i in range(self.train_steps):     
                loss = f(W,Y,x)
                #self.global_loss_graph += (0.8*torch.log10(torch.Tensor([i+1]))+1)*loss
                self.global_loss_graph += loss
              
                loss.backward(retain_graph=self.retain_graph_flag)  # defaults to False; set to True for the LSTM optimizer
                update, state = optimizer(x.grad.clone().detach(), state)
                self.losses.append(loss.detach().cpu())  # store a detached CPU copy for logging/plotting
             
                x = x + update  
                x.retain_grad()
                update.retain_grad()
                
            if state is not None:
                self.state = (state[0].detach(),state[1].detach())
                
            return self.losses ,self.global_loss_graph 

        else: #Pytorch Adam

            x.detach_()
            x.requires_grad = True
            optimizer= torch.optim.Adam( [x],lr=0.1 )
            
            for i in range(self.train_steps):
                
                optimizer.zero_grad()
                loss = f(W,Y,x)
                
                self.global_loss_graph += loss
                
                loss.backward(retain_graph=self.retain_graph_flag)
                optimizer.step()
                self.losses.append(loss.detach().cpu())  # store a detached CPU copy for logging/plotting
                
            return self.losses, self.global_loss_graph
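
# Usage sketch (this is what the evaluation sections below do): wrap `f` and an optimizer in a
# `Learner`, then call it; each call runs `train_steps` optimization steps and returns the
# per-step losses plus the accumulated `global_loss_graph` used for the LSTM's BPTT, e.g.
#   learner = Learner(f, RMS, 100, reset_theta=True)
#   losses, sum_loss = learner()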


#######   training the LSTM optimizer: Learning to learn   ###############

def Learning_to_learn_global_training(optimizer, global_training_steps, optimizer_Train_Steps, UnRoll_STEPS, Evaluate_period, optimizer_lr=0.1):
    """ Training the LSTM optimizer: Learning to learn.

    Args:
        `optimizer` : DeepLSTMCoordinateWise optimizer model
        `global_training_steps` : how many global steps to train the optimizer on the optimizee
        `optimizer_Train_Steps` : how many steps the optimizer runs on each function sampled from the IID distribution
        `UnRoll_STEPS` : how many steps the LSTM optimizer is unrolled to build the computation graph for BPTT
    """
    global_loss_list = []
    Total_Num_Unroll = optimizer_Train_Steps // UnRoll_STEPS
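    # Truncated BPTT: with the defaults below (optimizer_Train_Steps = 100, UnRoll_STEPS = 20),
    # each global step runs 5 unrolled segments of 20 optimizee steps, and the LSTM's parameters
    # are updated once per segment.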
    adam_global_optimizer = torch.optim.Adam(optimizer.parameters(),lr = optimizer_lr)

    LSTM_Learner = Learner(f, optimizer, UnRoll_STEPS, retain_graph_flag=True, reset_theta=True,)
    # If the batch dimension is treated as IID sampling, there is no need to re-sample the function
    # at every global step, i.e. set reset_function_from_IID_distirbution = False; otherwise keep it True.

    best_sum_loss = 999999
    best_final_loss = 999999
    best_flag = False
    for i in range(global_training_steps):

        print('\n========================================> global training steps: {}'.format(i))

        for num in range(Total_Num_Unroll):
            
            start = timer()
            _,global_loss = LSTM_Learner(num)   

            adam_global_optimizer.zero_grad()
            global_loss.backward() 
       
            adam_global_optimizer.step()
            # print('xxx',[(z.grad,z.requires_grad) for z in optimizer.lstm.parameters()  ])
            global_loss_list.append(global_loss.detach_())
            time = timer() - start
            #if i % 10 == 0:
            print('-> time [{:.1f}s] | optimizer train steps: [{}] | Global_Loss = [{:.1f}]'\
                  .format(time,(num +1)* UnRoll_STEPS,global_loss,))

        if (i + 1) % Evaluate_period == 0:
            
            best_sum_loss, best_final_loss, best_flag  = evaluate(best_sum_loss,best_final_loss,best_flag , optimizer_lr)
    return global_loss_list,best_flag


def evaluate(best_sum_loss,best_final_loss, best_flag,lr):
    print('\n --------> evaluate the model')

    STEPS = 100
    x = np.arange(STEPS)
    Adam = 'Adam'

    LSTM_learner = Learner(f , LSTM_optimizer, STEPS, eval_flag=True,reset_theta=True, retain_graph_flag=True)
    
    SGD_Learner = Learner(f , SGD, STEPS, eval_flag=True,reset_theta=True,)
    RMS_Learner = Learner(f , RMS, STEPS, eval_flag=True,reset_theta=True,)
    Adam_Learner = Learner(f , Adam, STEPS, eval_flag=True,reset_theta=True,)


    sgd_losses, sgd_sum_loss = SGD_Learner()
    rms_losses, rms_sum_loss = RMS_Learner()
    adam_losses, adam_sum_loss = Adam_Learner()
    lstm_losses, lstm_sum_loss = LSTM_learner()

    p1, = plt.plot(x, sgd_losses, label='SGD')
    p2, = plt.plot(x, rms_losses, label='RMS')
    p3, = plt.plot(x, adam_losses, label='Adam')
    p4, = plt.plot(x, lstm_losses, label='LSTM')
    plt.yscale('log')
    plt.legend(handles=[p1, p2, p3, p4])
    plt.title('Losses')
    plt.pause(1.5)
    #plt.show()
    print("sum_loss:sgd={},rms={},adam={},lstm={}".format(sgd_sum_loss,rms_sum_loss,adam_sum_loss,lstm_sum_loss ))
    plt.close() 
    torch.save(LSTM_optimizer.state_dict(),'current_LSTM_optimizer_ckpt.pth')

    try:
        best = torch.load('best_loss.txt')
    except IOError:
        print ('can not find best_loss.txt')
        now_sum_loss = lstm_sum_loss.cpu()
        now_final_loss = lstm_losses[-1].cpu()
        pass
    else:
        best_sum_loss = best[0].cpu()
        best_final_loss = best[1].cpu()
        now_sum_loss = lstm_sum_loss.cpu()
        now_final_loss = lstm_losses[-1].cpu()

        print(" ==> History: sum loss = [{:.1f}] \t| final loss = [{:.2f}]".format(best_sum_loss,best_final_loss))
        print(" ==> Current: sum loss = [{:.1f}] \t| final loss = [{:.2f}]".format(now_sum_loss,now_final_loss))

    # save the best model according to the conditions below
    # there may be several choices to make a trade-off
    if now_final_loss < best_final_loss: # and  now_sum_loss < best_sum_loss:
        
        best_final_loss = now_final_loss
        best_sum_loss =  now_sum_loss 
        
        print('\n\n===> update new best of final LOSS[{}]: =  {}, best_sum_loss ={}'.format(STEPS, best_final_loss,best_sum_loss))
        torch.save(LSTM_optimizer.state_dict(),'best_LSTM_optimizer.pth')
        torch.save([best_sum_loss ,best_final_loss,lr ],'best_loss.txt')
        best_flag = True
        
    return best_sum_loss, best_final_loss, best_flag 

    
##########################   baseline comparison before training the LSTM optimizer ###############################
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

STEPS = 100
x = np.arange(STEPS)

Adam = 'Adam' # Adam in Pytorch

for _ in range(1): 
   
    SGD_Learner = Learner(f , SGD, STEPS, eval_flag=True,reset_theta=True,)
    RMS_Learner = Learner(f , RMS, STEPS, eval_flag=True,reset_theta=True,)
    Adam_Learner = Learner(f , Adam, STEPS, eval_flag=True,reset_theta=True,)
    LSTM_learner = Learner(f , LSTM_optimizer, STEPS, eval_flag=True,reset_theta=True,retain_graph_flag=True)

    sgd_losses, sgd_sum_loss = SGD_Learner()
    rms_losses, rms_sum_loss = RMS_Learner()
    adam_losses, adam_sum_loss = Adam_Learner()
    lstm_losses, lstm_sum_loss = LSTM_learner()

    p1, = plt.plot(x, sgd_losses, label='SGD')
    p2, = plt.plot(x, rms_losses, label='RMS')
    p3, = plt.plot(x, adam_losses, label='Adam')
    p4, = plt.plot(x, lstm_losses, label='LSTM')
    p1.set_dashes([2, 2, 2, 2])   # 2pt line, 2pt break
    p2.set_dashes([4, 2, 8, 2])   # 4pt line, 2pt break, 8pt line, 2pt break
    p3.set_dashes([3, 2, 10, 2])  # 3pt line, 2pt break, 10pt line, 2pt break
    plt.yscale('log')
    plt.legend(handles=[p1, p2, p3, p4])
    plt.title('Losses')
    plt.pause(2.5)
    print("\n\nsum_loss:sgd={},rms={},adam={},lstm={}".format(sgd_sum_loss,rms_sum_loss,adam_sum_loss,lstm_sum_loss ))



#################### Learning to learn (optimizing optimizer) ######################
Global_Train_Steps = 1000  # changeable
optimizer_Train_Steps = 100
UnRoll_STEPS = 20
Evaluate_period = 1  # changeable
optimizer_lr = 0.1  # changeable
global_loss_list ,flag = Learning_to_learn_global_training(   LSTM_optimizer,
                                                        Global_Train_Steps,
                                                        optimizer_Train_Steps,
                                                        UnRoll_STEPS,
                                                        Evaluate_period,
                                                          optimizer_lr)


#########################################################################
##########################   show learning process results
#LSTM_optimizer.load_state_dict(torch.load('best_LSTM_optimizer.pth'))
#import numpy as np
#import matplotlib
#import matplotlib.pyplot as plt

#Global_T = np.arange(len(global_loss_list))
#p1, = plt.plot(Global_T, global_loss_list, label='Global_graph_loss')
#plt.legend(handles=[p1])
#plt.title('Training LSTM optimizer by gradient descent ')
#plt.show()


#########################################################################
##########################   compare SGD, RMS, Adam and the LSTM optimizer   ###############################
import copy
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

if flag:
    print('\n==== > load best LSTM model')
    last_state_dict = copy.deepcopy(LSTM_optimizer.state_dict())
    torch.save(LSTM_optimizer.state_dict(), 'final_LSTM_optimizer.pth')
    LSTM_optimizer.load_state_dict(torch.load('best_LSTM_optimizer.pth'))

# re-load the best checkpoint saved by `evaluate` (redundant when `flag` is True above)
LSTM_optimizer.load_state_dict(torch.load('best_LSTM_optimizer.pth'))
#LSTM_optimizer.load_state_dict(torch.load('final_LSTM_optimizer.pth'))
STEPS = 100
x = np.arange(STEPS)

Adam = 'Adam'

for _ in range(3):  # run the comparison a few times; the trained LSTM optimizer can be unstable
    
    SGD_Learner = Learner(f , SGD, STEPS, eval_flag=True,reset_theta=True,)
    RMS_Learner = Learner(f , RMS, STEPS, eval_flag=True,reset_theta=True,)
    Adam_Learner = Learner(f , Adam, STEPS, eval_flag=True,reset_theta=True,)
    LSTM_learner = Learner(f , LSTM_optimizer, STEPS, eval_flag=True,reset_theta=True,retain_graph_flag=True)

    
    sgd_losses, sgd_sum_loss = SGD_Learner()
    rms_losses, rms_sum_loss = RMS_Learner()
    adam_losses, adam_sum_loss = Adam_Learner()
    lstm_losses, lstm_sum_loss = LSTM_learner()

    p1, = plt.plot(x, sgd_losses, label='SGD')
    p2, = plt.plot(x, rms_losses, label='RMS')
    p3, = plt.plot(x, adam_losses, label='Adam')
    p4, = plt.plot(x, lstm_losses, label='LSTM')
    p1.set_dashes([2, 2, 2, 2])   # 2pt line, 2pt break
    p2.set_dashes([4, 2, 8, 2])   # 4pt line, 2pt break, 8pt line, 2pt break
    p3.set_dashes([3, 2, 10, 2])  # 3pt line, 2pt break, 10pt line, 2pt break
    #p4.set_dashes([2, 2, 10, 2])  # 2pt line, 2pt break, 10pt line, 2pt break
    plt.yscale('log')
    plt.legend(handles=[p1, p2, p3, p4])
    plt.title('Losses')
    plt.show()
    print("\n\nsum_loss:sgd={},rms={},adam={},lstm={}".format(sgd_sum_loss,rms_sum_loss,adam_sum_loss,lstm_sum_loss ))