/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.dataflow.worker;

import static org.apache.beam.runners.dataflow.util.Structs.getBytes;

import com.google.api.services.dataflow.model.PartialGroupByKeyInstruction;
import com.google.api.services.dataflow.model.SideInputInfo;
import java.util.List;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.GlobalCombineFnRunner;
import org.apache.beam.runners.core.GlobalCombineFnRunners;
import org.apache.beam.runners.core.NullSideInputReader;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.StepContext;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.PropertyNames;
import org.apache.beam.runners.dataflow.worker.util.common.worker.GroupingTable;
import org.apache.beam.runners.dataflow.worker.util.common.worker.GroupingTables;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoFn;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Receiver;
import org.apache.beam.runners.dataflow.worker.util.common.worker.SimplePartialGroupByKeyParDoFn;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.SdkHarnessOptions;
import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.AppliedCombineFn;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
import org.joda.time.Instant;

/** A factory class that creates {@link ParDoFn} for {@link PartialGroupByKeyInstruction}. */
public class PartialGroupByKeyParDoFns {

  /** Number of bytes in one megabyte; converts the configured table size from MB to bytes. */
  private static final long BYTES_PER_MB = 1024L * 1024L;

  /** Size-estimator sample rate handed to every grouping table created by this factory. */
  private static final double SIZE_ESTIMATOR_SAMPLE_RATE = 0.001;

