from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import sqlite3
import sys
from .base import database

# Python 2 keeps its builtin ``unicode``; on Python 3 the name is aliased to
# ``str`` so that code in this module can call ``unicode(...)`` on both.
if sys.version_info >= (3,):
    unicode = str


class PickalableSWIG:
    """
    Mixin that makes an object picklable by replaying its constructor.

    Subclasses must store their positional constructor arguments in
    ``self.args`` and may store keyword arguments in ``self.kwargs``.
    Pickling saves only these arguments, and unpickling calls ``__init__``
    again with them.
    """

    def __setstate__(self, state):
        # States produced before keyword arguments were recorded carry no
        # 'kwargs' entry; treat those as "no keyword arguments".
        self.__init__(*state['args'], **state.get('kwargs', {}))

    def __getstate__(self):
        return {'args': self.args, 'kwargs': getattr(self, 'kwargs', {})}


class PickalableSQL3Connect(sqlite3.Connection, PickalableSWIG):
    """``sqlite3.Connection`` that records its constructor arguments so it
    can be re-created on unpickling."""

    def __init__(self, *args, **kwargs):
        self.args = args
        # Keyword arguments (e.g. ``timeout``) must be kept as well, otherwise
        # an unpickled connection would be opened with different settings.
        self.kwargs = kwargs
        sqlite3.Connection.__init__(self, *args, **kwargs)


class PickalableSQL3Cursor(sqlite3.Cursor, PickalableSWIG):
    """``sqlite3.Cursor`` that records its constructor arguments so it can
    be re-created on unpickling."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        sqlite3.Cursor.__init__(self, *args, **kwargs)



class sql(database):
    """
    Database backend that stores every sampled run in an SQLite file.

    Rows go into a table named after ``dbname`` inside the file
    ``<dbname>.db``. Each ``save`` call commits immediately, so rows already
    written survive an interrupted sampling run.
    """

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + name.replace('"', '""') + '"'

    def __init__(self, *args, **kwargs):
        import os
        # init base class; it is expected to provide dbname, header,
        # dim_dict and db_precision used below (defined in .base).
        super(sql, self).__init__(*args, **kwargs)
        # Start from a fresh database file. A missing file (or any other
        # filesystem error while deleting it) is deliberately ignored.
        try:
            os.remove(self.dbname + '.db')
        except OSError:
            pass

        # The connection stays open until finalize() is called.
        self.db = PickalableSQL3Connect(self.dbname + '.db')
        self.db_cursor = PickalableSQL3Cursor(self.db)
        # Every column, including the last one, is declared 'real' so that
        # the numeric text written by save() is stored as a float everywhere.
        columns = ', '.join(self._quote_identifier(name) + ' real'
                            for name in self.header)
        self.db_cursor.execute('CREATE TABLE IF NOT EXISTS '
                               + self._quote_identifier(self.dbname)
                               + ' (' + columns + ')')

    def save(self, objectivefunction, parameterlist, simulations=None, chains=1):
        """
        Append one run as a table row and commit it immediately.

        :param objectivefunction: objective function value(s) of the run
        :param parameterlist: parameter values of the run
        :param simulations: simulation results of the run (may be None)
        :param chains: chain identifier, written as the last column
        """
        coll = (self.dim_dict['like'](objectivefunction) +
                self.dim_dict['par'](parameterlist) +
                self.dim_dict['simulation'](simulations) +
                [chains])
        # Apply rounding of floats, then bind each value as its text form.
        # With REAL column affinity SQLite converts numeric text to a float,
        # exactly as the former quoted-literal INSERT did. Bound parameters
        # avoid relying on SQLite's legacy acceptance of double-quoted
        # string literals, which some builds disable.
        values = [str(self.db_precision(value)) for value in coll]
        placeholders = ', '.join('?' * len(values))
        self.db_cursor.execute(
            'INSERT INTO ' + self._quote_identifier(self.dbname)
            + ' VALUES (' + placeholders + ')', values)

        self.db.commit()

    def finalize(self):
        """Close the connection currently held in ``self.db``."""
        self.db.close()

    def getdata(self):
        """
        Read the whole table back as a NumPy structured array.

        Reopens ``<dbname>.db`` (replacing ``self.db`` / ``self.db_cursor``)
        and builds one ``'<f8'`` field per table column, in table order. The
        connection is closed again before returning, even if reading fails.

        :return: numpy structured array with one record per saved run
        """
        self.db = PickalableSQL3Connect(self.dbname + '.db')
        self.db_cursor = PickalableSQL3Cursor(self.db)
        try:
            table = self._quote_identifier(self.dbname)
            # Column 1 of each PRAGMA table_info row is the column name.
            names = [row[1] for row in
                     self.db_cursor.execute('PRAGMA table_info(' + table + ');')]
            if sys.version_info[0] >= 3:
                headers = [(name, "<f8") for name in names]
            else:
                # Workaround for python2: with unicode_literals active the
                # dtype entries are encoded to byte strings (presumably
                # because numpy on Python 2 rejects unicode field names).
                headers = [(unicode(name).encode("ascii"),
                            unicode("<f8").encode("ascii")) for name in names]

            rows = self.db_cursor.execute('SELECT * FROM ' + table).fetchall()
            back = np.array(rows, dtype=headers)
        finally:
            self.db.close()
        return back