/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.streaming.runtime.operators.windowing;

import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.MergingState;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
import org.apache.flink.runtime.state.internal.InternalMergingState;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
import org.apache.flink.streaming.api.operators.InternalTimerService;
import org.apache.flink.streaming.api.operators.KeyContext;
import org.apache.flink.streaming.api.operators.TestInternalTimerService;
import org.apache.flink.streaming.api.windowing.triggers.Trigger;
import org.apache.flink.streaming.api.windowing.triggers.TriggerResult;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

/**
 * Utility for testing {@link Trigger} behaviour.
 *
 * <p>The harness drives a single {@link Trigger} instance against a heap keyed state backend and a
 * {@link TestInternalTimerService}. All state and timers are scoped to one fixed key; other tests
 * make sure that windows work across different keys.
 *
 * @param <T> the type of elements fed to the trigger
 * @param <W> the window type the trigger operates on
 */
public class TriggerTestHarness<T, W extends Window> {

	/** The single key under which all trigger state and timers live. */
	private static final Integer KEY = 1;

	private final Trigger<T, W> trigger;
	private final TypeSerializer<W> windowSerializer;

	private final HeapKeyedStateBackend<Integer> stateBackend;
	private final TestInternalTimerService<Integer, W> internalTimerService;

	/**
	 * Creates a harness for the given trigger.
	 *
	 * @param trigger the trigger under test
	 * @param windowSerializer serializer for the window type; used as the namespace serializer when
	 *                         the trigger requests partitioned state
	 * @throws Exception if the keyed state backend cannot be created
	 */
	public TriggerTestHarness(
			Trigger<T, W> trigger,
			TypeSerializer<W> windowSerializer) throws Exception {
		this.trigger = trigger;
		this.windowSerializer = windowSerializer;

		// we only ever use one key, other tests make sure that windows work across different
		// keys
		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
		MemoryStateBackend backend = new MemoryStateBackend();

		// use one job ID for both the backend and its KvState task registry so they agree
		JobID jobId = new JobID();

		@SuppressWarnings("unchecked")
		HeapKeyedStateBackend<Integer> heapBackend = (HeapKeyedStateBackend<Integer>) backend.createKeyedStateBackend(
			dummyEnv,
			jobId,
			"test_op",
			IntSerializer.INSTANCE,
			1,
			new KeyGroupRange(0, 0),
			new KvStateRegistry().createTaskRegistry(jobId, new JobVertexID()),
			TtlTimeProvider.DEFAULT,
			new UnregisteredMetricsGroup(),
			Collections.emptyList(),
			new CloseableRegistry());
		this.stateBackend = heapBackend;

		this.stateBackend.setCurrentKey(KEY);

		this.internalTimerService = new TestInternalTimerService<>(new KeyContext() {
			@Override
			public void setCurrentKey(Object key) {
				// ignore: the harness only ever operates on KEY
			}

			@Override
			public Object getCurrentKey() {
				return KEY;
			}
		});
	}

	/** Returns the processing-time timer count reported by the test timer service. */
	public int numProcessingTimeTimers() {
		return internalTimerService.numProcessingTimeTimers();
	}

	/** Returns the processing-time timer count the test timer service reports for {@code window}. */
	public int numProcessingTimeTimers(W window) {
		return internalTimerService.numProcessingTimeTimers(window);
	}

	/** Returns the event-time timer count reported by the test timer service. */
	public int numEventTimeTimers() {
		return internalTimerService.numEventTimeTimers();
	}

	/** Returns the event-time timer count the test timer service reports for {@code window}. */
	public int numEventTimeTimers(W window) {
		return internalTimerService.numEventTimeTimers(window);
	}

	/** Returns the number of key/value state entries held by the state backend. */
	public int numStateEntries() {
		return stateBackend.numKeyValueStateEntries();
	}

	/** Returns the number of key/value state entries held by the state backend for {@code window}. */
	public int numStateEntries(W window) {
		return stateBackend.numKeyValueStateEntries(window);
	}

	/**
	 * Injects one element into the trigger for the given window and returns the result of
	 * {@link Trigger#onElement(Object, long, Window, Trigger.TriggerContext)}.
	 */
	public TriggerResult processElement(StreamRecord<T> element, W window) throws Exception {
		TestTriggerContext<Integer, W> triggerContext = createTriggerContext(window);
		return trigger.onElement(element.getValue(), element.getTimestamp(), window, triggerContext);
	}

	/**
	 * Advances processing time and checks whether we have exactly one firing for the given
	 * window. The result of {@link Trigger#onProcessingTime(long, Window, Trigger.TriggerContext)}
	 * is returned for that firing.
	 *
	 * @throws IllegalStateException if there is not exactly one firing, or it is for another window
	 */
	public TriggerResult advanceProcessingTime(long time, W window) throws Exception {
		return getSingleFiring(advanceProcessingTime(time), window);
	}

