org.apache.flink.api.common.functions.RichMapPartitionFunction Java Examples

The following examples show how to use org.apache.flink.api.common.functions.RichMapPartitionFunction. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: ExecutionEnvironmentITCase.java    From Flink-CEPplus with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that a custom {@link Configuration} passed to
 * {@code ExecutionEnvironment.createLocalEnvironment} is respected: with the number of task slots
 * configured to {@code PARALLELISM} and the job parallelism set to
 * {@code ExecutionConfig.PARALLELISM_AUTO_MAX}, exactly {@code PARALLELISM} subtask indices must
 * be collected.
 */
@Test
public void testLocalEnvironmentWithConfig() throws Exception {
	final Configuration customConfig = new Configuration();
	customConfig.setInteger(TaskManagerOptions.NUM_TASK_SLOTS, PARALLELISM);

	final ExecutionEnvironment localEnv = ExecutionEnvironment.createLocalEnvironment(customConfig);
	localEnv.setParallelism(ExecutionConfig.PARALLELISM_AUTO_MAX);
	localEnv.getConfig().disableSysoutLogging();

	// Every partition reports only the index of the subtask that processed it.
	final RichMapPartitionFunction<Integer, Integer> reportSubtaskIndex =
			new RichMapPartitionFunction<Integer, Integer>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
					out.collect(getRuntimeContext().getIndexOfThisSubtask());
				}
			};

	final List<Integer> subtaskIndices = localEnv
			.createInput(new ParallelismDependentInputFormat())
			.rebalance()
			.mapPartition(reportSubtaskIndex)
			.collect();

	assertEquals(PARALLELISM, subtaskIndices.size());
}
 
Example #2
Source File: ExecutionEnvironmentITCase.java    From flink with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that a custom {@link Configuration} passed to
 * {@code ExecutionEnvironment.createLocalEnvironment} is respected: with the number of task slots
 * configured to {@code PARALLELISM}, exactly {@code PARALLELISM} subtask indices must be collected.
 */
@Test
public void testLocalEnvironmentWithConfig() throws Exception {
	final Configuration customConfig = new Configuration();
	customConfig.setInteger(TaskManagerOptions.NUM_TASK_SLOTS, PARALLELISM);

	final ExecutionEnvironment localEnv = ExecutionEnvironment.createLocalEnvironment(customConfig);
	localEnv.getConfig().disableSysoutLogging();

	// Every partition reports only the index of the subtask that processed it.
	final RichMapPartitionFunction<Integer, Integer> reportSubtaskIndex =
			new RichMapPartitionFunction<Integer, Integer>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
					out.collect(getRuntimeContext().getIndexOfThisSubtask());
				}
			};

	final List<Integer> subtaskIndices = localEnv
			.createInput(new ParallelismDependentInputFormat())
			.rebalance()
			.mapPartition(reportSubtaskIndex)
			.collect();

	assertEquals(PARALLELISM, subtaskIndices.size());
}
 
Example #3
Source File: ExecutionEnvironmentITCase.java    From flink with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that a custom {@link Configuration} passed to
 * {@code ExecutionEnvironment.createLocalEnvironment} is respected: with the number of task slots
 * configured to {@code PARALLELISM}, exactly {@code PARALLELISM} subtask indices must be collected.
 */
@Test
public void testLocalEnvironmentWithConfig() throws Exception {
	final Configuration customConfig = new Configuration();
	customConfig.setInteger(TaskManagerOptions.NUM_TASK_SLOTS, PARALLELISM);

	final ExecutionEnvironment localEnv = ExecutionEnvironment.createLocalEnvironment(customConfig);

	// Every partition reports only the index of the subtask that processed it.
	final RichMapPartitionFunction<Integer, Integer> reportSubtaskIndex =
			new RichMapPartitionFunction<Integer, Integer>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
					out.collect(getRuntimeContext().getIndexOfThisSubtask());
				}
			};

	final List<Integer> subtaskIndices = localEnv
			.createInput(new ParallelismDependentInputFormat())
			.rebalance()
			.mapPartition(reportSubtaskIndex)
			.collect();

	assertEquals(PARALLELISM, subtaskIndices.size());
}
 
Example #4
Source File: FunctionCompiler.java    From rheem with Apache License 2.0 6 votes vote down vote up
/**
 * Wraps the Java implementation of a map-partitions descriptor into a Flink
 * {@link RichMapPartitionFunction}.
 *
 * <p>The wrapped function is opened with {@code fex} from the rich function's {@code open} hook.
 * It is then applied to the whole partition, and every element of the {@link Iterable} it
 * returns is forwarded to the Flink collector.
 *
 * @param descriptor descriptor whose Java implementation is expected to be an
 *                   {@code ExtendedSerializableFunction<Iterable<I>, Iterable<O>>}
 * @param fex        execution context handed to the wrapped function when it is opened
 * @return a Flink function delegating to the descriptor's implementation
 */
