/*
 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.tensorflow.framework.optimizers;

import java.util.List;
import java.util.Optional;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TType;

/**
 * Optimizer that implements the Adam algorithm.
 *
 * <p>See the <a href="http://arxiv.org/abs/1412.6980">paper</a>.
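 *
 * <p>A minimal usage sketch; the model and {@code loss} are hypothetical placeholders, and
 * running the variable initializers is elided:
 *
 * <pre>{@code
 * try (Graph graph = new Graph()) {
 *   Ops tf = Ops.create(graph);
 *   // ... build a model that produces a scalar loss Operand ...
 *   Adam adam = new Adam(graph, 0.001f);
 *   Op minimize = adam.minimize(loss); // creates the m/v slots and beta power variables
 *   try (Session session = new Session(graph)) {
 *     // run the variable initializers first, then step repeatedly:
 *     session.runner().addTarget(minimize).run();
 *   }
 * }
 * }</pre>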
 */
@Operator
public class Adam extends Optimizer {

  public static final String FIRST_MOMENT = "m";
  public static final String SECOND_MOMENT = "v";

  private final float learningRate;
  private final float betaOne;
  private final float betaTwo;
  private final float epsilon;

  private Constant<TFloat32> learningRateConst;
  private Constant<TFloat32> epsilonConst;
  private Constant<TFloat32> betaOneConst;
  private Constant<TFloat32> betaTwoConst;
  private Variable<TFloat32> betaOnePower;
  private Variable<TFloat32> betaTwoPower;

  /**
   * Creates an Adam optimizer with the defaults from the paper: betaOne = 0.9,
   * betaTwo = 0.999, epsilon = 1e-8.
   *
   * @param graph the graph to optimize
   * @param learningRate the learning rate
   */
  public Adam(Graph graph, float learningRate) {
    this(graph, learningRate, 0.9f, 0.999f, 1e-8f);
  }

  /**
   * Creates an Adam optimizer.
   *
   * @param graph the graph to optimize
   * @param learningRate the learning rate
   * @param betaOne the exponential decay rate for the first moment estimates
   * @param betaTwo the exponential decay rate for the second moment estimates
   * @param epsilon a small constant for numerical stability
   */
  public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) {
    super(graph);
    this.learningRate = learningRate;
    this.betaOne = betaOne;
    this.betaTwo = betaTwo;
    this.epsilon = epsilon;
  }

  /**
   * Creates a named Adam optimizer with the defaults from the paper: betaOne = 0.9,
   * betaTwo = 0.999, epsilon = 1e-8.
   *
   * @param graph the graph to optimize
   * @param name the optimizer name
   * @param learningRate the learning rate
   */
  public Adam(Graph graph, String name, float learningRate) {
    this(graph, name, learningRate, 0.9f, 0.999f, 1e-8f);
  }

  /**
   * Creates a named Adam optimizer.
   *
   * @param graph the graph to optimize
   * @param name the optimizer name
   * @param learningRate the learning rate
   * @param betaOne the exponential decay rate for the first moment estimates
   * @param betaTwo the exponential decay rate for the second moment estimates
   * @param epsilon a small constant for numerical stability
   */
  public Adam(Graph graph, String name, float learningRate, float betaOne, float betaTwo,
      float epsilon) {
    super(graph, name);
    this.learningRate = learningRate;
    this.betaOne = betaOne;
    this.betaTwo = betaTwo;
    this.epsilon = epsilon;
  }

  /**
   * Creates an op that minimizes {@code loss} using a freshly constructed Adam optimizer.
   * This is the {@code adam_minimize} endpoint for the generated op API; a
   * {@code sharedName} supplied through {@code options} is used as the name of the returned
   * run target.
   */
  @Endpoint(name = "adam_minimize")
  public static <T extends TType> Op createAdamMinimize(Scope scope, Operand<T> loss,
      float learningRate, float betaOne, float betaTwo, float epsilon,
      Optimizer.Options... options) {
    if (!(scope.env() instanceof Graph)) {
      throw new IllegalArgumentException("Optimizers are only supported on Graphs");
    }
    Adam adam = new Adam((Graph) scope.env(), learningRate, betaOne, betaTwo, epsilon);
    String name = null;
    for (Options o : options) {
      if (o.sharedName != null) {
        name = o.sharedName;
      }
    }
    if (name == null) {
      return adam.minimize(loss);
    } else {
      return adam.minimize(loss, name);
    }
  }

  /**
   * Creates the first- and second-moment slots for each variable, plus the scalar
   * {@code beta1_power} and {@code beta2_power} variables that track betaOne^t and
   * betaTwo^t across steps (see {@link #finish}).
   */
  @Override
  protected void createSlots(List<Output<? extends TType>> variables) {
    for (Output<? extends TType> v : variables) {
      createAdamSlot(v.asOutput());
    }
    betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE);
    Assign<TFloat32> betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne));
    graph.addInitializer(betaOnePowerInit);
    betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE);
    Assign<TFloat32> betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo));
    graph.addInitializer(betaTwoPowerInit);
  }

  /**
   * Creates the hyperparameter constants consumed by {@link #applyDense}; no separate
   * preparation op is needed, so this returns empty.
   */
  @Override
  protected Optional<Op> prepare(String scopeName) {
    betaOneConst = tf.constant(betaOne);
    betaTwoConst = tf.constant(betaTwo);
    learningRateConst = tf.constant(learningRate);
    epsilonConst = tf.constant(epsilon);
    return Optional.empty();
  }

  /** Creates zero-initialized first- and second-moment slots with the variable's shape and type. */
  private <T extends TType> void createAdamSlot(Output<T> v) {
    Operand<T> firstMomentInitializer =
        tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType()));
    createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer);
    Operand<T> secondMomentInitializer =
        tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType()));
    createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer);
  }

  /**
   * Applies the Adam update to a single variable by delegating to the {@code ApplyAdam}
   * kernel, casting the float hyperparameters and power accumulators to the gradient's type.
   */
  @Override
  protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
    Variable<T> firstMomentSlot = getSlot(variable, FIRST_MOMENT).get();
    Variable<T> secondMomentSlot = getSlot(variable, SECOND_MOMENT).get();
    return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot,
        tf.dtypes.cast(betaOnePower, gradient.dataType()),
        tf.dtypes.cast(betaTwoPower, gradient.dataType()),
        tf.dtypes.cast(learningRateConst, gradient.dataType()),
        tf.dtypes.cast(betaOneConst, gradient.dataType()),
        tf.dtypes.cast(betaTwoConst, gradient.dataType()),
        tf.dtypes.cast(epsilonConst, gradient.dataType()),
        gradient);
  }

  /**
   * Gathers up the update operations into a single op that can be used as a run target.
   *
   * <p>Adds the betaOne and betaTwo power updates to the end of the updates list.
   *
   * @param updateOperations The update operations.
   * @param name The name of the run target.
   * @return A NoOp with a control dependency on each update operation.
   */
  @Override
  protected Op finish(List<Op> updateOperations, String name) {
    updateOperations.add(tf.assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst)));
    updateOperations.add(tf.assign(betaTwoPower, tf.math.mul(betaTwoPower, betaTwoConst)));
    return super.finish(updateOperations, name);
  }

  @Override
  public String toString() {
    return "Adam{"
        + "learningRate=" + learningRate
        + ", betaOne=" + betaOne
        + ", betaTwo=" + betaTwo
        + ", epsilon=" + epsilon
        + '}';
  }

  @Override
  public String getOptimizerName() {
    return "Adam";
  }
}
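/*
 * Implementation note: per the documented semantics of the ApplyAdam kernel invoked in
 * applyDense, each step performs, in scalar form,
 *
 *   lr_t = learningRate * sqrt(1 - betaTwo^t) / (1 - betaOne^t)
 *   m    = betaOne * m + (1 - betaOne) * gradient
 *   v    = betaTwo * v + (1 - betaTwo) * gradient^2
 *   var  = var - lr_t * m / (sqrt(v) + epsilon)
 *
 * where betaOne^t and betaTwo^t are the beta1_power/beta2_power variables that finish()
 * multiplies by betaOne and betaTwo after every step, giving the bias correction from the paper.
 */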