# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Make datasets and save specified directory.

Downloads datasets using scikit datasets and can also parse csv file
to save into pickle format.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from io import BytesIO
import os
import pickle
import StringIO
import tarfile
import urllib2

import keras.backend as K
from keras.datasets import cifar10
from keras.datasets import cifar100
from keras.datasets import mnist

import numpy as np
import pandas as pd
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.datasets import fetch_mldata
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
import sklearn.datasets.rcv1
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer

from absl import app
from absl import flags
from tensorflow import gfile

# Command-line flags read by get_mldata() and main().
flags.DEFINE_string(name='save_dir', default='/tmp/data',
                    help='Where to save outputs')
flags.DEFINE_string(name='datasets', default='',
                    help='Which datasets to download, comma separated.')
FLAGS = flags.FLAGS


class Dataset(object):
  """Minimal container pairing features with their targets.

  Exposes the same `data` / `target` attribute names that get_mldata() reads
  from the scikit-learn loaders, so both kinds of result are handled alike.
  """

  def __init__(self, X, y):
    self.data, self.target = X, y


def get_csv_data(filename):
  """Parse a numeric csv file and return a Dataset with data and targets.

  The first column of every row is taken as the target and the remaining
  columns as features. Double quotes are stripped before parsing and blank
  lines are skipped.

  Args:
    filename: complete path of the csv file (any path gfile can open).

  Returns:
    Dataset whose `data` is a 2-D float array of the feature columns and whose
    `target` is a 1-D float array of the first column.

  Raises:
    ValueError: if the file has no data rows, rows have differing numbers of
      columns, or a value is not numeric.
  """
  rows = []
  # `with` guarantees the file handle is released even if parsing fails.
  with gfile.GFile(filename, 'r') as f:
    for line in f:
      line = line.strip().replace('"', '')
      if not line:
        # Blank (e.g. trailing) lines would otherwise fail in float('').
        continue
      rows.append([float(value) for value in line.split(',')])
  if not rows:
    raise ValueError('no data rows found in %s' % filename)
  mat = np.array(rows)
  if mat.ndim != 2:
    raise ValueError('rows in %s have differing numbers of columns' % filename)
  return Dataset(mat[:, 1:], mat[:, 0])


def get_wikipedia_talk_data():
  """Get wikipedia talk dataset.

  See here for more information about the dataset:
  https://figshare.com/articles/Wikipedia_Detox_Data/4054689
  Downloads the annotated comments and their annotations. A comment is
  labeled an attack when more than half of its annotators marked it as one,
  and comment text is featurized with uni/bi-gram TF-IDF.

  Returns:
    Dataset whose `data` is the TF-IDF matrix (one row per comment) and whose
    `target` is a boolean array aligned row-for-row with `data`.
  """
  annotated_comments_url = 'https://ndownloader.figshare.com/files/7554634'
  annotations_url = 'https://ndownloader.figshare.com/files/7554637'

  def read_tsv(url, **kwargs):
    # Parse a tab-separated download, closing the HTTP response even if
    # parsing fails.
    response = urllib2.urlopen(urllib2.Request(url))
    try:
      return pd.read_table(response, sep='\t', **kwargs)
    finally:
      response.close()

  # Process comments. The first column is used as the index; it is assumed
  # to hold rev_id, the key the annotations are grouped by below.
  comments = read_tsv(annotated_comments_url, index_col=0)
  # Replace the newline and tab placeholder tokens with spaces.
  comments['comment'] = comments['comment'].apply(
      lambda text: text.replace('NEWLINE_TOKEN', ' ').replace('TAB_TOKEN', ' '))

  # Process labels: a comment is an attack if the majority of annotators
  # said so.
  annotations = read_tsv(annotations_url)
  is_attack = annotations.groupby('rev_id')['attack'].mean() > 0.5
  # groupby() returns labels sorted by rev_id while the comments keep file
  # order; align the labels to the comment rows explicitly so that X[i] and
  # y[i] describe the same comment.
  y = np.array(is_attack.loc[comments.index])

  # Perform data preprocessing, should probably tune these hyperparameters.
  vect = CountVectorizer(max_features=30000, ngram_range=(1, 2))
  tfidf = TfidfTransformer(norm='l2')
  X = tfidf.fit_transform(vect.fit_transform(comments['comment']))
  return Dataset(X, y)


