/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.nemo.compiler.frontend.spark.core;

import org.apache.nemo.client.JobLauncher;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.dag.DAGBuilder;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty;
import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.compiler.frontend.spark.SparkBroadcastVariables;
import org.apache.nemo.compiler.frontend.spark.SparkKeyExtractor;
import org.apache.nemo.compiler.frontend.spark.coder.SparkDecoderFactory;
import org.apache.nemo.compiler.frontend.spark.coder.SparkEncoderFactory;
import org.apache.nemo.compiler.frontend.spark.transform.CollectTransform;
import org.apache.nemo.compiler.frontend.spark.transform.GroupByKeyTransform;
import org.apache.nemo.compiler.frontend.spark.transform.ReduceByKeyTransform;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.serializer.JavaSerializer;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import scala.Function1;
import scala.Tuple2;
import scala.collection.JavaConverters;
import scala.collection.TraversableOnce;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Stack;

/**
 * Utility class for RDDs.
 */
public final class SparkFrontendUtils {
  private static final KeyExtractorProperty SPARK_KEY_EXTRACTOR_PROP =
    KeyExtractorProperty.of(new SparkKeyExtractor());

  /**
   * Private constructor.
   */
  private SparkFrontendUtils() {
  }

  /**
   * Derives the Spark serializer from a Spark context.
   *
   * @param sparkContext Spark context to derive the serializer from.
   * @return the serializer.
   */
  public static Serializer deriveSerializerFrom(final org.apache.spark.SparkContext sparkContext) {
    if (sparkContext.conf().get("spark.serializer", "")
      .equals("org.apache.spark.serializer.KryoSerializer")) {
      return new KryoSerializer(sparkContext.conf());
    } else {
      return new JavaSerializer(sparkContext.conf());
    }
  }
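  // Usage sketch for the method above (a hypothetical caller; 'sc' is an assumption,
  // not part of this class). The branch mirrors Spark's own 'spark.serializer'
  // configuration key: Kryo is chosen only when the conf names KryoSerializer
  // explicitly, and Java serialization is the fallback, matching Spark's default.
  //
  //   final org.apache.spark.SparkContext sc = ...; // an existing context (assumed)
  //   sc.conf().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
  //   final Serializer serializer = SparkFrontendUtils.deriveSerializerFrom(sc); // KryoSerializer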
  /**
   * Collects data by running the DAG.
   *
   * @param dag             the DAG to execute.
   * @param loopVertexStack loop vertex stack.
   * @param lastVertex      last vertex added to the dag.
   * @param serializer      serializer for the edges.
   * @param <T>             type of the return data.
   * @return the data collected.
   */
  public static <T> List<T> collect(final DAG<IRVertex, IREdge> dag,
                                    final Stack<LoopVertex> loopVertexStack,
                                    final IRVertex lastVertex,
                                    final Serializer serializer) {
    final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>(dag);

    final IRVertex collectVertex = new OperatorVertex(new CollectTransform<>());
    builder.addVertex(collectVertex, loopVertexStack);

    final IREdge newEdge = new IREdge(getEdgeCommunicationPattern(lastVertex, collectVertex),
      lastVertex, collectVertex);
    newEdge.setProperty(EncoderProperty.of(new SparkEncoderFactory(serializer)));
    newEdge.setProperty(DecoderProperty.of(new SparkDecoderFactory(serializer)));
    newEdge.setProperty(SPARK_KEY_EXTRACTOR_PROP);
    builder.connectVertices(newEdge);

    // Launch the DAG.
    JobLauncher.launchDAG(new IRDAG(builder.build()), SparkBroadcastVariables.getAll(), "");

    return JobLauncher.getCollectedData();
  }

  /**
   * Retrieves the communication pattern of the edge.
   *
   * @param src source vertex.
   * @param dst destination vertex.
   * @return the communication pattern.
   */
  public static CommunicationPatternProperty.Value getEdgeCommunicationPattern(final IRVertex src,
                                                                               final IRVertex dst) {
    if (dst instanceof OperatorVertex
      && (((OperatorVertex) dst).getTransform() instanceof ReduceByKeyTransform
      || ((OperatorVertex) dst).getTransform() instanceof GroupByKeyTransform)) {
      return CommunicationPatternProperty.Value.SHUFFLE;
    } else {
      return CommunicationPatternProperty.Value.ONE_TO_ONE;
    }
  }

  /**
   * Converts a {@link Function1} to a corresponding {@link Function}.
   * <p>
   * Here, we use the Spark 'JavaSerializer' to facilitate debugging in the future.
   * TODO #205: RDD Closure with Broadcast Variables Serialization Bug
   *
   * @param scalaFunction the scala function to convert.
   * @param <I>           the type of input.
   * @param <O>           the type of output.
   * @return the converted Java function.
   */
  public static <I, O> Function<I, O> toJavaFunction(final Function1<I, O> scalaFunction) {
    // This 'JavaSerializer' from Spark provides human-readable NotSerializableException stack traces,
    // which can be useful when addressing this problem.
    // The other toJavaFunction overloads can also use this serializer when debugging.
    final ClassTag<Function1<I, O>> classTag = ClassTag$.MODULE$.apply(scalaFunction.getClass());
    final byte[] serializedFunction = new JavaSerializer().newInstance().serialize(scalaFunction, classTag).array();

    return new Function<I, O>() {
      private Function1<I, O> deserializedFunction;

      @Override
      public O call(final I v1) throws Exception {
        if (deserializedFunction == null) {
          // TODO #205: RDD Closure with Broadcast Variables Serialization Bug
          final SerializerInstance js = new JavaSerializer().newInstance();
          deserializedFunction = js.deserialize(ByteBuffer.wrap(serializedFunction), classTag);
        }
        return deserializedFunction.apply(v1);
      }
    };
  }
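  // Sketch of the round trip above (hypothetical values; 'plusOne' is an assumption,
  // not part of this class). The closure is serialized eagerly with Spark's
  // JavaSerializer, then deserialized lazily on the first call, so each user of the
  // returned Function works on its own private copy of the closure:
  //
  //   final Function1<Integer, Integer> plusOne = ...;          // some Scala closure (assumed)
  //   final Function<Integer, Integer> javaFn = SparkFrontendUtils.toJavaFunction(plusOne);
  //   javaFn.call(41);  // deserializes the closure once, then applies it: 42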
  /**
   * Converts a {@link scala.Function2} to a corresponding {@link org.apache.spark.api.java.function.Function2}.
   *
   * @param scalaFunction the scala function to convert.
   * @param <I1>          the type of first input.
   * @param <I2>          the type of second input.
   * @param <O>           the type of output.
   * @return the converted Java function.
   */
  public static <I1, I2, O> Function2<I1, I2, O> toJavaFunction(final scala.Function2<I1, I2, O> scalaFunction) {
    return new Function2<I1, I2, O>() {
      @Override
      public O call(final I1 v1, final I2 v2) throws Exception {
        return scalaFunction.apply(v1, v2);
      }
    };
  }

  /**
   * Converts a {@link Function1} to a corresponding {@link FlatMapFunction}.
   *
   * @param scalaFunction the scala function to convert.
   * @param <I>           the type of input.
   * @param <O>           the type of output.
   * @return the converted Java function.
   */
  public static <I, O> FlatMapFunction<I, O> toJavaFlatMapFunction(
    final Function1<I, TraversableOnce<O>> scalaFunction) {
    return new FlatMapFunction<I, O>() {
      @Override
      public Iterator<O> call(final I i) throws Exception {
        return JavaConverters.asJavaIteratorConverter(scalaFunction.apply(i).toIterator()).asJava();
      }
    };
  }

  /**
   * Converts a {@link PairFunction} to a plain map {@link Function}.
   *
   * @param pairFunction the pair function to convert.
   * @param <T>          the type of original element.
   * @param <K>          the type of converted key.
   * @param <V>          the type of converted value.
   * @return the converted map function.
   */
  public static <T, K, V> Function<T, Tuple2<K, V>> pairFunctionToPlainFunction(
    final PairFunction<T, K, V> pairFunction) {
    return new Function<T, Tuple2<K, V>>() {
      @Override
      public Tuple2<K, V> call(final T elem) throws Exception {
        return pairFunction.call(elem);
      }
    };
  }
}
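// Usage sketch for the converter utilities above (hypothetical closures; 'sum' and
// 'toPair' are assumptions, not part of this class). The adapters simply delegate
// to the wrapped function, so the converted functions behave identically:
//
//   final scala.Function2<Integer, Integer, Integer> sum = ...;  // e.g. a reduce closure (assumed)
//   final Function2<Integer, Integer, Integer> javaSum = SparkFrontendUtils.toJavaFunction(sum);
//   javaSum.call(1, 2);  // delegates to sum.apply(1, 2): 3
//
//   final PairFunction<String, String, Integer> toPair = s -> new Tuple2<>(s, s.length());
//   final Function<String, Tuple2<String, Integer>> plain =
//     SparkFrontendUtils.pairFunctionToPlainFunction(toPair);
//   plain.call("nemo");  // ("nemo", 4)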