import sys
import cPickle as pickle
import numpy as np
import os
from scipy.misc import imread

import numpy as np
import lmdb
import caffe

def load_CIFAR_batch(filename, pad=True):
  """Load a single CIFAR batch file.

  The file is unpickled and must hold a dict with a 'data' entry (reshaped
  here to (N, 3, 32, 32) uint8) and a 'labels' entry.

  Args:
    filename: path to the pickled batch file.
    pad: if True (default), return the images centred on a 40x40 canvas
      whose 4-pixel border is filled with 128; if False, return the 32x32
      images as they are.

  Returns:
    A tuple (images, labels). images is a uint8 array of shape
    (N, 3, 40, 40) when pad is True and (N, 3, 32, 32) otherwise; labels is
    an int64 array built from the 'labels' entry.
  """
  # NOTE: unpickling can execute arbitrary code -- only load trusted files.
  with open(filename, 'rb') as f:
    datadict = pickle.load(f)

  # -1 lets numpy infer the number of images instead of assuming 10000.
  X = np.asarray(datadict['data']).reshape(-1, 3, 32, 32).astype(np.uint8)
  Y = np.array(datadict['labels'], dtype=np.int64)
  if not pad:
    return X, Y

  # Only build the padded canvas when it is actually requested.
  border = 4
  num_images, channels, height, width = X.shape
  padded = np.full(
      (num_images, channels, height + 2 * border, width + 2 * border),
      128, dtype=np.uint8)
  padded[:, :, border:-border, border:-border] = X
  return padded, Y

def load_CIFAR10(ROOT):
  """ load all of cifar """
  xs = []
  ys = []
  for b in range(1,6):
    f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
    X, Y = load_CIFAR_batch(f)
    xs.append(X)
    ys.append(Y)    
  Xtr = np.concatenate(xs)
  Ytr = np.concatenate(ys)
  idx = np.arange(len(Ytr))
  np.random.shuffle(idx)
  print 'shuffle training data', len(idx)
  Xtr = Xtr[idx]
  Ytr = Ytr[idx]
  print idx
  print 'tr label',Ytr.min(), Ytr.max()
  del X, Y
  Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'), pad=False)
  print 'te label',Yte.min(), Yte.max()
  print Xtr.shape
  print Ytr.shape
  print Xte.shape
  print Yte.shape
  return Xtr, Ytr, Xte, Yte

def py2lmdb(X, y, save_path):
  """Write images and labels to an LMDB database of Caffe Datum records.

  Each record is keyed by its zero-padded 8-digit index and stores the raw
  bytes of X[i] together with int(y[i]) as the label. All records are
  written in a single transaction.

  Args:
    X: uint8 array indexed as (N, channels, height, width).
    y: labels; y.shape[0] must equal N.
    save_path: database location handed to lmdb.open.

  Raises:
    AssertionError: if X is not uint8 or X and y disagree on N.
  """
  assert X.dtype == np.uint8
  N = X.shape[0]
  assert N == y.shape[0], str(N) + ' ' + str(y.shape)

  # We need to prepare the database for the size. We'll set it 10 times
  # greater than what we theoretically need. There is little drawback to
  # setting this too big. If you still run into problems after raising
  # this, you might want to try saving fewer entries in a single
  # transaction.
  map_size = X.nbytes * 10

  # The dimensions are the same for every record, so read them once.
  channels, height, width = X.shape[1], X.shape[2], X.shape[3]

  env = lmdb.open(save_path, map_size=map_size)
  try:
    with env.begin(write=True) as txn:
      # txn is a Transaction object
      for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = channels
        datum.height = height
        datum.width = width
        datum.data = X[i].tobytes()  # or .tostring() if numpy < 1.9
        datum.label = int(y[i])
        str_id = '{:08}'.format(i)

        # The encode is only essential in Python 3
        txn.put(str_id.encode('ascii'), datum.SerializeToString())
  finally:
    # Release the environment handle even if writing fails; the original
    # never closed it.
    env.close()


if __name__ == '__main__':
  root = sys.argv[1]
  Xtr, Ytr, Xte, Yte = load_CIFAR10(root)
  paths = [ os.path.join(root, i) for i in ['train', 'test']]
  py2lmdb(Xtr, Ytr, paths[0])
  py2lmdb(Xte, Yte, paths[1])
  for i in paths:
    print 'saved to', i