import collections
import multiprocessing as mp
import multiprocessing.pool
import functools

import numpy as np

from buzzard._actors.message import Msg
from buzzard._actors.pool_job import ProductionJobWaiting, PoolJobWorking

class ActorComputer(object):
    """Actor that takes care of scheduling computations by using user's `compute_array` function

    Without a `computation_pool` on the raster, arrays are computed synchronously when the
    `compute_this_array` message is received. With a pool, each computation is first scheduled in
    the pool's waiting room, then launched in the pool's working room once a token is granted.
    """

    def __init__(self, raster):
        self._raster = raster
        self._alive = True
        computation_pool = raster.computation_pool
        if computation_pool is None:
            # Computations are run inline, in this address space. Defined so that the attribute
            # always exists, even though `Work` short-circuits before reading it in that case.
            self._same_address_space = True
        else:
            self._waiting_room_address = '/Pool{}/WaitingRoom'.format(id(computation_pool))
            self._working_room_address = '/Pool{}/WorkingRoom'.format(id(computation_pool))
            # ThreadPool subclasses Pool, so it must be tested first
            if isinstance(computation_pool, mp.pool.ThreadPool):
                self._same_address_space = True
            elif isinstance(computation_pool, mp.pool.Pool):
                self._same_address_space = False
            else: # pragma: no cover
                assert False, 'Type should be checked in facade'
        # query infos -> set of `Wait` jobs currently scheduled in the waiting room
        self._waiting_jobs_per_query = collections.defaultdict(set)
        # `Work` jobs currently launched in the working room
        self._working_jobs = set()

        # Footprints already computed (or being computed); used to avoid computing twice
        self._performed_computations = set() # type: Set[Footprint]
        self.address = '/Raster{}/Computer'.format(self._raster.uid)

    @property
    def alive(self):
        return self._alive

    # ******************************************************************************************* **
    def receive_compute_this_array(self, qi, compute_idx):
        """Receive message: Start making this array

        Parameters
        ----------
        qi: _actors.cached.query_infos.QueryInfos
        compute_idx: int
            Index in `qi.cache_computation.list_of_compute_fp`
        """
        msgs = []

        if self._raster.computation_pool is None:
            # The work job is created unconditionally: its construction consumes this
            # computation's primitive arrays and advances `collected_count`.
            work = self._create_work_job(qi, compute_idx)
            compute_fp = qi.cache_computation.list_of_compute_fp[compute_idx]
            if compute_fp not in self._performed_computations:
                res = work.func()
                res = self._normalize_user_result(compute_fp, res)
                self._raster.debug_mngr.event('object_allocated', res)
                self._performed_computations.add(compute_fp)
                msgs += self._commit_work_result(work, res)

        else:
            wait = Wait(self, qi, compute_idx)
            self._waiting_jobs_per_query[qi].add(wait)
            msgs += [Msg(self._waiting_room_address, 'schedule_job', wait)]

        return msgs

    def receive_token_to_working_room(self, job, token):
        """Receive message: A token to the working room was granted to the waiting `job`

        The job is launched with that token unless its footprint was already computed, in
        which case the token is handed back with `salvage_token`.
        """
        msgs = []

        self._waiting_jobs_per_query[job.qi].remove(job)
        if len(self._waiting_jobs_per_query[job.qi]) == 0:
            del self._waiting_jobs_per_query[job.qi]

        # Always created, to consume this computation's primitive arrays
        work = self._create_work_job(job.qi, job.compute_idx)

        compute_fp = job.qi.cache_computation.list_of_compute_fp[job.compute_idx]
        if compute_fp not in self._performed_computations:
            msgs += [Msg(self._working_room_address, 'launch_job_with_token', work, token)]
            self._performed_computations.add(compute_fp)
            self._working_jobs.add(work)
        else:
            msgs += [Msg(self._working_room_address, 'salvage_token', token)]

        return msgs

    def receive_job_done(self, job, result):
        """Receive message: A `Work` job launched in the working room finished with `result`"""
        result = self._normalize_user_result(job.compute_fp, result)
        self._raster.debug_mngr.event('object_allocated', result)
        self._working_jobs.remove(job)
        return self._commit_work_result(job, result)

    def receive_cancel_this_query(self, qi):
        """Receive message: One query was dropped

        Only the jobs still in the waiting room are unscheduled; launched jobs are untouched.

        Parameters
        ----------
        qi: _actors.cached.query_infos.QueryInfos
        """
        # `pop` with a default avoids creating-then-deleting an entry for unknown queries
        return [
            Msg(self._waiting_room_address, 'unschedule_job', job)
            for job in self._waiting_jobs_per_query.pop(qi, ())
        ]

    def receive_die(self):
        """Receive message: The raster was killed"""
        assert self._alive
        self._alive = False

        msgs = []
        msgs += [
            Msg(self._waiting_room_address, 'unschedule_job', job)
            for jobs in self._waiting_jobs_per_query.values()
            for job in jobs
        ]
        self._waiting_jobs_per_query.clear()

        msgs += [
            Msg(self._working_room_address, 'cancel_job', job)
            for job in self._working_jobs
        ]
        self._working_jobs.clear()

        self._raster = None
        return msgs

    # ******************************************************************************************* **
    def _create_work_job(self, qi, compute_idx):
        return Work(
            self, qi, compute_idx,
        )

    def _commit_work_result(self, work_job, res):
        """Forward a normalized computed array to the computation accumulator"""
        return [Msg('ComputationAccumulator', 'combine_this_array', work_job.compute_fp, res)]

    def _normalize_user_result(self, compute_fp, res):
        """Validate the output of the user's `compute_array` and cast it to the raster's dtype

        Returns
        -------
        np.ndarray of shape (y, x, band_count); 2d results get a trailing band axis.

        Raises
        ------
        ValueError
            If `res` is not an ndarray, has more than 3 dimensions, or has a shape that does not
            match `compute_fp.shape` and the raster's band count.
        """
        if not isinstance(res, np.ndarray): # pragma: no cover
            raise ValueError("Result of recipe's `compute_array` have type {}, it should be ndarray".format(
                type(res)
            ))
        res = np.atleast_3d(res)
        # atleast_3d never reduces rank; reject higher ranks explicitly instead of failing on
        # the tuple unpacking below with an opaque error
        if res.ndim != 3: # pragma: no cover
            raise ValueError("Result of recipe's `compute_array` have shape `{}`, it should have 2 or 3 dimensions".format(
                res.shape,
            ))
        y, x, c = res.shape
        if (y, x) != tuple(compute_fp.shape): # pragma: no cover
            raise ValueError("Result of recipe's `compute_array` have shape `{}`, should start with {}".format(
                res.shape,
                compute_fp.shape,
            ))
        if c != len(self._raster): # pragma: no cover
            raise ValueError("Result of recipe's `compute_array` have shape `{}`, should have {} bands".format(
                res.shape,
                len(self._raster),
            ))
        res = res.astype(self._raster.dtype, copy=False)
        return res

    # ******************************************************************************************* **