	/**
	 * Advances the watermark and checks whether we have exactly one firing for the given
	 * window. The result of {@link Trigger#onEventTime(long, Window, Trigger.TriggerContext)}
	 * is returned for that firing.
	 *
	 * @throws IllegalStateException if there is not exactly one firing, or it is for another window
	 */
	public TriggerResult advanceWatermark(long time, W window) throws Exception {
		return getSingleFiring(advanceWatermark(time), window);
	}

	/**
	 * Advances processing time and processes any timers that fire because of this. The
	 * window and {@link TriggerResult} for each firing are returned.
	 */
	public Collection<Tuple2<W, TriggerResult>> advanceProcessingTime(long time) throws Exception {
		Collection<TestInternalTimerService.Timer<Integer, W>> firedTimers =
				internalTimerService.advanceProcessingTime(time);

		Collection<Tuple2<W, TriggerResult>> result = new ArrayList<>(firedTimers.size());

		for (TestInternalTimerService.Timer<Integer, W> timer : firedTimers) {
			TriggerResult triggerResult = invokeOnProcessingTime(timer);
			result.add(new Tuple2<>(timer.getNamespace(), triggerResult));
		}

		return result;
	}

	/**
	 * Advances the watermark and processes any timers that fire because of this. The
	 * window and {@link TriggerResult} for each firing are returned.
	 */
	public Collection<Tuple2<W, TriggerResult>> advanceWatermark(long time) throws Exception {
		Collection<TestInternalTimerService.Timer<Integer, W>> firedTimers =
				internalTimerService.advanceWatermark(time);

		Collection<Tuple2<W, TriggerResult>> result = new ArrayList<>(firedTimers.size());

		for (TestInternalTimerService.Timer<Integer, W> timer : firedTimers) {
			TriggerResult triggerResult = invokeOnEventTime(timer);
			result.add(new Tuple2<>(timer.getNamespace(), triggerResult));
		}

		return result;
	}

	/** Calls {@link Trigger#onProcessingTime} for a fired timer, in the timer's window. */
	private TriggerResult invokeOnProcessingTime(TestInternalTimerService.Timer<Integer, W> timer) throws Exception {
		TestTriggerContext<Integer, W> triggerContext = createTriggerContext(timer.getNamespace());
		return trigger.onProcessingTime(timer.getTimestamp(), timer.getNamespace(), triggerContext);
	}

	/** Calls {@link Trigger#onEventTime} for a fired timer, in the timer's window. */
	private TriggerResult invokeOnEventTime(TestInternalTimerService.Timer<Integer, W> timer) throws Exception {
		TestTriggerContext<Integer, W> triggerContext = createTriggerContext(timer.getNamespace());
		return trigger.onEventTime(timer.getTimestamp(), timer.getNamespace(), triggerContext);
	}

	/**
	 * Manually invoke {@link Trigger#onEventTime(long, Window, Trigger.TriggerContext)} with
	 * the given parameters.
	 */
	public TriggerResult invokeOnEventTime(long timestamp, W window) throws Exception {
		TestInternalTimerService.Timer<Integer, W> timer =
				new TestInternalTimerService.Timer<>(timestamp, KEY, window);

		return invokeOnEventTime(timer);
	}

	/**
	 * Calls {@link Trigger#onMerge(Window, Trigger.OnMergeContext)} with the given
	 * parameters. This also calls {@link Trigger#clear(Window, Trigger.TriggerContext)} on the
	 * merged windows as does {@link WindowOperator}.
	 */
	public void mergeWindows(W targetWindow, Collection<W> mergedWindows) throws Exception {
		TestOnMergeContext<Integer, W> onMergeContext = new TestOnMergeContext<>(
				KEY,
				targetWindow,
				mergedWindows,
				internalTimerService,
				stateBackend,
				windowSerializer);
		trigger.onMerge(targetWindow, onMergeContext);

		for (W mergedWindow : mergedWindows) {
			clearTriggerState(mergedWindow);
		}
	}

	/**
	 * Calls {@link Trigger#clear(Window, Trigger.TriggerContext)} for the given window.
	 */
	public void clearTriggerState(W window) throws Exception {
		TestTriggerContext<Integer, W> triggerContext = createTriggerContext(window);
		trigger.clear(window, triggerContext);
	}

	/** Creates a trigger context bound to {@link #KEY} and the given window. */
	private TestTriggerContext<Integer, W> createTriggerContext(W window) {
		return new TestTriggerContext<>(KEY, window, internalTimerService, stateBackend, windowSerializer);
	}