public <I, O> RichMapPartitionFunction<I, O> compile(MapPartitionsDescriptor<I, O> descriptor, FlinkExecutionContext fex){
    final FunctionDescriptor.ExtendedSerializableFunction<Iterable<I>, Iterable<O>> partitionFunction =
            (FunctionDescriptor.ExtendedSerializableFunction<Iterable<I>, Iterable<O>>) descriptor.getJavaImplementation();

    return new RichMapPartitionFunction<I, O>() {
        @Override
        public void open(Configuration parameters) throws Exception {
            partitionFunction.open(fex);
        }

        @Override
        public void mapPartition(Iterable<I> iterable, Collector<O> collector) throws Exception {
            final Iterable<O> produced = partitionFunction.apply(iterable);
            produced.forEach(collector::collect);
        }
    };
}
 
Example #5
Source File: DataSetUtils.java    From Flink-CEPplus with Apache License 2.0 5 votes vote down vote up
/**
 * Counts the elements held by each parallel partition of {@code input}.
 *
 * <p>Every subtask iterates over its whole partition and emits exactly one tuple per
 * {@code mapPartition} invocation.
 *
 * @param input the DataSet received as input
 * @return a data set of {@code (subtask index, element count)} tuples
 */
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
			long elementCount = 0L;
			for (T ignored : values) {
				elementCount += 1;
			}
			final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
			out.collect(new Tuple2<>(subtaskIndex, elementCount));
		}
	});
}
 
Example #6
Source File: RemoteEnvironmentITCase.java    From Flink-CEPplus with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a parallelism set explicitly on a remote environment is honored even when a
 * {@link Configuration} is supplied at creation time: the job must produce exactly
 * {@code USER_DOP} subtask indices.
 */
@Test
public void testUserSpecificParallelism() throws Exception {
	final Configuration clientConfig = new Configuration();
	clientConfig.setString(AkkaOptions.STARTUP_TIMEOUT, VALID_STARTUP_TIMEOUT);

	final URI clusterAddress = MINI_CLUSTER_RESOURCE.getRestAddres();

	final ExecutionEnvironment remoteEnv = ExecutionEnvironment.createRemoteEnvironment(
			clusterAddress.getHost(),
			clusterAddress.getPort(),
			clientConfig);
	remoteEnv.setParallelism(USER_DOP);
	remoteEnv.getConfig().disableSysoutLogging();

	// Every partition reports only the index of the subtask that processed it.
	final RichMapPartitionFunction<Integer, Integer> reportSubtaskIndex =
			new RichMapPartitionFunction<Integer, Integer>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
					out.collect(getRuntimeContext().getIndexOfThisSubtask());
				}
			};

	final List<Integer> subtaskIndices = remoteEnv
			.createInput(new ParallelismDependentInputFormat())
			.rebalance()
			.mapPartition(reportSubtaskIndex)
			.collect();

	assertEquals(USER_DOP, subtaskIndices.size());
}
 
Example #7
Source File: DataSetUtils.java    From Flink-CEPplus with Apache License 2.0 5 votes vote down vote up
/**
 * Assigns a unique {@link Long} label to every element of the input data set.
 *
 * <ul>
 *  <li>each subtask keeps a counter {@code c} that starts at 0 and grows by one per record
 *  <li>{@code c} is shifted left by {@code n} bits, where {@code n = getBitSize(parallelism - 1)}
 *  <li>the subtask index is added to the shifted counter, so labels produced by different
 *      subtasks never collide
 *  <li>once {@code getBitSize(c) + n} reaches {@code getBitSize(Long.MAX_VALUE)}, an exception
 *      is thrown instead of emitting an overflowing label
 * </ul>
 *
 * <p>Labels are unique but not necessarily consecutive.
 *
 * @param input the input data set
 * @return a data set of tuple 2 consisting of ids and initial values.
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId (DataSet <T> input) {

	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

		// Exclusive upper bound for getBitSize(counter) + indexBits.
		long maxBits = getBitSize(Long.MAX_VALUE);
		// Number of low-order bits reserved for the subtask index; set in open().
		long indexBits = 0;
		// Per-subtask record counter.
		long counter = 0;
		// Index of this subtask; set in open().
		long subtaskIndex = 0;
		// Label of the record currently being emitted.
		long currentLabel = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);
			final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
			indexBits = getBitSize(parallelism - 1);
			subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
		}

		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
			for (T element : values) {
				currentLabel = (counter << indexBits) + subtaskIndex;

				final boolean labelFits = getBitSize(counter) + indexBits < maxBits;
				if (!labelFits) {
					throw new Exception("Exceeded Long value range while generating labels");
				}

				out.collect(new Tuple2<>(currentLabel, element));
				counter++;
			}
		}
	});
}
 
Example #8
Source File: DataSetUtils.java    From flink with Apache License 2.0 5 votes vote down vote up
/**
 * Assigns a unique {@link Long} label to every element of the input data set.
 *
 * <ul>
 *  <li>each subtask keeps a counter {@code c} that starts at 0 and grows by one per record
 *  <li>{@code c} is shifted left by {@code n} bits, where {@code n = getBitSize(parallelism - 1)}
 *  <li>the subtask index is added to the shifted counter, so labels produced by different
 *      subtasks never collide
 *  <li>once {@code getBitSize(c) + n} reaches {@code getBitSize(Long.MAX_VALUE)}, an exception
 *      is thrown instead of emitting an overflowing label
 * </ul>
 *
 * <p>Labels are unique but not necessarily consecutive.
 *
 * @param input the input data set
 * @return a data set of tuple 2 consisting of ids and initial values.
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId (DataSet <T> input) {

	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

		// Exclusive upper bound for getBitSize(counter) + indexBits.
		long maxBits = getBitSize(Long.MAX_VALUE);
		// Number of low-order bits reserved for the subtask index; set in open().
		long indexBits = 0;
		// Per-subtask record counter.
		long counter = 0;
		// Index of this subtask; set in open().
		long subtaskIndex = 0;
		// Label of the record currently being emitted.
		long currentLabel = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);
			final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
			indexBits = getBitSize(parallelism - 1);
			subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
		}

		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
			for (T element : values) {
				currentLabel = (counter << indexBits) + subtaskIndex;

				final boolean labelFits = getBitSize(counter) + indexBits < maxBits;
				if (!labelFits) {
					throw new Exception("Exceeded Long value range while generating labels");
				}

				out.collect(new Tuple2<>(currentLabel, element));
				counter++;
			}
		}
	});
}
 
Example #9
Source File: RemoteEnvironmentITCase.java    From flink with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a parallelism set explicitly on a remote environment is honored even when a
 * {@link Configuration} is supplied at creation time: the job must produce exactly
 * {@code USER_DOP} subtask indices.
 */
