from dask.distributed import Client

# TODO: Make this path absolute (using Pathlib for example).
# Default worker template handed to `KubeCluster.from_yaml` by
# `DaskScaler._get_kubernetes_cluster`. Because it is a relative path it is
# presumably resolved against the process's current working directory at call
# time, not against this module's directory — confirm before relying on it.
WORKER_TEMPLATE_PATH = "worker-template.yaml"
# Cluster types `DaskScaler` accepts. Each name must have a matching
# `_get_<name>_cluster` factory on `DaskScaler`, which `get_cluster` looks up
# by name.
ALLOWED_CLUSTER_TYPES = ["kubernetes", "local"]

# TODO: Improve this class (methods and properties)


class DaskScaler(object):

    client = None

    def __init__(self, cluster_type):
        if cluster_type not in ALLOWED_CLUSTER_TYPES:
            raise Exception("Can't choose this type of cluster for now. Choose one from: {}".format(
                ALLOWED_CLUSTER_TYPES))
        self.cluster_type = cluster_type
        self._cluster = self.get_cluster()
        self._client = Client(self._cluster)

    @staticmethod
    def _get_kubernetes_cluster(worker_template_path=WORKER_TEMPLATE_PATH):
        from dask_kubernetes import KubeCluster

        cluster = KubeCluster.from_yaml(worker_template_path)
        return Client(cluster)

    @staticmethod
    def _get_local_cluster():
        # TODO: Add more parameters and configurations.
        from distributed import LocalCluster
        return LocalCluster()

    def get_cluster(self):
        return getattr(self, "_get_" + self.cluster_type + "_cluster")

    @property
    def cluster(self):
        if self._cluster is None:
            self._cluster = self.get_cluster()
        return self._cluster

    @property
    def client(self):
        if self._client is None:
            self._client = Client(self._cluster)
        return self._client