  /**
   * Creates a partial group-by-key {@link ParDoFn} from the (optional) serialized combine fn.
   *
   * <p>When {@code cloudUserFn} is {@code null} no combine fn is used: the side input reader is
   * the empty {@link NullSideInputReader} and the step context is {@code null}. Otherwise the
   * combine fn is deserialized from the {@link PropertyNames#SERIALIZED_FN} property, and the side
   * input reader and step context are obtained from {@code executionContext}.
   *
   * @param options pipeline options forwarded to the grouping table and the combine fn
   * @param inputElementCoder coder of the incoming key/value pairs
   * @param cloudUserFn cloud object carrying the serialized {@link AppliedCombineFn}, or {@code
   *     null} when there is no combine fn
   * @param sideInputInfos side input descriptions; only read when {@code cloudUserFn} is non-null
   * @param receivers output receivers; only the first one is used
   * @param executionContext supplies the side input reader and the step context
   * @param operationContext operation context passed through to {@code executionContext}
   * @return the {@link ParDoFn} built by the six-argument {@code create} overload
   * @throws Exception propagated from the calls made while building the {@link ParDoFn}
   */
  public static <K, InputT, AccumT> ParDoFn create(
      PipelineOptions options,
      KvCoder<K, ?> inputElementCoder,
      @Nullable CloudObject cloudUserFn,
      @Nullable List<SideInputInfo> sideInputInfos,
      List<Receiver> receivers,
      DataflowExecutionContext<?> executionContext,
      DataflowOperationContext operationContext)
      throws Exception {
    AppliedCombineFn<K, InputT, AccumT, ?> combineFn;
    SideInputReader sideInputReader;
    StepContext stepContext;
    if (cloudUserFn == null) {
      combineFn = null;
      sideInputReader = NullSideInputReader.empty();
      stepContext = null;
    } else {
      Object deserializedFn =
          SerializableUtils.deserializeFromByteArray(
              getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), "serialized combine fn");
      @SuppressWarnings("unchecked")
      AppliedCombineFn<K, InputT, AccumT, ?> combineFnUnchecked =
          ((AppliedCombineFn<K, InputT, AccumT, ?>) deserializedFn);
      combineFn = combineFnUnchecked;

      sideInputReader =
          executionContext.getSideInputReader(
              sideInputInfos, combineFn.getSideInputViews(), operationContext);
      stepContext = executionContext.getStepContext(operationContext);
    }
    return create(
        options, inputElementCoder, combineFn, sideInputReader, receivers.get(0), stepContext);
  }

  /**
   * Creates a partial group-by-key {@link ParDoFn} from an already-deserialized combine fn.
   *
   * <ul>
   *   <li>With no combine fn, values go into a table created by {@link
   *       GroupingTables#bufferingAndSampling}.
   *   <li>With a combine fn, values go into a table created by {@link
   *       GroupingTables#combiningAndSampling}. If the side input reader is not empty, the table
   *       is wrapped in {@link StreamingSideInputPGBKParDoFn} when the pipeline is streaming and
   *       in {@link BatchSideInputPGBKParDoFn} otherwise.
   * </ul>
   *
   * <p>The table's maximum size is {@link SdkHarnessOptions#getGroupingTableMaxSizeMb()}
   * converted to bytes.
   *
   * @param options pipeline options; also viewed as {@link SdkHarnessOptions} and {@link
   *     StreamingOptions}
   * @param inputElementCoder coder of the incoming key/value pairs
   * @param combineFn the combine fn to apply, or {@code null} to only buffer values
   * @param sideInputReader reader handed to the combine fn
   * @param receiver receiver that grouping table output is sent to
   * @param stepContext step context; cast to {@link
   *     StreamingModeExecutionContext.StreamingModeStepContext} on the streaming side-input path
   * @return the {@link ParDoFn} wrapping the grouping table
   * @throws Exception propagated from the calls made while building the {@link ParDoFn}
   */
  @VisibleForTesting
  static <K, InputT, AccumT> ParDoFn create(
      PipelineOptions options,
      KvCoder<K, ?> inputElementCoder,
      @Nullable AppliedCombineFn<K, InputT, AccumT, ?> combineFn,
      SideInputReader sideInputReader,
      Receiver receiver,
      @Nullable StepContext stepContext)
      throws Exception {
    long maxSizeBytes =
        options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * BYTES_PER_MB;
    Coder<K> keyCoder = inputElementCoder.getKeyCoder();
    Coder<?> valueCoder = inputElementCoder.getValueCoder();
    if (combineFn == null) {
      @SuppressWarnings("unchecked")
      Coder<InputT> inputCoder = (Coder<InputT>) valueCoder;
      GroupingTable<?, ?, ?> groupingTable =
          GroupingTables.bufferingAndSampling(
              new WindowingCoderGroupingKeyCreator<>(keyCoder),
              PairInfo.create(),
              new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
              new CoderSizeEstimator<>(inputCoder),
              SIZE_ESTIMATOR_SAMPLE_RATE,
              maxSizeBytes);
      return new SimplePartialGroupByKeyParDoFn<>(groupingTable, receiver);
    } else {
      GroupingTables.Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
          new ValueCombiner<>(
              GlobalCombineFnRunners.create(combineFn.getFn()), sideInputReader, options);

      GroupingTable<WindowedValue<K>, InputT, AccumT> groupingTable =
          GroupingTables.combiningAndSampling(
              new WindowingCoderGroupingKeyCreator<>(keyCoder),
              PairInfo.create(),
              valueCombiner,
              new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
              new CoderSizeEstimator<>(combineFn.getAccumulatorCoder()),
              SIZE_ESTIMATOR_SAMPLE_RATE,
              maxSizeBytes);
      if (sideInputReader.isEmpty()) {
        return new SimplePartialGroupByKeyParDoFn<>(groupingTable, receiver);
      } else if (options.as(StreamingOptions.class).isStreaming()) {
        StreamingSideInputFetcher<KV<K, InputT>, ?> sideInputFetcher =
            new StreamingSideInputFetcher<>(
                combineFn.getSideInputViews(),
                combineFn.getKvCoder(),
                combineFn.getWindowingStrategy(),
                (StreamingModeExecutionContext.StreamingModeStepContext) stepContext);
        return new StreamingSideInputPGBKParDoFn<>(groupingTable, receiver, sideInputFetcher);
      } else {
        return new BatchSideInputPGBKParDoFn<>(groupingTable, receiver);
      }
    }
  }

  /**
   * Implements {@link GroupingTables.Combiner} by delegating to a {@link GlobalCombineFnRunner}.
   *
   * <p>Every call forwards the pipeline options, the side input reader and the windows of the
   * windowed key to the runner.
   */
  public static class ValueCombiner<K, InputT, AccumT, OutputT>
      implements GroupingTables.Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
    private final SideInputReader sideInputReader;
    private final PipelineOptions options;

    private ValueCombiner(
        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
        SideInputReader sideInputReader,
        PipelineOptions options) {
      this.combineFn = combineFn;
      this.sideInputReader = sideInputReader;
      this.options = options;
    }

    @Override
    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
      return this.combineFn.createAccumulator(options, sideInputReader, windowedKey.getWindows());
    }

    @Override
    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT value) {
      return this.combineFn.addInput(
          accumulator, value, options, sideInputReader, windowedKey.getWindows());
    }

    @Override
    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> accumulators) {
      return this.combineFn.mergeAccumulators(
          accumulators, options, sideInputReader, windowedKey.getWindows());
    }

    @Override
    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
      return this.combineFn.compact(
          accumulator, options, sideInputReader, windowedKey.getWindows());
    }

    @Override
    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
      return this.combineFn.extractOutput(
          accumulator, options, sideInputReader, windowedKey.getWindows());
    }
  }

  /**
   * Implements {@link GroupingTables.PairInfo} for {@link WindowedValue}s of {@link KV}s.
   *
   * <p>Stateless; a single shared instance is handed out by {@link #create()}.
   */
  public static class PairInfo implements GroupingTables.PairInfo {
    private static final PairInfo INSTANCE = new PairInfo();

    /** Returns the shared instance. */
    public static PairInfo create() {
      return INSTANCE;
    }

    private PairInfo() {}

    /** Returns the input's windowed value with the {@link KV} replaced by its key. */
    @Override
    public Object getKeyFromInputPair(Object pair) {
      @SuppressWarnings("unchecked")
      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
      return windowedKv.withValue(windowedKv.getValue().getKey());
    }

    /** Returns the (unwindowed) value of the input's {@link KV}. */
    @Override
    public Object getValueFromInputPair(Object pair) {
      @SuppressWarnings("unchecked")
      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
      return windowedKv.getValue().getValue();
    }

    /** Returns the windowed key with its value replaced by {@code KV.of(key, values)}. */
    @Override
    public Object makeOutputPair(Object key, Object values) {
      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
    }
  }

  /** Implements {@link GroupingTables.GroupingKeyCreator} via {@link Coder}. */
  // TODO: Actually support window merging in the combiner table.
  public static class WindowingCoderGroupingKeyCreator<K>
      implements GroupingTables.GroupingKeyCreator<WindowedValue<K>> {

    /** Fixed timestamp placed on every grouping key so the real timestamp does not affect it. */
    private static final Instant IGNORED_TIMESTAMP = BoundedWindow.TIMESTAMP_MIN_VALUE;

    private final Coder<K> coder;

    public WindowingCoderGroupingKeyCreator(Coder<K> coder) {
      this.coder = coder;
    }

    /**
     * Builds the grouping key from the coder's structural value of the key plus the key's windows
     * and pane.
     */
    @Override
    public Object createGroupingKey(WindowedValue<K> key) throws Exception {
      // Ignore timestamp for grouping purposes.
      // The PGBK output will inherit the timestamp of one of its inputs.
      return WindowedValue.of(
          coder.structuralValue(key.getValue()),
          IGNORED_TIMESTAMP,
          key.getWindows(),
          key.getPane());
    }
  }

  /** Implements {@link GroupingTables.SizeEstimator} via {@link Coder}. */
  public static class CoderSizeEstimator<T> implements GroupingTables.SizeEstimator<T> {
    /** Basic implementation of {@link ElementByteSizeObserver} for use in size estimation. */
    private static class Observer extends ElementByteSizeObserver {
      private long observedSize = 0;

      @Override
      protected void reportElementSize(long elementSize) {
        observedSize += elementSize;
      }
    }

    final Coder<T> coder;

    public CoderSizeEstimator(Coder<T> coder) {
      this.coder = coder;
    }

    /**
     * Estimates the encoded size of {@code value} in bytes.
     *
     * <p>Uses the coder's byte size observer when it is not lazy; otherwise encodes the value into
     * a counting stream that discards its output and returns the number of bytes written.
     */
    @Override
    public long estimateSize(T value) throws Exception {
      // First try using byte size observer
      Observer observer = new Observer();
      coder.registerByteSizeObserver(value, observer);

      if (!observer.getIsLazy()) {
        observer.advance();
        return observer.observedSize;
      } else {
        // Coder byte size observation is lazy (requires iteration for observation) so fall back to
        // counting output stream
        CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
        coder.encode(value, os);
        return os.getCount();
      }
    }
  }

  /**
   * A {@link ParDoFn} that puts each input into the grouping table once per window and flushes
   * the table when the bundle finishes.
   */
  static class BatchSideInputPGBKParDoFn<K, InputT, AccumT, W extends BoundedWindow>
      implements ParDoFn {
    private final GroupingTable<WindowedValue<K>, InputT, AccumT> groupingTable;
    private final Receiver receiver;

    public BatchSideInputPGBKParDoFn(
        GroupingTable<WindowedValue<K>, InputT, AccumT> groupingTable, Receiver receiver) {
      this.groupingTable = groupingTable;
      this.receiver = receiver;
    }

    @Override
    public void startBundle(Receiver... receivers) throws Exception {}

    @Override
    public void processElement(Object elem) throws Exception {
      @SuppressWarnings({"unchecked"})
      WindowedValue<KV<K, InputT>> input = (WindowedValue<KV<K, InputT>>) elem;
      // Expand the element into one single-window value per window before grouping.
      for (BoundedWindow w : input.getWindows()) {
        WindowedValue<KV<K, InputT>> windowsExpandedInput =
            WindowedValue.of(input.getValue(), input.getTimestamp(), w, input.getPane());
        groupingTable.put(windowsExpandedInput, receiver);
      }
    }

    @Override
    public void processTimers() {}

    @Override
    public void finishBundle() throws Exception {
      groupingTable.flush(receiver);
    }

    @Override
    public void abort() throws Exception {}
  }

  /**
   * A {@link ParDoFn} that uses a {@link StreamingSideInputFetcher} to hold back elements the
   * fetcher reports as blocked, and puts all other elements into the grouping table.
   */
  static class StreamingSideInputPGBKParDoFn<K, InputT, AccumT, W extends BoundedWindow>
      implements ParDoFn {
    private final GroupingTable<WindowedValue<K>, InputT, AccumT> groupingTable;
    private final Receiver receiver;
    private final StreamingSideInputFetcher<KV<K, InputT>, W> sideInputFetcher;

    StreamingSideInputPGBKParDoFn(
        GroupingTable<WindowedValue<K>, InputT, AccumT> groupingTable,
        Receiver receiver,
        StreamingSideInputFetcher<KV<K, InputT>, W> sideInputFetcher) {
      this.groupingTable = groupingTable;
      this.receiver = receiver;
      this.sideInputFetcher = sideInputFetcher;
    }

    @Override
    public void startBundle(Receiver... receivers) throws Exception {
      // Find the set of ready windows.
      Set<W> readyWindows = sideInputFetcher.getReadyWindows();

      Iterable<BagState<WindowedValue<KV<K, InputT>>>> elementsBags =
          sideInputFetcher.prefetchElements(readyWindows);

      // Put elements into the grouping table now that all side inputs are ready.
      for (BagState<WindowedValue<KV<K, InputT>>> elementsBag : elementsBags) {
        Iterable<WindowedValue<KV<K, InputT>>> elements = elementsBag.read();
        for (WindowedValue<KV<K, InputT>> elem : elements) {
          groupingTable.put(elem, receiver);
        }
        elementsBag.clear();
      }
      sideInputFetcher.releaseBlockedWindows(readyWindows);
    }

    @Override
    public void processElement(Object elem) throws Exception {
      @SuppressWarnings({"unchecked"})
      WindowedValue<KV<K, InputT>> input = (WindowedValue<KV<K, InputT>>) elem;
      // Expand the element into one single-window value per window; each is either stored by the
      // fetcher (when blocked) or put into the grouping table.
      for (BoundedWindow w : input.getWindows()) {
        WindowedValue<KV<K, InputT>> windowsExpandedInput =
            WindowedValue.of(input.getValue(), input.getTimestamp(), w, input.getPane());

        if (!sideInputFetcher.storeIfBlocked(windowsExpandedInput)) {
          groupingTable.put(windowsExpandedInput, receiver);
        }
      }
    }

    @Override
    public void processTimers() {}

    @Override
    public void finishBundle() throws Exception {
      groupingTable.flush(receiver);
      sideInputFetcher.persist();
    }

    @Override
    public void abort() throws Exception {}
  }
}