/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.executiongraph; import org.apache.flink.api.common.JobStatus; import org.apache.flink.api.common.time.Deadline; import org.apache.flink.api.common.time.Time; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.restart.NoRestartStrategy; import org.apache.flink.runtime.executiongraph.restart.RestartStrategy; import org.apache.flink.runtime.executiongraph.utils.SimpleAckingTaskManagerGateway; import org.apache.flink.runtime.executiongraph.utils.SimpleSlotProvider; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.ScheduleMode; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway; import org.apache.flink.runtime.jobmaster.LogicalSlot; import org.apache.flink.runtime.jobmaster.slotpool.SlotProvider; import org.apache.flink.runtime.taskmanager.TaskManagerLocation; import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.testtasks.NoOpInvokable; import org.apache.flink.runtime.testutils.DirectScheduledExecutorService; import javax.annotation.Nullable; import java.lang.reflect.Field; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeoutException; import java.util.function.Predicate; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; /** * A collection of utility methods for testing the ExecutionGraph and its related classes. */ public class ExecutionGraphTestUtils { private static final Time DEFAULT_TIMEOUT = AkkaUtils.getDefaultTimeout(); // ------------------------------------------------------------------------ // reaching states // ------------------------------------------------------------------------ /** * Waits until the Job has reached a certain state. * * <p>This method is based on polling and might miss very fast state transitions! */ public static void waitUntilJobStatus(ExecutionGraph eg, JobStatus status, long maxWaitMillis) throws TimeoutException { checkNotNull(eg); checkNotNull(status); checkArgument(maxWaitMillis >= 0); // this is a poor implementation - we may want to improve it eventually final long deadline = maxWaitMillis == 0 ? Long.MAX_VALUE : System.nanoTime() + (maxWaitMillis * 1_000_000); while (eg.getState() != status && System.nanoTime() < deadline) { try { Thread.sleep(2); } catch (InterruptedException ignored) {} } if (System.nanoTime() >= deadline) { throw new TimeoutException( String.format("The job did not reach status %s in time. Current status is %s.", status, eg.getState())); } } /** * Waits until the Execution has reached a certain state. * * <p>This method is based on polling and might miss very fast state transitions! */ public static void waitUntilExecutionState(Execution execution, ExecutionState state, long maxWaitMillis) throws TimeoutException { checkNotNull(execution); checkNotNull(state); checkArgument(maxWaitMillis >= 0); // this is a poor implementation - we may want to improve it eventually final long deadline = maxWaitMillis == 0 ? Long.MAX_VALUE : System.nanoTime() + (maxWaitMillis * 1_000_000); while (execution.getState() != state && System.nanoTime() < deadline) { try { Thread.sleep(2); } catch (InterruptedException ignored) {} } if (System.nanoTime() >= deadline) { throw new TimeoutException( String.format("The execution did not reach state %s in time. Current state is %s.", state, execution.getState())); } } /** * Waits until the ExecutionVertex has reached a certain state. * * <p>This method is based on polling and might miss very fast state transitions! */ public static void waitUntilExecutionVertexState(ExecutionVertex executionVertex, ExecutionState state, long maxWaitMillis) throws TimeoutException { checkNotNull(executionVertex); checkNotNull(state); checkArgument(maxWaitMillis >= 0); // this is a poor implementation - we may want to improve it eventually final long deadline = maxWaitMillis == 0 ? Long.MAX_VALUE : System.nanoTime() + (maxWaitMillis * 1_000_000); while (true) { Execution execution = executionVertex.getCurrentExecutionAttempt(); if (execution == null || (execution.getState() != state && System.nanoTime() < deadline)) { try { Thread.sleep(2); } catch (InterruptedException ignored) { } } else { break; } if (System.nanoTime() >= deadline) { if (execution != null) { throw new TimeoutException( String.format("The execution vertex did not reach state %s in time. Current state is %s.", state, execution.getState())); } else { throw new TimeoutException( "Cannot get current execution attempt of " + executionVertex + '.'); } } } } /** * Waits until all executions fulfill the given predicate. * * @param executionGraph for which to check the executions * @param executionPredicate predicate which is to be fulfilled * @param maxWaitMillis timeout for the wait operation * @throws TimeoutException if the executions did not reach the target state in time */ public static void waitForAllExecutionsPredicate( ExecutionGraph executionGraph, Predicate<AccessExecution> executionPredicate, long maxWaitMillis) throws TimeoutException { final Predicate<AccessExecutionGraph> allExecutionsPredicate = allExecutionsPredicate(executionPredicate); final Deadline deadline = Deadline.fromNow(Duration.ofMillis(maxWaitMillis)); boolean predicateResult; do { predicateResult = allExecutionsPredicate.test(executionGraph); if (!predicateResult) { try { Thread.sleep(2L); } catch (InterruptedException ignored) { Thread.currentThread().interrupt(); } } } while (!predicateResult && deadline.hasTimeLeft()); if (!predicateResult) { throw new TimeoutException("Not all executions fulfilled the predicate in time."); } } public static Predicate<AccessExecutionGraph> allExecutionsPredicate(final Predicate<AccessExecution> executionPredicate) { return accessExecutionGraph -> { final Iterable<? extends AccessExecutionVertex> allExecutionVertices = accessExecutionGraph.getAllExecutionVertices(); for (AccessExecutionVertex executionVertex : allExecutionVertices) { final AccessExecution currentExecutionAttempt = executionVertex.getCurrentExecutionAttempt(); if (currentExecutionAttempt == null || !executionPredicate.test(currentExecutionAttempt)) { return false; } } return true; }; } public static Predicate<AccessExecution> isInExecutionState(ExecutionState executionState) { return (AccessExecution execution) -> execution.getState() == executionState; } /** * Takes all vertices in the given ExecutionGraph and switches their current * execution to RUNNING. */ public static void switchAllVerticesToRunning(ExecutionGraph eg) { for (ExecutionVertex vertex : eg.getAllExecutionVertices()) { vertex.getCurrentExecutionAttempt().switchToRunning(); } } /** * Takes all vertices in the given ExecutionGraph and attempts to move them * from CANCELING to CANCELED. */ public static void completeCancellingForAllVertices(ExecutionGraph eg) { for (ExecutionVertex vertex : eg.getAllExecutionVertices()) { vertex.getCurrentExecutionAttempt().completeCancelling(); } } /** * Takes all vertices in the given ExecutionGraph and switches their current * execution to FINISHED. */ public static void finishAllVertices(ExecutionGraph eg) { for (ExecutionVertex vertex : eg.getAllExecutionVertices()) { vertex.getCurrentExecutionAttempt().markFinished(); } } /** * Checks that all execution are in state DEPLOYING and then switches them * to state RUNNING. */ public static void switchToRunning(ExecutionGraph eg) { // check that all execution are in state DEPLOYING for (ExecutionVertex ev : eg.getAllExecutionVertices()) { final Execution exec = ev.getCurrentExecutionAttempt(); final ExecutionState executionState = exec.getState(); assert executionState == ExecutionState.DEPLOYING : "Expected executionState to be DEPLOYING, was: " + executionState; } // switch executions to RUNNING for (ExecutionVertex ev : eg.getAllExecutionVertices()) { final Execution exec = ev.getCurrentExecutionAttempt(); exec.switchToRunning(); } } // ------------------------------------------------------------------------ // state modifications // ------------------------------------------------------------------------ public static void setVertexState(ExecutionVertex vertex, ExecutionState state) { try { Execution exec = vertex.getCurrentExecutionAttempt(); Field f = Execution.class.getDeclaredField("state"); f.setAccessible(true); f.set(exec, state); } catch (Exception e) { throw new RuntimeException("Modifying the state failed", e); } } public static void setVertexResource(ExecutionVertex vertex, LogicalSlot slot) { Execution exec = vertex.getCurrentExecutionAttempt(); if (!exec.tryAssignResource(slot)) { throw new RuntimeException("Could not assign resource."); } } // ------------------------------------------------------------------------ // Mocking ExecutionGraph // ------------------------------------------------------------------------ /** * Creates an execution graph with on job vertex of parallelism 10 that does no restarts. */ public static ExecutionGraph createSimpleTestGraph() throws Exception { return createSimpleTestGraph(new NoRestartStrategy()); } /** * Creates an execution graph with on job vertex of parallelism 10, using the given * restart strategy. */ public static ExecutionGraph createSimpleTestGraph(RestartStrategy restartStrategy) throws Exception { JobVertex vertex = createNoOpVertex(10); return createSimpleTestGraph(new SimpleAckingTaskManagerGateway(), restartStrategy, vertex); } /** * Creates an execution graph containing the given vertices. * * <p>The execution graph uses {@link NoRestartStrategy} as the restart strategy. */ public static ExecutionGraph createSimpleTestGraph(JobVertex... vertices) throws Exception { return createSimpleTestGraph(new SimpleAckingTaskManagerGateway(), new NoRestartStrategy(), vertices); } /** * Creates an execution graph containing the given vertices and the given restart strategy. */ public static ExecutionGraph createSimpleTestGraph( TaskManagerGateway taskManagerGateway, RestartStrategy restartStrategy, JobVertex... vertices) throws Exception { int numSlotsNeeded = 0; for (JobVertex vertex : vertices) { numSlotsNeeded += vertex.getParallelism(); } SlotProvider slotProvider = new SimpleSlotProvider(numSlotsNeeded, taskManagerGateway); return createSimpleTestGraph(slotProvider, restartStrategy, vertices); } public static ExecutionGraph createSimpleTestGraph( SlotProvider slotProvider, RestartStrategy restartStrategy, JobVertex... vertices) throws Exception { return createExecutionGraph(slotProvider, restartStrategy, TestingUtils.defaultExecutor(), vertices); } public static ExecutionGraph createExecutionGraph( SlotProvider slotProvider, RestartStrategy restartStrategy, ScheduledExecutorService executor, JobVertex... vertices) throws Exception { return createExecutionGraph(slotProvider, restartStrategy, executor, Time.seconds(10L), vertices); } public static ExecutionGraph createExecutionGraph( SlotProvider slotProvider, RestartStrategy restartStrategy, ScheduledExecutorService executor, Time timeout, JobVertex... vertices) throws Exception { checkNotNull(restartStrategy); checkNotNull(vertices); checkNotNull(timeout); return TestingExecutionGraphBuilder .newBuilder() .setJobGraph(new JobGraph(vertices)) .setFutureExecutor(executor) .setIoExecutor(executor) .setSlotProvider(slotProvider) .setAllocationTimeout(timeout) .setRpcTimeout(timeout) .setRestartStrategy(restartStrategy) .build(); } public static JobVertex createNoOpVertex(int parallelism) { return createNoOpVertex("vertex", parallelism); } public static JobVertex createNoOpVertex(String name, int parallelism) { JobVertex vertex = new JobVertex(name); vertex.setInvokableClass(NoOpInvokable.class); vertex.setParallelism(parallelism); return vertex; } // ------------------------------------------------------------------------ // utility mocking methods // ------------------------------------------------------------------------ public static JobVertex createJobVertex(String task1, int numTasks, Class<NoOpInvokable> invokable) { JobVertex groupVertex = new JobVertex(task1); groupVertex.setInvokableClass(invokable); groupVertex.setParallelism(numTasks); return groupVertex; } public static ExecutionJobVertex getExecutionJobVertex( JobVertexID id, ScheduledExecutorService executor) throws Exception { return getExecutionJobVertex(id, executor, ScheduleMode.LAZY_FROM_SOURCES); } public static ExecutionJobVertex getExecutionJobVertex( JobVertexID id, ScheduledExecutorService executor, ScheduleMode scheduleMode) throws Exception { return getExecutionJobVertex(id, 1, null, executor, scheduleMode); } public static ExecutionJobVertex getExecutionJobVertex( JobVertexID id, int parallelism, @Nullable SlotSharingGroup slotSharingGroup, ScheduledExecutorService executor, ScheduleMode scheduleMode) throws Exception { JobVertex ajv = new JobVertex("TestVertex", id); ajv.setInvokableClass(AbstractInvokable.class); ajv.setParallelism(parallelism); if (slotSharingGroup != null) { ajv.setSlotSharingGroup(slotSharingGroup); } JobGraph jobGraph = new JobGraph(ajv); jobGraph.setScheduleMode(scheduleMode); ExecutionGraph graph = TestingExecutionGraphBuilder .newBuilder() .setJobGraph(jobGraph) .setIoExecutor(executor) .setFutureExecutor(executor) .build(); graph.start(ComponentMainThreadExecutorServiceAdapter.forMainThread()); return new ExecutionJobVertex(graph, ajv, 1, AkkaUtils.getDefaultTimeout()); } public static ExecutionJobVertex getExecutionJobVertex(JobVertexID id) throws Exception { return getExecutionJobVertex(id, new DirectScheduledExecutorService()); } public static Execution getExecution() throws Exception { final ExecutionJobVertex ejv = getExecutionJobVertex(new JobVertexID()); return ejv.getTaskVertices()[0].getCurrentExecutionAttempt(); } public static Execution getExecution(final TaskManagerLocation... preferredLocations) throws Exception { return getExecution(mapToPreferredLocationFutures(preferredLocations)); } private static Collection<CompletableFuture<TaskManagerLocation>> mapToPreferredLocationFutures( final TaskManagerLocation... preferredLocations) { final Collection<CompletableFuture<TaskManagerLocation>> preferredLocationFutures = new ArrayList<>(); for (TaskManagerLocation preferredLocation : preferredLocations) { preferredLocationFutures.add(CompletableFuture.completedFuture(preferredLocation)); } return preferredLocationFutures; } public static Execution getExecution( final Collection<CompletableFuture<TaskManagerLocation>> preferredLocationFutures) throws Exception { final ExecutionJobVertex ejv = getExecutionJobVertex(new JobVertexID()); final TestExecutionVertex ev = new TestExecutionVertex(ejv, 0, new IntermediateResult[0], DEFAULT_TIMEOUT); ev.setPreferredLocationFutures(preferredLocationFutures); return ev.getCurrentExecutionAttempt(); } public static Execution getExecution( final JobVertexID jid, final int subtaskIndex, final int numTasks, final SlotSharingGroup slotSharingGroup) throws Exception { return getExecution(jid, subtaskIndex, numTasks, slotSharingGroup, null); } public static Execution getExecution( final JobVertexID jid, final int subtaskIndex, final int numTasks, final SlotSharingGroup slotSharingGroup, @Nullable final TaskManagerLocation... locations) throws Exception { final ExecutionJobVertex ejv = getExecutionJobVertex( jid, numTasks, slotSharingGroup, new DirectScheduledExecutorService(), ScheduleMode.LAZY_FROM_SOURCES); final TestExecutionVertex ev = new TestExecutionVertex( ejv, subtaskIndex, new IntermediateResult[0], DEFAULT_TIMEOUT); if (locations != null) { ev.setPreferredLocationFutures(mapToPreferredLocationFutures(locations)); } return ev.getCurrentExecutionAttempt(); } // ------------------------------------------------------------------------ // graph vertex verifications // ------------------------------------------------------------------------ /** * Verifies the generated {@link ExecutionJobVertex} for a given {@link JobVertex} in a {@link ExecutionGraph}. * * @param executionGraph the generated execution graph * @param originJobVertex the vertex to verify for * @param inputJobVertices upstream vertices of the verified vertex, used to check inputs of generated vertex * @param outputJobVertices downstream vertices of the verified vertex, used to * check produced data sets of generated vertex */ public static void verifyGeneratedExecutionJobVertex( ExecutionGraph executionGraph, JobVertex originJobVertex, @Nullable List<JobVertex> inputJobVertices, @Nullable List<JobVertex> outputJobVertices) { ExecutionJobVertex ejv = executionGraph.getAllVertices().get(originJobVertex.getID()); assertNotNull(ejv); // verify basic properties assertEquals(originJobVertex.getParallelism(), ejv.getParallelism()); assertEquals(executionGraph.getJobID(), ejv.getJobId()); assertEquals(originJobVertex.getID(), ejv.getJobVertexId()); assertEquals(originJobVertex, ejv.getJobVertex()); // verify produced data sets if (outputJobVertices == null) { assertEquals(0, ejv.getProducedDataSets().length); } else { assertEquals(outputJobVertices.size(), ejv.getProducedDataSets().length); for (int i = 0; i < outputJobVertices.size(); i++) { assertEquals(originJobVertex.getProducedDataSets().get(i).getId(), ejv.getProducedDataSets()[i].getId()); assertEquals(originJobVertex.getParallelism(), ejv.getProducedDataSets()[0].getPartitions().length); } } // verify task vertices for their basic properties and their inputs assertEquals(originJobVertex.getParallelism(), ejv.getTaskVertices().length); int subtaskIndex = 0; for (ExecutionVertex ev : ejv.getTaskVertices()) { assertEquals(executionGraph.getJobID(), ev.getJobId()); assertEquals(originJobVertex.getID(), ev.getJobvertexId()); assertEquals(originJobVertex.getParallelism(), ev.getTotalNumberOfParallelSubtasks()); assertEquals(subtaskIndex, ev.getParallelSubtaskIndex()); if (inputJobVertices == null) { assertEquals(0, ev.getNumberOfInputs()); } else { assertEquals(inputJobVertices.size(), ev.getNumberOfInputs()); for (int i = 0; i < inputJobVertices.size(); i++) { ExecutionEdge[] inputEdges = ev.getInputEdges(i); assertEquals(inputJobVertices.get(i).getParallelism(), inputEdges.length); int expectedPartitionNum = 0; for (ExecutionEdge inEdge : inputEdges) { assertEquals(i, inEdge.getInputNum()); assertEquals(expectedPartitionNum, inEdge.getSource().getPartitionNumber()); expectedPartitionNum++; } } } subtaskIndex++; } } }