@Test
public void testUserSpecificParallelism() throws Exception {
	final Configuration clientConfig = new Configuration();
	clientConfig.setString(AkkaOptions.STARTUP_TIMEOUT, VALID_STARTUP_TIMEOUT);

	final URI clusterAddress = MINI_CLUSTER_RESOURCE.getRestAddres();

	final ExecutionEnvironment remoteEnv = ExecutionEnvironment.createRemoteEnvironment(
			clusterAddress.getHost(),
			clusterAddress.getPort(),
			clientConfig);
	remoteEnv.setParallelism(USER_DOP);
	remoteEnv.getConfig().disableSysoutLogging();

	// Every partition reports only the index of the subtask that processed it.
	final RichMapPartitionFunction<Integer, Integer> reportSubtaskIndex =
			new RichMapPartitionFunction<Integer, Integer>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
					out.collect(getRuntimeContext().getIndexOfThisSubtask());
				}
			};

	final List<Integer> subtaskIndices = remoteEnv
			.createInput(new ParallelismDependentInputFormat())
			.rebalance()
			.mapPartition(reportSubtaskIndex)
			.collect();

	assertEquals(USER_DOP, subtaskIndices.size());
}
 
Example #10
Source File: DataSetUtils.java    From flink with Apache License 2.0 5 votes vote down vote up
/**
 * Counts the elements held by each parallel partition of {@code input}.
 *
 * <p>Every subtask iterates over its whole partition and emits exactly one tuple per
 * {@code mapPartition} invocation.
 *
 * @param input the DataSet received as input
 * @return a data set of {@code (subtask index, element count)} tuples
 */
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
			long elementCount = 0L;
			for (T ignored : values) {
				elementCount += 1;
			}
			final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
			out.collect(new Tuple2<>(subtaskIndex, elementCount));
		}
	});
}
 
Example #11
Source File: DataSetUtils.java    From flink with Apache License 2.0 5 votes vote down vote up
/**
 * Assigns a unique {@link Long} label to every element of the input data set.
 *
 * <ul>
 *  <li>each subtask keeps a counter {@code c} that starts at 0 and grows by one per record
 *  <li>{@code c} is shifted left by {@code n} bits, where {@code n = getBitSize(parallelism - 1)}
 *  <li>the subtask index is added to the shifted counter, so labels produced by different
 *      subtasks never collide
 *  <li>once {@code getBitSize(c) + n} reaches {@code getBitSize(Long.MAX_VALUE)}, an exception
 *      is thrown instead of emitting an overflowing label
 * </ul>
 *
 * <p>Labels are unique but not necessarily consecutive.
 *
 * @param input the input data set
 * @return a data set of tuple 2 consisting of ids and initial values.
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithUniqueId (DataSet <T> input) {

	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

		// Exclusive upper bound for getBitSize(counter) + indexBits.
		long maxBits = getBitSize(Long.MAX_VALUE);
		// Number of low-order bits reserved for the subtask index; set in open().
		long indexBits = 0;
		// Per-subtask record counter.
		long counter = 0;
		// Index of this subtask; set in open().
		long subtaskIndex = 0;
		// Label of the record currently being emitted.
		long currentLabel = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);
			final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
			indexBits = getBitSize(parallelism - 1);
			subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
		}

		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
			for (T element : values) {
				currentLabel = (counter << indexBits) + subtaskIndex;

				final boolean labelFits = getBitSize(counter) + indexBits < maxBits;
				if (!labelFits) {
					throw new Exception("Exceeded Long value range while generating labels");
				}

				out.collect(new Tuple2<>(currentLabel, element));
				counter++;
			}
		}
	});
}
 
Example #12
Source File: DataSetUtils.java    From flink with Apache License 2.0 5 votes vote down vote up
/**
 * Counts the elements held by each parallel partition of {@code input}.
 *
 * <p>Every subtask iterates over its whole partition and emits exactly one tuple per
 * {@code mapPartition} invocation.
 *
 * @param input the DataSet received as input
 * @return a data set of {@code (subtask index, element count)} tuples
 */
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
			long elementCount = 0L;
			for (T ignored : values) {
				elementCount += 1;
			}
			final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
			out.collect(new Tuple2<>(subtaskIndex, elementCount));
		}
	});
}
 