class Wait(ProductionJobWaiting):
    """Waiting-room job of `ActorComputer`: one pending computation of one query.

    Exposes `qi` and `compute_idx` so that the actor can build the matching `Work` job
    once a token to the working room is granted.
    """

    def __init__(self, actor, qi, compute_idx):
        self.qi = qi
        self.compute_idx = compute_idx

        cache_comp = qi.cache_computation
        fp = cache_comp.list_of_compute_fp[compute_idx]
        min_prod_idx = cache_comp.dict_of_min_prod_idx_per_compute_fp[fp]

        # NOTE(review): the literal `4` is forwarded unchanged to ProductionJobWaiting; its
        # meaning is defined there (presumably a priority rank) — confirm
        super().__init__(actor.address, qi, min_prod_idx, 4, fp)

class Work(PoolJobWorking):
    """Working job of `ActorComputer`: a ready-to-call `compute_array` invocation.

    Building it pops one array from each primitive queue of the query and advances
    `collected_count`, so computations of a query must be created in `compute_idx` order.
    """

    def __init__(self, actor, qi, compute_idx):
        cache_comp = qi.cache_computation
        assert cache_comp.collected_count == compute_idx, (cache_comp.collected_count, compute_idx)

        fp = cache_comp.list_of_compute_fp[compute_idx]
        self.compute_fp = fp

        arrays_by_prim = {}
        fps_by_prim = {}
        for name, prim_queue in cache_comp.primitive_queue_per_primitive.items():
            arrays_by_prim[name] = prim_queue.get_nowait()
            fps_by_prim[name] = cache_comp.primitive_fps_per_primitive[name][compute_idx]
        cache_comp.collected_count += 1

        raster = actor._raster
        compute = raster.compute_array
        # The facade proxy is only passed when the call happens in this address space
        # (no pool, or a thread pool); otherwise `None` is passed in its place.
        if raster.computation_pool is None or actor._same_address_space:
            facade = raster.facade_proxy
        else:
            facade = None
        func = functools.partial(compute, fp, fps_by_prim, arrays_by_prim, facade)
        raster.debug_mngr.event('object_allocated', func)

        super().__init__(actor.address, func)