# -*- coding: utf-8 -*-
from multiprocessing import RawValue, RawArray, Semaphore, Lock
import ctypes
import numpy as np
import tensorflow as tf


class SharedCounter(object):
    def __init__(self, initval=0):
        self.val = RawValue('i', initval)
        self.last_step_update_target = RawValue('i', initval)
        self.lock = Lock()

    def increment(self, elapsed_steps=None):
        self.val.value += 1
        if ((elapsed_steps is not None) 
            and ((self.val.value - self.last_step_update_target.value) 
                >= elapsed_steps)):
            self.last_step_update_target.value = self.val.value
            return self.val.value, True
        else:
            return self.val.value, False

    def set_value(self, value):
        self.lock.acquire()
        self.val.value = value
        self.lock.release()

    def value(self):
        return self.val.value

class Barrier:
    def __init__(self, n):
        self.n = n
        self.counter = SharedCounter(0)
        self.barrier = Semaphore(0)

    def wait(self):
        with self.counter.lock:
            self.counter.val.value += 1
            if self.counter.val.value == self.n:
                self.barrier.release()

        self.barrier.acquire()
        self.barrier.release()

class SharedVars(object):
    def __init__(self, params, opt_type=None, lr=0, step=0):
        self.var_shapes = [
            var.get_shape().as_list()
            for var in params]
        self.size = sum([np.prod(shape) for shape in self.var_shapes])
        self.step = RawValue(ctypes.c_int, step)

        if opt_type == 'adam':
            self.ms = self.malloc_contiguous(self.size)
            self.vs = self.malloc_contiguous(self.size)
            self.lr = RawValue(ctypes.c_float, lr)
        elif opt_type == 'adamax':
            self.ms = self.malloc_contiguous(self.size)
            self.vs = self.malloc_contiguous(self.size)
            self.lr = RawValue(ctypes.c_float, lr)
        elif opt_type == 'rmsprop':
            self.vars = self.malloc_contiguous(self.size, np.ones(self.size, dtype=np.float))
        elif opt_type == 'momentum':
            self.vars = self.malloc_contiguous(self.size)
        else:
            self.vars = self.malloc_contiguous(self.size)

            
    def malloc_contiguous(self, size, initial_val=None):
        if initial_val is None:
            return RawArray(ctypes.c_float, size)
        else:
            return RawArray(ctypes.c_float, initial_val)


class SharedFlags(object):
    def __init__(self, num_actors):
        self.updated = RawArray(ctypes.c_int, num_actors)