Example #13
Source File: RemoteEnvironmentITCase.java    From flink with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that a parallelism set explicitly on a remote environment is honored even when a
 * {@link Configuration} is supplied at creation time: the job must produce exactly
 * {@code USER_DOP} subtask indices.
 */
@Test
public void testUserSpecificParallelism() throws Exception {
	final Configuration clientConfig = new Configuration();
	clientConfig.setString(AkkaOptions.STARTUP_TIMEOUT, VALID_STARTUP_TIMEOUT);

	final URI clusterAddress = MINI_CLUSTER_RESOURCE.getRestAddres();

	final ExecutionEnvironment remoteEnv = ExecutionEnvironment.createRemoteEnvironment(
			clusterAddress.getHost(),
			clusterAddress.getPort(),
			clientConfig);
	remoteEnv.setParallelism(USER_DOP);

	// Every partition reports only the index of the subtask that processed it.
	final RichMapPartitionFunction<Integer, Integer> reportSubtaskIndex =
			new RichMapPartitionFunction<Integer, Integer>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Integer> out) throws Exception {
					out.collect(getRuntimeContext().getIndexOfThisSubtask());
				}
			};

	final List<Integer> subtaskIndices = remoteEnv
			.createInput(new ParallelismDependentInputFormat())
			.rebalance()
			.mapPartition(reportSubtaskIndex)
			.collect();

	assertEquals(USER_DOP, subtaskIndices.size());
}
 
Example #14
Source File: PartitionMapOperatorTest.java    From flink with Apache License 2.0 4 votes vote down vote up
/**
 * Runs a {@link RichMapPartitionFunction} through
 * {@code MapPartitionOperatorBase.executeOnCollections} twice: once with object reuse disabled
 * and once with it enabled. It checks that
 * <ul>
 *   <li>{@code open} sees a runtime context for subtask 0 of 1 named {@code taskName},</li>
 *   <li>both runs parse the input strings into the integers 1..6, and</li>
 *   <li>the {@code open} and {@code close} lifecycle hooks are invoked.</li>
 * </ul>
 *
 * <p>Exceptions are deliberately propagated instead of being caught and turned into
 * {@code fail(e.getMessage())}. JUnit then reports them with their full stack trace and cause,
 * and a failure is never reduced to a {@code null} message.
 */
@Test
public void testMapPartitionWithRuntimeContext() throws Exception {
	final String taskName = "Test Task";
	final AtomicBoolean opened = new AtomicBoolean();
	final AtomicBoolean closed = new AtomicBoolean();

	final MapPartitionFunction<String, Integer> parser = new RichMapPartitionFunction<String, Integer>() {

		@Override
		public void open(Configuration parameters) throws Exception {
			opened.set(true);
			RuntimeContext ctx = getRuntimeContext();
			assertEquals(0, ctx.getIndexOfThisSubtask());
			assertEquals(1, ctx.getNumberOfParallelSubtasks());
			assertEquals(taskName, ctx.getTaskName());
		}

		@Override
		public void mapPartition(Iterable<String> values, Collector<Integer> out) {
			for (String s : values) {
				out.collect(Integer.parseInt(s));
			}
		}

		@Override
		public void close() throws Exception {
			closed.set(true);
		}
	};

	MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>> op =
			new MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>>(
					parser,
					new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO),
					taskName);

	List<String> input = new ArrayList<>(asList("1", "2", "3", "4", "5", "6"));

	final TaskInfo taskInfo = new TaskInfo(taskName, 1, 0, 1, 0);

	ExecutionConfig executionConfig = new ExecutionConfig();

	// First run: object reuse disabled.
	executionConfig.disableObjectReuse();
	List<Integer> resultMutableSafe = op.executeOnCollections(input,
			new RuntimeUDFContext(taskInfo, null, executionConfig,
					new HashMap<String, Future<Path>>(),
					new HashMap<String, Accumulator<?, ?>>(),
					new UnregisteredMetricsGroup()),
			executionConfig);

	// Second run: object reuse enabled; the result must not differ.
	executionConfig.enableObjectReuse();
	List<Integer> resultRegular = op.executeOnCollections(input,
			new RuntimeUDFContext(taskInfo, null, executionConfig,
					new HashMap<String, Future<Path>>(),
					new HashMap<String, Accumulator<?, ?>>(),
					new UnregisteredMetricsGroup()),
			executionConfig);

	assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
	assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);

	assertTrue(opened.get());
	assertTrue(closed.get());
}
 
