# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs the preprocessing job to produce records for training."""

import argparse
import ConfigParser
import logging
import os
import sys

import apache_beam as beam

from preprocessing import preprocess


def _parse_arguments(argv):
    """Parses command line arguments."""
    parser = argparse.ArgumentParser(
        description='Runs preprocessing on census train data.')
    parser.add_argument(
        '--project_id',
        required=True,
        help='Name of the project.')
    parser.add_argument(
        '--job_name',
        required=False,
        help='Name of the dataflow job.')
    parser.add_argument(
        '--job_dir',
        required=True,
        help='Directory to write outputs.')
    parser.add_argument(
        '--cloud',
        default=False,
        action='store_true',
        help='Run preprocessing on the cloud.')
    parser.add_argument(
        '--input_data',
        required=True,
        help='Path to input data.')
    args, _ = parser.parse_known_args(args=argv[1:])
    return args


def _parse_config(env, config_file_path):
    """Parses configuration file.

    Args:
      env: The environment in which the preprocessing job will be run.
      config_file_path: Path to the configuration file to be parsed.

    Returns:
      A dictionary containing the parsed runtime config.
    """
    config = ConfigParser.ConfigParser()
    config.read(config_file_path)
    return dict(config.items(env))


def _set_logging(log_level):
    logging.getLogger().setLevel(getattr(logging, log_level.upper()))


def main():
    """Configures pipeline and spawns preprocessing job."""

    args = _parse_arguments(sys.argv)
    # The config file lives next to this script.
    config_path = os.path.abspath(
        os.path.join(__file__, os.pardir, 'preprocessing_config.ini'))
    environment = 'CLOUD' if args.cloud else 'LOCAL'
    config = _parse_config(environment, config_path)

    # Pipeline options common to local and cloud runs.
    options = {'project': args.project_id}

    if args.cloud:
        # Dataflow runs additionally need a job name plus worker,
        # staging, and region settings from the config file.
        if not args.job_name:
            raise ValueError('Job name must be specified for cloud runs.')
        options['job_name'] = args.job_name
        options['num_workers'] = int(config.get('num_workers'))
        options['max_num_workers'] = int(config.get('max_num_workers'))
        options['staging_location'] = os.path.join(args.job_dir, 'staging')
        options['temp_location'] = os.path.join(args.job_dir, 'tmp')
        options['region'] = config.get('region')
        options['setup_file'] = os.path.abspath(
            os.path.join(__file__, '../..', 'dataflow_setup.py'))

    pipeline_options = beam.pipeline.PipelineOptions(flags=[], **options)
    _set_logging(config.get('log_level'))

    # Runner name (e.g. DirectRunner / DataflowRunner) comes from config.
    with beam.Pipeline(config.get('runner'), options=pipeline_options) as p:
        preprocess.run(p, args.input_data, args.job_dir)


# Script entry point: only run the pipeline when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()