import os
import time

import h5py
import numpy as np
import torch
from torch.utils.data.sampler import BatchSampler

class BalancedRandomBatchSampler(BatchSampler):
    """Batch sampler that replays a precomputed, per-epoch index order.

    At construction an order tensor of shape ``(num_epochs, num_slots)`` is
    built by ``balanced_shuffle`` (class-interleaved, padded with -1) or by
    ``shuffle`` (plain random permutations).  Each full pass over the sampler
    consumes one row of that tensor; after the last row it wraps back to row 0.

    Args:
        data_source: dataset exposing a ``target_tensor`` attribute with the
            labels (a torch tensor, since the order builders call
            ``.numpy()`` / ``.size(0)`` on it).
        batch_size: number of order slots per batch.  Padding slots (-1) are
            removed, so a batch may contain fewer indices than this.
        num_epochs: number of distinct epoch orders to generate.
        balance: use ``balanced_shuffle`` if True, otherwise ``shuffle``.
        path: optional directory in which the generated order is cached.
    """

    def __init__(self, data_source, batch_size, num_epochs=50, balance=True, path=None):
        self.data_source = data_source
        self.batch_size = batch_size
        self.order = balanced_shuffle(data_source.target_tensor, num_epochs, path) \
                     if balance else shuffle(data_source.target_tensor, num_epochs, path)
        # Row of ``self.order`` used by the next pass over the sampler.
        self.epoch = 0

    def __iter__(self):
        """Yield the current epoch's batches, each as a list of dataset indices.

        Trailing slots that do not fill a whole batch are skipped (see
        ``__len__``).  The epoch counter advances only once the generator has
        been exhausted.
        """
        for i in range(len(self)):
            batch = self.order[self.epoch, i*self.batch_size:(i+1)*self.batch_size]
            # -1 marks padding slots (from balanced_shuffle); drop them.  Yield a
            # list, as DataLoader's batch_sampler contract expects, rather than a
            # one-shot iterator that cannot be indexed, measured or re-read.
            yield batch[batch >= 0].tolist()
        self.epoch += 1
        if self.epoch >= self.order.size(0):
            self.epoch = 0

    def __len__(self):
        """Number of full batches per epoch (a partial final batch is dropped)."""
        return self.order.size(1) // self.batch_size

