/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.schemas.transforms; import com.google.auto.value.AutoValue; import java.io.Serializable; import java.util.List; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.FieldTypeDescriptors; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.schemas.utils.RowSelector; import org.apache.beam.sdk.schemas.utils.SelectHelpers; import org.apache.beam.sdk.schemas.utils.SelectHelpers.RowSelectorContainer; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.CombineFns; import org.apache.beam.sdk.transforms.CombineFns.CoCombineResult; import org.apache.beam.sdk.transforms.CombineFns.ComposedCombineFn; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; /** This is the builder used by {@link Group} to build up a composed {@link CombineFn}. */ @Experimental(Kind.SCHEMAS) class SchemaAggregateFn { static Inner create() { return new AutoValue_SchemaAggregateFn_Inner.Builder() .setFieldAggregations(Lists.newArrayList()) .build(); } /** Implementation of {@link #create}. */ @AutoValue abstract static class Inner extends CombineFn<Row, Object[], Row> { // Represents an aggregation of one or more fields. static class FieldAggregation<FieldT, AccumT, OutputT> implements Serializable { FieldAccessDescriptor fieldsToAggregate; private final boolean aggregateBaseValues; // The specification of the output field. private final Field outputField; // The combine function. private final CombineFn<FieldT, AccumT, OutputT> fn; // The TupleTag identifying this aggregation element in the composed combine fn. private final TupleTag<Object> combineTag; // The schema corresponding to the the subset of input fields being aggregated. @Nullable private final Schema inputSubSchema; @Nullable private final FieldAccessDescriptor flattenedFieldAccessDescriptor; // The flattened version of inputSubSchema. @Nullable private final Schema flattenedInputSubSchema; // The output schema resulting from the aggregation. private final Schema aggregationSchema; private final boolean needsFlattening; FieldAggregation( FieldAccessDescriptor fieldsToAggregate, boolean aggregateBaseValues, Field outputField, CombineFn<FieldT, AccumT, OutputT> fn, TupleTag<Object> combineTag) { this( fieldsToAggregate, aggregateBaseValues, outputField, fn, combineTag, Schema.builder().addField(outputField).build(), null); } FieldAggregation( FieldAccessDescriptor fieldsToAggregate, boolean aggregateBaseValues, Field outputField, CombineFn<FieldT, AccumT, OutputT> fn, TupleTag<Object> combineTag, Schema aggregationSchema, @Nullable Schema inputSchema) { this.aggregateBaseValues = aggregateBaseValues; if (inputSchema != null) { this.fieldsToAggregate = fieldsToAggregate.resolve(inputSchema); if (aggregateBaseValues) { Preconditions.checkArgument(fieldsToAggregate.referencesSingleField()); } this.inputSubSchema = SelectHelpers.getOutputSchema(inputSchema, this.fieldsToAggregate); this.flattenedFieldAccessDescriptor = SelectHelpers.allLeavesDescriptor(inputSubSchema, SelectHelpers.CONCAT_FIELD_NAMES); this.flattenedInputSubSchema = SelectHelpers.getOutputSchema(inputSubSchema, flattenedFieldAccessDescriptor); this.needsFlattening = !inputSchema.equals(flattenedInputSubSchema); } else { this.fieldsToAggregate = fieldsToAggregate; this.inputSubSchema = null; this.flattenedFieldAccessDescriptor = null; this.flattenedInputSubSchema = null; this.needsFlattening = false; } this.outputField = outputField; this.fn = fn; this.combineTag = combineTag; this.aggregationSchema = aggregationSchema; } // The Schema is not necessarily known when the SchemaAggregateFn is created. Once the schema // is known, resolve will be called with the proper schema. FieldAggregation<FieldT, AccumT, OutputT> resolve(Schema schema) { return new FieldAggregation<>( fieldsToAggregate, aggregateBaseValues, outputField, fn, combineTag, aggregationSchema, schema); } } abstract Builder toBuilder(); @AutoValue.Builder abstract static class Builder { abstract Builder setInputSchema(@Nullable Schema inputSchema); abstract Builder setOutputSchema(@Nullable Schema outputSchema); abstract Builder setComposedCombineFn(@Nullable ComposedCombineFn composedCombineFn); abstract Builder setFieldAggregations(List<FieldAggregation> fieldAggregations); abstract Inner build(); } abstract @Nullable Schema getInputSchema(); abstract @Nullable Schema getOutputSchema(); abstract @Nullable ComposedCombineFn getComposedCombineFn(); abstract List<FieldAggregation> getFieldAggregations(); /** Once the schema is known, this function is called by the {@link Group} transform. */ Inner withSchema(Schema inputSchema) { List<FieldAggregation> fieldAggregations = getFieldAggregations().stream() .map(f -> f.resolve(inputSchema)) .collect(Collectors.toList()); ComposedCombineFn composedCombineFn = null; for (int i = 0; i < fieldAggregations.size(); ++i) { FieldAggregation fieldAggregation = fieldAggregations.get(i); SimpleFunction<Row, ?> extractFunction; Coder extractOutputCoder; if (fieldAggregation.fieldsToAggregate.referencesSingleField()) { extractFunction = new ExtractSingleFieldFunction( inputSchema, fieldAggregation.aggregateBaseValues, fieldAggregation); FieldType fieldType = fieldAggregation.flattenedInputSubSchema.getField(0).getType(); if (fieldAggregation.aggregateBaseValues) { while (fieldType.getTypeName().isLogicalType()) { fieldType = fieldType.getLogicalType().getBaseType(); } } extractOutputCoder = SchemaCoder.coderForFieldType(fieldType); } else { extractFunction = new ExtractFieldsFunction(inputSchema, fieldAggregation); extractOutputCoder = SchemaCoder.of(fieldAggregation.inputSubSchema); } if (i == 0) { composedCombineFn = CombineFns.compose() .with( extractFunction, extractOutputCoder, fieldAggregation.fn, fieldAggregation.combineTag); } else { composedCombineFn = composedCombineFn.with( extractFunction, extractOutputCoder, fieldAggregation.fn, fieldAggregation.combineTag); } } return toBuilder() .setInputSchema(inputSchema) .setComposedCombineFn(composedCombineFn) .setFieldAggregations(fieldAggregations) .build(); } /** Aggregate all values of a set of fields into an output field. */ <CombineInputT, AccumT, CombineOutputT> Inner aggregateFields( FieldAccessDescriptor fieldsToAggregate, boolean aggregateBaseValues, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, String outputFieldName) { return aggregateFields( fieldsToAggregate, aggregateBaseValues, fn, Field.of(outputFieldName, FieldTypeDescriptors.fieldTypeForJavaType(fn.getOutputType()))); } /** Aggregate all values of a set of fields into an output field. */ <CombineInputT, AccumT, CombineOutputT> Inner aggregateFields( FieldAccessDescriptor fieldsToAggregate, boolean aggregateBaseValues, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) { List<FieldAggregation> fieldAggregations = getFieldAggregations(); TupleTag<Object> combineTag = new TupleTag<>(Integer.toString(fieldAggregations.size())); FieldAggregation fieldAggregation = new FieldAggregation<>( fieldsToAggregate, aggregateBaseValues, outputField, fn, combineTag); fieldAggregations.add(fieldAggregation); return toBuilder() .setOutputSchema(getOutputSchema(fieldAggregations)) .setFieldAggregations(fieldAggregations) .build(); } private Schema getOutputSchema(List<FieldAggregation> fieldAggregations) { Schema.Builder outputSchema = Schema.builder(); for (FieldAggregation aggregation : fieldAggregations) { outputSchema.addField(aggregation.outputField); } return outputSchema.build(); } /** Extract a single field from an input {@link Row}. */ private static class ExtractSingleFieldFunction<OutputT> extends SimpleFunction<Row, OutputT> { private final RowSelector rowSelector; private final boolean extractBaseValue; @Nullable private final RowSelector flatteningSelector; private final FieldAggregation fieldAggregation; private ExtractSingleFieldFunction( Schema inputSchema, boolean extractBaseValue, FieldAggregation fieldAggregation) { rowSelector = new RowSelectorContainer(inputSchema, fieldAggregation.fieldsToAggregate, true); this.extractBaseValue = extractBaseValue; flatteningSelector = fieldAggregation.needsFlattening ? new RowSelectorContainer( fieldAggregation.inputSubSchema, fieldAggregation.flattenedFieldAccessDescriptor, true) : null; this.fieldAggregation = fieldAggregation; } @Override public OutputT apply(Row row) { Row selected = rowSelector.select(row); if (fieldAggregation.needsFlattening) { selected = flatteningSelector.select(selected); } if (extractBaseValue && selected.getSchema().getField(0).getType().getTypeName().isLogicalType()) { return (OutputT) selected.getBaseValue(0, Object.class); } return selected.getValue(0); } } /** Extract multiple fields from an input {@link Row}. */ private static class ExtractFieldsFunction extends SimpleFunction<Row, Row> { private final RowSelector rowSelector; private final FieldAggregation fieldAggregation; private ExtractFieldsFunction(Schema inputSchema, FieldAggregation fieldAggregation) { rowSelector = new RowSelectorContainer(inputSchema, fieldAggregation.fieldsToAggregate, true); this.fieldAggregation = fieldAggregation; } @Override public Row apply(Row row) { return rowSelector.select(row); } } @Override public Object[] createAccumulator() { return getComposedCombineFn().createAccumulator(); } @Override public Object[] addInput(Object[] accumulator, Row input) { return getComposedCombineFn().addInput(accumulator, input); } @Override public Object[] mergeAccumulators(Iterable<Object[]> accumulator) { return getComposedCombineFn().mergeAccumulators(accumulator); } @Override public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<Row> inputCoder) throws CannotProvideCoderException { return getComposedCombineFn().getAccumulatorCoder(registry, inputCoder); } @Override public Coder<Row> getDefaultOutputCoder(CoderRegistry registry, Coder<Row> inputCoder) { return SchemaCoder.of(getOutputSchema()); } @Override public Row extractOutput(Object[] accumulator) { // Build a row containing a field for every aggregate that was registered. CoCombineResult coCombineResult = getComposedCombineFn().extractOutput(accumulator); Row.Builder output = Row.withSchema(getOutputSchema()); for (FieldAggregation fieldAggregation : getFieldAggregations()) { Object aggregate = coCombineResult.get(fieldAggregation.combineTag); output.addValue(aggregate); } return output.build(); } } }