/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.reef.examples.group.bgd; import org.apache.reef.annotations.audience.DriverSide; import org.apache.reef.driver.context.ActiveContext; import org.apache.reef.driver.context.ServiceConfiguration; import org.apache.reef.driver.task.CompletedTask; import org.apache.reef.driver.task.FailedTask; import org.apache.reef.driver.task.RunningTask; import org.apache.reef.driver.task.TaskConfiguration; import org.apache.reef.evaluator.context.parameters.ContextIdentifier; import org.apache.reef.examples.group.bgd.data.parser.Parser; import org.apache.reef.examples.group.bgd.data.parser.SVMLightParser; import org.apache.reef.examples.group.bgd.loss.LossFunction; import org.apache.reef.examples.group.bgd.operatornames.*; import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup; import org.apache.reef.examples.group.bgd.parameters.BGDControlParameters; import org.apache.reef.examples.group.bgd.parameters.ModelDimensions; import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure; import org.apache.reef.io.data.loading.api.DataLoadingService; import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver; import org.apache.reef.io.network.group.api.driver.GroupCommDriver; import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; import org.apache.reef.io.serialization.Codec; import org.apache.reef.io.serialization.SerializableCodec; import org.apache.reef.poison.PoisonedConfiguration; import org.apache.reef.tang.Configuration; import org.apache.reef.tang.Configurations; import org.apache.reef.tang.Tang; import org.apache.reef.tang.annotations.Unit; import org.apache.reef.tang.exceptions.InjectionException; import org.apache.reef.tang.formats.ConfigurationSerializer; import org.apache.reef.wake.EventHandler; import javax.inject.Inject; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; /** * Driver for BGD example. */ @DriverSide @Unit public class BGDDriver { private static final Logger LOG = Logger.getLogger(BGDDriver.class.getName()); private static final Tang TANG = Tang.Factory.getTang(); private static final double STARTUP_FAILURE_PROB = 0.01; private final DataLoadingService dataLoadingService; private final GroupCommDriver groupCommDriver; private final ConfigurationSerializer confSerializer; private final CommunicationGroupDriver communicationsGroup; private final AtomicBoolean masterSubmitted = new AtomicBoolean(false); private final AtomicInteger slaveIds = new AtomicInteger(0); private final Map<String, RunningTask> runningTasks = new HashMap<>(); private final AtomicBoolean jobComplete = new AtomicBoolean(false); private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>(); private final BGDControlParameters bgdControlParameters; private String communicationsGroupMasterContextId; @Inject public BGDDriver(final DataLoadingService dataLoadingService, final GroupCommDriver groupCommDriver, final ConfigurationSerializer confSerializer, final BGDControlParameters bgdControlParameters) { this.dataLoadingService = dataLoadingService; this.groupCommDriver = groupCommDriver; this.confSerializer = confSerializer; this.bgdControlParameters = bgdControlParameters; final int minNumOfPartitions = bgdControlParameters.isRampup() ? bgdControlParameters.getMinParts() : dataLoadingService.getNumberOfPartitions(); final int numParticipants = minNumOfPartitions + 1; this.communicationsGroup = this.groupCommDriver.newCommunicationGroup( AllCommunicationGroup.class, // NAME numParticipants); // Number of participants LOG.log(Level.INFO, "Obtained entire communication group: start with {0} partitions", numParticipants); this.communicationsGroup .addBroadcast(ControlMessageBroadcaster.class, BroadcastOperatorSpec.newBuilder() .setSenderId(MasterTask.TASK_ID) .setDataCodecClass(SerializableCodec.class) .build()) .addBroadcast(ModelBroadcaster.class, BroadcastOperatorSpec.newBuilder() .setSenderId(MasterTask.TASK_ID) .setDataCodecClass(SerializableCodec.class) .build()) .addReduce(LossAndGradientReducer.class, ReduceOperatorSpec.newBuilder() .setReceiverId(MasterTask.TASK_ID) .setDataCodecClass(SerializableCodec.class) .setReduceFunctionClass(LossAndGradientReduceFunction.class) .build()) .addBroadcast(ModelAndDescentDirectionBroadcaster.class, BroadcastOperatorSpec.newBuilder() .setSenderId(MasterTask.TASK_ID) .setDataCodecClass(SerializableCodec.class) .build()) .addBroadcast(DescentDirectionBroadcaster.class, BroadcastOperatorSpec.newBuilder() .setSenderId(MasterTask.TASK_ID) .setDataCodecClass(SerializableCodec.class) .build()) .addReduce(LineSearchEvaluationsReducer.class, ReduceOperatorSpec.newBuilder() .setReceiverId(MasterTask.TASK_ID) .setDataCodecClass(SerializableCodec.class) .setReduceFunctionClass(LineSearchReduceFunction.class) .build()) .addBroadcast(MinEtaBroadcaster.class, BroadcastOperatorSpec.newBuilder() .setSenderId(MasterTask.TASK_ID) .setDataCodecClass(SerializableCodec.class) .build()) .finalise(); LOG.log(Level.INFO, "Added operators to communicationsGroup"); } final class ContextActiveHandler implements EventHandler<ActiveContext> { @Override public void onNext(final ActiveContext activeContext) { LOG.log(Level.INFO, "Got active context: {0}", activeContext.getId()); if (jobRunning(activeContext)) { if (!groupCommDriver.isConfigured(activeContext)) { // The Context is not configured with the group communications service let's do that. submitGroupCommunicationsService(activeContext); } else { // The group communications service is already active on this context. We can submit the task. submitTask(activeContext); } } } /** * @param activeContext a context to be configured with group communications. */ private void submitGroupCommunicationsService(final ActiveContext activeContext) { final Configuration contextConf = groupCommDriver.getContextConfiguration(); final String contextId = getContextId(contextConf); final Configuration serviceConf; if (!dataLoadingService.isDataLoadedContext(activeContext)) { communicationsGroupMasterContextId = contextId; serviceConf = groupCommDriver.getServiceConfiguration(); } else { final Configuration parsedDataServiceConf = ServiceConfiguration.CONF .set(ServiceConfiguration.SERVICES, ExampleList.class) .build(); serviceConf = Tang.Factory.getTang() .newConfigurationBuilder(groupCommDriver.getServiceConfiguration(), parsedDataServiceConf) .bindImplementation(Parser.class, SVMLightParser.class) .build(); } LOG.log(Level.FINEST, "Submit GCContext conf: {0} and Service conf: {1}", new Object[]{ confSerializer.toString(contextConf), confSerializer.toString(serviceConf)}); activeContext.submitContextAndService(contextConf, serviceConf); } private void submitTask(final ActiveContext activeContext) { assert groupCommDriver.isConfigured(activeContext); final Configuration partialTaskConfiguration; if (activeContext.getId().equals(communicationsGroupMasterContextId) && !masterTaskSubmitted()) { partialTaskConfiguration = getMasterTaskConfiguration(); LOG.info("Submitting MasterTask conf"); } else { partialTaskConfiguration = getSlaveTaskConfiguration(getSlaveId(activeContext)); // partialTaskConfiguration = Configurations.merge( // getSlaveTaskConfiguration(getSlaveId(activeContext)), // getTaskPoisonConfiguration()); LOG.info("Submitting SlaveTask conf"); } communicationsGroup.addTask(partialTaskConfiguration); final Configuration taskConfiguration = groupCommDriver.getTaskConfiguration(partialTaskConfiguration); LOG.log(Level.FINEST, "{0}", confSerializer.toString(taskConfiguration)); activeContext.submitTask(taskConfiguration); } private boolean jobRunning(final ActiveContext activeContext) { synchronized (runningTasks) { if (!jobComplete.get()) { return true; } else { LOG.log(Level.INFO, "Job complete. Not submitting any task. Closing context {0}", activeContext); activeContext.close(); return false; } } } } final class TaskRunningHandler implements EventHandler<RunningTask> { @Override public void onNext(final RunningTask runningTask) { synchronized (runningTasks) { if (!jobComplete.get()) { LOG.log(Level.INFO, "Job has not completed yet. Adding to runningTasks: {0}", runningTask); runningTasks.put(runningTask.getId(), runningTask); } else { LOG.log(Level.INFO, "Job complete. Closing context: {0}", runningTask.getActiveContext().getId()); runningTask.getActiveContext().close(); } } } } final class TaskFailedHandler implements EventHandler<FailedTask> { @Override public void onNext(final FailedTask failedTask) { final String failedTaskId = failedTask.getId(); LOG.log(Level.WARNING, "Got failed Task: " + failedTaskId); if (jobRunning(failedTaskId)) { final ActiveContext activeContext = failedTask.getActiveContext().get(); final Configuration partialTaskConf = getSlaveTaskConfiguration(failedTaskId); // Do not add the task back: // allCommGroup.addTask(partialTaskConf); final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); LOG.log(Level.FINEST, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); activeContext.submitTask(taskConf); } } private boolean jobRunning(final String failedTaskId) { synchronized (runningTasks) { if (!jobComplete.get()) { return true; } else { final RunningTask rTask = runningTasks.remove(failedTaskId); LOG.log(Level.INFO, "Job has completed. Not resubmitting"); if (rTask != null) { LOG.log(Level.INFO, "Closing activecontext"); rTask.getActiveContext().close(); } else { LOG.log(Level.INFO, "Master must have closed my context"); } return false; } } } } final class TaskCompletedHandler implements EventHandler<CompletedTask> { @Override public void onNext(final CompletedTask task) { LOG.log(Level.INFO, "Got CompletedTask: {0}", task.getId()); final byte[] retVal = task.get(); if (retVal != null) { final List<Double> losses = BGDDriver.this.lossCodec.decode(retVal); for (final Double loss : losses) { LOG.log(Level.INFO, "OUT: LOSS = {0}", loss); } } synchronized (runningTasks) { LOG.log(Level.INFO, "Acquired lock on runningTasks. Removing {0}", task.getId()); final RunningTask rTask = runningTasks.remove(task.getId()); if (rTask != null) { LOG.log(Level.INFO, "Closing active context: {0}", task.getActiveContext().getId()); task.getActiveContext().close(); } else { LOG.log(Level.INFO, "Master must have closed active context already for task {0}", task.getId()); } if (MasterTask.TASK_ID.equals(task.getId())) { jobComplete.set(true); LOG.log(Level.INFO, "Master(=>Job) complete. Closing other running tasks: {0}", runningTasks.values()); for (final RunningTask runTask : runningTasks.values()) { runTask.getActiveContext().close(); } LOG.finest("Clearing runningTasks"); runningTasks.clear(); } } } } /** * @return Configuration for the MasterTask */ public Configuration getMasterTaskConfiguration() { return Configurations.merge( TaskConfiguration.CONF .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID) .set(TaskConfiguration.TASK, MasterTask.class) .build(), bgdControlParameters.getConfiguration()); } /** * @return Configuration for the SlaveTask */ private Configuration getSlaveTaskConfiguration(final String taskId) { final double pSuccess = bgdControlParameters.getProbOfSuccessfulIteration(); final int numberOfPartitions = dataLoadingService.getNumberOfPartitions(); final double pFailure = 1 - Math.pow(pSuccess, 1.0 / numberOfPartitions); return Tang.Factory.getTang() .newConfigurationBuilder( TaskConfiguration.CONF .set(TaskConfiguration.IDENTIFIER, taskId) .set(TaskConfiguration.TASK, SlaveTask.class) .build()) .bindNamedParameter(ModelDimensions.class, "" + bgdControlParameters.getDimensions()) .bindImplementation(LossFunction.class, bgdControlParameters.getLossFunction()) .bindNamedParameter(ProbabilityOfFailure.class, Double.toString(pFailure)) .build(); } private Configuration getTaskPoisonConfiguration() { return PoisonedConfiguration.TASK_CONF .set(PoisonedConfiguration.CRASH_PROBABILITY, STARTUP_FAILURE_PROB) .set(PoisonedConfiguration.CRASH_TIMEOUT, 1) .build(); } private String getContextId(final Configuration contextConf) { try { return TANG.newInjector(contextConf).getNamedInstance(ContextIdentifier.class); } catch (final InjectionException e) { throw new RuntimeException("Unable to inject context identifier from context conf", e); } } private String getSlaveId(final ActiveContext activeContext) { return "SlaveTask-" + slaveIds.getAndIncrement(); } private boolean masterTaskSubmitted() { return !masterSubmitted.compareAndSet(false, true); } }