/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.state.api; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.JobSubmissionResult; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.time.Deadline; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.client.ClientUtils; import org.apache.flink.client.program.ClusterClient; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.state.FunctionInitializationContext; import org.apache.flink.runtime.state.FunctionSnapshotContext; import org.apache.flink.runtime.state.memory.MemoryStateBackend; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction; import org.apache.flink.streaming.api.functions.sink.DiscardingSink; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.test.util.AbstractTestBase; import org.apache.flink.util.AbstractID; import org.apache.flink.util.Collector; import org.junit.Assert; import org.junit.Test; import java.io.IOException; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; /** * IT case for reading state. */ public abstract class SavepointReaderITTestBase extends AbstractTestBase { static final String UID = "stateful-operator"; static final String LIST_NAME = "list"; static final String UNION_NAME = "union"; static final String BROADCAST_NAME = "broadcast"; private final ListStateDescriptor<Integer> list; private final ListStateDescriptor<Integer> union; private final MapStateDescriptor<Integer, String> broadcast; SavepointReaderITTestBase( ListStateDescriptor<Integer> list, ListStateDescriptor<Integer> union, MapStateDescriptor<Integer, String> broadcast) { this.list = list; this.union = union; this.broadcast = broadcast; } @Test public void testOperatorStateInputFormat() throws Exception { StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.getExecutionEnvironment(); streamEnv.setParallelism(4); DataStream<Integer> data = streamEnv .addSource(new SavepointSource()) .rebalance(); data .connect(data.broadcast(broadcast)) .process(new StatefulOperator(list, union, broadcast)) .uid(UID) .addSink(new DiscardingSink<>()); JobGraph jobGraph = streamEnv.getStreamGraph().getJobGraph(); String savepoint = takeSavepoint(jobGraph); ExecutionEnvironment batchEnv = ExecutionEnvironment.getExecutionEnvironment(); verifyListState(savepoint, batchEnv); verifyUnionState(savepoint, batchEnv); verifyBroadcastState(savepoint, batchEnv); } abstract DataSet<Integer> readListState(ExistingSavepoint savepoint) throws IOException; abstract DataSet<Integer> readUnionState(ExistingSavepoint savepoint) throws IOException; abstract DataSet<Tuple2<Integer, String>> readBroadcastState(ExistingSavepoint savepoint) throws IOException; private void verifyListState(String path, ExecutionEnvironment batchEnv) throws Exception { ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend()); List<Integer> listResult = readListState(savepoint).collect(); listResult.sort(Comparator.naturalOrder()); Assert.assertEquals("Unexpected elements read from list state", SavepointSource.getElements(), listResult); } private void verifyUnionState(String path, ExecutionEnvironment batchEnv) throws Exception { ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend()); List<Integer> unionResult = readUnionState(savepoint).collect(); unionResult.sort(Comparator.naturalOrder()); Assert.assertEquals("Unexpected elements read from union state", SavepointSource.getElements(), unionResult); } private void verifyBroadcastState(String path, ExecutionEnvironment batchEnv) throws Exception { ExistingSavepoint savepoint = Savepoint.load(batchEnv, path, new MemoryStateBackend()); List<Tuple2<Integer, String>> broadcastResult = readBroadcastState(savepoint) .collect(); List<Integer> broadcastStateKeys = broadcastResult. stream() .map(entry -> entry.f0) .sorted(Comparator.naturalOrder()) .collect(Collectors.toList()); List<String> broadcastStateValues = broadcastResult .stream() .map(entry -> entry.f1) .sorted(Comparator.naturalOrder()) .collect(Collectors.toList()); Assert.assertEquals("Unexpected element in broadcast state keys", SavepointSource.getElements(), broadcastStateKeys); Assert.assertEquals( "Unexpected element in broadcast state values", SavepointSource.getElements().stream().map(Object::toString).sorted().collect(Collectors.toList()), broadcastStateValues ); } private String takeSavepoint(JobGraph jobGraph) throws Exception { SavepointSource.initializeForTest(); ClusterClient<?> client = miniClusterResource.getClusterClient(); JobID jobId = jobGraph.getJobID(); Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5)); String dirPath = getTempDirPath(new AbstractID().toHexString()); try { JobSubmissionResult result = ClientUtils.submitJob(client, jobGraph); boolean finished = false; while (deadline.hasTimeLeft()) { if (SavepointSource.isFinished()) { finished = true; break; } try { Thread.sleep(2L); } catch (InterruptedException ignored) { Thread.currentThread().interrupt(); } } if (!finished) { Assert.fail("Failed to initialize state within deadline"); } CompletableFuture<String> path = client.triggerSavepoint(result.getJobID(), dirPath); return path.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS); } finally { client.cancel(jobId).get(); } } private static class SavepointSource implements SourceFunction<Integer> { private static volatile boolean finished; private volatile boolean running = true; private static final Integer[] elements = {1, 2, 3}; @Override public void run(SourceContext<Integer> ctx) { synchronized (ctx.getCheckpointLock()) { for (Integer element : elements) { ctx.collect(element); } finished = true; } while (running) { try { Thread.sleep(100); } catch (InterruptedException e) { // ignore } } } @Override public void cancel() { running = false; } private static void initializeForTest() { finished = false; } private static boolean isFinished() { return finished; } private static List<Integer> getElements() { return Arrays.asList(elements); } } private static class StatefulOperator extends BroadcastProcessFunction<Integer, Integer, Void> implements CheckpointedFunction { private final ListStateDescriptor<Integer> list; private final ListStateDescriptor<Integer> union; private final MapStateDescriptor<Integer, String> broadcast; private List<Integer> elements; private ListState<Integer> listState; private ListState<Integer> unionState; private StatefulOperator( ListStateDescriptor<Integer> list, ListStateDescriptor<Integer> union, MapStateDescriptor<Integer, String> broadcast) { this.list = list; this.union = union; this.broadcast = broadcast; } @Override public void open(Configuration parameters) { elements = new ArrayList<>(); } @Override public void processElement(Integer value, ReadOnlyContext ctx, Collector<Void> out) { elements.add(value); } @Override public void processBroadcastElement(Integer value, Context ctx, Collector<Void> out) throws Exception { ctx.getBroadcastState(broadcast).put(value, value.toString()); } @Override public void snapshotState(FunctionSnapshotContext context) throws Exception { listState.clear(); listState.addAll(elements); unionState.clear(); unionState.addAll(elements); } @Override public void initializeState(FunctionInitializationContext context) throws Exception { listState = context.getOperatorStateStore().getListState(list); unionState = context.getOperatorStateStore().getUnionListState(union); } } }