/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.optimize.listeners; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.io.Serializable; import java.net.InetAddress; import java.util.*; /** * WARNING: THIS LISTENER SHOULD ONLY BE USED FOR MANUAL TESTING PURPOSES<br> * It intentionally causes various types of failures according to some criteria, in order to test the response * to it.<br> * This is useful for example in: * (a) Testing Spark fault tolerance<br> * (b) Testing OOM exception crash dump information<br> * Generally it should not be used in unit tests either, depending on how it is configured.<br> * <br> * Two aspects need to be configured to use this listener: * 1. If/when the "failure" should be triggered - via FailureTrigger classes<br> * 2. The type of failure when triggered - via FailureMode enum<br> * <br> * To specify if/when a failure should be triggered, use a {@link FailureTrigger} instance. Some built-in ones * are provided, random probability, time since initialized, username, and iteration/epoch count. * <br> * Types of failures available:<br> * - OOM (allocate large arrays in loop until OOM).<br> * - System.exit(1)<br> * - IllegalStateException<br> * - Infinite sleep<br> * * @author Alex Black */ @Slf4j public class FailureTestingListener implements TrainingListener, Serializable { public enum FailureMode {OOM, SYSTEM_EXIT_1, ILLEGAL_STATE, INFINITE_SLEEP} public enum CallType {ANY, EPOCH_START, EPOCH_END, FORWARD_PASS, GRADIENT_CALC, BACKWARD_PASS, ITER_DONE} private final FailureTrigger trigger; private final FailureMode failureMode; public FailureTestingListener(@NonNull FailureMode mode, @NonNull FailureTrigger trigger){ this.trigger = trigger; this.failureMode = mode; } @Override public void iterationDone(Model model, int iteration, int epoch) { call(CallType.ITER_DONE, model); } @Override public void onEpochStart(Model model) { call(CallType.EPOCH_START, model); } @Override public void onEpochEnd(Model model) { call(CallType.EPOCH_END, model); } @Override public void onForwardPass(Model model, List<INDArray> activations) { call(CallType.FORWARD_PASS, model); } @Override public void onForwardPass(Model model, Map<String, INDArray> activations) { call(CallType.FORWARD_PASS, model); } @Override public void onGradientCalculation(Model model) { call(CallType.GRADIENT_CALC, model); } @Override public void onBackwardPass(Model model) { call(CallType.BACKWARD_PASS, model); } protected void call(CallType callType, Model model){ if(!trigger.initialized()){ trigger.initialize(); } int iter; int epoch; if(model instanceof MultiLayerNetwork){ iter = ((MultiLayerNetwork) model).getIterationCount(); epoch = ((MultiLayerNetwork) model).getEpochCount(); } else { iter = ((ComputationGraph) model).getIterationCount(); epoch = ((ComputationGraph) model).getEpochCount(); } boolean triggered = trigger.triggerFailure(callType, iter, epoch, model); if(triggered){ log.error("*** FailureTestingListener was triggered on iteration {}, epoch {} - Failure mode is set to {} ***", iter, epoch, failureMode); switch (failureMode){ case OOM: List<INDArray> list = new ArrayList<>(); while(true){ INDArray arr = Nd4j.createUninitialized(1_000_000_000); list.add(arr); } //break; case SYSTEM_EXIT_1: log.error("Exiting due to FailureTestingListener triggering - calling System.exit(1)"); System.exit(1); break; case ILLEGAL_STATE: log.error("Throwing new IllegalStateException due to FailureTestingListener triggering"); throw new IllegalStateException("FailureTestListener was triggered with failure mode " + failureMode + " - iteration " + iter + ", epoch " + epoch); case INFINITE_SLEEP: while(true){ try { Thread.sleep(10000); } catch (InterruptedException e){ //Ignore } } default: throw new RuntimeException("Unknown enum value: " + failureMode); } } } @Data public static abstract class FailureTrigger implements Serializable { private boolean initialized = false; /** * If true: trigger the failure. If false: don't trigger failure * @param callType Type of call * @param iteration Iteration number * @param epoch Epoch number * @param model Model * @return */ public abstract boolean triggerFailure(CallType callType, int iteration, int epoch, Model model); public boolean initialized(){ return initialized; } public void initialize(){ this.initialized = true; } } @AllArgsConstructor public static class And extends FailureTrigger{ protected List<FailureTrigger> triggers; public And(FailureTrigger... triggers){ this.triggers = Arrays.asList(triggers); } @Override public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { boolean b = true; for(FailureTrigger ft : triggers) b &= ft.triggerFailure(callType, iteration, epoch, model); return b; } @Override public void initialize(){ super.initialize(); for(FailureTrigger ft : triggers) ft.initialize(); } } public static class Or extends And { public Or(FailureTrigger... triggers) { super(triggers); } @Override public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { boolean b = false; for(FailureTrigger ft : triggers) b |= ft.triggerFailure(callType, iteration, epoch, model); return b; } } @Data public static class RandomProb extends FailureTrigger { private final CallType callType; private final double probability; private Random rng; public RandomProb(CallType callType, double probability){ this.callType = callType; this.probability = probability; } @Override public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { return (this.callType == CallType.ANY || callType == this.callType) && rng.nextDouble() < probability; } @Override public void initialize(){ super.initialize(); this.rng = new Random(); } } @Data public static class TimeSinceInitializedTrigger extends FailureTrigger { private final long msSinceInit; private long initTime; public TimeSinceInitializedTrigger(long msSinceInit){ this.msSinceInit = msSinceInit; } @Override public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { return (System.currentTimeMillis() - initTime) > msSinceInit; } @Override public void initialize(){ super.initialize(); this.initTime = System.currentTimeMillis(); } } @Data public static class UserNameTrigger extends FailureTrigger { private final String userName; private boolean shouldFail = false; public UserNameTrigger(@NonNull String userName) { this.userName = userName; } @Override public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { return shouldFail; } @Override public void initialize(){ super.initialize(); shouldFail = this.userName.equalsIgnoreCase(System.getProperty("user.name")); } } //System.out.println("Hostname: " + InetAddress.getLocalHost().getHostName()); @Data public static class HostNameTrigger extends FailureTrigger{ private final String hostName; private boolean shouldFail = false; public HostNameTrigger(@NonNull String hostName) { this.hostName = hostName; } @Override public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { return shouldFail; } @Override public void initialize(){ super.initialize(); try { String hostname = InetAddress.getLocalHost().getHostName(); log.info("FailureTestingListere hostname: {}", hostname); shouldFail = this.hostName.equalsIgnoreCase(hostname); } catch (Exception e){ throw new RuntimeException(e); } } } @Data public static class IterationEpochTrigger extends FailureTrigger { private final boolean isEpoch; private final int count; public IterationEpochTrigger(boolean isEpoch, int count){ this.isEpoch = isEpoch; this.count = count; } @Override public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { return (isEpoch && epoch == count) || (!isEpoch && iteration == count); } } }