from easydict import EasyDict as edict

__C = edict()
# Consumers can get config by:
#   from fast_rcnn_config import cfg
cfg = __C
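# Values can then be read (and overridden) as attributes of the EasyDict,
# e.g. cfg.CONST.BATCH_SIZE or cfg.TRAIN.NUM_ITERATION.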

#
# Common
#
__C.SUB_CONFIG_FILE = []
__C.DATASET = './experiments/dataset/shapenet_1000.json'  # YAML/JSON file that specifies a dataset (training/testing)
__C.NET_NAME = 'res_gru_net'
__C.PROFILE = False

__C.CONST = edict()
__C.CONST.DEVICE = 'gpu0'
__C.CONST.RNG_SEED = 0
__C.CONST.IMG_W = 127
__C.CONST.IMG_H = 127
__C.CONST.N_VOX = 32
__C.CONST.N_VIEWS = 5
__C.CONST.BATCH_SIZE = 36
__C.CONST.NETWORK_CLASS = 'ResidualGRUNet'
__C.CONST.WEIGHTS = ''  # when set, load the weights from the file

#
# Directories
#
__C.DIR = edict()
# Path where taxonomy.json is stored
__C.DIR.SHAPENET_QUERY_PATH = './ShapeNet/ShapeNetVox32/'
__C.DIR.MODEL_PATH = './ShapeNet/ShapeNetCore.v1/%s/%s/model.obj'
__C.DIR.VOXEL_PATH = './ShapeNet/ShapeNetVox32/%s/%s/model.binvox'
__C.DIR.RENDERING_PATH = './ShapeNet/ShapeNetRendering/%s/%s/rendering'
__C.DIR.OUT_PATH = './output/default'

#
# Training
#
__C.TRAIN = edict()

__C.TRAIN.RESUME_TRAIN = False
__C.TRAIN.INITIAL_ITERATION = 0  # iteration number to start from when training is resumed
__C.TRAIN.USE_REAL_IMG = False
__C.TRAIN.DATASET_PORTION = [0, 0.8]
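# (presumably the fraction of models used for training: the first 80%, with
#  __C.TEST.DATASET_PORTION below covering the remaining [0.8, 1])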

# Data worker
__C.TRAIN.NUM_WORKER = 1  # number of data workers
__C.TRAIN.NUM_ITERATION = 60000  # maximum number of training iterations
__C.TRAIN.WORKER_LIFESPAN = 100  # if using Blender, kill a worker after this many iterations to clear its cache
__C.TRAIN.WORKER_CAPACITY = 1000  # if using OSG, load only a limited number of models at a time
__C.TRAIN.NUM_RENDERING = 24
__C.TRAIN.NUM_VALIDATION_ITERATIONS = 24
__C.TRAIN.VALIDATION_FREQ = 2000
__C.TRAIN.NAN_CHECK_FREQ = 2000
__C.TRAIN.RANDOM_NUM_VIEWS = True  # feed in a random number of views if N_VIEWS > 1

__C.QUEUE_SIZE = 15  # maximum number of minibatches that can be put in a data queue

# Data augmentation
__C.TRAIN.RANDOM_CROP = True
__C.TRAIN.PAD_X = 10
__C.TRAIN.PAD_Y = 10
__C.TRAIN.FLIP = True

# When random background images are not used, fill the background with a random color from this range
__C.TRAIN.NO_BG_COLOR_RANGE = [[225, 255], [225, 255], [225, 255]]
__C.TRAIN.RANDOM_BACKGROUND = False
__C.TRAIN.SIMPLE_BACKGROUND_RATIO = 0.5  # ratio of images rendered with a simple (plain) background

# Learning
# For SGD use 0.1; for Adam use 0.0001
__C.TRAIN.DEFAULT_LEARNING_RATE = 1e-4
__C.TRAIN.POLICY = 'adam'  # available policies: 'sgd', 'adam'
# EasyDict cannot use integers as keys, so the iteration thresholds are stored as strings
__C.TRAIN.LEARNING_RATES = {'20000': 1e-5, '60000': 1e-6}
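# Illustrative sketch only (not necessarily how the training loop consumes the
# schedule above): the keys are iteration thresholds stored as strings, so a
# consumer would convert them back to integers, e.g.:
#
#   lr = cfg.TRAIN.DEFAULT_LEARNING_RATE
#   for step, rate in sorted(cfg.TRAIN.LEARNING_RATES.items(),
#                            key=lambda kv: int(kv[0])):
#       if iteration >= int(step):
#           lr = rate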
__C.TRAIN.MOMENTUM = 0.90
# Weight decay (regularization constant). If not set, the loss can diverge
# after training has almost converged, since the weights can grow indefinitely
# (for the cross-entropy loss). Too much regularization also hinders training.
__C.TRAIN.WEIGHT_DECAY = 0.00005
__C.TRAIN.LOSS_LIMIT = 2  # stop training if the loss exceeds the limit
__C.TRAIN.SAVE_FREQ = 10000  # weights are saved (overwritten) every SAVE_FREQ iterations
__C.TRAIN.PRINT_FREQ = 40

#
# Testing options
#
__C.TEST = edict()
__C.TEST.EXP_NAME = 'test'
__C.TEST.USE_IMG = False
__C.TEST.MODEL_ID = []
__C.TEST.DATASET_PORTION = [0.8, 1]
__C.TEST.SAMPLE_SIZE = 0
__C.TEST.IMG_PATH = ''
__C.TEST.AZIMUTH = []
__C.TEST.NO_BG_COLOR_RANGE = [[240, 240], [240, 240], [240, 240]]

__C.TEST.VISUALIZE = False
__C.TEST.VOXEL_THRESH = [0.4]


def _merge_a_into_b(a, b):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.
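
    Illustrative example: merging edict({'TRAIN': edict({'NUM_ITERATION': 80000})})
    into __C updates only TRAIN.NUM_ITERATION and leaves every other default
    intact; unknown keys raise KeyError and type mismatches raise ValueError.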
    """
    if type(a) is not edict:
        return

    for k, v in a.items():
        # a must only specify keys that already exist in b
        if k not in b.keys():
            raise KeyError('{} is not a valid config key'.format(k))

        # the types must match, too
        if type(b[k]) is not type(v):
            raise ValueError(('Type mismatch ({} vs. {}) '
                              'for config key: {}').format(type(b[k]), type(v), k))

        # recursively merge dicts
        if type(v) is edict:
            try:
                _merge_a_into_b(a[k], b[k])
            except Exception:
                print('Error under config key: {}'.format(k))
                raise
        else:
            b[k] = v


def cfg_from_file(filename):
    """Load a config file and merge it into the default options."""
    import yaml
    with open(filename, 'r') as f:
        yaml_cfg = edict(yaml.safe_load(f))  # safe_load: yaml.load() without a Loader is deprecated in PyYAML >= 5.1

    _merge_a_into_b(yaml_cfg, __C)
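
# A minimal override file might look like this (hypothetical path and values;
# every key must already exist in the defaults above and keep its type):
#
#   # ./experiments/cfgs/shapenet_5views.yaml
#   CONST:
#     N_VIEWS: 5
#     BATCH_SIZE: 24
#   TRAIN:
#     NUM_ITERATION: 80000
#
# and would be loaded with cfg_from_file('./experiments/cfgs/shapenet_5views.yaml').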


def cfg_from_list(cfg_list):
    """Set config keys via list (e.g., from command line)."""
    from ast import literal_eval
    assert len(cfg_list) % 2 == 0
    for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
        key_list = k.split('.')
        d = __C
        for subkey in key_list[:-1]:
            assert subkey in d.keys()
            d = d[subkey]
        subkey = key_list[-1]
        assert subkey in d.keys()
        try:
            value = literal_eval(v)
        except (ValueError, SyntaxError):
            # handle the case when v is a string literal
            value = v
        assert type(value) == type(d[subkey]), \
            'type {} does not match original type {}'.format(
                type(value), type(d[subkey]))
        d[subkey] = value
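
# Example (illustrative) of command-line style overrides; each key must already
# exist in the defaults and each value must parse to the matching type:
#
#   cfg_from_list(['CONST.BATCH_SIZE', '24', 'TRAIN.RANDOM_CROP', 'False'])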