	/**
	 * Returns the {@link TriggerResult} of the only firing in {@code firings}, after checking
	 * that it belongs to {@code window}.
	 *
	 * @throws IllegalStateException if there is not exactly one firing, or it is for another window
	 */
	private TriggerResult getSingleFiring(Collection<Tuple2<W, TriggerResult>> firings, W window) {
		if (firings.size() != 1) {
			throw new IllegalStateException("Must have exactly one timer firing. Fired timers: " + firings);
		}

		Tuple2<W, TriggerResult> firing = firings.iterator().next();

		if (!firing.f0.equals(window)) {
			throw new IllegalStateException("Trigger fired for another window.");
		}

		return firing.f1;
	}

	/**
	 * {@link Trigger.TriggerContext} that routes timer calls to the given timer service (using the
	 * context's window as namespace) and resolves partitioned state in the window namespace.
	 */
	private static class TestTriggerContext<K, W extends Window> implements Trigger.TriggerContext {

		protected final InternalTimerService<W> timerService;
		protected final KeyedStateBackend<Integer> stateBackend;
		protected final K key;
		protected final W window;
		protected final TypeSerializer<W> windowSerializer;

		/** Non-null group so triggers that register metrics do not fail under the harness. */
		private final MetricGroup metricGroup = new UnregisteredMetricsGroup();

		TestTriggerContext(
				K key,
				W window,
				InternalTimerService<W> timerService,
				KeyedStateBackend<Integer> stateBackend,
				TypeSerializer<W> windowSerializer) {
			this.key = key;
			this.window = window;
			this.timerService = timerService;
			this.stateBackend = stateBackend;
			this.windowSerializer = windowSerializer;
		}

		@Override
		public long getCurrentProcessingTime() {
			return timerService.currentProcessingTime();
		}

		@Override
		public MetricGroup getMetricGroup() {
			return metricGroup;
		}

		@Override
		public long getCurrentWatermark() {
			return timerService.currentWatermark();
		}

		@Override
		public void registerProcessingTimeTimer(long time) {
			timerService.registerProcessingTimeTimer(window, time);
		}

		@Override
		public void registerEventTimeTimer(long time) {
			timerService.registerEventTimeTimer(window, time);
		}

		@Override
		public void deleteProcessingTimeTimer(long time) {
			timerService.deleteProcessingTimeTimer(window, time);
		}

		@Override
		public void deleteEventTimeTimer(long time) {
			timerService.deleteEventTimeTimer(window, time);
		}

		@Override
		public <S extends State> S getPartitionedState(StateDescriptor<S, ?> stateDescriptor) {
			try {
				return stateBackend.getPartitionedState(window, windowSerializer, stateDescriptor);
			} catch (Exception e) {
				throw new RuntimeException("Error getting state", e);
			}
		}

		@Override
		public <S extends Serializable> ValueState<S> getKeyValueState(
				String name, Class<S> stateType, S defaultState) {
			return getPartitionedState(new ValueStateDescriptor<>(name, stateType, defaultState));
		}

		@Override
		public <S extends Serializable> ValueState<S> getKeyValueState(
				String name, TypeInformation<S> stateType, S defaultState) {
			return getPartitionedState(new ValueStateDescriptor<>(name, stateType, defaultState));
		}
	}

	/**
	 * {@link Trigger.OnMergeContext} whose context window is the merge target and which can merge
	 * state from the given merged windows into it.
	 */
	private static class TestOnMergeContext<K, W extends Window> extends TestTriggerContext<K, W> implements Trigger.OnMergeContext {

		private final Collection<W> mergedWindows;

		TestOnMergeContext(
				K key,
				W targetWindow,
				Collection<W> mergedWindows,
				InternalTimerService<W> timerService,
				KeyedStateBackend<Integer> stateBackend,
				TypeSerializer<W> windowSerializer) {
			super(key, targetWindow, timerService, stateBackend, windowSerializer);

			this.mergedWindows = mergedWindows;
		}

		/**
		 * Merges the state described by {@code stateDescriptor} from all merged windows into the
		 * target window.
		 *
		 * @throws RuntimeException wrapping any failure, including the case where the descriptor
		 *                          does not resolve to an {@link InternalMergingState}
		 */
		@Override
		public <S extends MergingState<?, ?>> void mergePartitionedState(StateDescriptor<S, ?> stateDescriptor) {
			try {
				S rawState = stateBackend.getOrCreateKeyedState(windowSerializer, stateDescriptor);

				if (rawState instanceof InternalMergingState) {
					@SuppressWarnings("unchecked")
					InternalMergingState<K, W, ?, ?, ?> mergingState = (InternalMergingState<K, W, ?, ?, ?>) rawState;
					mergingState.mergeNamespaces(window, mergedWindows);
				}
				else {
					// note: caught by the handler below and wrapped in a RuntimeException
					throw new IllegalArgumentException(
							"The given state descriptor does not refer to a mergeable state (MergingState)");
				}
			}
			catch (Exception e) {
				throw new RuntimeException("Error while merging state.", e);
			}
		}
	}
}