from __future__ import print_function
from __future__ import division

import argparse
import gzip
import json
import logging
import os
import sys
import warnings
from collections import OrderedDict
from os import path as pt

import h5py as h5
import numpy as np
import pandas as pd
import six
from six.moves import range

from keras import backend as K
from keras import models as km
from keras import layers as kl
from keras.utils.np_utils import to_categorical

from kipoi.metadata import GenomicRanges
from kipoi.data import BatchIterator

from deepcpg import data as dat
from deepcpg import evaluation as ev
from deepcpg.data import annotations as an
from deepcpg.data import stats
from deepcpg.data import dna
from deepcpg.data import fasta
from deepcpg.data import feature_extractor as fext
from deepcpg.data.dna import int_to_onehot
from deepcpg.models.utils import decode_replicate_names, encode_replicate_names, get_sample_weights
from deepcpg.utils import make_dir, to_list


def prepro_pos_table(pos_tables):
    """Merge position tables into one table of unique, sorted positions.

    Parameters
    ----------
    pos_tables: pd.DataFrame or list of pd.DataFrame
        Table(s) with 'chromo' and 'pos' columns.

    Returns
    -------
    pd.DataFrame
        Table with columns ['chromo', 'pos'], containing each (chromo, pos)
        pair exactly once, sorted by chromosome then position.
        Returns `None` for an empty input list (matching previous behavior).
    """
    if not isinstance(pos_tables, list):
        pos_tables = [pos_tables]
    if not pos_tables:
        return None

    # Concatenate once and deduplicate/sort a single time, instead of
    # re-grouping and re-sorting the accumulated table on every iteration.
    pos_table = pd.concat([t[['chromo', 'pos']] for t in pos_tables])
    pos_table = pos_table.drop_duplicates()
    pos_table = pos_table.sort_values(['chromo', 'pos'])
    pos_table.reset_index(drop=True, inplace=True)
    return pos_table


def split_ext(filename):
    """Return the basename of `filename` with every extension stripped."""
    base = os.path.basename(filename)
    return base.split(os.extsep)[0]

def get_fh(filename, mode, *args, **kwargs):
    """Open `filename`, transparently handling gzip compression.

    Workaround for a bug in the DeepCpG 1.0.4 release. Files ending in
    '.gz' are opened with `gzip.open`, everything else with builtin `open`.

    Note: `gzip.open` defaults to binary mode; pass a text mode (e.g. 'rt')
    explicitly if text is expected.

    Parameters
    ----------
    filename: str
        Path of the file to open.
    mode: str
        File mode passed to the underlying open function.

    Returns
    -------
    file object
    """
    # BUGFIX: `gzip` was used here without being imported (NameError on
    # any .gz input); it is now imported at the top of the file.
    if filename.endswith('.gz'):
        return gzip.open(filename, mode, *args, **kwargs)
    return open(filename, mode, *args, **kwargs)


def read_cpg_profiles(filenames, log=None, *args, **kwargs):
    """Read methylation profiles from `filenames`.

    Input files can be gzip compressed. Extra arguments are forwarded to
    :func:`dat.read_cpg_profile`.

    Returns
    -------
    OrderedDict
        Maps the output name (basename without extensions) of each file to
        its CpG table.
    """
    profiles = OrderedDict()
    for fname in filenames:
        if log:
            log(fname)
        handle = get_fh(fname, 'r')
        profiles[split_ext(fname)] = dat.read_cpg_profile(
            handle, sort=True, *args, **kwargs)
        handle.close()
    return profiles