Example #15
Source File: DataSetUtils.java    From flink with Apache License 2.0 4 votes vote down vote up
/**
 * Assigns consecutive {@link Long} indices to all elements of the input data set.
 *
 * <p>The per-partition element counts from {@code countElementsPerPartition} are broadcast to
 * every subtask under the name {@code "counts"}. Each subtask sums the counts of all
 * lower-indexed subtasks to find its first index, then numbers its own elements consecutively
 * from there.
 *
 * @param input the input data set
 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {

	final DataSet<Tuple2<Integer, Long>> partitionSizes = countElementsPerPartition(input);

	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

		// Next index to assign; open() advances it to this subtask's global offset.
		long nextIndex = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			final List<Tuple2<Integer, Long>> sizesBySubtask = getRuntimeContext().getBroadcastVariableWithInitializer(
					"counts",
					new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
						@Override
						public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
							// Order by subtask index so that position i holds the size of partition i.
							final List<Tuple2<Integer, Long>> ordered = new ArrayList<>();
							data.forEach(ordered::add);
							ordered.sort(Comparator.comparing(sizeEntry -> sizeEntry.f0));
							return ordered;
						}
					});

			// This partition starts right after all lower-indexed partitions.
			final int ownIndex = getRuntimeContext().getIndexOfThisSubtask();
			for (Tuple2<Integer, Long> lowerPartition : sizesBySubtask.subList(0, ownIndex)) {
				nextIndex += lowerPartition.f1;
			}
		}

		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
			for (T element : values) {
				final long index = nextIndex++;
				out.collect(new Tuple2<>(index, element));
			}
		}
	}).withBroadcastSet(partitionSizes, "counts");
}
 
Example #16
Source File: ParallelPrefixSpan.java    From Alink with Apache License 2.0 4 votes vote down vote up
/**
 * Generates frequent sequence patterns using the PrefixSpan algorithm.
 *
 * <p>The integer-keyed sequences returned by {@code partitionSequence} are routed to partition
 * {@code key % numPartitions}. Each subtask then collects the sequences of its partition. It
 * expands only those broadcast items whose id modulo the environment parallelism equals its own
 * subtask index, using the broadcast minimum support count.
 *
 * @return Frequent sequence patterns and their supports.
 */
public DataSet<Tuple2<int[], Integer>> run() {
    final int parallelism = BatchOperator.getExecutionEnvironmentFromDataSets(sequences).getParallelism();
    final DataSet<Tuple2<Integer, int[]>> partitionedSequence = partitionSequence(sequences, itemCounts, parallelism);
    final int maxLength = maxPatternLength;

    final Partitioner<Integer> moduloPartitioner = new Partitioner<Integer>() {
        @Override
        public int partition(Integer key, int numPartitions) {
            return key % numPartitions;
        }
    };

    final RichMapPartitionFunction<Tuple2<Integer, int[]>, Tuple2<int[], Integer>> patternGenerator =
        new RichMapPartitionFunction<Tuple2<Integer, int[]>, Tuple2<int[], Integer>>() {
            @Override
            public void mapPartition(Iterable<Tuple2<Integer, int[]>> values,
                                     Collector<Tuple2<int[], Integer>> out) throws Exception {
                final List<Long> minSupportBc = getRuntimeContext().getBroadcastVariable("minSupportCnt");
                final List<Tuple2<Integer, Integer>> itemCountsBc = getRuntimeContext().getBroadcastVariable("itemCounts");
                final int taskId = getRuntimeContext().getIndexOfThisSubtask();

                final long minSuppCnt = minSupportBc.get(0);

                final List<int[]> allSeq = new ArrayList<>();
                for (Tuple2<Integer, int[]> keyedSequence : values) {
                    allSeq.add(keyedSequence.f1);
                }

                final List<Postfix> initialPostfixes = new ArrayList<>(allSeq.size());
                for (int seqIndex = 0; seqIndex < allSeq.size(); seqIndex++) {
                    initialPostfixes.add(new Postfix(seqIndex));
                }

                for (Tuple2<Integer, Integer> itemCount : itemCountsBc) {
                    final int item = itemCount.f0;
                    // Only the subtask whose index equals item % parallelism expands this item.
                    if (item % parallelism != taskId) {
                        continue;
                    }
                    generateFreqPattern(allSeq, initialPostfixes, item, minSuppCnt, maxLength, out);
                }
            }
        };

    return partitionedSequence
        .partitionCustom(moduloPartitioner, 0)
        .mapPartition(patternGenerator)
        .withBroadcastSet(this.minSupportCnt, "minSupportCnt")
        .withBroadcastSet(this.itemCounts, "itemCounts")
        .name("generate_freq_pattern");
}
 
Example #17
Source File: AlsTrainBatchOp.java    From Alink with Apache License 2.0 4 votes vote down vote up
/**
 * Factorizes a user-item rating matrix with the ALS (alternating least squares) algorithm.
 *
 * <p>The configured user, item and rating columns of the first input (via
 * {@code checkAndGetFirst}) are converted to {@code (userId, itemId, rating)} tuples and passed
 * to {@code AlsTrain}. The resulting factors are written out as model rows by
 * {@code AlsModelDataConverter}.
 *
 * @param inputs a dataset of user-item-rating tuples
 * @return this operator, with its output set to the user and item factors.
 */
