/*
 * Copyright 2017 LinkedIn Corp. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License. You may obtain a
 * copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package com.linkedin.photon.ml.hyperparameter.criteria

import breeze.linalg.DenseVector
import breeze.numerics.sqrt
import breeze.stats.distributions.Gaussian
import com.linkedin.photon.ml.hyperparameter.estimators.PredictionTransformation

/**
 * Expected improvement (EI) selection criterion. This transformation produces the expected improvement of the model
 * predictions over the current "best" value, where a lower evaluation value counts as an improvement.
 *
 * @see "Practical Bayesian Optimization of Machine Learning Algorithms" (PBO),
 *   https://papers.nips.cc/paper/4522-practical-bayesian-optimization-of-machine-learning-algorithms.pdf
 *
 * @param bestEvaluation The current best evaluation
 */
class ExpectedImprovement(bestEvaluation: Double) extends PredictionTransformation {

  // Maximize EI to minimize the evaluation value.
  def isMaxOpt: Boolean = true

  private val standardNormal = new Gaussian(0, 1)

  /**
   * Applies the expected improvement transformation to the model output.
   *
   * For a point with predictive mean `mu` and standard deviation `sigma > 0`, this computes
   * `sigma * (gamma * cdf(gamma) + pdf(gamma))` with `gamma = (bestEvaluation - mu) / sigma`, where `cdf` and `pdf`
   * belong to the standard normal distribution. For a point whose predictive variance is exactly zero, the closed
   * form is undefined (it evaluates to NaN in floating point), so the limit of the expression as `sigma` goes to
   * zero is returned instead: `max(bestEvaluation - mu, 0)`.
   *
   * Negative or NaN variances are not sanitized; they yield NaN for the corresponding element.
   *
   * @param predictiveMeans Predictive mean output from the model
   * @param predictiveVariances Predictive variance output from the model
   * @return The expected improvement over the current best evaluation, one value per input point
   * @throws IllegalArgumentException if the two input vectors differ in length
   */
  def apply(
      predictiveMeans: DenseVector[Double],
      predictiveVariances: DenseVector[Double]): DenseVector[Double] = {

    require(
      predictiveMeans.length == predictiveVariances.length,
      s"Predictive means (length ${predictiveMeans.length}) and predictive variances " +
        s"(length ${predictiveVariances.length}) must have the same length")

    val std = sqrt(predictiveVariances)

    DenseVector.tabulate(predictiveMeans.length) { i =>
      val sigma = std(i)
      // Positive when the predicted mean is below (i.e. better than) the current best evaluation.
      val improvement = - (predictiveMeans(i) - bestEvaluation)

      if (sigma == 0.0) {
        // Degenerate (deterministic) prediction: gamma would be +/-Infinity or NaN and the closed form below
        // would collapse to NaN, so use its analytic limit instead.
        math.max(improvement, 0.0)
      } else {
        // PBO Eq. 1
        val gamma = improvement / sigma

        // Eq. 2
        sigma * (gamma * standardNormal.cdf(gamma) + standardNormal.pdf(gamma))
      }
    }
  }
}