/*
 * Copyright (c) 2016 Shingo Omura
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

package com.github.everpeace.banditsbook.algorithm.exp3

import breeze.linalg.Vector._
import breeze.linalg._
import breeze.numerics.exp
import breeze.stats.distributions.{Rand, RandBasis}
import breeze.storage.Zero
import com.github.everpeace.banditsbook.algorithm._
import com.github.everpeace.banditsbook.arm.Arm

import scala.collection.immutable.Seq
import scala.reflect.ClassTag

/**
  * Exp3 (exponential-weight algorithm for exploration and exploitation) bandit algorithm.
  *
  * Arms are sampled from a mixture of a weight-proportional distribution and the uniform
  * distribution; after each pull the chosen arm's weight is multiplied by an exponential of its
  * importance-weighted reward estimate.
  *
  * see: http://www.cs.nyu.edu/~mohri/pub/bandit.pdf
  */
object Exp3 {

  /**
    * Exp3 algorithm state.
    *
    * @param γ       exploration rate; the share of probability mass spread uniformly over all arms
    * @param weights one exponential weight per arm
    * @param counts  number of times each arm has been chosen
    */
  case class State(γ: Double, weights: Vector[Double], counts: Vector[Int])

  /**
    * Builds an Exp3 algorithm with exploration rate `γ`.
    *
    * @param γ    exploration rate; must satisfy 0 < γ <= 1
    * @param rand random source available implicitly while sampling arms (defaults to breeze's `Rand`)
    * @throws IllegalArgumentException if `γ` is not in (0, 1]
    */
  def Algorithm(γ: Double)(implicit zeroReward: Zero[Double], zeroInt: Zero[Int], tag: ClassTag[Double], rand: RandBasis = Rand)
  = {
    require(0 < γ && γ <= 1, "γ must be in (0,1]")

    new Algorithm[Double, State] {

      /** Every arm starts with weight 1.0 and a pull count of 0. */
      override def initialState(arms: Seq[Arm[Double]]): State = State(
        γ, fill(arms.size)(1.0d), zeros[Int](arms.size)
      )

      /** Samples an arm index from the Exp3 mixture distribution `probs(state.γ, state.weights)`. */
      override def selectArm(arms: Seq[Arm[Double]], state: State): Int =
        CategoricalDistribution(probs(state.γ, state.weights)).draw()

      /**
        * Records `reward` for arm `chosen` and returns the next state.
        *
        * The chosen arm's pull count is incremented. Its weight is then multiplied by
        * `exp(γ * (reward / p(chosen)) / K)`, where `p` is the selection distribution computed from the
        * weights before this update. All other weights are left unchanged.
        *
        * `state` itself is not modified; the returned state holds fresh vectors.
        *
        * NOTE(review): Exp3 is usually analysed for rewards in [0, 1]; no range check is done here.
        */
      override def updateState(arms: Seq[Arm[Double]], state: State, chosen: Int, reward: Double): State = {
        // Work on copies. Updating state.counts / state.weights in place would silently rewrite
        // the caller's previous State, because the returned State would share the same vectors.
        val counts = state.counts.copy
        val weights = state.weights.copy

        counts.update(chosen, counts(chosen) + 1)

        val K = weights.size
        // Use the pre-update weights: this is the distribution selectArm samples from for `state`.
        val p = probs(state.γ, weights)
        // Importance-weighted reward estimate: reward / p for the chosen arm, 0 for every other arm.
        val x = zeros[Double](K)
        x.update(chosen, reward / p(chosen))
        // w_i <- w_i * exp(γ * x_i / K); arms with x_i = 0 are multiplied by exp(0) = 1.
        // NOTE(review): weights are never renormalised, so with non-negative rewards they only grow
        // and may overflow to Infinity on very long runs. probs depends only on weights / sum(weights),
        // so rescaling would be safe if that ever becomes a problem.
        weights *= exp((state.γ * x) / K.toDouble)

        state.copy(weights = weights, counts = counts)
      }

      /**
        * Exp3 selection probabilities.
        *
        * Arm i gets `(1 - γ) * w_i / sum(w) + γ / K`, where K is the number of arms.
        */
      private def probs(γ: Double, weights: Vector[Double]): Vector[Double] = {
        val K = weights.size // #arms
        ((1 - γ) * (weights / sum(weights))) + (γ / K)
      }
    }
  }
}