@Override
public AlsTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    final BatchOperator<?> in = checkAndGetFirst(inputs);

    final String userColName = getUserCol();
    final String itemColName = getItemCol();
    final String rateColName = getRateCol();

    final double lambda = getLambda();
    final int rank = getRank();
    final int numIter = getNumIter();
    final boolean nonNegative = getNonnegative();
    final boolean implicitPrefs = getImplicitPrefs();
    final double alpha = getAlpha();
    final int numMiniBatches = getNumBlocks();

    final int userIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), userColName);
    final int itemIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), itemColName);
    final int ratingIdx = TableUtil.findColIndexWithAssertAndHint(in.getColNames(), rateColName);

    // (userId, itemId, rating) triples extracted from the input rows.
    final DataSet<Tuple3<Long, Long, Float>> ratings = in.getDataSet()
        .map(new MapFunction<Row, Tuple3<Long, Long, Float>>() {
            @Override
            public Tuple3<Long, Long, Float> map(Row value) {
                final long userId = ((Number) value.getField(userIdx)).longValue();
                final long itemId = ((Number) value.getField(itemIdx)).longValue();
                final float rating = ((Number) value.getField(ratingIdx)).floatValue();
                return new Tuple3<>(userId, itemId, rating);
            }
        });

    final AlsTrain trainer = new AlsTrain(rank, numIter, lambda, implicitPrefs, alpha, numMiniBatches, nonNegative);
    final DataSet<Tuple3<Byte, Long, float[]>> factorData = trainer.fit(ratings);

    final DataSet<Row> modelRows = factorData.mapPartition(
        new RichMapPartitionFunction<Tuple3<Byte, Long, float[]>, Row>() {
            @Override
            public void mapPartition(Iterable<Tuple3<Byte, Long, float[]>> values, Collector<Row> out) {
                new AlsModelDataConverter(userColName, itemColName).save(values, out);
            }
        });

    this.setOutput(modelRows, new AlsModelDataConverter(userColName, itemColName).getModelSchema());
    return this;
}
 
Example #18
Source File: BaseTuning.java    From Alink with Apache License 2.0 4 votes vote down vote up
/**
 * Shuffles the rows of {@code data} and tags every row with a fold index, producing
 * {@code (foldIndex, row)} pairs (presumably for k-fold cross-validation — confirm with callers).
 *
 * <p>Each row gets a global position derived from the per-partition row counts, which are
 * broadcast as {@code "counts"}. Folds are contiguous ranges of those positions, whose bounds come
 * from {@code DefaultDistributedInfo.startPos} and {@code DefaultDistributedInfo.localRowCnt}
 * with {@code k} as the number of folds.
 *
 * @param data operator whose rows are split into folds
 * @param k    number of folds
 * @return a data set of {@code (fold index, row)} tuples
 */
private DataSet<Tuple2<Integer, Row>> split(BatchOperator<?> data, int k) {

		DataSet<Row> input = shuffle(data.getDataSet());

		// (subtask index, row count) per partition; broadcast to the mapPartition below as "counts".
		DataSet<Tuple2<Integer, Long>> counts = DataSetUtils.countElementsPerPartition(input);

		return input
			.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Row>>() {
				// Global position of this subtask's first row; computed in open().
				long taskStart = 0L;
				// Total number of rows over all partitions; computed in open().
				long totalNumInstance = 0L;

				@Override
				public void open(Configuration parameters) throws Exception {
					List<Tuple2<Integer, Long>> counts1 = getRuntimeContext().getBroadcastVariable("counts");

					int taskId = getRuntimeContext().getIndexOfThisSubtask();
					for (Tuple2<Integer, Long> cnt : counts1) {

						// NOTE(review): this sums the counts of partitions with a HIGHER subtask
						// index, so partitions are laid out in descending index order. The
						// resulting ranges are still disjoint and cover [0, totalNumInstance),
						// but confirm the ordering is intentional.
						if (taskId < cnt.f0) {
							taskStart += cnt.f1;
						}

						totalNumInstance += cnt.f1;
					}
				}

				@Override
				public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Row>> out) throws Exception {
					DistributedInfo distributedInfo = new DefaultDistributedInfo();
					// split1 = (current fold index, exclusive end position of that fold).
					Tuple2<Integer, Long> split1 = new Tuple2<>(-1, -1L);
					// Global position of the row currently being processed.
					long lcnt = taskStart;

					// Locate the fold containing position taskStart: the previous fold if
					// taskStart lies strictly before fold i's start, or fold i if it starts
					// exactly at taskStart.
					for (int i = 0; i <= k; ++i) {
						long sp = distributedInfo.startPos(i, k, totalNumInstance);
						long lrc = distributedInfo.localRowCnt(i, k, totalNumInstance);

						if (taskStart < sp) {
							split1.f0 = i - 1;
							split1.f1 = distributedInfo.startPos(i - 1, k, totalNumInstance)
								+ distributedInfo.localRowCnt(i - 1, k, totalNumInstance);

							break;
						}

						if (taskStart == sp) {
							split1.f0 = i;
							split1.f1 = sp + lrc;

							break;
						}
					}

					for (Row val : values) {

						// Advance to the next fold once the current fold's end position is reached.
						// NOTE(review): at most one fold is skipped per row; if localRowCnt can be 0
						// for some fold (e.g. totalNumInstance < k), a row may be tagged with an
						// empty fold — confirm against DefaultDistributedInfo.
						if (lcnt >= split1.f1) {
							split1.f0 += 1;
							split1.f1 = distributedInfo.localRowCnt(split1.f0, k, totalNumInstance) + lcnt;
						}

						out.collect(Tuple2.of(split1.f0, val));
						lcnt++;
					}
				}
			}).withBroadcastSet(counts, "counts");
	}
 
