Java Code Examples for burlap.mdp.singleagent.SADomain#getModel()

The following examples show how to use burlap.mdp.singleagent.SADomain#getModel(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: DifferentiableSparseSampling.java    From burlap with Apache License 2.0 6 votes vote down vote up
/**
 * Initializes. The model of this planner will automatically be set to a {@link CustomRewardModel} that wraps the
 * domain's model and uses the provided reward function.
 * @param domain the problem domain
 * @param rf the differentiable reward function
 * @param gamma the discount factor
 * @param hashingFactory the hashing factory used to compare state equality
 * @param h the planning horizon
 * @param c how many samples from the transition dynamics to use. Set to -1 to use the full (unsampled) transition dynamics.
 * @param boltzBeta the Boltzmann beta parameter for the differentiable Boltzmann (softmax) backup equation. The larger
 *                  the value the more deterministic (closer to a hard max) the backup; the closer to 0 the softer
 *                  (closer to uniform) it becomes.
 */
public DifferentiableSparseSampling(SADomain domain, DifferentiableRF rf, double gamma, HashableStateFactory hashingFactory, int h, int c, double boltzBeta){
	this.solverInit(domain, gamma, hashingFactory);
	this.h = h;
	this.c = c;
	this.rf = rf;
	this.boltzBeta = boltzBeta;
	this.nodesByHeight = new HashMap<SparseSampling.HashedHeightState, DiffStateNode>();
	this.rootLevelQValues = new HashMap<HashableState, DifferentiableSparseSampling.QAndQGradient>();
	// the dimensionality of the value gradients is the number of reward-function parameters
	this.rfDim = rf.numParameters();

	this.vinit = new VanillaDiffVinit(new ConstantValueFunction(), rf);

	// wrap the domain's model so rewards are produced by the differentiable rf (see class contract above)
	this.model = new CustomRewardModel(domain.getModel(), rf);

	this.operator = new DifferentiableSoftmaxOperator(boltzBeta);

	// debug code assigned to this planner instance
	this.debugCode = 6368290;
}
 
Example 2
Source File: SimulatedEnvironment.java    From burlap with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a simulated environment whose state generator always yields {@code initialState}.
 * @param domain the domain whose model drives the simulation; its model must not be {@code null}
 * @param initialState the state wrapped by a {@link ConstantStateGenerator}, also used as the current state
 * @throws RuntimeException if the domain does not provide a model
 */
public SimulatedEnvironment(SADomain domain, State initialState) {
	this.curState = initialState;
	this.stateGenerator = new ConstantStateGenerator(initialState);
	if(domain.getModel() != null){
		this.model = domain.getModel();
	}
	else{
		throw new RuntimeException("SimulatedEnvironment requires a Domain with a model, but the input domain does not have one.");
	}
}
 
Example 3
Source File: SimulatedEnvironment.java    From burlap with Apache License 2.0 5 votes vote down vote up
public SimulatedEnvironment(SADomain domain, StateGenerator stateGenerator) {
	this.stateGenerator = stateGenerator;
	this.curState = stateGenerator.generateState();
	if(domain.getModel() == null){
		throw new RuntimeException("SimulatedEnvironment requires a Domain with a model, but the input domain does not have one.");
	}
	this.model = domain.getModel();
}
 
Example 4
Source File: RLGlueEnvironment.java    From burlap with Apache License 2.0 5 votes vote down vote up
/**
 * Constructs with all the BURLAP information necessary for generating an RLGlue Environment.
 * Integer action indices are assigned by enumerating, in action-type order, every action applicable in a single
 * example state drawn from {@code stateGenerator}; that same example state becomes the first state returned.
 * @param domain the BURLAP domain; must provide a non-null model
 * @param stateGenerator a generator for generating states at the start of each episode.
 * @param stateFlattener used to flatten states into a numeric representation
 * @param valueRanges the value ranges of the flattened vector state
 * @param rewardRange the reward function value range
 * @param isEpisodic whether the task is episodic or continuing
 * @param discount the discount factor to use for the task
 * @throws RuntimeException if {@code domain.getModel()} returns {@code null}
 */
public RLGlueEnvironment(SADomain domain, StateGenerator stateGenerator, DenseStateFeatures stateFlattener,
						 DoubleRange[] valueRanges,
						 DoubleRange rewardRange, boolean isEpisodic, double discount){

	if(domain.getModel() == null){
		throw new RuntimeException("RLGlueEnvironment requires a BURLAP domain with a SampleModel, but the domain does not provide one.");
	}

	this.domain = domain;
	this.stateGenerator = stateGenerator;
	this.stateFlattener = stateFlattener;
	this.valueRanges = valueRanges;
	this.rewardRange = rewardRange;
	this.isEpisodic = isEpisodic;
	this.discount = discount;

	// Build the integer -> action index from one example state.
	// NOTE(review): actions applicable only in other states receive no index; confirm this is acceptable
	// for the domains used with this environment.
	State exampleState = this.stateGenerator.generateState();
	int actionInd = 0;
	for(ActionType a : this.domain.getActionTypes()){
		// fully qualified, presumably to avoid a clash with an RL-Glue Action type imported in this file
		List<burlap.mdp.core.action.Action> gas = a.allApplicableActions(exampleState);
		for(burlap.mdp.core.action.Action ga : gas){
			this.actionMap.put(actionInd, ga);
			actionInd++;
		}
	}

	//set this to be the first state returned
	this.curState = exampleState;
}
 
