package burlap.behavior.singleagent.options.model;

import burlap.behavior.policy.support.ActionProb;
import burlap.behavior.singleagent.options.Option;
import burlap.datastructures.HashedAggregator;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.environment.SimulatedEnvironment;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.SampleModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;

import java.util.*;

/**
 * A model that can compute a Markov option's transition model, and cache it, from a source {@link SampleModel}. A {@link FullModel} is
 * required for the {@link #transitions(State, Action)} method. Note that the transition model for an option
 * is a multi-time model, which means the state transition probabilities factor in the discount factor. That is,
 * P(s' | s, o) = \sum_{k=1}^\infty p(s', k | s, o) \gamma^k, where p(s', k | s, o) is the probability that the
 * agent will terminate in state s' after k steps, given that option o was initiated in state s.
 * <p>
 * The computation of the transition model can be quite expensive (particularly for stochastic domains), so ideally
 * you should consider a custom implementation of your option model. The computation of
 * the model proceeds by running a BFS-like algorithm from the input state, following the option policy
 * to possible option (or environment) termination states. The BFS expansion stops when a minimum threshold
 * of the probability mass of all possible trajectories following the policy has been computed (by default 0.999).
 * You can shrink this probability threshold using the method {@link #setMinProb(double)} to decrease computation time.
 * When you decrease the probability threshold, the computed probabilities are normalized by the amount of
 * trajectory probability mass actually computed, yielding an estimated option transition model.
 * <p>
 * If you need a model for non-Markov options (e.g., a {@link burlap.behavior.singleagent.options.MacroAction}), use
 * the {@link BFSNonMarkovOptionModel} model, which uses slightly more memory in the computation to maintain
 * the full trajectory history.
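 * <p>
 * A minimal usage sketch follows. It assumes you already have a {@link FullModel} for the primitive actions,
 * an {@link Option} to evaluate, a {@link HashableStateFactory}, and an initiation {@link State}; the
 * placeholder values below are illustrative, not part of this class's API.
 * <pre>{@code
 * FullModel baseModel = ...;       // model of the primitive actions (assumed to exist)
 * Option opt = ...;                // a Markov option (assumed to exist)
 * HashableStateFactory hf = ...;   // e.g., a SimpleHashableStateFactory
 * State s = ...;                   // a state in which opt can be initiated
 *
 * BFSMarkovOptionModel omodel = new BFSMarkovOptionModel(baseModel, 0.99, hf);
 * omodel.setMinProb(0.99);         // optionally trade model accuracy for computation time
 *
 * // multi-time-model transition distribution for executing opt from s
 * List<TransitionProb> tps = omodel.transitions(s, opt);
 * }</pre>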
 * @author James MacGlashan.
 */
public class BFSMarkovOptionModel implements FullModel {

	protected SampleModel model;
	protected double discount;
	protected HashableStateFactory hashingFactory;

	/**
	 * The transition models that have already been computed, keyed by option.
	 */
	protected Map<Option, CachedModel> cachedModels = new HashMap<Option, CachedModel>();

	/**
	 * States in which the source environment model terminates.
	 */
	protected Set<HashableState> srcTerminateStates = new HashSet<HashableState>();

	/**
	 * The minimum trajectory probability mass that must be computed before the BFS expansion stops.
	 */
	protected double minProb = 0.999;

	protected boolean requireMarkov = true;

	public BFSMarkovOptionModel(SampleModel model, double discount, HashableStateFactory hashingFactory) {
		this.model = model;
		this.discount = discount;
		this.hashingFactory = hashingFactory;
	}

	public void setMinProb(double minProb) {
		this.minProb = minProb;
	}

	@Override
	public List<TransitionProb> transitions(State s, Action a) {

		if(!(model instanceof FullModel)){
			throw new RuntimeException("Cannot compute option transition function probability distribution, because the underlying state model is " +
					"not a FullModel");
		}

		FullModel fmodel = (FullModel)model;

		//primitive actions are delegated directly to the underlying model
		if(!(a instanceof Option)){
			return fmodel.transitions(s, a);
		}

		Option o = (Option)a;
		if(!o.markov() && requireMarkov){
			throw new RuntimeException("BFSMarkovOptionModel can only compute the transition function probability distribution for Markov options, but the input Option is not Markov");
		}

		//return the cached distribution for this initiation state, if it has already been computed
		CachedModel cmodel = this.getOrCreateModel(o);
		HashableState hs = hashingFactory.hashState(s);
		List<TransitionProb> result = cmodel.cachedExpectations.get(hs);
		if(result != null){
			return result;
		}

		HashedAggregator<HashableState> possibleTerminations = new HashedAggregator<HashableState>();
		double[] expectedReturn = new double[]{0.};

		double sumProb = this.computeTransitions(s, o, possibleTerminations, expectedReturn);

		double r = expectedReturn[0];

		List<TransitionProb> transitions = new ArrayList<TransitionProb>(possibleTerminations.size());
		for(Map.Entry<HashableState, Double> e : possibleTerminations.entrySet()){
			EnvironmentOutcome eo = new EnvironmentOutcome(s, a, e.getKey().s(), r, srcTerminateStates.contains(e.getKey()));
			//normalize by the amount of trajectory probability mass that was actually computed
			double p = e.getValue() / sumProb;
			TransitionProb tp = new TransitionProb(p, eo);
			transitions.add(tp);
		}

		//cache the computed distribution so it is only computed once per initiation state
		cmodel.cachedExpectations.put(hs, transitions);

		return transitions;
	}

	@Override
	public EnvironmentOutcome sample(State s, Action a) {

		if(!(a instanceof Option)){
			return model.sample(s, a);
		}

		Option o = (Option)a;
		SimulatedEnvironment env = new SimulatedEnvironment(model, s);
		return o.control(env, discount);
	}

	@Override
	public boolean terminal(State s) {
		return this.model.terminal(s);
	}

	protected CachedModel getOrCreateModel(Option o){
		CachedModel model = this.cachedModels.get(o);
		if(model != null){
			return model;
		}
		model = new CachedModel();
		this.cachedModels.put(o, model);
		return model;
	}
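	/**
	 * Runs the BFS-like expansion from the input state, following the option's policy, until at least
	 * {@link #minProb} of the trajectory probability mass has been computed. The discounted probability of
	 * terminating in each reached state is accumulated in possibleTerminations, and the expected discounted
	 * return is accumulated in expectedReturn[0].
	 * @param s the state in which the option is initiated
	 * @param o the option whose transition model is being computed
	 * @param possibleTerminations an aggregator that receives the discounted termination probability of each termination state
	 * @param expectedReturn a length-1 array that receives the expected discounted return of executing the option
	 * @return the (undiscounted) trajectory probability mass that was computed, which callers use to normalize the results
	 */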
	protected double computeTransitions(State s, Option o, HashedAggregator<HashableState> possibleTerminations, double[] expectedReturn){

		double sumTermProb = 0.;

		LinkedList<OptionScanNode> openList = new LinkedList<OptionScanNode>();
		OptionScanNode inode = new OptionScanNode(s);
		openList.addLast(inode);

		while(!openList.isEmpty() && sumTermProb < this.minProb){

			OptionScanNode src = openList.poll();

			double probTerm = 0.0; //can never terminate in initiation state
			if(src.nSteps > 0){
				probTerm = o.probabilityOfTermination(src.s, null);
			}
			if(this.model.terminal(src.s)){
				probTerm = 1.;
			}

			double probContinue = 1. - probTerm;
			double stackedDiscount = Math.pow(this.discount, src.nSteps);

			//handle possible termination
			if(probTerm > 0.){
				double probOfDiscountedTrajectory = src.probability * stackedDiscount * probTerm;
				possibleTerminations.add(hashingFactory.hashState(src.s), probOfDiscountedTrajectory);
				expectedReturn[0] += src.cumulativeDiscountedReward * src.probability * probTerm;
				sumTermProb += src.probability * probTerm; //only the terminating share of this node's mass
			}

			//handle continuation
			if(probContinue > 0.){

				//handle option policy selection
				List<ActionProb> actionSelection = o.policyDistribution(src.s, null);
				for(ActionProb ap : actionSelection){

					//now get possible outcomes of each action
					List<TransitionProb> transitions = ((FullModel)model).transitions(src.s, ap.ga);
					for(TransitionProb tp : transitions){
						double totalTransP = ap.pSelection * tp.p * probContinue;
						double r = stackedDiscount * tp.eo.r;
						if(tp.eo.terminated){
							srcTerminateStates.add(hashingFactory.hashState(tp.eo.op));
						}
						OptionScanNode next = new OptionScanNode(src, tp.eo.op, totalTransP, r);
						openList.addLast(next);
					}
				}
			}
		}

		return sumTermProb;
	}

	public static class CachedModel{

		/**
		 * The cached transition probabilities from each initiation state
		 */
		protected Map<HashableState, List<TransitionProb>> cachedExpectations = new HashMap<HashableState, List<TransitionProb>>();

	}

	public static class OptionScanNode{

		/**
		 * the state this search node wraps
		 */
		public State s;

		/**
		 * the *un*-discounted probability of reaching this search node
		 */
		public double probability;

		/**
		 * The cumulative discounted reward received reaching this node.
		 */
		public double cumulativeDiscountedReward;

		/**
		 * The number of steps taken to reach this node.
		 */
		public int nSteps;

		public OptionScanNode() {
		}

		public OptionScanNode(State s) {
			this.s = s;
			this.probability = 1.;
			this.cumulativeDiscountedReward = 0.;
			this.nSteps = 0;
		}

		/**
		 * Initializes.
		 * @param src a source parent node from which this node was generated
		 * @param s the state this search node wraps
		 * @param transProb the transition probability of reaching this node from the source node
		 * @param discountedR the discounted reward received from reaching this node from the source node.
		 */
		public OptionScanNode(OptionScanNode src, State s, double transProb, double discountedR){
			this.s = s;
			this.probability = src.probability * transProb;
			this.cumulativeDiscountedReward = src.cumulativeDiscountedReward + discountedR;
			this.nSteps = src.nSteps + 1;
		}

	}

}