Example #19
Source File: PartitionMapOperatorTest.java    From flink with Apache License 2.0 4 votes vote down vote up
/**
 * Runs a {@link RichMapPartitionFunction} through
 * {@code MapPartitionOperatorBase.executeOnCollections} twice: once with object reuse disabled
 * and once with it enabled. It checks that
 * <ul>
 *   <li>{@code open} sees a runtime context for subtask 0 of 1 named {@code taskName},</li>
 *   <li>both runs parse the input strings into the integers 1..6, and</li>
 *   <li>the {@code open} and {@code close} lifecycle hooks are invoked.</li>
 * </ul>
 *
 * <p>Exceptions are deliberately propagated instead of being caught and turned into
 * {@code fail(e.getMessage())}. JUnit then reports them with their full stack trace and cause,
 * and a failure is never reduced to a {@code null} message.
 */
@Test
public void testMapPartitionWithRuntimeContext() throws Exception {
	final String taskName = "Test Task";
	final AtomicBoolean opened = new AtomicBoolean();
	final AtomicBoolean closed = new AtomicBoolean();

	final MapPartitionFunction<String, Integer> parser = new RichMapPartitionFunction<String, Integer>() {

		@Override
		public void open(Configuration parameters) throws Exception {
			opened.set(true);
			RuntimeContext ctx = getRuntimeContext();
			assertEquals(0, ctx.getIndexOfThisSubtask());
			assertEquals(1, ctx.getNumberOfParallelSubtasks());
			assertEquals(taskName, ctx.getTaskName());
		}

		@Override
		public void mapPartition(Iterable<String> values, Collector<Integer> out) {
			for (String s : values) {
				out.collect(Integer.parseInt(s));
			}
		}

		@Override
		public void close() throws Exception {
			closed.set(true);
		}
	};

	MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>> op =
			new MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>>(
					parser,
					new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO),
					taskName);

	List<String> input = new ArrayList<>(asList("1", "2", "3", "4", "5", "6"));

	final TaskInfo taskInfo = new TaskInfo(taskName, 1, 0, 1, 0);

	ExecutionConfig executionConfig = new ExecutionConfig();

	// First run: object reuse disabled.
	executionConfig.disableObjectReuse();
	List<Integer> resultMutableSafe = op.executeOnCollections(input,
			new RuntimeUDFContext(taskInfo, null, executionConfig,
					new HashMap<String, Future<Path>>(),
					new HashMap<String, Accumulator<?, ?>>(),
					new UnregisteredMetricsGroup()),
			executionConfig);

	// Second run: object reuse enabled; the result must not differ.
	executionConfig.enableObjectReuse();
	List<Integer> resultRegular = op.executeOnCollections(input,
			new RuntimeUDFContext(taskInfo, null, executionConfig,
					new HashMap<String, Future<Path>>(),
					new HashMap<String, Accumulator<?, ?>>(),
					new UnregisteredMetricsGroup()),
			executionConfig);

	assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
	assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);

	assertTrue(opened.get());
	assertTrue(closed.get());
}
 
Example #20
Source File: DataSetUtils.java    From flink with Apache License 2.0 4 votes vote down vote up
/**
 * Assigns consecutive {@link Long} indices to all elements of the input data set.
 *
 * <p>The per-partition element counts from {@code countElementsPerPartition} are broadcast to
 * every subtask under the name {@code "counts"}. Each subtask sums the counts of all
 * lower-indexed subtasks to find its first index, then numbers its own elements consecutively
 * from there.
 *
 * @param input the input data set
 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {

	final DataSet<Tuple2<Integer, Long>> partitionSizes = countElementsPerPartition(input);

	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

		// Next index to assign; open() advances it to this subtask's global offset.
		long nextIndex = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			final List<Tuple2<Integer, Long>> sizesBySubtask = getRuntimeContext().getBroadcastVariableWithInitializer(
					"counts",
					new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
						@Override
						public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
							// Order by subtask index so that position i holds the size of partition i.
							final List<Tuple2<Integer, Long>> ordered = new ArrayList<>();
							data.forEach(ordered::add);
							ordered.sort(Comparator.comparing(sizeEntry -> sizeEntry.f0));
							return ordered;
						}
					});

			// This partition starts right after all lower-indexed partitions.
			final int ownIndex = getRuntimeContext().getIndexOfThisSubtask();
			for (Tuple2<Integer, Long> lowerPartition : sizesBySubtask.subList(0, ownIndex)) {
				nextIndex += lowerPartition.f1;
			}
		}

		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
			for (T element : values) {
				final long index = nextIndex++;
				out.collect(new Tuple2<>(index, element));
			}
		}
	}).withBroadcastSet(partitionSizes, "counts");
}
 
Example #21
Source File: PartitionMapOperatorTest.java    From Flink-CEPplus with Apache License 2.0 4 votes vote down vote up
/**
 * Runs a {@link RichMapPartitionFunction} through
 * {@code MapPartitionOperatorBase.executeOnCollections} twice: once with object reuse disabled
 * and once with it enabled. It checks that
 * <ul>
 *   <li>{@code open} sees a runtime context for subtask 0 of 1 named {@code taskName},</li>
 *   <li>both runs parse the input strings into the integers 1..6, and</li>
 *   <li>the {@code open} and {@code close} lifecycle hooks are invoked.</li>
 * </ul>
 *
 * <p>Exceptions are deliberately propagated instead of being caught and turned into
 * {@code fail(e.getMessage())}. JUnit then reports them with their full stack trace and cause,
 * and a failure is never reduced to a {@code null} message.
 */