def get_keras_data(dataname):
  """Load a dataset through the Keras API, merging its train and test splits.

  Args:
    dataname: one of 'cifar10_keras', 'cifar100_coarse_keras',
      'cifar100_keras' or 'mnist_keras'.

  Returns:
    Dataset whose `data` stacks the training examples followed by the test
    examples, and whose `target` is the matching flattened label array.

  Raises:
    NotImplementedError: if `dataname` is not one of the names above.
  """
  # Lambdas defer touching the keras modules until a loader is selected.
  loaders = {
      'cifar10_keras': lambda: cifar10.load_data(),
      'cifar100_coarse_keras': lambda: cifar100.load_data('coarse'),
      'cifar100_keras': lambda: cifar100.load_data(),
      'mnist_keras': lambda: mnist.load_data(),
  }
  loader = loaders.get(dataname)
  if loader is None:
    raise NotImplementedError('dataset not supported')

  train_split, test_split = loader()
  images = np.concatenate((train_split[0], test_split[0]))
  labels = np.concatenate((train_split[1], test_split[1]))

  if dataname == 'mnist_keras':
    # Insert a singleton channel axis where the Keras backend expects it.
    channels_first = images.reshape(
        images.shape[0], 1, images.shape[1], images.shape[2])
    if K.image_data_format() == 'channels_last':
      images = channels_first.transpose(0, 2, 3, 1)
    else:
      images = channels_first

  return Dataset(images, labels.flatten())


# TODO(lishal): remove the regular cifar10 dataset and only use the dataset
# downloaded from keras, which keeps the image dimensions needed to build
# tensors for tf models. This requires run_experiment.py to handle training
# methods that expect either 2d or tensor data.
def get_cifar10():
  """Get CIFAR-10 dataset from source dir.

  Slightly redundant with keras function to get cifar10 but this returns
  in flat format instead of keras numpy image tensor.

  Returns:
    Dataset whose `data` is a uint8 array built from every batch's 'data'
    entry and whose `target` is the matching label array. Batches are
    concatenated in the order they appear in the archive.

  Raises:
    ValueError: if the archive contains no batch files.
  """
  url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
  response = urllib2.urlopen(urllib2.Request(url))
  try:
    # Buffer the whole archive in memory so tarfile can seek within it.
    archive = BytesIO(response.read())
  finally:
    response.close()

  images = []
  labels = []
  with tarfile.open(mode='r:gz', fileobj=archive) as tar:
    for member in tar.getnames():
      # Only members whose name contains '_batch' are pickled data batches.
      if '_batch' not in member:
        continue
      # NOTE(review): unpickling can execute arbitrary code; this trusts the
      # archive downloaded (over plain http) from the URL above.
      batch = pickle.loads(tar.extractfile(member).read())
      images.append(np.array(batch['data'], dtype=np.uint8))
      labels.append(np.array(batch['labels']))

  if not images:
    raise ValueError('no CIFAR-10 batches found in %s' % url)
  # A single concatenate avoids re-copying the growing array per batch.
  return Dataset(np.concatenate(images), np.concatenate(labels))


def _load_raw_dataset(source):
  """Fetch the dataset named by `source` without saving it.

  Args:
    source: a path ending in 'csv', one of the special-cased names
      ('breast_cancer', 'iris', 'newsgroup', 'rcv1', 'wikipedia_attack',
      'cifar10'), a Keras dataset name containing 'keras', or otherwise an
      mldata.org dataset name.

  Returns:
    An object exposing `.data` (features) and `.target` (labels).

  Raises:
    Exception: if the fallback mldata.org fetch fails.
  """
  if source.endswith('csv'):
    return get_csv_data(source)
  if source == 'breast_cancer':
    return load_breast_cancer()
  if source == 'iris':
    return load_iris()
  if source == 'newsgroup':
    # Removing header information to make sure that no newsgroup identifying
    # information is included in data. `remove` must be a tuple, hence the
    # trailing comma.
    data = fetch_20newsgroups_vectorized(subset='all', remove=('headers',))
    data.data = TfidfTransformer(norm='l2').fit_transform(data.data)
    return data
  if source == 'rcv1':
    # Override the download URLs used by scikit-learn's rcv1 module.
    sklearn.datasets.rcv1.URL = (
        'http://www.ai.mit.edu/projects/jmlr/papers/'
        'volume5/lewis04a/a13-vector-files/lyrl2004_vectors')
    sklearn.datasets.rcv1.URL_topics = (
        'http://www.ai.mit.edu/projects/jmlr/papers/'
        'volume5/lewis04a/a08-topic-qrels/rcv1-v2.topics.qrels.gz')
    return sklearn.datasets.fetch_rcv1(data_home='/tmp')
  if source == 'wikipedia_attack':
    return get_wikipedia_talk_data()
  if source == 'cifar10':
    return get_cifar10()
  if 'keras' in source:
    return get_keras_data(source)
  try:
    return fetch_mldata(source)
  except Exception:  # Not bare: let KeyboardInterrupt/SystemExit through.
    raise Exception('ERROR: failed to fetch data from mldata.org')


