# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple smoke test that runs these examples for 1 training iteraton."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import pandas as pd

from six.moves import StringIO

import iris_data
import custom_estimator
import premade_estimator

# Four newline-separated CSV rows (five comma-separated values each, no
# trailing newline). four_lines_data() parses them against
# iris_data.CSV_COLUMN_NAMES. The stray spaces are part of the fixture and
# are kept verbatim.
FOUR_LINES = (
    "1,52.40, 2823,152,2\n"
    "164, 99.80,176.60,66.20,1\n"
    "176,2824, 136,3.19,0\n"
    "2,177.30,66.30, 53.10,1")

def four_lines_data():
  """Build a tiny in-memory dataset from FOUR_LINES.

  RegressionTest patches this in place of the examples' ``load_data``.

  Returns:
    A pair ``(train, test)`` in which both elements are the *same*
    ``(features, labels)`` tuple object. ``features`` is the parsed
    DataFrame with the "Species" column removed; ``labels`` is that
    removed column.
  """
  frame = pd.read_csv(StringIO(FOUR_LINES), names=iris_data.CSV_COLUMN_NAMES)
  # DataFrame.pop removes the column from `frame` in place and returns it.
  labels = frame.pop("Species")
  pair = (frame, labels)
  # Train and test deliberately share one split; only wiring is exercised.
  return pair, pair


class RegressionTest(tf.test.TestCase):
  """Smoke-test the estimator examples in this directory.

  Each test temporarily maps the name ``load_data`` in the example module's
  namespace to ``four_lines_data``, then runs that module's ``main`` for a
  single training step.
  """

  def _run_one_step(self, module):
    """Run ``module.main`` for one step with ``load_data`` patched.

    NOTE(review): patch.dict only rebinds the module-global name
    ``load_data``; if the example instead calls ``iris_data.load_data()``
    the stub is not used -- confirm against the example sources.
    """
    stubs = {"load_data": four_lines_data}
    # patch.dict restores the module namespace when the block exits.
    with tf.test.mock.patch.dict(module.__dict__, stubs):
      module.main([None, "--train_steps=1"])

  def test_premade_estimator(self):
    self._run_one_step(premade_estimator)

  def test_custom_estimator(self):
    self._run_one_step(custom_estimator)

if __name__ == "__main__":
  # Run this module's tests when it is executed directly as a script.
  tf.test.main()