@Test
public void testMapPartitionWithRuntimeContext() throws Exception {
	final String taskName = "Test Task";
	final AtomicBoolean opened = new AtomicBoolean();
	final AtomicBoolean closed = new AtomicBoolean();

	final MapPartitionFunction<String, Integer> parser = new RichMapPartitionFunction<String, Integer>() {

		@Override
		public void open(Configuration parameters) throws Exception {
			opened.set(true);
			RuntimeContext ctx = getRuntimeContext();
			assertEquals(0, ctx.getIndexOfThisSubtask());
			assertEquals(1, ctx.getNumberOfParallelSubtasks());
			assertEquals(taskName, ctx.getTaskName());
		}

		@Override
		public void mapPartition(Iterable<String> values, Collector<Integer> out) {
			for (String s : values) {
				out.collect(Integer.parseInt(s));
			}
		}

		@Override
		public void close() throws Exception {
			closed.set(true);
		}
	};

	MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>> op =
			new MapPartitionOperatorBase<String, Integer, MapPartitionFunction<String, Integer>>(
					parser,
					new UnaryOperatorInformation<String, Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO),
					taskName);

	List<String> input = new ArrayList<>(asList("1", "2", "3", "4", "5", "6"));

	final TaskInfo taskInfo = new TaskInfo(taskName, 1, 0, 1, 0);

	ExecutionConfig executionConfig = new ExecutionConfig();

	// First run: object reuse disabled.
	executionConfig.disableObjectReuse();
	List<Integer> resultMutableSafe = op.executeOnCollections(input,
			new RuntimeUDFContext(taskInfo, null, executionConfig,
					new HashMap<String, Future<Path>>(),
					new HashMap<String, Accumulator<?, ?>>(),
					new UnregisteredMetricsGroup()),
			executionConfig);

	// Second run: object reuse enabled; the result must not differ.
	executionConfig.enableObjectReuse();
	List<Integer> resultRegular = op.executeOnCollections(input,
			new RuntimeUDFContext(taskInfo, null, executionConfig,
					new HashMap<String, Future<Path>>(),
					new HashMap<String, Accumulator<?, ?>>(),
					new UnregisteredMetricsGroup()),
			executionConfig);

	assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
	assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);

	assertTrue(opened.get());
	assertTrue(closed.get());
}
 
Example #22
Source File: DataSetUtils.java    From Flink-CEPplus with Apache License 2.0 4 votes vote down vote up
/**
 * Assigns consecutive {@link Long} indices to all elements of the input data set.
 *
 * <p>The per-partition element counts from {@code countElementsPerPartition} are broadcast to
 * every subtask under the name {@code "counts"}. Each subtask sums the counts of all
 * lower-indexed subtasks to find its first index, then numbers its own elements consecutively
 * from there.
 *
 * @param input the input data set
 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
 */
public static <T> DataSet<Tuple2<Long, T>> zipWithIndex(DataSet<T> input) {

	final DataSet<Tuple2<Integer, Long>> partitionSizes = countElementsPerPartition(input);

	return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {

		// Next index to assign; open() advances it to this subtask's global offset.
		long nextIndex = 0;

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			final List<Tuple2<Integer, Long>> sizesBySubtask = getRuntimeContext().getBroadcastVariableWithInitializer(
					"counts",
					new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
						@Override
						public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
							// Order by subtask index so that position i holds the size of partition i.
							final List<Tuple2<Integer, Long>> ordered = new ArrayList<>();
							data.forEach(ordered::add);
							ordered.sort(Comparator.comparing(sizeEntry -> sizeEntry.f0));
							return ordered;
						}
					});

			// This partition starts right after all lower-indexed partitions.
			final int ownIndex = getRuntimeContext().getIndexOfThisSubtask();
			for (Tuple2<Integer, Long> lowerPartition : sizesBySubtask.subList(0, ownIndex)) {
				nextIndex += lowerPartition.f1;
			}
		}

		@Override
		public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
			for (T element : values) {
				final long index = nextIndex++;
				out.collect(new Tuple2<>(index, element));
			}
		}
	}).withBroadcastSet(partitionSizes, "counts");
}