import os
import torch
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
import torch.nn as nn
import argparse
import models
import math

# Command-line interface for the sampler. NOTE: parse_args() runs at import
# time, so this module is a script, not an importable library.
parser = argparse.ArgumentParser(description='sample.py')

parser.add_argument('-init', default='The meaning of life is ',
                    help="""Initial text """)
parser.add_argument('-load_model', default='',
                    help="""Model filename to load""")                   
parser.add_argument('-seq_length', type=int, default=50,
                    help="""Maximum sequence length""")
parser.add_argument('-temperature', type=float, default=0.4,
                    help="""Temperature for sampling.""")
parser.add_argument('-neuron', type=int, default=0,
                    help="""Neuron to read.""")
parser.add_argument('-overwrite', type=float, default=0,
                    help="""Value used to overwrite the neuron. 0 means don't overwrite.""")
parser.add_argument('-layer', type=int, default=-1,
                    help="""Layer to read. -1 = last layer""")
# GPU
parser.add_argument('-cuda', action='store_true',
                    help="""Use CUDA""")
                    
opt = parser.parse_args()    


def batchify(data, bsz):
    """Turn a string into a (seq_len, bsz) LongTensor of UTF-8 byte ids.

    The string is encoded to bytes, trailing bytes that do not fill a
    whole batch column are dropped, and the result is laid out so that
    row ``t`` holds the ``t``-th byte of each of the ``bsz`` streams.

    Args:
        data: text to encode (one byte per time step).
        bsz:  batch size (number of parallel columns).

    Returns:
        torch.LongTensor of shape (len(data.encode()) // bsz, bsz).
    """
    # Build the id tensor in one shot instead of encoding twice and
    # copying byte-by-byte in a Python loop.
    ids = torch.LongTensor(list(data.encode()))
    nbatch = ids.size(0) // bsz
    # Trim off any remainder that wouldn't cleanly divide into batches.
    ids = ids.narrow(0, 0, nbatch * bsz)
    # (bsz, nbatch) -> transpose to time-major (nbatch, bsz).
    return ids.view(bsz, -1).t().contiguous()


def color(p):
    """Map a neuron activation to a 24-bit ANSI foreground escape (bytes).

    The activation is squashed through tanh into (0, 1) and split into
    three channel intensities; each channel is truncated to an int and
    clamped into the valid 0-255 range before being formatted as an
    ``ESC[38;2;R;G;Bm`` truecolor sequence.
    """
    # Squash the raw activation into (0, 1).
    s = math.tanh(3 * p) * 0.5 + 0.5
    q = 1.0 - s * 1.3
    r = 1.0 - abs(0.5 - s) * 1.3 + 0.3 * q
    s = 1.3 * s - 0.3
    # Truncate first, then clamp — matches int()-then-bound semantics.
    chan_a = max(0, int(q * 255))
    chan_b = min(255, max(0, int(r * 255)))
    chan_c = max(0, int(s * 255))
    return ('\033[38;2;%d;%d;%dm' % (chan_a, chan_b, chan_c)).encode()



# Sampling always runs a single stream.
batch_size = 1

# SECURITY NOTE(review): torch.load unpickles arbitrary objects — only load
# checkpoints from a trusted source.
checkpoint = torch.load(opt.load_model)
embed = checkpoint['embed']   # byte-embedding module from the checkpoint
rnn = checkpoint['rnn']       # recurrent model from the checkpoint

# NOTE(review): loss_fn appears unused in this script — likely leftover.
loss_fn = nn.CrossEntropyLoss() 


# Byte ids of the priming text, shape (len, 1).
text = batchify(opt.init, batch_size)

def make_cuda(state):
    """Move an RNN state — a tensor or a (hidden, cell) pair — to the GPU."""
    if not isinstance(state, tuple):
        return state.cuda()
    hidden, cell = state
    return (hidden.cuda(), cell.cuda())

# Wrap the priming ids for the (legacy) autograd API and get the model's
# initial recurrent state.
batch = Variable(text) 
states = rnn.state0(batch_size)
# states is either (hidden, cell) for LSTM-style models or a bare tensor.
if isinstance(states, tuple):
	hidden, cell = states
