import os 
import sys 
import argparse 
 
import _pickle as pickle
import numpy as np 
from sklearn.multiclass import OneVsOneClassifier 
from sklearn.svm import LinearSVC 
from sklearn import preprocessing 
 
# To train the classifier 
class ClassifierTrainer(object): 
    def __init__(self, X, label_words): 
        # Encoding the labels (words to numbers) 
        self.le = preprocessing.LabelEncoder() 
 
        # Initialize One vs One Classifier using a linear kernel 
        self.clf = OneVsOneClassifier(LinearSVC(random_state=0)) 
 
        y = self._encodeLabels(label_words) 
        X = np.asarray(X) 
        self.clf.fit(X, y) 
 
    # Predict the output class for the input datapoint 
    def _fit(self, X): 
        X = np.asarray(X) 
        return self.clf.predict(X) 
 
    # Encode the labels (convert words to numbers) 
    def _encodeLabels(self, labels_words): 
        self.le.fit(labels_words) 
        return np.array(self.le.transform(labels_words), 
         dtype=np.float32) 
 
    # Classify the input datapoint 
    def classify(self, X): 
        labels_nums = self._fit(X) 
        labels_words = self.le.inverse_transform([int(x) for x in 
         labels_nums]) 
        return labels_words 
 
def build_arg_parser(): 
    parser = argparse.ArgumentParser(description='Trains the classifier models')
    parser.add_argument("--feature-map-file", dest="feature_map_file", required=True,\
        help="Input pickle file containing the feature map") 
    parser.add_argument("--svm-file", dest="svm_file", required=False,\
        help="Output file where the pickled SVM model will be stored") 
    return parser 
    
if __name__=='__main__': 
    args = build_arg_parser().parse_args() 
    feature_map_file = args.feature_map_file 
    svm_file = args.svm_file 
 
    # Load the feature map 
    with open(feature_map_file, 'rb') as f: 
        feature_map = pickle.load(f) 
 
    # Extract feature vectors and the labels 
    labels_words = [x['label'] for x in feature_map] 
 
    # Here, 0 refers to the first element in the 
    # feature_map, and 1 refers to the second 
    # element in the shape vector of that element 
    # (which gives us the size) 
    dim_size = feature_map[0]['feature_vector'].shape[1] 
 
    X = [np.reshape(x['feature_vector'], (dim_size,)) for x in feature_map] 
 
    # Train the SVM 
    svm = ClassifierTrainer(X, labels_words) 
    if args.svm_file: 
        with open(args.svm_file, 'wb') as f: 
            pickle.dump(svm, f)