package edu.brown.cs.burlap.tutorials;

import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.policy.PolicyUtils;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.auxiliary.EpisodeSequenceVisualizer;
import burlap.behavior.singleagent.auxiliary.StateReachability;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.valuefunction.ConstantValueFunction;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.behavior.valuefunction.ValueFunction;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.domain.singleagent.gridworld.GridWorldTerminalFunction;
import burlap.domain.singleagent.gridworld.GridWorldVisualizer;
import burlap.domain.singleagent.gridworld.state.GridAgent;
import burlap.domain.singleagent.gridworld.state.GridWorldState;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import burlap.statehashing.simple.SimpleHashableStateFactory;
import burlap.visualizer.Visualizer;

import java.util.*;

/**
 * A tutorial implementation of tabular value iteration, built by extending BURLAP's
 * {@link MDPSolver} class and implementing the {@link Planner} and {@link QProvider}
 * interfaces so that a {@link GreedyQPolicy} can be derived from the computed values.
 *
 * @author James MacGlashan.
 */
public class VITutorial extends MDPSolver implements Planner, QProvider {

	protected Map<HashableState, Double> valueFunction;
	protected ValueFunction vinit;
	protected int numIterations;


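	/**
	 * Initializes the planner.
	 * @param domain the domain in which to plan
	 * @param gamma the discount factor
	 * @param hashingFactory the state hashing factory used to index and compare states
	 * @param vinit the value function initialization to use for previously unseen states
	 * @param numIterations the number of full value iteration sweeps to perform over the state space
	 */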
	public VITutorial(SADomain domain, double gamma,
					  HashableStateFactory hashingFactory, ValueFunction vinit, int numIterations){
		this.solverInit(domain, gamma, hashingFactory);
		this.vinit = vinit;
		this.numIterations = numIterations;
		this.valueFunction = new HashMap<HashableState, Double>();
	}

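	/**
	 * Returns the current value estimate for state s; if the state has not yet been
	 * added to the value function, the initialization function's value is returned instead.
	 */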
	@Override
	public double value(State s) {
		Double d = this.valueFunction.get(hashingFactory.hashState(s));
		if(d == null){
			return vinit.value(s);
		}
		return d;
	}

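	/**
	 * Returns a Q-value for each action applicable in state s, computed from the
	 * current value function estimate.
	 */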
	@Override
	public List<QValue> qValues(State s) {
		List<Action> applicableActions = this.applicableActions(s);
		List<QValue> qs = new ArrayList<QValue>(applicableActions.size());
		for(Action a : applicableActions){
			qs.add(new QValue(s, a, this.qValue(s, a)));
		}
		return qs;
	}

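	/**
	 * Computes the Q-value of action a in state s as a full Bellman backup: the
	 * expectation, over all possible outcomes, of the immediate reward plus the
	 * discounted value of the next state. Terminal states always have a Q-value of 0.
	 */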
	@Override
	public double qValue(State s, Action a) {

		if(this.model.terminal(s)){
			return 0.;
		}

		//what are the possible outcomes?
		List<TransitionProb> tps = ((FullModel)this.model).transitions(s, a);

		//aggregate over each possible outcome
		double q = 0.;
		for(TransitionProb tp : tps){
			//what is the reward for this transition?
			double r = tp.eo.r;

			//what is the value for the next state?
			double vp = this.valueFunction.get(this.hashingFactory.hashState(tp.eo.op));

			//add the contribution of this outcome: the transition probability times
			//the reward plus the discounted value of the next state
			q += tp.p * (r + this.gamma * vp);
		}

		return q;
	}

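	/**
	 * Plans from the given state and returns a greedy policy over the resulting Q-values.
	 * If the state has been seen before, the previously computed value function is reused;
	 * otherwise, all states reachable from it are enumerated and then numIterations sweeps
	 * of Bellman updates are performed over them.
	 */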
	@Override
	public GreedyQPolicy planFromState(State initialState) {

		HashableState hashedInitialState = this.hashingFactory.hashState(initialState);
		if(this.valueFunction.containsKey(hashedInitialState)){
			return new GreedyQPolicy(this); //already performed planning here!
		}

		//if the state is new, then find all reachable states from it first
		this.performReachabilityFrom(initialState);

		//now perform multiple iterations over the whole state space
		for(int i = 0; i < this.numIterations; i++){
			//iterate over each state
			for(HashableState sh : this.valueFunction.keySet()){
				//update its value using the Bellman equation
				this.valueFunction.put(sh, QProvider.Helper.maxQ(this, sh.s()));
			}
		}

		return new GreedyQPolicy(this);

	}

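	/**
	 * Resets the solver by clearing the computed value function.
	 */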
	@Override
	public void resetSolver() {
		this.valueFunction.clear();
	}

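	/**
	 * Finds all states reachable from the seed state and initializes their values with
	 * the value function initialization; states already in the value function are left untouched.
	 */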
	public void performReachabilityFrom(State seedState){

		Set<HashableState> hashedStates = StateReachability.getReachableHashedStates(seedState, this.domain, this.hashingFactory);

		//initialize the value function for all states not already initialized
		for(HashableState hs : hashedStates){
			if(!this.valueFunction.containsKey(hs)){
				this.valueFunction.put(hs, this.vinit.value(hs.s()));
			}
		}

	}


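	/**
	 * Example usage: plans in a stochastic 11x11 four-rooms grid world with a terminal
	 * state at (10,10), rolls out the resulting greedy policy once from (0,0), and
	 * visualizes the episode.
	 */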
	public static void main(String [] args){

		GridWorldDomain gwd = new GridWorldDomain(11, 11);
		gwd.setTf(new GridWorldTerminalFunction(10, 10));
		gwd.setMapToFourRooms();

		//only go in the intended direction 80% of the time
		gwd.setProbSucceedTransitionDynamics(0.8);

		SADomain domain = gwd.generateDomain();

		//get the initial state, with the agent at (0,0)
		State s = new GridWorldState(new GridAgent(0, 0));

		//set up VI with a 0.99 discount factor, a value function
		//initialization that sets all states to a value of 0, and which will
		//run for 30 iterations over the state space
		VITutorial vi = new VITutorial(domain, 0.99, new SimpleHashableStateFactory(),
				new ConstantValueFunction(0.0), 30);

		//run planning from our initial state
		Policy p = vi.planFromState(s);

		//evaluate the policy with one rollout and visualize the trajectory
		Episode ea = PolicyUtils.rollout(p, s, domain.getModel());

		Visualizer v = GridWorldVisualizer.getVisualizer(gwd.getMap());
		new EpisodeSequenceVisualizer(v, domain, Arrays.asList(ea));

	}

}