/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package org.platanios.tensorflow.api.ops.training.optimizers import org.platanios.tensorflow.api._ import org.platanios.tensorflow.api.ops.variables.Variable import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers /** * @author Emmanouil Antonios Platanios */ class GradientDescentSpec extends AnyFlatSpec with Matchers { "Gradient descent" must "work for dense updates to resource-based variables" in { val value0 = Tensor[Double](1.0, 2.0) val value1 = Tensor[Double](3.0, 4.0) val updatedValue0 = Tensor[Double](1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1) val updatedValue1 = Tensor[Double](3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01) val graph = Graph() val (variable0, variable1, gdOp) = tf.createWith(graph) { val variable0 = tf.variable[Double]("v0", Shape(2), tf.ConstantInitializer(Tensor(1, 2))) val variable1 = tf.variable[Double]("v1", Shape(2), tf.ConstantInitializer(Tensor(3, 4))) val gradient0 = tf.constant(Tensor[Double](0.1, 0.1)) val gradient1 = tf.constant(Tensor[Double](0.01, 0.01)) val gdOp = GradientDescent(3.0f).applyGradients(Seq( (gradient0, variable0.asInstanceOf[Variable[Any]]), (gradient1, variable1.asInstanceOf[Variable[Any]]))) (variable0.value, variable1.value, gdOp) } val session = Session(graph) session.run(targets = graph.trainableVariablesInitializer()) var variable0Value = session.run(fetches = variable0) var variable1Value = session.run(fetches = variable1) // TODO: !!! ??? [TENSORS] // assert(variable0Value === value0 +- 1e-6) // assert(variable1Value === value1 +- 1e-6) session.run(targets = gdOp) variable0Value = session.run(fetches = variable0) variable1Value = session.run(fetches = variable1) // assert(variable0Value === updatedValue0 +- 1e-6) // assert(variable1Value === updatedValue1 +- 1e-6) } }