/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.streaming.runtime.tasks; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.checkpoint.CheckpointMetaData; import org.apache.flink.runtime.checkpoint.CheckpointMetrics; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.StateObjectCollection; import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.DoneFuture; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.LocalRecoveryConfig; import org.apache.flink.runtime.state.LocalRecoveryDirectoryProviderImpl; import org.apache.flink.runtime.state.OperatorStateHandle; import 
org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.TaskLocalStateStore;
import org.apache.flink.runtime.state.TaskLocalStateStoreImpl;
import org.apache.flink.runtime.state.TaskStateManagerImpl;
import org.apache.flink.runtime.state.TestTaskStateManager;
import org.apache.flink.runtime.taskmanager.TestCheckpointResponder;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.util.TestLogger;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import javax.annotation.Nonnegative;
import javax.annotation.Nullable;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.mockito.Mockito.mock;

/**
 * Test for forwarding of state reporting to and from {@link org.apache.flink.runtime.state.TaskStateManager}.
 */
public class LocalStateForwardingTest extends TestLogger {

	@Rule
	public TemporaryFolder temporaryFolder = new TemporaryFolder();

	/**
	 * This tests the forwarding of jm and tm-local state from the futures reported by the backends, through the
	 * async checkpointing thread to the {@link org.apache.flink.runtime.state.TaskStateManager}.
	 */
	@Test
	public void testReportingFromSnapshotToTaskStateManager() throws Exception {

		TestTaskStateManager taskStateManager = new TestTaskStateManager();

		StreamMockEnvironment streamMockEnvironment = new StreamMockEnvironment(
			new Configuration(),
			new Configuration(),
			new ExecutionConfig(),
			1024 * 1024,
			new MockInputSplitProvider(),
			0,
			taskStateManager);

		StreamTask testStreamTask = new StreamTaskTest.NoOpStreamTask(streamMockEnvironment);
		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(0L, 0L);
		CheckpointMetrics checkpointMetrics = new CheckpointMetrics();

		// One operator whose four state futures are already completed, each carrying a distinct
		// (mock) JM-owned handle and a distinct (mock) task-local handle.
		Map<OperatorID, OperatorSnapshotFutures> snapshots = new HashMap<>(1);
		OperatorSnapshotFutures osFuture = new OperatorSnapshotFutures();

		osFuture.setKeyedStateManagedFuture(createSnapshotResult(KeyedStateHandle.class));
		osFuture.setKeyedStateRawFuture(createSnapshotResult(KeyedStateHandle.class));
		osFuture.setOperatorStateManagedFuture(createSnapshotResult(OperatorStateHandle.class));
		osFuture.setOperatorStateRawFuture(createSnapshotResult(OperatorStateHandle.class));

		OperatorID operatorID = new OperatorID();
		snapshots.put(operatorID, osFuture);

		StreamTask.AsyncCheckpointRunnable checkpointRunnable =
			new StreamTask.AsyncCheckpointRunnable(
				testStreamTask,
				snapshots,
				checkpointMetaData,
				checkpointMetrics,
				0L);

		// Executed directly on the test thread, so all reporting has happened once run() returns.
		checkpointRunnable.run();

		TaskStateSnapshot lastJobManagerTaskStateSnapshot = taskStateManager.getLastJobManagerTaskStateSnapshot();
		TaskStateSnapshot lastTaskManagerTaskStateSnapshot = taskStateManager.getLastTaskManagerTaskStateSnapshot();

		// Fail with a descriptive message instead of an NPE further down if nothing was reported.
		Assert.assertNotNull("No JM state snapshot was reported.", lastJobManagerTaskStateSnapshot);
		Assert.assertNotNull("No TM-local state snapshot was reported.", lastTaskManagerTaskStateSnapshot);

		OperatorSubtaskState jmState = lastJobManagerTaskStateSnapshot.getSubtaskStateByOperatorID(operatorID);
		OperatorSubtaskState tmState = lastTaskManagerTaskStateSnapshot.getSubtaskStateByOperatorID(operatorID);

		Assert.assertNotNull("No JM state reported for the operator.", jmState);
		Assert.assertNotNull("No TM-local state reported for the operator.", tmState);

		performCheck(osFuture.getKeyedStateManagedFuture(), jmState.getManagedKeyedState(), tmState.getManagedKeyedState());
		performCheck(osFuture.getKeyedStateRawFuture(), jmState.getRawKeyedState(), tmState.getRawKeyedState());
		performCheck(osFuture.getOperatorStateManagedFuture(), jmState.getManagedOperatorState(), tmState.getManagedOperatorState());
		performCheck(osFuture.getOperatorStateRawFuture(), jmState.getRawOperatorState(), tmState.getRawOperatorState());
	}

