# -*- coding: utf-8 -*-
"""
Created on Fri Jan 29 16:56:11 2016

@author: konopczynski

Perform the dictionary learning for a given settings,
on the provided patches
"""

import numpy as np
import pickle
from sklearn.decomposition import MiniBatchDictionaryLearning
from sklearn.decomposition import DictionaryLearning
import warnings
import sys
sys.path.append('../')
import config
warnings.filterwarnings("ignore")


def learn_dictionary_mini(patches, n_c=512, a=1, n_i=800, n_j=3, b_s=3, es=5, fit_algorithm='lars'):
    """
    patches  - patches to learn on (should be normalized before)
    n_c - number of components (atoms) e.g. 512
    a   - alpha sparsity controlling parameter
    n_i - total number of iterations to perform
    b_s - batch size: number of samples in each mini-batch
    fit_algorithm - {‘lars’, ‘cd’}
    n_j - number of parallel jobs to run (number of threads)
    e_s - size of each element in the dictionary
    """
    dic = MiniBatchDictionaryLearning(n_components=n_c, alpha=a, n_iter=n_i,
                                      n_jobs=n_j, batch_size=b_s, fit_algorithm=fit_algorithm)
    print ("Start learning dictionary_mini: n_c: "+str(n_c)+", alpha: "+str(a)+", n_i: " +
           str(n_i)+", n_j: "+str(n_j)+", es: "+str(es)+", b_s: "+str(b_s))
    v1 = dic.fit(patches).components_
    d1 = v1.reshape(n_c, es, es, es)  # e.g. 512x5x5x5
    return d1


def learn_dictionary(patches, n_c=512, a=1, n_i=100, n_j=3, es=5, fit_algorithm='lars'):
    dic = DictionaryLearning(n_components=n_c, alpha=a, max_iter=n_i,
                             n_jobs=n_j, fit_algorithm=fit_algorithm)
    print ("Start learning dictionary: n_c: "+str(n_c)+", alpha: "+str(a)+", n_i: " +
           str(n_i)+", es: "+str(es)+", n_j: "+str(n_j))
    v2 = dic.fit(patches).components_
    d2 = v2.reshape(n_c, es, es, es)  # e.g. 512x5x5x5
    return d2


def serialize_dictionary(d, path2save):
    full_saving_path = path2save
    output = open(full_saving_path, 'wb')
    pickle.dump(d, output)
    output.close()
    print("saved at: "+full_saving_path)
    return None


def main():
    param = config.read_parameters()
    na = param.numOfAtoms
    es = param.eS
    bs = param.bS
    ni = param.nI
    ac = param.aC

    file_with_patches = param.path2patches+param.fileWithPatches+'.npy'
    patches = np.load(file_with_patches)
    # Learn the dictionary
    dictionary = learn_dictionary_mini(patches, n_c=na, a=ac, n_i=ni,
                                       n_j=4, b_s=bs, es=es, fit_algorithm='lars')
    # Serialize the dictionary
    path2save_dictionary = param.path2dicts+param.dictionaryName+'.pkl'
    serialize_dictionary(dictionary, path2save_dictionary)
    return None

if __name__ == '__main__':
    main()