import logging

from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType, StructField, StructType
from snorkel.labeling.model import LabelModel
from snorkel.labeling.apply.spark import SparkLFApplier

from drybell_lfs_spark import (
    article_mentions_person,
    body_contains_fortune,
    person_in_db,
)

logging.basicConfig(level=logging.INFO)


def main(data_path, output_path):
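    """Label documents with Snorkel over Spark.

    Reads raw documents from Parquet, applies the labeling functions to
    build a label matrix, fits a LabelModel on it, and writes the data
    back out with a per-document probabilistic label column, "y_prob".
    """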
    # Read data
    logging.info(f"Reading data from {data_path}")
    spark = SparkSession.builder.getOrCreate()
    data = spark.read.parquet(data_path)

    # Build label matrix
    logging.info("Applying LFs")
    lfs = [article_mentions_person, body_contains_fortune, person_in_db]
    applier = SparkLFApplier(lfs)
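    # L is an (n_documents, n_LFs) matrix of integer votes,
    # where -1 means the labeling function abstained on that document.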
    L = applier.apply(data.rdd)

    # Train label model
    logging.info("Training label model")
    label_model = LabelModel(cardinality=2)
    label_model.fit(L)

    # Generate training labels
    logging.info("Generating probabilistic labels")
    y_prob = label_model.predict_proba(L)[:, 1]
    # Attach each document's probability to its own row. Both applier.apply
    # and zipWithIndex follow the RDD's partition ordering, so y_prob[i]
    # corresponds to the i-th row of the DataFrame.
    rows_with_prob = data.rdd.zipWithIndex().map(
        lambda row_and_idx: tuple(row_and_idx[0]) + (float(y_prob[row_and_idx[1]]),)
    )
    schema_with_prob = StructType(
        data.schema.fields + [StructField("y_prob", DoubleType())]
    )
    data_labeled = spark.createDataFrame(rows_with_prob, schema_with_prob)
    data_labeled.write.mode("overwrite").parquet(output_path)
    logging.info(f"Labels saved to {output_path}")


if __name__ == "__main__":
    main("drybell/data/raw_data.parquet", "drybell/data/labeled_data_spark.parquet")