def extract_seq_windows(seq, pos, wlen, seq_index=1, assert_cpg=False):
    """Extracts DNA sequence windows at positions.

    Parameters
    ----------
    seq: str
        DNA sequence.
    pos: list
        Positions at which windows are extracted.
    wlen: int
        Window length.
    seq_index: int
        Offset at which positions start.
    assert_cpg: bool
        If `True`, check if positions in `pos` point to CpG sites.

    Returns
    -------
    np.array
        Array with integer-encoded sequence windows.
    """
    half = wlen // 2
    seq = seq.upper()
    windows = np.zeros((len(pos), wlen), dtype='int8')

    for row, position in enumerate(pos):
        p = position - seq_index
        if not 0 <= p < len(seq):
            raise ValueError('Position %d not on chromosome!' % (p + seq_index))
        if seq[p:p + 2] != 'CG':
            warnings.warn('No CpG site at position %d!' % (p + seq_index))
        win = seq[max(0, p - half):min(len(seq), p + half + 1)]
        if len(win) < wlen:
            # Pad windows that reach beyond the chromosome ends with 'N'.
            win = 'N' * max(0, half - p) + win \
                + 'N' * max(0, p + half + 1 - len(seq))
            assert len(win) == wlen
        windows[row] = dna.char_to_int(win)

    # Randomly choose missing nucleotides
    missing = windows == dna.CHAR_TO_INT['N']
    windows[missing] = np.random.randint(0, 4, missing.sum())
    assert windows.max() < 4
    if assert_cpg:
        assert np.all(windows[:, half] == 3)
        assert np.all(windows[:, half + 1] == 2)
    return windows


def map_values(values, pos, target_pos, dtype=None, nan=dat.CPG_NAN):
    """Map `values` at positions `pos` onto `target_pos`.

    Values whose position is not in `target_pos` are dropped; target
    positions not covered by `pos` are filled with `nan`. Both position
    arrays must be sorted.
    """
    assert len(values) == len(pos)
    assert np.all(pos == np.sort(pos))
    assert np.all(target_pos == np.sort(target_pos))

    values = values.ravel()
    pos = pos.ravel()
    target_pos = target_pos.ravel()

    keep = np.in1d(pos, target_pos)
    pos, values = pos[keep], values[keep]

    out_dtype = dtype if dtype else values.dtype
    target_values = np.full(len(target_pos), nan, dtype=out_dtype)

    hit = np.in1d(target_pos, pos).nonzero()[0]
    assert len(hit) == len(values)
    assert np.all(target_pos[hit] == pos)
    target_values[hit] = values
    return target_values


def map_cpg_tables(cpg_tables, chromo, chromo_pos):
    """Map values from `cpg_tables` to `chromo_pos`.

    Positions in `cpg_tables` for `chromo` must be a subset of `chromo_pos`.
    Inserts `dat.CPG_NAN` for uncovered positions.
    """
    # In-place sort: callers rely on `chromo_pos` being sorted afterwards.
    chromo_pos.sort()
    mapped = OrderedDict()
    for name, table in six.iteritems(cpg_tables):
        sub = table.loc[table.chromo == chromo].sort_values('pos')
        result = map_values(sub.value.values, sub.pos.values, chromo_pos)
        assert len(result) == len(chromo_pos)
        mapped[name] = result
    return mapped


def format_out_of(out, of):
    """Format a count as 'out / of (percent%)'."""
    percent = out / of * 100
    return '%d / %d (%.1f%%)' % (out, of, percent)

def select_dict(data, idx):
    """Return a shallow copy of `data` with `idx` applied to every leaf.

    Nested dicts are traversed recursively; non-dict leaves are indexed
    with `idx` (e.g. a boolean mask, integer array, or slice).

    Parameters
    ----------
    data: dict
        Possibly nested dict of indexable values (e.g. numpy arrays).
    idx: object
        Index applied to each leaf value.

    Returns
    -------
    dict
        New dict with the same structure and indexed leaves; the input
        dict itself is not modified.
    """
    data = data.copy()
    # `dict.items()` behaves identically on Python 2 and 3 here, so the
    # `six.iteritems` indirection is unnecessary. Reassigning existing keys
    # while iterating is safe because the key set does not change.
    for key, value in data.items():
        if isinstance(value, dict):
            data[key] = select_dict(value, idx)
        else:
            data[key] = value[idx]
    return data


def annotate(anno_file, chromo, pos):
    """Return an int8 mask marking which `pos` fall inside regions of
    `anno_file` on chromosome `chromo`.

    `anno_file` is a (possibly gzipped) BED-like table whose first three
    columns are chromosome, start, and end.
    """
    handle = get_fh(anno_file, 'r')
    regions = pd.read_table(handle, header=None, usecols=[0, 1, 2],
                            dtype={0: 'str', 1: 'int32', 2: 'int32'})
    handle.close()

    regions.columns = ['chromo', 'start', 'end']
    # Normalize chromosome names, e.g. 'chr1' -> '1'.
    regions.chromo = regions.chromo.str.upper().str.replace('CHR', '')
    regions = regions.loc[regions.chromo == chromo].sort_values('start')

    start, end = an.join_overlapping(regions.start.values, regions.end.values)
    return np.array(an.is_in(pos, start, end), dtype='int8')