def balanced_shuffle(labels, num_epochs=50, path=None, start_time=None):
    """Build per-epoch index orders that interleave samples from every class.

    For each epoch a ``(max_data_per_class, num_classes)`` table is filled,
    starting from all -1.  Column ``j`` receives a random permutation of the
    indices whose label equals the ``j``-th distinct label (``np.unique``
    order).  That permutation is cut into ``evenness`` consecutive chunks, and
    chunk ``i`` is written to rows ``i, i+evenness, i+2*evenness, ...``.  The
    table is then flattened row-major, so consecutive slots cycle through the
    classes.  Cells that receive no index stay -1.

    Args:
        labels: 1-D torch tensor of class labels (``.numpy()`` is called on
            it).  Label values need not be contiguous or start at 0.
        num_epochs: number of independent epoch orders to generate.
        path: optional directory.  If ``{path}/balanced_order_{num_epochs}.h5``
            exists, its ``order`` dataset is loaded instead of regenerating.
            Otherwise the generated order is written there.
        start_time: reference timestamp for the progress messages.  Defaults
            to the time of this call.

    Returns:
        Integer tensor of dataset indices padded with -1.  When generated here
        its shape is ``(num_epochs, max_data_per_class * num_classes)``.
    """
    # Resolved per call: a ``time.time()`` default would be frozen at import.
    if start_time is None:
        start_time = time.time()

    order_path = '{path}/balanced_order_{num_epochs}.h5' \
                       .format(path=path, num_epochs=num_epochs)
    if path is not None and os.path.isfile(order_path):
        with h5py.File(order_path, 'r') as f:
            order = f['order'][:]
    else:
        evenness = 5 # NOTE(review): presumably batch_size should divide evenness*num_classes
        labels_np = labels.numpy()
        classes = np.unique(labels_np)
        num_classes = len(classes)
        # Positions of each class, stored by column number (not label value).
        loc_data_per_class = [np.flatnonzero(labels_np == k) for k in classes]
        max_data_per_class = max(len(loc) for loc in loc_data_per_class)
        # Split max_data_per_class slots into `evenness` near-equal chunks.
        num_loc_split = (max_data_per_class // evenness) * np.ones(evenness, dtype=int)
        num_loc_split[:(max_data_per_class % evenness)] += 1
        loc_split = [0]
        loc_split.extend(np.cumsum(num_loc_split).tolist())
        order = -np.ones([num_epochs, max_data_per_class*num_classes], dtype=int)
        print_freq = min([100, (num_epochs-1) // 5 + 1])
        for epoch in range(num_epochs):
            order_e = -np.ones([max_data_per_class, num_classes], dtype=int)
            for col, loc in enumerate(loc_data_per_class):
                loc_k = np.random.permutation(loc)
                for i in range(evenness):
                    loc_i = loc_k[loc_split[i]:loc_split[i+1]]
                    # Chunk i fills rows i, i+evenness, i+2*evenness, ...
                    order_e[i:(len(loc_i)*evenness+i):evenness, col] = loc_i
            order[epoch] = order_e.flatten()
            print_me = (epoch == 0 or epoch == num_epochs-1 or (epoch+1) % print_freq == 0)
            if print_me:
                print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch+1, num_epochs=num_epochs), end='')
                print('generate balanced random order; {time:8.3f} s'.format(time=time.time()-start_time))

        if path is not None:
            with h5py.File(order_path, 'w') as f:
                f.create_dataset('order', data=order, compression='gzip', compression_opts=9)

    print('balanced random order; {time:8.3f} s'.format(time=time.time()-start_time))
    return torch.from_numpy(order)

def shuffle(labels, num_epochs=50, path=None, start_time=None):
    """Build per-epoch random permutations of all dataset indices.

    Args:
        labels: torch tensor whose first dimension is the dataset size
            (only ``labels.size(0)`` is used).
        num_epochs: number of independent permutations to generate.
        path: optional directory.  If ``{path}/order_{num_epochs}.h5`` exists,
            its ``order`` dataset is loaded instead of regenerating.
            Otherwise the generated order is written there.
        start_time: reference timestamp for the progress messages.  Defaults
            to the time of this call.

    Returns:
        Integer tensor of dataset indices.  When generated here its shape is
        ``(num_epochs, labels.size(0))`` and each row is a permutation of
        ``range(labels.size(0))``.
    """
    # Resolved per call: a ``time.time()`` default would be frozen at import.
    if start_time is None:
        start_time = time.time()

    order_path = '{path}/order_{num_epochs}.h5' \
                       .format(path=path, num_epochs=num_epochs)
    if path is not None and os.path.isfile(order_path):
        with h5py.File(order_path, 'r') as f:
            order = f['order'][:]
    else:
        num_data = labels.size(0)
        order = -np.ones([num_epochs, num_data], dtype=int)
        print_freq = min([100, (num_epochs-1) // 5 + 1])
        for epoch in range(num_epochs):
            order[epoch] = np.random.permutation(num_data)
            print_me = (epoch == 0 or epoch == num_epochs-1 or (epoch+1) % print_freq == 0)
            if print_me:
                print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch+1, num_epochs=num_epochs), end='')
                print('generate random order; {time:8.3f} s'.format(time=time.time()-start_time))

        if path is not None:
            with h5py.File(order_path, 'w') as f:
                f.create_dataset('order', data=order, compression='gzip', compression_opts=9)

    print('random order; {time:8.3f} s'.format(time=time.time()-start_time))
    return torch.from_numpy(order)