Example 5
Source File: StateReachability.java    From burlap with Apache License 2.0 4 votes vote down vote up
/**
 * Returns the set of states reachable from a source state. The set is computed by breadth-first expansion over
 * every applicable action of the domain's {@link FullModel}. Outcome states flagged as terminated are included
 * in the result but are not expanded further. Progress is reported roughly once per second through {@code DPrint}.
 * @param from the source state
 * @param inDomain the domain of the state; its model must be a {@link FullModel}
 * @param usingHashFactory the state hashing factory to use for indexing states and testing equality.
 * @return the set of hashed states reachable from {@code from}, including {@code from} itself.
 * @throws RuntimeException if the domain's model is not a {@link FullModel}
 */
public static Set <HashableState> getReachableHashedStates(State from, SADomain inDomain, HashableStateFactory usingHashFactory){

	if(!(inDomain.getModel() instanceof FullModel)){
		throw new RuntimeException( "State reachablity requires a domain with a FullModel, but one is not provided");
	}
	FullModel fullModel = (FullModel)inDomain.getModel();

	Set<HashableState> visited = new HashSet<HashableState>();
	HashableState source = usingHashFactory.hashState(from);
	List<ActionType> types = inDomain.getActionTypes();
	int generatedCount = 0;

	LinkedList<HashableState> frontier = new LinkedList<HashableState>();
	frontier.offer(source);
	visited.add(source);

	long startMillis = System.currentTimeMillis();
	long lastReportMillis = startMillis;
	while(!frontier.isEmpty()){
		HashableState current = frontier.poll();

		for(Action action : ActionUtils.allApplicableActionsForTypes(types, current.s())){
			List<TransitionProb> outcomes = fullModel.transitions(current.s(), action);
			generatedCount += outcomes.size();
			for(TransitionProb outcome : outcomes){
				HashableState next = usingHashFactory.hashState(outcome.eo.op);
				// always record the state; only non-terminal, newly seen states are queued for expansion
				boolean newlySeen = visited.add(next);
				if(newlySeen && !outcome.eo.terminated){
					frontier.offer(next);
				}
			}
		}

		long nowMillis = System.currentTimeMillis();
		boolean reportDue = nowMillis - lastReportMillis >= 1000;
		if(reportDue){
			double elapsedSeconds = ((double)nowMillis - startMillis)/1000.0;
			DPrint.cl(debugID, "Num generated: " + generatedCount + " Unique: " + visited.size() +
					" time: " + elapsedSeconds);
			lastReportMillis = nowMillis;
		}
	}

	DPrint.cl(debugID, "Num generated: " + generatedCount + "; num unique: " + visited.size());

	return visited;
}
 
Example 6
Source File: StateReachability.java    From burlap with Apache License 2.0 4 votes vote down vote up
/**
 * Finds the set of states ({@link burlap.statehashing.HashableState}) that are reachable under a policy from a source state.
 * Expansion is breadth-first: only actions the policy selects with probability greater than zero are followed, and
 * every outcome of those actions under the domain's {@link FullModel} is considered. Outcome states flagged as
 * terminated are included in the result but not expanded further. Progress is reported periodically via {@code DPrint}.
 * @param domain the domain containing the model to use for evaluating reachable states; its model must be a {@link FullModel}
 * @param p the policy that must be followed
 * @param from the source {@link State} from which the policy would be initiated.
 * @param usingHashFactory the {@link burlap.statehashing.HashableStateFactory} used to hash states and test equality.
 * @return a {@link java.util.Set} of {@link burlap.statehashing.HashableState} objects that could be reached, including the source.
 * @throws RuntimeException if the domain's model is not a {@link FullModel}
 */
public static Set<HashableState> getPolicyReachableHashedStates(SADomain domain, EnumerablePolicy p, State from, HashableStateFactory usingHashFactory){

	if(!(domain.getModel() instanceof FullModel)){
		throw new RuntimeException( "State reachablity requires a domain with a FullModel, but one is not provided");
	}
	FullModel fullModel = (FullModel)domain.getModel();

	Set<HashableState> visited = new HashSet<HashableState>();
	HashableState source = usingHashFactory.hashState(from);
	int generatedCount = 0;

	LinkedList<HashableState> frontier = new LinkedList<HashableState>();
	frontier.offer(source);
	visited.add(source);

	MyTimer reportTimer = new MyTimer(true);
	while(!frontier.isEmpty()){
		HashableState current = frontier.poll();

		for(ActionProb choice : p.policyDistribution(current.s())){
			// keep the "> 0" form so that NaN probabilities are skipped exactly as before
			if(choice.pSelection > 0){
				List<TransitionProb> outcomes = fullModel.transitions(current.s(), choice.ga);
				generatedCount += outcomes.size();
				for(TransitionProb outcome : outcomes){
					HashableState next = usingHashFactory.hashState(outcome.eo.op);
					boolean newlySeen = visited.add(next);
					if(newlySeen && !outcome.eo.terminated){
						frontier.offer(next);
					}
				}
			}
		}

		boolean reportDue = reportTimer.peekAtTime() > 1;
		if(reportDue){
			reportTimer.stop();
			DPrint.cl(debugID, "Num generated: " + generatedCount + " Unique: " + visited.size() +
					" time: " + reportTimer.getTime());
			reportTimer.start();
		}
	}

	reportTimer.stop();

	DPrint.cl(debugID, "Num generated: " + generatedCount + "; num unique: " + visited.size());

	return visited;
}
 
Example 7
Source File: SimulatedEnvironment.java    From burlap with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a simulated environment backed by the given domain's model.
 * No state generator or current state is set by this constructor.
 * @param domain the domain whose model drives the simulation; its model must not be {@code null}
 * @throws RuntimeException if the domain does not provide a model
 */
public SimulatedEnvironment(SADomain domain){
	if(domain.getModel() != null){
		this.model = domain.getModel();
	}
	else{
		throw new RuntimeException("SimulatedEnvironment requires a Domain with a model, but the input domain does not have one.");
	}
}