/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.reef.tests.group; import org.apache.reef.annotations.audience.DriverSide; import org.apache.reef.driver.context.ActiveContext; import org.apache.reef.driver.evaluator.AllocatedEvaluator; import org.apache.reef.driver.evaluator.EvaluatorRequest; import org.apache.reef.driver.evaluator.EvaluatorRequestor; import org.apache.reef.driver.task.CompletedTask; import org.apache.reef.driver.task.RunningTask; import org.apache.reef.driver.task.TaskConfiguration; import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver; import org.apache.reef.io.network.group.api.driver.GroupCommDriver; import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; import org.apache.reef.io.serialization.SerializableCodec; import org.apache.reef.tang.Configuration; import org.apache.reef.tang.annotations.Name; import org.apache.reef.tang.annotations.NamedParameter; import org.apache.reef.tang.annotations.Unit; import org.apache.reef.wake.EventHandler; import org.apache.reef.wake.time.event.StartTime; import javax.inject.Inject; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; /** * Driver used for testing multiple communication groups. */ @DriverSide @Unit public final class MultipleCommGroupsDriver { private static final Logger LOG = Logger.getLogger(MultipleCommGroupsDriver.class.getName()); private final EvaluatorRequestor requestor; private final GroupCommDriver groupCommDriver; private final String[][] taskIds; private final AtomicInteger[] taskCounter; private final List<CommunicationGroupDriver> commGroupDriverList; private final List<ActiveContext> activeContextsToBeHandled; @Inject private MultipleCommGroupsDriver(final EvaluatorRequestor requestor, final GroupCommDriver groupCommDriver) { this.requestor = requestor; this.groupCommDriver = groupCommDriver; taskIds = new String[][]{ {"MasterTask-1", "SlaveTask-1-1", "SlaveTask-1-2", "SlaveTask-1-3"}, {"MasterTask-2", "SlaveTask-2-1"} }; taskCounter = new AtomicInteger[]{new AtomicInteger(0), new AtomicInteger(0)}; commGroupDriverList = new ArrayList<>(2); activeContextsToBeHandled = new ArrayList<>(2); initializeCommGroups(); } private void initializeCommGroups() { commGroupDriverList.add(groupCommDriver.newCommunicationGroup(Group1.class, 4)); commGroupDriverList.add(groupCommDriver.newCommunicationGroup(Group2.class, 2)); commGroupDriverList.get(0).addBroadcast(BroadcastOperatorName.class, BroadcastOperatorSpec.newBuilder() .setSenderId(taskIds[0][0]) .setDataCodecClass(SerializableCodec.class) .build()); commGroupDriverList.get(1).addBroadcast(BroadcastOperatorName.class, BroadcastOperatorSpec.newBuilder() .setSenderId(taskIds[1][0]) .setDataCodecClass(SerializableCodec.class) .build()); } final class StartHandler implements EventHandler<StartTime> { @Override public void onNext(final StartTime startTime) { requestor.submit(EvaluatorRequest.newBuilder() .setNumber(4) .setMemory(128) .build()); } } final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> { @Override public void onNext(final AllocatedEvaluator allocatedEvaluator) { LOG.log(Level.INFO, "Evaluator allocated {0}", allocatedEvaluator); allocatedEvaluator.submitContextAndService( groupCommDriver.getContextConfiguration(), groupCommDriver.getServiceConfiguration()); } } final class ContextActiveHandler implements EventHandler<ActiveContext> { private final AtomicInteger contextCounter = new AtomicInteger(0); @Override public void onNext(final ActiveContext activeContext) { final int count = contextCounter.getAndIncrement(); if (count <= 1) { LOG.log(Level.INFO, "{0} will be handled after tasks in Group1 started", activeContext); activeContextsToBeHandled.add(activeContext); } else { // Add task to Group1 submitTask(activeContext, 0); } } } final class TaskRunningHandler implements EventHandler<RunningTask> { private final AtomicInteger runningTaskCounter = new AtomicInteger(0); @Override public void onNext(final RunningTask runningTask) { LOG.log(Level.INFO, "{0} has started", runningTask); final int count = runningTaskCounter.getAndIncrement(); // After two tasks has started, submit tasks to the active contexts in activeContextsToBeHandled if (count == 1) { for (final ActiveContext activeContext : activeContextsToBeHandled) { // Add task to Group2 submitTask(activeContext, 1); } } } } private void submitTask(final ActiveContext activeContext, final int groupIndex) { final String taskId = taskIds[groupIndex][taskCounter[groupIndex].getAndIncrement()]; LOG.log(Level.INFO, "Got active context {0}. Submit {1}", new Object[]{activeContext, taskId}); final Configuration partialTaskConf; if (taskId.equals(taskIds[groupIndex][0])) { partialTaskConf = TaskConfiguration.CONF .set(TaskConfiguration.IDENTIFIER, taskId) .set(TaskConfiguration.TASK, MasterTask.class) .build(); } else { partialTaskConf = TaskConfiguration.CONF .set(TaskConfiguration.IDENTIFIER, taskId) .set(TaskConfiguration.TASK, SlaveTask.class) .build(); } commGroupDriverList.get(groupIndex).addTask(partialTaskConf); activeContext.submitTask(groupCommDriver.getTaskConfiguration(partialTaskConf)); } final class TaskCompletedHandler implements EventHandler<CompletedTask> { private final AtomicInteger completedTaskCounter = new AtomicInteger(0); @Override public void onNext(final CompletedTask completedTask) { final int count = completedTaskCounter.getAndIncrement(); LOG.log(Level.INFO, "{0} has completed.", completedTask); if (count <= 1) { // Add task to Group1 submitTask(completedTask.getActiveContext(), 0); } else { completedTask.getActiveContext().close(); } } } @NamedParameter() final class Group1 implements Name<String> { } @NamedParameter() final class Group2 implements Name<String> { } @NamedParameter() final class BroadcastOperatorName implements Name<String> { } }