"""Reduce the count of features in a feature set"""

from typing import cast, Dict, List

from pyspark.ml.feature import VectorAssembler
from pyspark.sql import DataFrame, functions as F

import mjolnir.feature_engineering
import mjolnir.spark
import mjolnir.transform as mt


def explode_features(metadata: Dict) -> mt.Transformer:
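    """Transform exploding the `features` vector into per-feature columns.

    As a side effect, stashes the input feature metadata and the list of
    non-feature columns in `metadata` for later transformations.
    """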
    def transform(df: DataFrame) -> DataFrame:
        # We drop the features column when exploding, so we need to hold
        # onto the metadata somewhere else.
        metadata['input_feature_meta'] = df.schema['features'].metadata
        # Later code could infer which columns of the exploded dataframe
        # came from the input based on the feature metadata, but be
        # explicit and track them directly.
        metadata['default_cols'] = [x for x in df.columns if x != 'features']
        return mjolnir.feature_engineering.explode_features(df) \
            .drop('features')
    return transform


def select_features(
    wiki: str,
    num_features: int,
    metadata: Dict
) -> mt.Transformer:
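    """Transform selecting the top `num_features` features for `wiki`.

    The names of the selected features are recorded in
    `metadata['wiki_features'][wiki]` as a side effect.
    """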
    def transform(df: DataFrame) -> DataFrame:
        # Compute the "best" features, per some metric
        sc = df.sql_ctx.sparkSession.sparkContext
        features = metadata['input_feature_meta']['features']
        selected = mjolnir.feature_engineering.select_features(
            sc, df, features, num_features, algo='mrmr')
        metadata['wiki_features'][wiki] = selected

        # Rebuild the `features` col with only the selected features
        keep_cols = metadata['default_cols'] + selected
        df_selected = df.select(*keep_cols)
        assembler = VectorAssembler(
            inputCols=selected, outputCol='features')
        return assembler.transform(df_selected).drop(*selected)
    return transform


def attach_feature_metadata(metadata: Dict) -> mt.Transformer:
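    """Transform re-attaching feature metadata to the `features` column.

    Merges the original input feature metadata with the per-wiki feature
    selections recorded by `select_features`.
    """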
    def transform(df: DataFrame) -> DataFrame:
        feature_meta = dict(
            metadata['input_feature_meta'],
            wiki_features=metadata['wiki_features'])
        sc = df.sql_ctx.sparkSession.sparkContext
        return df.withColumn(
            'features',
            mjolnir.spark.add_meta(sc, F.col('features'), feature_meta))
    return transform


@mt.typed_transformer(mt.FeatureVectors, mt.FeatureVectors, __name__)
def transformer(
    df_label: DataFrame,
    temp_dir: str,
    wikis: List[str],
    num_features: int
) -> mt.Transformer:
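    """Build the feature selection transformer.

    A minimal usage sketch; the dataframes, path and values here are
    illustrative, not provided by this module:

        select_50 = transformer(
            df_label, 'hdfs:///tmp/mjolnir', ['enwiki'], num_features=50)
        df_selected = select_50(df_features)
    """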
    mt.check_schema(df_label, mt.LabeledQueryPage)

    # Hack to transfer metadata between transformations. This is populated
    # in time because `select_features` computes the selected features
    # eagerly rather than deferring to spark's lazy evaluation.
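    # By the time attach_feature_metadata runs it also holds
    # 'input_feature_meta' and 'default_cols' (set by explode_features)
    # and a per-wiki entry in 'wiki_features' (set by select_features).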
    metadata = cast(Dict, {'wiki_features': {}})

    return mt.seq_transform([
        mt.restrict_wikis(wikis),
        mt.join_labels(df_label),
        explode_features(metadata),
        mt.cache_to_disk(temp_dir, partition_by='wikiid'),
        mt.for_each_item('wikiid', wikis, lambda wiki: select_features(
            wiki, num_features, metadata)),
        attach_feature_metadata(metadata),
        # While we used the labels for selecting features, they are not
        # part of the feature vectors. Allow them to be joined with any
        # other label set for export to training.
        lambda df: df.drop('cluster_id', 'label'),
        lambda df: df.repartition(200, 'wikiid', 'query'),
    ])