package burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners; import burlap.behavior.policy.BoltzmannQPolicy; import burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners.dpoperator.DifferentiableSoftmaxOperator; import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableRF; import burlap.behavior.singleagent.planning.Planner; import burlap.debugtools.DPrint; import burlap.mdp.core.action.Action; import burlap.mdp.core.state.State; import burlap.mdp.singleagent.SADomain; import burlap.mdp.singleagent.model.FullModel; import burlap.mdp.singleagent.model.TransitionProb; import burlap.statehashing.HashableState; import burlap.statehashing.HashableStateFactory; import java.util.*; /** * Performs Differentiable Value Iteration using the Boltzmann backup operator and a * {@link burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableRF}. This class * behaves the same as the normal {@link burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration} * valueFunction except for being in the differentiable value function case. * @author James MacGlashan. */ public class DifferentiableVI extends DifferentiableDP implements Planner { /** * When the maximum change in the value function is smaller than this value, VI will terminate. */ protected double maxDelta; /** * When the number of VI iterations exceeds this value, VI will terminate. */ protected int maxIterations; /** * Indicates whether the reachable states has been computed yet. */ protected boolean foundReachableStates = false; /** * When the reachability analysis to find the state space is performed, a breadth first search-like pass * (spreading over all stochastic transitions) is performed. It can optionally be set so that the * search is pruned at terminal states by setting this value to true. By default, it is false and the full * reachable state space is found */ protected boolean stopReachabilityFromTerminalStates = false; /** * Indicates whether VI has been run or not */ protected boolean hasRunVI = false; protected double boltzBeta; /** * Initializes the valueFunction. * @param domain the domain in which to plan * @param rf the differentiable reward function that will be used * @param gamma the discount factor * @param boltzBeta the scaling factor in the boltzmann distribution used for the state value function. The larger the value, the more deterministic. * @param hashingFactory the state hashing factor to use * @param maxDelta when the maximum change in the value function is smaller than this value, VI will terminate. * @param maxIterations when the number of VI iterations exceeds this value, VI will terminate. */ public DifferentiableVI(SADomain domain, DifferentiableRF rf, double gamma, double boltzBeta, HashableStateFactory hashingFactory, double maxDelta, int maxIterations){ this.DPPInit(domain, gamma, hashingFactory); this.rf = rf; this.maxDelta = maxDelta; this.maxIterations = maxIterations; this.operator = new DifferentiableSoftmaxOperator(boltzBeta); this.boltzBeta = boltzBeta; } /** * Calling this method will force the valueFunction to recompute the reachable states when the {@link #planFromState(State)} method is called next. * This may be useful if the transition dynamics from the last planning call have changed and if planning needs to be restarted as a result. */ public void recomputeReachableStates(){ this.foundReachableStates = false; } /** * Sets whether the state reachability search to generate the state space will be prune the search from terminal states. * The default is not to prune. * @param toggle true if the search should prune the search at terminal states; false if the search should find all reachable states regardless of terminal states. */ public void toggleReachabiltiyTerminalStatePruning(boolean toggle){ this.stopReachabilityFromTerminalStates = toggle; } /** * Plans from the input state and returns a {@link burlap.behavior.policy.BoltzmannQPolicy} following the * Boltzmann parameter used for value Botlzmann value backups in this planner. * @param initialState the initial state of the planning problem * @return a {@link burlap.behavior.policy.BoltzmannQPolicy} */ @Override public BoltzmannQPolicy planFromState(State initialState){ if(!this.valueFunction.containsKey(this.hashingFactory.hashState(initialState))){ this.performReachabilityFrom(initialState); this.runVI(); } return new BoltzmannQPolicy(this, 1./this.boltzBeta); } @Override public void resetSolver(){ super.resetSolver(); this.foundReachableStates = false; this.hasRunVI = false; } /** * Runs VI until the specified termination conditions are met. In general, this method should only be called indirectly through the {@link #planFromState(State)} method. * The {@link #performReachabilityFrom(State)} must have been performed at least once * in the past or a runtime exception will be thrown. The {@link #planFromState(State)} method will automatically call the {@link #performReachabilityFrom(State)} * method first and then this if it hasn't been run. */ public void runVI(){ if(!this.foundReachableStates){ throw new RuntimeException("Cannot run VI until the reachable states have been found. Use the planFromState, performReachabilityFrom, addStateToStateSpace or addStatesToStateSpace methods at least once before calling runVI."); } Set<HashableState> states = valueFunction.keySet(); int i; for(i = 0; i < this.maxIterations; i++){ double delta = 0.; for(HashableState sh : states){ double v = this.value(sh); double newV = this.performBellmanUpdateOn(sh); this.performDPValueGradientUpdateOn(sh); delta = Math.max(Math.abs(newV - v), delta); } if(delta < this.maxDelta){ break; //approximated well enough; stop iterating } } DPrint.cl(this.debugCode, "Passes: " + i); this.hasRunVI = true; } /** * Adds the given state to the state space over which VI iterates. * @param s the state to add */ public void addStateToStateSpace(State s){ HashableState sh = this.hashingFactory.hashState(s); this.valueFunction.put(sh, valueInitializer.value(s)); this.foundReachableStates = true; } /** * Adds a {@link java.util.Collection} of states over which VI will iterate. * @param states the collection of states. */ public void addStatesToStateSpace(Collection<State> states){ for(State s : states){ this.addStateToStateSpace(s); } } /** * This method will find all reachable states that will be used by the {@link #runVI()} method and will cache all the transition dynamics. * This method will not do anything if all reachable states from the input state have been discovered from previous calls to this method. * @param si the source state from which all reachable states will be found * @return true if a reachability analysis had never been performed from this state; false otherwise. */ public boolean performReachabilityFrom(State si){ HashableState sih = this.stateHash(si); DPrint.cl(this.debugCode, "Starting reachability analysis"); //add to the open list LinkedList<HashableState> openList = new LinkedList<HashableState>(); Set <HashableState> openedSet = new HashSet<HashableState>(); openList.offer(sih); openedSet.add(sih); while(!openList.isEmpty()){ HashableState sh = openList.poll(); //skip this if it's already been expanded if(valueFunction.containsKey(sh)){ continue; } //do not need to expand from terminal states if set to prune if(model.terminal(sh.s()) && stopReachabilityFromTerminalStates){ continue; } valueFunction.put(sh, valueInitializer.value(sh.s())); List<Action> actions = this.applicableActions(sh.s()); for(Action a : actions){ List<TransitionProb> tps = ((FullModel)model).transitions(sh.s(), a); for(TransitionProb tp : tps){ HashableState tsh = this.stateHash(tp.eo.op); if(!openedSet.contains(tsh) && !valueFunction.containsKey(tsh)){ openedSet.add(tsh); openList.offer(tsh); } } } } DPrint.cl(this.debugCode, "Finished reachability analysis; # states: " + valueFunction.size()); this.foundReachableStates = true; this.hasRunVI = false; return true; } }