def get_mldata(dataset):
  """Fetch a dataset and pickle it as <FLAGS.save_dir>/<name>.pkl.

  Nothing is fetched if the pickle file already exists.

  Args:
    dataset: (source, name) tuple; `source` selects how the data is fetched
      (see _load_raw_dataset) and `name` names the output pickle.

  Raises:
    ValueError: if data and target disagree on the number of examples even
      after transposing data.
  """
  save_dir = FLAGS.save_dir
  filename = os.path.join(save_dir, dataset[1] + '.pkl')

  if not gfile.Exists(save_dir):
    # MakeDirs also creates missing parent directories, unlike MkDir.
    gfile.MakeDirs(save_dir)
  if gfile.Exists(filename):
    return

  data = _load_raw_dataset(dataset[0])
  X = data.data
  y = data.target
  # If row counts disagree, assume examples are stored as columns and
  # transpose (presumably the case for some sources — confirm per dataset).
  if X.shape[0] != y.shape[0]:
    X = np.transpose(X)
  if X.shape[0] != y.shape[0]:
    raise ValueError('data has %d examples but target has %d for %s'
                     % (X.shape[0], y.shape[0], dataset[1]))

  # Closing the file flushes buffered gfile writes.
  with gfile.GFile(filename, 'wb') as f:
    pickle.dump({'data': X, 'target': y}, f)


def main(argv):
  """Download every known dataset, or only those named in --datasets."""
  del argv  # Unused.
  # First entry of tuple is the source passed to get_mldata (an mldata.org
  # name or a special-cased name); second is the name that we'll use to
  # reference the data. Each name appears once.
  datasets = [('mnist (original)', 'mnist'), ('australian', 'australian'),
              ('heart', 'heart'), ('breast_cancer', 'breast_cancer'),
              ('iris', 'iris'), ('vehicle', 'vehicle'), ('wine', 'wine'),
              ('waveform ida', 'waveform'), ('german ida', 'german'),
              ('splice ida', 'splice'), ('ringnorm ida', 'ringnorm'),
              ('twonorm ida', 'twonorm'), ('diabetes_scale', 'diabetes'),
              ('mushrooms', 'mushrooms'), ('letter', 'letter'), ('dna', 'dna'),
              ('banana-ida', 'banana'),
              ('newsgroup', 'newsgroup'), ('cifar10', 'cifar10'),
              ('cifar10_keras', 'cifar10_keras'),
              ('cifar100_keras', 'cifar100_keras'),
              ('cifar100_coarse_keras', 'cifar100_coarse_keras'),
              ('mnist_keras', 'mnist_keras'),
              ('wikipedia_attack', 'wikipedia_attack'),
              ('rcv1', 'rcv1')]

  if FLAGS.datasets:
    # Tolerate spaces around commas, e.g. --datasets="iris, wine".
    requested = set(name.strip() for name in FLAGS.datasets.split(',')
                    if name.strip())
    unknown = sorted(requested - set(name for _, name in datasets))
    if unknown:
      print('Ignoring unknown datasets: %s' % ', '.join(unknown))
    datasets = [d for d in datasets if d[1] in requested]

  for source, name in datasets:
    print(name)
    get_mldata((source, name))


if __name__ == '__main__':
  # Script entry point: hand main() to absl's app runner.
  app.run(main=main)