/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.accumulators;

import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.AccumulatorHelper;
import org.apache.flink.api.common.accumulators.DoubleCounter;
import org.apache.flink.api.common.accumulators.Histogram;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.test.util.JavaProgramTestBase;
import org.apache.flink.types.StringValue;
import org.apache.flink.util.Collector;

import org.junit.Assert;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Test for the basic functionality of accumulators. We cannot cover all kinds of
 * plans here (iterative, etc.).
 *
 * <p>TODO Test the conflict that arises when different UDFs write to an accumulator with the
 * same name but a different type. The conflict occurs in the JobManager while merging.
 */
@SuppressWarnings("serial")
public class AccumulatorITCase extends JavaProgramTestBase {

	private static final String INPUT = "one\n" + "two two\n" + "three three three\n";
	private static final String EXPECTED = "one 1\ntwo 2\nthree 3\n";

	private String dataPath;
	private String resultPath;

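	// Captured by testProgram() so that postSubmit() can check the accumulator results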
	private JobExecutionResult result;

	@Override
	protected void preSubmit() throws Exception {
		dataPath = createTempFile("datapoints.txt", INPUT);
		resultPath = getTempFilePath("result");
	}

	@Override
	protected void postSubmit() throws Exception {
		compareResultsByLinesInMemory(EXPECTED, resultPath);

		// Test accumulator results
		System.out.println("Accumulator results:");
		JobExecutionResult res = this.result;
		System.out.println(AccumulatorHelper.getResultsFormatted(res.getAllAccumulatorResults()));

		Assert.assertEquals(Integer.valueOf(3), res.getAccumulatorResult("num-lines"));
		Assert.assertEquals(Integer.valueOf(3), res.getIntCounterResult("num-lines"));

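		// Each parallel instance of TokenizeLine adds 0.5 in open() and 0.5 in close(),
		// so the merged counter equals the parallelism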
		Assert.assertEquals(Double.valueOf(getParallelism()), res.getAccumulatorResult("open-close-counter"));

		// Test histogram (words per line distribution)
		Map<Integer, Integer> dist = new HashMap<>();
		dist.put(1, 1); dist.put(2, 1); dist.put(3, 1);
		Assert.assertEquals(dist, res.getAccumulatorResult("words-per-line"));

		// Test distinct words (custom accumulator)
		Set<StringValue> distinctWords = new HashSet<>();
		distinctWords.add(new StringValue("one"));
		distinctWords.add(new StringValue("two"));
		distinctWords.add(new StringValue("three"));
		Assert.assertEquals(distinctWords, res.getAccumulatorResult("distinct-words"));
	}

	@Override
	protected void testProgram() throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		DataSet<String> input = env.readTextFile(dataPath);

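		// Word count: tokenize each line into (word, 1) pairs, group by word and sum the counts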
		input.flatMap(new TokenizeLine())
			.groupBy(0)
			.reduceGroup(new CountWords())
			.writeAsCsv(resultPath, "\n", " ");

		this.result = env.execute();
	}

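	/**
	 * Tokenizer that emits (word, 1) pairs and exercises both built-in and custom accumulators.
	 */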
	private static class TokenizeLine extends RichFlatMapFunction<String, Tuple2<String, Integer>> {

		// These need to be created in open(), since the runtime context is not yet
		// initialized when the function object is constructed
		private IntCounter cntNumLines;
		private Histogram wordsPerLineDistribution;

		// This counter will be added without convenience functions
		private DoubleCounter openCloseCounter = new DoubleCounter();
		private SetAccumulator<StringValue> distinctWords;

		@Override
		public void open(Configuration parameters) {

			// Add counters using convenience functions
			this.cntNumLines = getRuntimeContext().getIntCounter("num-lines");
			this.wordsPerLineDistribution = getRuntimeContext().getHistogram("words-per-line");

			// Add built-in accumulator without convenience function
			getRuntimeContext().addAccumulator("open-close-counter", this.openCloseCounter);

			// Add custom counter
			this.distinctWords = new SetAccumulator<>();
			this.getRuntimeContext().addAccumulator("distinct-words", distinctWords);

			// Create counter and test increment
			IntCounter simpleCounter = getRuntimeContext().getIntCounter("simple-counter");
			simpleCounter.add(1);
			Assert.assertEquals(1, simpleCounter.getLocalValue().intValue());

			// Test if we get the same counter
			IntCounter simpleCounter2 = getRuntimeContext().getIntCounter("simple-counter");
			Assert.assertEquals(simpleCounter.getLocalValue(), simpleCounter2.getLocalValue());

			// Should fail if we request it with a different type
			try {
				@SuppressWarnings("unused")
				DoubleCounter simpleCounter3 = getRuntimeContext().getDoubleCounter("simple-counter");
				Assert.fail("Should not be able to obtain previously created counter with different type");
			}
			catch (UnsupportedOperationException ex) {
				// expected!
			}

			// Test counter used in open() and close()
			this.openCloseCounter.add(0.5);
		}

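		// Splits the line into lowercase words, updating the line count, the words-per-line
		// histogram and the distinct-words accumulator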
		@Override
		public void flatMap(String value, Collector<Tuple2<String, Integer>> out) {
			this.cntNumLines.add(1);
			int wordsPerLine = 0;

			for (String token : value.toLowerCase().split("\\W+")) {
				distinctWords.add(new StringValue(token));
				out.collect(new Tuple2<>(token, 1));
				++wordsPerLine;
			}
			wordsPerLineDistribution.add(wordsPerLine);
		}

		@Override
		public void close() throws Exception {
			// Test counter used in open and close only
			this.openCloseCounter.add(0.5);
			Assert.assertEquals(1, this.openCloseCounter.getLocalValue().intValue());
		}
	}

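	/**
	 * Sums the counts per word and tracks the number of reduce() and combine() invocations
	 * via accumulators.
	 */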
	private static class CountWords
		extends RichGroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>
		implements GroupCombineFunction<Tuple2<String, Integer>, Tuple2<String, Integer>> {

		private IntCounter reduceCalls;
		private IntCounter combineCalls;

		@Override
		public void open(Configuration parameters) {
			this.reduceCalls = getRuntimeContext().getIntCounter("reduce-calls");
			this.combineCalls = getRuntimeContext().getIntCounter("combine-calls");
		}

		@Override
		public void reduce(Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
			reduceCalls.add(1);
			reduceInternal(values, out);
		}

		@Override
		public void combine(Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
			combineCalls.add(1);
			reduceInternal(values, out);
		}

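		// Sums the counts for a single key; shared by reduce() and combine()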
		private void reduceInternal(Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
			int sum = 0;
			String key = null;

			for (Tuple2<String, Integer> e : values) {
				key = e.f0;
				sum += e.f1;
			}
			out.collect(new Tuple2<>(key, sum));
		}
	}

	/**
	 * Custom accumulator that accumulates distinct elements into a {@link HashSet}.
	 */
	public static class SetAccumulator<T> implements Accumulator<T, HashSet<T>> {

		private static final long serialVersionUID = 1L;

		private HashSet<T> set = new HashSet<>();

		@Override
		public void add(T value) {
			this.set.add(value);
		}

		@Override
		public HashSet<T> getLocalValue() {
			return this.set;
		}

		@Override
		public void resetLocal() {
			this.set.clear();
		}

		@Override
		public void merge(Accumulator<T, HashSet<T>> other) {
			// build union
			this.set.addAll(other.getLocalValue());
		}

		@Override
		public Accumulator<T, HashSet<T>> clone() {
			SetAccumulator<T> result = new SetAccumulator<>();
			result.set.addAll(set);
			return result;
		}
	}
}