def flatten_dict(obj, output_dict, prefix="", no_prefix=False):
    """Flatten nested dict `obj` into `output_dict` using '/'-joined keys.

    With `no_prefix=True`, top-level keys are used as-is (no leading
    prefix); nested levels always receive the '/'-joined prefix.
    """
    prefix = prefix.rstrip("/")
    assert isinstance(obj, dict)
    for key, value in obj.items():
        if no_prefix:
            flat_key = str(key)
        else:
            flat_key = prefix + "/" + str(key)
        if isinstance(value, dict):
            flatten_dict(value, output_dict, flat_key, no_prefix=False)
        else:
            output_dict[flat_key] = value



def run_dcpg_data(pos_file = None,
                    cpg_profiles = None,
                    dna_files = None,
                    cpg_wlen=None,
                    cpg_cov = 1,
                    dna_wlen=1001,
                    anno_files=None,
                    chromos = None,
                    nb_sample = None,
                    nb_sample_chromo = None,
                    chunk_size = 32768,
                    seed = 0,
                    verbose = False):
    """Generate preprocessed data chunks, adapted from DeepCpG's `dcpg_data.py`.

    Instead of writing HDF5 chunk files like the original script, each chunk
    is yielded as a flat dict mapping '/'-joined keys (e.g. 'inputs/dna',
    'outputs/cpg/<name>', 'chromo', 'pos') to numpy arrays.

    Parameters
    ----------
    pos_file: str
        Optional two-column table (chromo, pos) of CpG positions; if absent,
        positions are extracted from `cpg_profiles`.
    cpg_profiles: list
        Paths of single-cell methylation profiles (optionally gzipped).
    dna_files: list
        Paths of FASTA files to read chromosome sequences from.
    cpg_wlen: int
        Number of neighboring CpG sites to extract (must be even).
    cpg_cov: int
        Minimum number of cells that must cover a position.
    dna_wlen: int
        DNA sequence window length (must be odd).
    anno_files: list
        Optional annotation BED-like files (currently only read per
        chromosome; annotations are not included in the yielded dicts).
    chromos: list
        Restrict to these chromosomes.
    nb_sample: int
        Maximum total number of positions.
    nb_sample_chromo: int
        Number of positions sampled per chromosome.
    chunk_size: int
        Number of positions per yielded chunk.
    seed: int
        Seed for the global numpy RNG; `None` leaves the RNG untouched.
    verbose: bool
        Enable debug logging.

    Yields
    ------
    dict
        Flat dict of numpy arrays for one chunk.
    """
    if seed is not None:
        np.random.seed(seed)

    # Set up logging
    name = "dcpg_data"
    logging.basicConfig(format='%(levelname)s (%(asctime)s): %(message)s')
    log = logging.getLogger(name)
    if verbose:
        log.setLevel(logging.DEBUG)
    else:
        log.setLevel(logging.INFO)

    # Check input arguments
    if not cpg_profiles:
        if not (pos_file or dna_files):
            raise ValueError('Position table and DNA database expected!')

    # BUGFIX: these previously raised bare strings, which is a TypeError in
    # Python 3 ("exceptions must derive from BaseException").
    if dna_wlen and dna_wlen % 2 == 0:
        raise ValueError('dna_wlen must be odd!')
    if cpg_wlen and cpg_wlen % 2 != 0:
        raise ValueError('cpg_wlen must be even!')

    outputs = OrderedDict()

    # Read single-cell profiles if provided
    if cpg_profiles:
        log.info('Reading CpG profiles ...')
        outputs['cpg'] = read_cpg_profiles(
            cpg_profiles,
            chromos=chromos,
            nb_sample=nb_sample,
            nb_sample_chromo=nb_sample_chromo,
            log=log.info)

    # Create table with unique positions
    if pos_file:
        # Read positions from file
        log.info('Reading position table ...')
        pos_table = pd.read_table(pos_file, usecols=[0, 1],
                                  dtype={0: str, 1: np.int32},
                                  header=None, comment='#')
        pos_table.columns = ['chromo', 'pos']
        pos_table['chromo'] = dat.format_chromo(pos_table['chromo'])
        pos_table = prepro_pos_table(pos_table)
    else:
        # Extract positions from profiles
        pos_tables = []
        for cpg_table in list(outputs['cpg'].values()):
            pos_tables.append(cpg_table[['chromo', 'pos']])
        pos_table = prepro_pos_table(pos_tables)

    if chromos:
        pos_table = pos_table.loc[pos_table.chromo.isin(chromos)]
    if nb_sample_chromo:
        pos_table = dat.sample_from_chromo(pos_table, nb_sample_chromo)
    if nb_sample:
        pos_table = pos_table.iloc[:nb_sample]

    log.info('%d samples' % len(pos_table))

    # Iterate over chromosomes
    # ------------------------
    for chromo in pos_table.chromo.unique():
        log.info('-' * 80)
        log.info('Chromosome %s ...' % (chromo))
        idx = pos_table.chromo == chromo
        chromo_pos = pos_table.loc[idx].pos.values
        chromo_outputs = OrderedDict()

        if 'cpg' in outputs:
            # Concatenate CpG tables into single nb_site x nb_output matrix
            chromo_outputs['cpg'] = map_cpg_tables(outputs['cpg'],
                                                   chromo, chromo_pos)
            chromo_outputs['cpg_mat'] = np.vstack(
                list(chromo_outputs['cpg'].values())).T
            assert len(chromo_outputs['cpg_mat']) == len(chromo_pos)

        if 'cpg_mat' in chromo_outputs and cpg_cov:
            # Drop positions covered by fewer than `cpg_cov` cells.
            cov = np.sum(chromo_outputs['cpg_mat'] != dat.CPG_NAN, axis=1)
            assert np.all(cov >= 1)
            idx = cov >= cpg_cov
            tmp = '%s sites matched minimum coverage filter'
            tmp %= format_out_of(idx.sum(), len(idx))
            log.info(tmp)
            if idx.sum() == 0:
                continue

            chromo_pos = chromo_pos[idx]
            chromo_outputs = select_dict(chromo_outputs, idx)

        # Read DNA of chromosome
        chromo_dna = None
        if dna_files:
            chromo_dna = fasta.read_chromo(dna_files, chromo)

        annos = None
        if anno_files:
            log.info('Annotating CpG sites ...')
            annos = dict()
            for anno_file in anno_files:
                name = split_ext(anno_file)
                annos[name] = annotate(anno_file, chromo, chromo_pos)

        # Iterate over chunks
        # -------------------
        nb_chunk = int(np.ceil(len(chromo_pos) / chunk_size))
        for chunk in range(nb_chunk):
            log.info('Chunk \t%d / %d' % (chunk + 1, nb_chunk))
            chunk_start = chunk * chunk_size
            chunk_end = min(len(chromo_pos), chunk_start + chunk_size)
            chunk_idx = slice(chunk_start, chunk_end)
            chunk_pos = chromo_pos[chunk_idx]

            chunk_outputs = select_dict(chromo_outputs, chunk_idx)

            yield_dict = {}

            # NOTE(review): dtype 'S2' truncates chromosome names longer than
            # two characters -- confirm chromosome naming upstream.
            yield_dict["chromo"] = np.array([chromo.encode()] * len(chunk_pos),
                                            dtype='S2')
            yield_dict["pos"] = np.array(chunk_pos, dtype=np.int32)

            if len(chunk_outputs):
                yield_dict["outputs"] = {}
                out_group = yield_dict["outputs"]

            # CpG output profiles; continuous values are rounded to int8.
            if 'cpg' in chunk_outputs:
                yield_dict["outputs"]['cpg'] = {}
                for name, value in six.iteritems(chunk_outputs['cpg']):
                    assert len(value) == len(chunk_pos)
                    out_group['cpg'][name] = np.array(value.round(), np.int8)

            # Input features
            yield_dict["inputs"] = {}
            in_group = yield_dict["inputs"]

            # DNA windows
            if chromo_dna:
                log.info('Extracting DNA sequence windows ...')
                dna_wins = extract_seq_windows(chromo_dna, pos=chunk_pos,
                                               wlen=dna_wlen)
                assert len(dna_wins) == len(chunk_pos)
                in_group['dna'] = np.array(dna_wins, dtype=np.int8)

            # CpG neighbors
            if cpg_wlen:
                log.info('Extracting CpG neighbors ...')
                cpg_ext = fext.KnnCpgFeatureExtractor(cpg_wlen // 2)
                in_group['cpg'] = {}
                context_group = in_group['cpg']
                # Use the full chromosome tables (outputs['cpg']), since
                # neighboring CpG sites might lie outside chunk borders and
                # un-mapped values are needed.
                for name, cpg_table in six.iteritems(outputs['cpg']):
                    cpg_table = cpg_table.loc[cpg_table.chromo == chromo]
                    state, dist = cpg_ext.extract(chunk_pos,
                                                  cpg_table.pos.values,
                                                  cpg_table.value.values)
                    nan = np.isnan(state)
                    state[nan] = dat.CPG_NAN
                    dist[nan] = dat.CPG_NAN
                    # States can be binary (np.int8) or continuous
                    # (np.float32).
                    state = state.astype(cpg_table.value.dtype, copy=False)
                    dist = dist.astype(np.float32, copy=False)

                    assert len(state) == len(chunk_pos)
                    assert len(dist) == len(chunk_pos)
                    assert np.all((dist > 0) | (dist == dat.CPG_NAN))

                    context_group[name] = {'state': state, 'dist': dist}

            # Flatten the nested dict into '/'-joined keys and yield.
            flat_dict = {}
            flatten_dict(yield_dict, flat_dict, no_prefix=True)
            yield flat_dict

    log.info('Done preprocessing!')




##### This function is needed to extract info on model architecture so that the output can be generated correctly.
def data_reader_config_from_model(model, config_out_fpath = None, replicate_names=None):
    """Return :class:`DataReader` from `model`.
    Builds a :class:`DataReader` for reading data for `model`.
    Parameters
    ----------
    model: :class:`Model`.
        :class:`Model`.
    outputs: bool
        If `True`, return output labels.
    replicate_names: list
        Name of input cells of `model`.
    Returns
    -------
    :class:`DataReader`
        Instance of :class:`DataReader`.
    """
    use_dna = False
    dna_wlen = None
    cpg_wlen = None
    output_names = None
    encode_replicates = False
    #
    input_shapes = to_list(model.input_shape)
    for input_name, input_shape in zip(model.input_names, input_shapes):
        if input_name == 'dna':
            # Read DNA sequences.
            use_dna = True
            dna_wlen = input_shape[1]
        elif input_name.startswith('cpg/state/'):
            # DEPRECATED: legacy model. Decode replicate names from input name.
            replicate_names = decode_replicate_names(input_name.replace('cpg/state/', ''))
            assert len(replicate_names) == input_shape[1]
            cpg_wlen = input_shape[2]
            encode_replicates = True
        elif input_name == 'cpg/state':
            # Read neighboring CpG sites.
            if not replicate_names:
                raise ValueError('Replicate names required!')
            if len(replicate_names) != input_shape[1]:
                tmp = '{r} replicates found but CpG model was trained with' \
                    ' {s} replicates. Use `--nb_replicate {s}` or ' \
                    ' `--replicate_names` option to select {s} replicates!'
                tmp = tmp.format(r=len(replicate_names), s=input_shape[1])
                raise ValueError(tmp)
            cpg_wlen = input_shape[2]
    output_names = model.output_names
    config = {"output_names":output_names,
                      "use_dna":use_dna,
                      "dna_wlen":dna_wlen,
                      "cpg_wlen":cpg_wlen,
                      "replicate_names":replicate_names,
                      "encode_replicates":encode_replicates}
    if config_out_fpath is not None:
        with open(config_out_fpath, "w") as ofh:
            json.dump(config, ofh)
    return config


def data_reader_from_config(config_fpath, outputs = True):
    """Instantiate a :class:`DataReader` from a JSON configuration file.

    With `outputs=False`, output labels are suppressed by clearing
    `output_names`.
    """
    with open(config_fpath, "r") as ifh:
        reader_kwargs = json.load(ifh)
    if not outputs:
        reader_kwargs["output_names"] = None
    return DataReader(**reader_kwargs)



class DataReader(object):
    """Read and preprocess data produced by :func:`run_dcpg_data`.

    Generator factory that reads data batches and pre-processes them into
    model inputs, targets, and metadata.

    Parameters
    ----------
    output_names: list
        Names of outputs to be read.
    use_dna: bool
        If `True`, read DNA sequence windows.
    dna_wlen: int
        Maximum length of DNA sequence windows.
    replicate_names: list
        Name of cells (profiles) whose neighboring CpG sites are read.
    cpg_wlen: int
        Maximum number of neighboring CpG sites.
    cpg_max_dist: int
        Value to threshold the distance of neighboring CpG sites.
    encode_replicates: bool
        If `True`, encode replicated names in key of returned dict. This
        option is deprecated and will be removed in the future.
    """

    def __init__(self, output_names=None,
                 use_dna=True, dna_wlen=None,
                 replicate_names=None, cpg_wlen=None, cpg_max_dist=25000,
                 encode_replicates=False):
        self.output_names = to_list(output_names)
        self.use_dna = use_dna
        self.dna_wlen = dna_wlen
        self.replicate_names = to_list(replicate_names)
        self.cpg_wlen = cpg_wlen
        self.cpg_max_dist = cpg_max_dist
        self.encode_replicates = encode_replicates

    def _prepro_dna(self, dna):
        """Center-crop DNA windows to `dna_wlen` and one-hot encode them."""
        if self.dna_wlen:
            cur_wlen = dna.shape[1]
            center = cur_wlen // 2
            delta = self.dna_wlen // 2
            dna = dna[:, (center - delta):(center + delta + 1)]
        return int_to_onehot(dna)

    def _prepro_cpg(self, states, dists):
        """Preprocess the state and distance of neighboring CpG sites.

        Missing states are imputed by sampling from the observed mean;
        distances are clipped at `cpg_max_dist` and normalized to [0, 1].
        """
        prepro_states = []
        prepro_dists = []
        for state, dist in zip(states, dists):
            nan = state == dat.CPG_NAN
            if np.any(nan):
                state[nan] = np.random.binomial(1, state[~nan].mean(),
                                                nan.sum())
                dist[nan] = self.cpg_max_dist
            dist = np.minimum(dist, self.cpg_max_dist) / self.cpg_max_dist
            prepro_states.append(np.expand_dims(state, 1))
            prepro_dists.append(np.expand_dims(dist, 1))
        prepro_states = np.concatenate(prepro_states, axis=1)
        prepro_dists = np.concatenate(prepro_dists, axis=1)
        if self.cpg_wlen:
            # Center-crop the neighbor window to `cpg_wlen`.
            center = prepro_states.shape[2] // 2
            delta = self.cpg_wlen // 2
            tmp = slice(center - delta, center + delta)
            prepro_states = prepro_states[:, :, tmp]
            prepro_dists = prepro_dists[:, :, tmp]
        return (prepro_states, prepro_dists)

    def __call__(self, dcpg_data_kwargs, class_weights=None):
        """Return generator for reading and preprocessing data.

        Parameters
        ----------
        dcpg_data_kwargs: dict
            Named arguments passed to :func:`run_dcpg_data`. `dna_wlen` and
            `cpg_wlen` are overridden to match the model's requirements.
        class_weights: dict
            dict of dict with class weights of individual outputs.

        Returns
        -------
        generator
            Yields dicts with keys 'inputs', 'targets', and 'metadata'.
        """
        # BUGFIX: `log` was previously referenced without being defined in
        # this scope (NameError); use a module logger. `warning` replaces
        # the deprecated `warn`.
        log = logging.getLogger(__name__)

        names = []
        if self.use_dna:
            names.append('inputs/dna')

        if self.replicate_names:
            for name in self.replicate_names:
                names.append('inputs/cpg/%s/state' % name)
                names.append('inputs/cpg/%s/dist' % name)

        if self.output_names:
            for name in self.output_names:
                names.append('outputs/%s' % name)

        # Check that the kwargs fit the model; model requirements win.
        if self.dna_wlen is not None:
            if ("dna_wlen" in dcpg_data_kwargs) and (dcpg_data_kwargs["dna_wlen"] != self.dna_wlen):
                log.warning("dna_wlen does not match requirements of the model (%d)" % self.dna_wlen)
            dcpg_data_kwargs["dna_wlen"] = self.dna_wlen

        if self.cpg_wlen is not None:
            if ("cpg_wlen" in dcpg_data_kwargs) and (dcpg_data_kwargs["cpg_wlen"] != self.cpg_wlen):
                log.warning("cpg_wlen does not match requirements of the model (%d)" % self.cpg_wlen)
            dcpg_data_kwargs["cpg_wlen"] = self.cpg_wlen

        # Stream flattened chunks from run_dcpg_data() and reformat them.
        data_iter = run_dcpg_data(**dcpg_data_kwargs)
        id_ctr_offset = 0
        for data_raw in data_iter:
            for k in names:
                if k not in data_raw:
                    raise ValueError('%s does not exist! Sample mismatch between model and input data?' % k)
            inputs = dict()

            if self.use_dna:
                inputs['dna'] = self._prepro_dna(data_raw['inputs/dna'])

            if self.replicate_names:
                states = []
                dists = []
                for name in self.replicate_names:
                    tmp = 'inputs/cpg/%s/' % name
                    states.append(data_raw[tmp + 'state'])
                    dists.append(data_raw[tmp + 'dist'])
                states, dists = self._prepro_cpg(states, dists)
                if self.encode_replicates:
                    # DEPRECATED: to support loading data for legacy models
                    tmp = '/' + encode_replicate_names(self.replicate_names)
                else:
                    tmp = ''
                inputs['cpg/state%s' % tmp] = states
                inputs['cpg/dist%s' % tmp] = dists

            outputs = dict()
            weights = dict()
            if self.output_names:
                for name in self.output_names:
                    outputs[name] = data_raw['outputs/%s' % name]
                    cweights = class_weights[name] if class_weights else None
                    weights[name] = get_sample_weights(outputs[name], cweights)
                    if name.endswith('cat_var'):
                        # Categorical outputs: one-hot encode, zero out NaNs.
                        output = outputs[name]
                        outputs[name] = to_categorical(output, 3)
                        outputs[name][output == dat.CPG_NAN] = 0

            meta_data = {}
            # Metadata is only generated if the respective window length is
            # given.
            if ("dna_wlen" in dcpg_data_kwargs) and (dcpg_data_kwargs["dna_wlen"] is not None):
                wlen = dcpg_data_kwargs["dna_wlen"]
                delta_pos = wlen // 2
                chrom = data_raw["chromo"].astype(str)
                start = data_raw["pos"] - delta_pos
                end = data_raw["pos"] + delta_pos + 1
                meta_data["dna_ranges"] = GenomicRanges(chrom, start, end, np.arange(chrom.shape[0]) + id_ctr_offset)

            if ("cpg_wlen" in dcpg_data_kwargs) and (dcpg_data_kwargs["cpg_wlen"] is not None):
                wlen = dcpg_data_kwargs["cpg_wlen"]
                delta_pos = wlen // 2
                chrom = data_raw["chromo"].astype(str)
                start = data_raw["pos"] - delta_pos
                end = data_raw["pos"] + delta_pos + 1
                meta_data["cpg_ranges"] = GenomicRanges(chrom, start, end, np.arange(chrom.shape[0]) + id_ctr_offset)

            id_ctr_offset += data_raw["chromo"].shape[0]
            # Weights are not supported at the moment
            yield {"inputs": inputs, "targets": outputs, "metadata": meta_data}


class Dataloader(BatchIterator):
    """Batch iterator over chunks produced by a configured :class:`DataReader`.

    The reader configuration is loaded from 'model_config.json' next to this
    file; `batch_size` is mapped onto `run_dcpg_data`'s `chunk_size`.
    """

    def __init__(self, cpg_profiles, reference_fpath, batch_size = 100, outputs = True,
                class_weights=None):
        assert isinstance(cpg_profiles, list)
        # The model config JSON sits in the same directory as this file.
        config_fpath = os.path.dirname(os.path.realpath(__file__)) + "/model_config.json"
        dcpg_data_kwargs = {
            "cpg_profiles": cpg_profiles,
            "dna_files": [reference_fpath],
            "dna_wlen": 1001,
            "cpg_wlen": 50,
            # chunk_size === batch_size in the current setup!
            "chunk_size": batch_size,
        }
        reader = data_reader_from_config(config_fpath, outputs)
        self.dr_iter_obj = reader(dcpg_data_kwargs, class_weights)

    def __next__(self):
        return next(self.dr_iter_obj)

    def __iter__(self):
        return self.dr_iter_obj