else:
	hidden = states
# Index of the layer whose neuron we visualize; -1 (default) means last.
last = hidden.size(0)-1
if opt.layer <= last and opt.layer >= 0:
	last = opt.layer

if opt.cuda:
    batch =batch.cuda()
    states = make_cuda(states)
    embed.cuda()
    rnn.cuda()
    
# NOTE(review): loss_avg / loss appear unused below — likely leftovers.
loss_avg = 0
loss = 0
# Output buffer: ANSI color escapes interleaved with generated bytes.
gen = bytearray()
# Feed the priming text through the model one byte at a time, recording
# the tracked neuron's activation as a color for each printable byte.
for t in range(text.size(0)):                           
    emb = embed(batch[t])
    ni = (batch[t]).data[0]
    states, output = rnn(emb, states)
    if isinstance(states, tuple):
        hidden, cell = states
    else:
        hidden = states
    # Activation of the selected neuron in the selected layer.
    feat = hidden.data[last,0,opt.neuron]
    # Only ASCII bytes get a color escape; multi-byte UTF-8 continuation
    # bytes are appended uncolored so the sequence stays decodable.
    if ni< 128:
        col = color(feat)
        gen+=(col)
    gen.append(ni)        	

# Echo the prompt (uncolored) before streaming the colored generation.
print(opt.init)        	


# Generation. `output`, `hidden`, and `states` carry over from the priming
# loop above. temperature == 0 means greedy argmax decoding; otherwise
# sample from the temperature-scaled softmax.
if opt.temperature == 0:
    # Greedy branch. NOTE(review): unlike the sampling branch, this one
    # never appends color escapes or honors -overwrite — confirm intended.
    topv, topi = output.data.topk(1)
    ni = topi[0][0]
    gen.append(ni)
    inp = Variable(topi[0], volatile=True)
    if opt.cuda:
        inp = inp.cuda()  

    for t in range(opt.seq_length):
        
        emb = embed(inp)
        states, output = rnn(emb, states)        
        topv, topi = output.data.topk(1)
        ni = topi[0][0]
        gen.append(ni)
        inp = Variable(topi[0])
        if opt.cuda:
            inp = inp.cuda() 

else:
    # Sampling branch. NOTE(review): F.softmax without dim= relies on the
    # deprecated implicit-dimension behavior of old torch versions.
    probs = F.softmax(output[0].squeeze().div(opt.temperature)).data.cpu()
    ni = torch.multinomial(probs,1)[0]
    feat = hidden.data[last,0,opt.neuron]
    
    if ni < 128:
    	col = color(feat)
    	gen+=(col)
    	
    gen.append(ni)	
    
    
    inp = Variable(torch.LongTensor([ni]), volatile=True) 
    if opt.cuda:
        inp = inp.cuda()      
    for t in range(opt.seq_length):
        
        emb = embed(inp)
        states, output = rnn(emb, states)
        if isinstance(states, tuple):
        	hidden, cell = states
        else:
        	hidden = states
        feat = hidden.data[last,0,opt.neuron] 
        if isinstance(output, list): 
        	output =output[0]      
        probs = F.softmax(output.squeeze().div(opt.temperature)).data.cpu()
        ni = torch.multinomial(probs,1)[0]
        
        # Color printable ASCII bytes by the tracked neuron's activation.
        if ni< 128:
        	col = color(feat)
        	gen+=(col)
        gen.append(ni)
        inp = Variable(torch.LongTensor([ni]))
        if opt.cuda:
            inp = inp.cuda() 
        # Optionally clamp the tracked neuron in-place; hidden.data shares
        # storage with states, so the next rnn() call sees the overwrite.
        if opt.overwrite != 0: 
        	hidden.data[last,0,opt.neuron] = opt.overwrite

# Reset terminal attributes, then decode (skipping any invalid UTF-8
# produced by byte-level sampling) and print the colored text.
gen+=('\033[0m').encode()

print(gen.decode("utf-8",errors = 'ignore' ))