	/**
	 * This tests that state that was reported to the {@link org.apache.flink.runtime.state.TaskStateManager} is also
	 * reported to {@link org.apache.flink.runtime.taskmanager.CheckpointResponder} and {@link TaskLocalStateStoreImpl}.
	 */
	@Test
	public void testReportingFromTaskStateManagerToResponderAndTaskLocalStateStore() throws Exception {

		final JobID jobID = new JobID();
		final AllocationID allocationID = new AllocationID();
		final ExecutionAttemptID executionAttemptID = new ExecutionAttemptID();
		final CheckpointMetaData checkpointMetaData = new CheckpointMetaData(42L, 4711L);
		final CheckpointMetrics checkpointMetrics = new CheckpointMetrics();
		final int subtaskIdx = 42;
		JobVertexID jobVertexID = new JobVertexID();

		// Both snapshots are empty and therefore compare equal via equals(); the callbacks below use
		// assertSame so that a mix-up between JM and TM-local state is actually detected.
		final TaskStateSnapshot jmSnapshot = new TaskStateSnapshot();
		final TaskStateSnapshot tmSnapshot = new TaskStateSnapshot();

		final AtomicBoolean jmReported = new AtomicBoolean(false);
		final AtomicBoolean tmReported = new AtomicBoolean(false);

		TestCheckpointResponder checkpointResponder = new TestCheckpointResponder() {

			@Override
			public void acknowledgeCheckpoint(
				JobID lJobID,
				ExecutionAttemptID lExecutionAttemptID,
				long lCheckpointId,
				CheckpointMetrics lCheckpointMetrics,
				TaskStateSnapshot lSubtaskState) {

				Assert.assertEquals(jobID, lJobID);
				Assert.assertEquals(executionAttemptID, lExecutionAttemptID);
				Assert.assertEquals(checkpointMetaData.getCheckpointId(), lCheckpointId);
				Assert.assertEquals(checkpointMetrics, lCheckpointMetrics);
				// The acknowledged state must be exactly the JM snapshot, not the TM-local one.
				Assert.assertSame(jmSnapshot, lSubtaskState);
				jmReported.set(true);
			}
		};

		Executor executor = Executors.directExecutor();

		LocalRecoveryDirectoryProviderImpl directoryProvider = new LocalRecoveryDirectoryProviderImpl(
			temporaryFolder.newFolder(),
			jobID,
			jobVertexID,
			subtaskIdx);

		LocalRecoveryConfig localRecoveryConfig = new LocalRecoveryConfig(true, directoryProvider);

		TaskLocalStateStore taskLocalStateStore =
			new TaskLocalStateStoreImpl(jobID, allocationID, jobVertexID, subtaskIdx, localRecoveryConfig, executor) {
				@Override
				public void storeLocalState(
					@Nonnegative long checkpointId,
					@Nullable TaskStateSnapshot localState) {

					Assert.assertEquals(checkpointMetaData.getCheckpointId(), checkpointId);
					// The locally stored state must be exactly the TM-local snapshot, not the JM one.
					Assert.assertSame(tmSnapshot, localState);
					tmReported.set(true);
				}
			};

		TaskStateManagerImpl taskStateManager =
			new TaskStateManagerImpl(
				jobID,
				executionAttemptID,
				taskLocalStateStore,
				null,
				checkpointResponder);

		taskStateManager.reportTaskStateSnapshots(
			checkpointMetaData,
			checkpointMetrics,
			jmSnapshot,
			tmSnapshot);

		Assert.assertTrue("Reporting for JM state was not called.", jmReported.get());
		Assert.assertTrue("Reporting for TM state was not called.", tmReported.get());
	}

	/**
	 * Checks that each reported state collection holds exactly one handle. Each handle must equal the
	 * corresponding handle in the {@link SnapshotResult} produced by {@code resultFuture}.
	 *
	 * @param resultFuture future holding the snapshot result that the backend produced
	 * @param jmState the state collection reported to the job manager
	 * @param tmState the state collection reported as task-local state
	 * @throws Exception if retrieving the result from {@code resultFuture} fails
	 */
	private static <T extends StateObject> void performCheck(
		Future<SnapshotResult<T>> resultFuture,
		StateObjectCollection<T> jmState,
		StateObjectCollection<T> tmState) throws Exception {

		SnapshotResult<T> snapshotResult = resultFuture.get();

		// Size checks first, so an empty collection fails as an assertion, not as NoSuchElementException.
		Assert.assertEquals("Unexpected number of JM state handles.", 1, jmState.size());
		Assert.assertEquals(
			snapshotResult.getJobManagerOwnedSnapshot(),
			jmState.iterator().next());

		Assert.assertEquals("Unexpected number of TM-local state handles.", 1, tmState.size());
		Assert.assertEquals(
			snapshotResult.getTaskLocalSnapshot(),
			tmState.iterator().next());
	}

	/**
	 * Creates an already completed future for a {@link SnapshotResult}. The result carries two independent
	 * Mockito mocks of {@code clazz}: one as the JM-owned snapshot and one as the task-local snapshot.
	 */
	private static <T extends StateObject> RunnableFuture<SnapshotResult<T>> createSnapshotResult(Class<T> clazz) {
		return DoneFuture.of(SnapshotResult.withLocalState(mock(clazz), mock(clazz)));
	}
}