burlap.statehashing.simple.SimpleHashableStateFactory Java Examples

The following examples show how to use burlap.statehashing.simple.SimpleHashableStateFactory. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: GridWorldDQN.java    From burlap_caffe with Apache License 2.0 6 votes vote down vote up
public GridWorldDQN(String solverFile, double gamma) {

        //create the domain
        gwdg = new GridWorldDomain(11, 11);
        gwdg.setMapToFourRooms();
        rf = new UniformCostRF();
        tf = new SinglePFTF(PropositionalFunction.findPF(gwdg.generatePfs(), GridWorldDomain.PF_AT_LOCATION));
        gwdg.setRf(rf);
        gwdg.setTf(tf);
        domain = gwdg.generateDomain();

        goalCondition = new TFGoalCondition(tf);

        //set up the initial state of the task
        initialState = new GridWorldState(new GridAgent(0, 0), new GridLocation(10, 10, "loc0"));

        //set up the state hashing system for tabular algorithms
        hashingFactory = new SimpleHashableStateFactory();

        //set up the environment for learners algorithms
        env = new SimulatedEnvironment(domain, initialState);

        dqn = new DQN(solverFile, actionSet, new NNGridStateConverter(), gamma);
    }
 
Example #2
Source File: TestHashing.java    From burlap with Apache License 2.0 6 votes vote down vote up
/**
 * Stress-tests the default {@link SimpleHashableStateFactory} on progressively larger
 * parameterizations, first with the trailing boolean argument {@code false} and then with
 * {@code true}. All runs share one factory instance and execute in the listed order.
 */
@Test
public void testSimpleHashFactoryLargeState() {
	HashableStateFactory factory = new SimpleHashableStateFactory();

	int[][] runsWithFlagOff = {
			{10, 100}, {50, 1000}, {100, 10000}, {200, 100000}, {500, 100000}
	};
	for (int[] run : runsWithFlagOff) {
		testSimpleHashFactoryLargeState(factory, run[0], run[1], false);
	}

	int[][] runsWithFlagOn = {
			{10, 100}, {20, 1000}, {50, 10000}, {100, 100000}
	};
	for (int[] run : runsWithFlagOn) {
		testSimpleHashFactoryLargeState(factory, run[0], run[1], true);
	}
}
 
Example #3
Source File: TestHashing.java    From burlap with Apache License 2.0 6 votes vote down vote up
/**
 * Checks an identifier-dependent {@link SimpleHashableStateFactory} (constructed with
 * {@code false}). The grid-world test start state must reach exactly 104 distinct hashed
 * states. Re-hashing a copy of each one passed through {@code renameObjects} must then add
 * 104 more, giving 208 in total.
 *
 * <p>Failures are raised as explicit {@link AssertionError}s instead of Java {@code assert}
 * statements. Those statements are skipped unless the JVM runs with {@code -ea}, which would
 * let this test pass without checking anything.
 */
@Test
public void testSimpleHashFactoryIdentifierDependent() {
	SADomain domain = (SADomain)this.gridWorldTest.getDomain();
	State startState = this.gridWorldTest.generateState();
	HashableStateFactory factory = new SimpleHashableStateFactory(false);
	Set<HashableState> hashedStates = this.getReachableHashedStates(startState, domain, factory);
	if (hashedStates.size() != 104) {
		throw new AssertionError("Expected 104 reachable hashed states but found " + hashedStates.size());
	}

	Set<HashableState> renamedStates = new HashSet<HashableState>();
	for (HashableState state : hashedStates) {
		State source = state.s();
		State renamed = this.renameObjects((GridWorldState)source.copy());
		HashableState renamedHashed = factory.hashState(renamed);
		renamedStates.add(renamedHashed);
	}
	hashedStates.addAll(renamedStates);
	// identifier-dependent hashing must keep renamed copies distinct from the originals
	if (hashedStates.size() != 208) {
		throw new AssertionError("Expected 208 hashed states after adding renamed copies but found " + hashedStates.size());
	}
}
 
Example #4
Source File: TestHashing.java    From burlap with Apache License 2.0 6 votes vote down vote up
/**
 * Checks the default (identifier-independent) {@link SimpleHashableStateFactory}. The
 * grid-world test start state must reach exactly 104 distinct hashed states. Adding hashed
 * copies of each state passed through {@code renameObjects} must not increase that count.
 *
 * <p>Failures are raised as explicit {@link AssertionError}s instead of Java {@code assert}
 * statements. Those statements are skipped unless the JVM runs with {@code -ea}, which would
 * let this test pass without checking anything.
 */
@Test
public void testSimpleHashFactoryIdentifierIndependent() {
	SADomain domain = (SADomain)this.gridWorldTest.getDomain();
	State startState = this.gridWorldTest.generateState();
	HashableStateFactory factory = new SimpleHashableStateFactory();
	Set<HashableState> hashedStates = this.getReachableHashedStates(startState, domain, factory);
	if (hashedStates.size() != 104) {
		throw new AssertionError("Expected 104 reachable hashed states but found " + hashedStates.size());
	}

	Set<HashableState> renamedStates = new HashSet<HashableState>();
	for (HashableState state : hashedStates) {
		State source = state.s();
		State renamed = this.renameObjects((GridWorldState)source.copy());
		HashableState renamedHashed = factory.hashState(renamed);
		renamedStates.add(renamedHashed);
	}
	hashedStates.addAll(renamedStates);
	// identifier-independent hashing must treat renamed copies as equal to the originals
	if (hashedStates.size() != 104) {
		throw new AssertionError("Expected renamed copies to collapse into 104 hashed states but found " + hashedStates.size());
	}
}
 
Example #5
Source File: BasicBehavior.java    From burlap_examples with MIT License 6 votes vote down vote up
public BasicBehavior(){
		gwdg = new GridWorldDomain(11, 11);
		gwdg.setMapToFourRooms();
		tf = new GridWorldTerminalFunction(10, 10);
		gwdg.setTf(tf);
		goalCondition = new TFGoalCondition(tf);
		domain = gwdg.generateDomain();

		initialState = new GridWorldState(new GridAgent(0, 0), new GridLocation(10, 10, "loc0"));
		hashingFactory = new SimpleHashableStateFactory();

		env = new SimulatedEnvironment(domain, initialState);


//		VisualActionObserver observer = new VisualActionObserver(domain, GridWorldVisualizer.getVisualizer(gwdg.getMap()));
//		observer.initGUI();
//		env.addObservers(observer);
	}
 
Example #6
Source File: ContinuousDomainTutorial.java    From burlap_examples with MIT License 6 votes vote down vote up
/**
 * Plans for the inverted pendulum with {@link SparseSampling}, rolls out the greedy Q policy
 * from the default pendulum state (rollout limit 500), prints the episode's final time step
 * and opens an episode visualizer on it.
 */
public static void IPSS(){

	// noise-free pendulum; reward and terminal functions are both built from PI/8
	InvertedPendulum pendulum = new InvertedPendulum();
	pendulum.physParams.actionNoise = 0.;
	double angleBound = Math.PI / 8.;
	RewardFunction rf = new InvertedPendulum.InvertedPendulumRewardFunction(angleBound);
	TerminalFunction tf = new InvertedPendulum.InvertedPendulumTerminalFunction(angleBound);
	pendulum.setRf(rf);
	pendulum.setTf(tf);
	SADomain domain = pendulum.generateDomain();

	State initialState = new InvertedPendulumState();

	// planner re-plans from scratch each call and stays quiet
	SparseSampling planner = new SparseSampling(domain, 1, new SimpleHashableStateFactory(), 10, 1);
	planner.setForgetPreviousPlanResults(true);
	planner.toggleDebugPrinting(false);
	Policy greedy = new GreedyQPolicy(planner);

	Episode episode = PolicyUtils.rollout(greedy, initialState, domain.getModel(), 500);
	System.out.println("Num steps: " + episode.maxTimeStep());

	Visualizer visualizer = CartPoleVisualizer.getCartPoleVisualizer();
	new EpisodeSequenceVisualizer(visualizer, domain, Arrays.asList(episode));

}
 
Example #7
Source File: CommandReachable.java    From burlapcraft with GNU Lesser General Public License v3.0 6 votes vote down vote up
/**
 * Enumerates every state reachable from the current dungeon state.
 *
 * <p>For each state it prints the agent object's {@code x, y, z, rdir, vdir, selected}
 * fields as one comma-separated line, then prints the number of reachable states.
 * Both command arguments are ignored.
 */
@Override
public void processCommand(ICommandSender p_71515_1_, String[] p_71515_2_) {

	SADomain domain = new MinecraftDomainGenerator().generateDomain();

	State current = MinecraftStateGeneratorHelper.getCurrentState(BurlapCraft.currentDungeon);
	List<State> reachableStates =
			StateReachability.getReachableStates(current, domain, new SimpleHashableStateFactory());

	for (State reachable : reachableStates) {
		BCAgent agent = (BCAgent) ((OOState) reachable).object(CLASS_AGENT);
		System.out.println(agent.x + ", " + agent.y + ", " + agent.z + ", "
				+ agent.rdir + ", " + agent.vdir + ", " + agent.selected);
	}
	System.out.println(reachableStates.size());

}
 
Example #8
Source File: TestHashing.java    From burlap with Apache License 2.0 5 votes vote down vote up
/**
 * Checks an identifier-dependent {@link SimpleHashableStateFactory} on a large grid world
 * ({@code generateLargeGW} size argument 100).
 *
 * <p>First it reports on stderr if any hash codes collide among the reachable hashed states.
 * It then requires that adding hashed copies of every state passed through
 * {@code renameObjects} exactly doubles the set size.
 *
 * <p>The collision message is printed once, and only when a collision occurs. The final
 * check throws an explicit {@link AssertionError} so it runs even without {@code -ea}.
 */
@Test
public void testSimpleHashFactoryLargeStateIdentifierDependent() {
	SADomain domain = (SADomain)this.gridWorldTest.getDomain();
	State startState = this.generateLargeGW(domain, 100);
	HashableStateFactory factory = new SimpleHashableStateFactory(false);
	Set<HashableState> hashedStates = this.getReachableHashedStates(startState, domain, factory);
	int size = hashedStates.size();

	// distinct hash codes; fewer than states means at least one hash collision
	Set<Integer> hashes = new HashSet<Integer>();
	for (HashableState hs : hashedStates) {
		hashes.add(hs.hashCode());
	}
	if (size != hashes.size()) {
		System.err.println("Hashed states: " + size + ", hashes: " + hashes.size());
	}

	Set<HashableState> renamedStates = new HashSet<HashableState>();
	for (HashableState state : hashedStates) {
		State source = state.s();
		State renamed = this.renameObjects((GridWorldState)source.copy());
		HashableState renamedHashed = factory.hashState(renamed);
		renamedStates.add(renamedHashed);
	}
	hashedStates.addAll(renamedStates);
	// identifier-dependent hashing must keep every renamed copy distinct
	if (hashedStates.size() != size * 2) {
		throw new AssertionError("Expected " + (size * 2) + " hashed states after adding renamed copies but found " + hashedStates.size());
	}

}
 
Example #9
Source File: TestHashing.java    From burlap with Apache License 2.0 5 votes vote down vote up
/**
 * Checks that the default {@link SimpleHashableStateFactory} yields exactly 104 distinct
 * hashed states reachable from the grid-world test start state.
 *
 * <p>The check throws an explicit {@link AssertionError} rather than using a Java
 * {@code assert}, which would be skipped unless the JVM runs with {@code -ea}.
 */
@Test
public void testSimpleHashFactory() {
	SADomain domain = (SADomain)this.gridWorldTest.getDomain();
	State startState = this.gridWorldTest.generateState();
	HashableStateFactory factory = new SimpleHashableStateFactory();
	Set<HashableState> hashedStates = this.getReachableHashedStates(startState, domain, factory);
	if (hashedStates.size() != 104) {
		throw new AssertionError("Expected 104 reachable hashed states but found " + hashedStates.size());
	}
}
 
Example #10
Source File: TestPlanning.java    From burlap with Apache License 2.0 5 votes vote down vote up
/**
 * Per-test fixture. It builds:
 * <ul>
 *   <li>an 11x11 four-rooms grid world with a uniform-cost reward;</li>
 *   <li>a terminal function on the {@code PF_AT_LOCATION} propositional function;</li>
 *   <li>the goal condition derived from that terminal function;</li>
 *   <li>a default simple state-hashing factory.</li>
 * </ul>
 */
@Before
public void setup() {
	GridWorldDomain world = new GridWorldDomain(11, 11);
	this.gw = world;
	world.setMapToFourRooms();
	world.setRf(new UniformCostRF());

	TerminalFunction atLocationTerminal =
			new SinglePFTF(PropositionalFunction.findPF(world.generatePfs(), PF_AT_LOCATION));
	world.setTf(atLocationTerminal);

	this.domain = world.generateDomain();
	this.goalCondition = new TFGoalCondition(atLocationTerminal);
	this.hashingFactory = new SimpleHashableStateFactory();
}
 
Example #11
Source File: RewardValueProjection.java    From burlap with Apache License 2.0 5 votes vote down vote up
/**
 * Stores the reward function, projection type and domain.
 *
 * <p>For {@link RewardProjectionType#ONESTEP}, it also builds a one-step
 * {@link SparseSampling} planner (gamma 1, horizon argument 1, sample argument -1). The
 * planner's model wraps the domain model in a {@code CustomRewardNoTermModel} that uses
 * {@code rf}. Debug printing is off and previous plan results are forgotten.
 *
 * @param rf the input {@link RewardFunction} to project for one step.
 * @param projectionType the type of reward projection to use.
 * @param domain the {@link SADomain} in which the {@link RewardFunction} is evaluated.
 */
public RewardValueProjection(RewardFunction rf, RewardProjectionType projectionType, SADomain domain){
	this.rf = rf;
	this.projectionType = projectionType;
	this.domain = domain;

	// only the one-step projection needs a planner
	if(projectionType != RewardProjectionType.ONESTEP){
		return;
	}

	SparseSampling planner = new SparseSampling(domain, 1., new SimpleHashableStateFactory(), 1, -1);
	this.oneStepBellmanPlanner = planner;
	planner.setModel(new CustomRewardNoTermModel(domain.getModel(), rf));
	planner.toggleDebugPrinting(false);
	planner.setForgetPreviousPlanResults(true);
}
 
Example #12
Source File: ApprenticeshipLearning.java    From burlap with Apache License 2.0 5 votes vote down vote up
/**
 * Creates the policy's bookkeeping; no planning or computation happens here.
 *
 * <p>It sets up:
 * <ul>
 *   <li>two empty state-keyed maps;</li>
 *   <li>the domain's action types;</li>
 *   <li>a fresh {@link Random};</li>
 *   <li>a {@link SimpleHashableStateFactory} constructed with {@code true}.</li>
 * </ul>
 *
 * @param domain domain whose action types this policy draws from
 */
private StationaryRandomDistributionPolicy(SADomain domain) {
	// per-state lookup tables, initially empty
	stateActionMapping = new HashMap<HashableState, Action>();
	stateActionDistributionMapping = new HashMap<HashableState, List<ActionProb>>();

	actionTypes = domain.getActionTypes();
	rando = new Random();

	// true -> identifier-independent state hashing
	hashFactory = new SimpleHashableStateFactory(true);
}
 
Example #13
Source File: QLTutorial.java    From burlap_examples with MIT License 5 votes vote down vote up
/**
 * Runs 1000 episodes of the tutorial Q-learning agent on a stochastic four-rooms grid world.
 * The environment is reset after each episode. All recorded episodes are then shown in a
 * visualizer.
 */
public static void main(String[] args) {

	// four-rooms grid world: 0.8 success probability, terminal at (10, 10)
	GridWorldDomain gwd = new GridWorldDomain(11, 11);
	gwd.setMapToFourRooms();
	gwd.setProbSucceedTransitionDynamics(0.8);
	gwd.setTf(new GridWorldTerminalFunction(10, 10));
	SADomain domain = gwd.generateDomain();

	// agent starts at (0, 0)
	State start = new GridWorldState(new GridAgent(0, 0));
	SimulatedEnvironment env = new SimulatedEnvironment(domain, start);

	QLTutorial learner = new QLTutorial(domain, 0.99, new SimpleHashableStateFactory(),
			new ConstantValueFunction(), 0.1, 0.1);

	final int numEpisodes = 1000;
	List<Episode> episodes = new ArrayList<Episode>(numEpisodes);
	for (int episode = 0; episode < numEpisodes; episode++) {
		episodes.add(learner.runLearningEpisode(env));
		env.resetEnvironment();
	}

	Visualizer visualizer = GridWorldVisualizer.getVisualizer(gwd.getMap());
	new EpisodeSequenceVisualizer(visualizer, domain, episodes);

}
 
Example #14
Source File: VITutorial.java    From burlap_examples with MIT License 5 votes vote down vote up
/**
 * Plans with the tutorial value-iteration implementation on a stochastic four-rooms grid
 * world, then visualizes a single rollout of the resulting policy.
 */
public static void main(String [] args){

	GridWorldDomain gwd = new GridWorldDomain(11, 11);
	gwd.setTf(new GridWorldTerminalFunction(10, 10));
	gwd.setMapToFourRooms();

	// moves go in the intended direction only 80% of the time
	gwd.setProbSucceedTransitionDynamics(0.8);

	SADomain domain = gwd.generateDomain();

	// agent starts at (0, 0)
	State start = new GridWorldState(new GridAgent(0, 0));

	// discount 0.99, every state's value initialized to 0.0, 30 iterations over the state space
	VITutorial planner = new VITutorial(domain, 0.99, new SimpleHashableStateFactory(),
			new ConstantValueFunction(0.0), 30);
	Policy policy = planner.planFromState(start);

	// one rollout of the policy, shown in the episode visualizer
	Episode rollout = PolicyUtils.rollout(policy, start, domain.getModel());
	Visualizer visualizer = GridWorldVisualizer.getVisualizer(gwd.getMap());
	new EpisodeSequenceVisualizer(visualizer, domain, Arrays.asList(rollout));

}
 
Example #15
Source File: MinecraftSolver.java    From burlapcraft with GNU Lesser General Public License v3.0 4 votes vote down vote up
/**
 * Runs one learning episode of a potential-shaped R-Max agent in the live Minecraft
 * environment and prints the timer's accumulated total time.
 *
 * <p>A new navigation-only domain and agent are created only when the current dungeon
 * differs from the cached one or no agent exists yet. Otherwise the cached agent is reused.
 */
public static void learn(){

	boolean rebuildAgent = BurlapCraft.currentDungeon != lastDungeon || lastLearningAgent == null;
	if (rebuildAgent) {
		MinecraftDomainGenerator generator = new MinecraftDomainGenerator();
		generator.setActionWhiteListToNavigationOnly();

		lastDomain = generator.generateDomain();
		lastLearningAgent = new PotentialShapedRMax(lastDomain, 0.99, new SimpleHashableStateFactory(), 0, 1, 0.01, 200);
		lastDungeon = BurlapCraft.currentDungeon;

		System.out.println("Starting new RMax");
	}

	MinecraftEnvironment environment = new MinecraftEnvironment();

	// time only the learning episode itself
	newTimer.start();
	lastLearningAgent.runLearningEpisode(environment);
	newTimer.stop();

	System.out.println(newTimer.getTotalTime());

}
 
Example #16
Source File: RandomStartStateGenerator.java    From burlap with Apache License 2.0 4 votes vote down vote up
/**
 * Discovers the states reachable from {@code seedState}; later random selections are drawn
 * from these. Reachability uses a {@link SimpleHashableStateFactory} with identifier
 * dependence.
 *
 * @param domain the domain from which states will be drawn.
 * @param seedState the seed state from which the reachable states will be found.
 */
public RandomStartStateGenerator(SADomain domain, State seedState) {
	// false -> identifier-dependent hashing
	this.reachableStates = StateReachability.getReachableStates(
			seedState, domain, new SimpleHashableStateFactory(false));
	this.random = new Random();
}
 
Example #17
Source File: TestBlockDude.java    From burlap with Apache License 2.0 4 votes vote down vote up
/**
 * Solves a Block Dude state with A* (null heuristic) and rolls out the resulting plan with a
 * step limit argument of 100. It then checks three things about the final state and episode:
 * <ul>
 *   <li>the final state is terminal;</li>
 *   <li>the final state satisfies the goal condition;</li>
 *   <li>the undiscounted return is -94.</li>
 * </ul>
 *
 * @param s the Block Dude start state to solve
 */
public void testDude(State s) {
	TerminalFunction tf = new BlockDudeTF();
	StateConditionTest sc = new TFGoalCondition(tf);

	AStar astar = new AStar(domain, sc, new SimpleHashableStateFactory(), new NullHeuristic());
	astar.toggleDebugPrinting(false);
	astar.planFromState(s);

	Policy p = new SDPlannerPolicy(astar);
	Episode ea = PolicyUtils.rollout(p, s, domain.getModel(), 100);

	State lastState = ea.stateSequence.get(ea.stateSequence.size() - 1);
	Assert.assertTrue(tf.isTerminal(lastState));
	Assert.assertTrue(sc.satisfies(lastState));
	// discount factor 1.0, so this is the plain sum of rewards along the episode
	Assert.assertEquals(-94.0, ea.discountedReturn(1.0), 0.001);
}
 
Example #18
Source File: IRLExample.java    From burlap_examples with MIT License 4 votes vote down vote up
/**
 * Runs MLIRL on demonstration episodes loaded from {@code pathToEpisodes} and then
 * visualizes the learned reward function over every state reachable from
 * {@code basicState()}.
 *
 * @param pathToEpisodes location passed to {@link Episode#readEpisodes(String)} to load the
 *                       demonstration episodes
 */
public void runIRL(String pathToEpisodes){

	// one hashing factory shared by the planner and the reachability sweep below
	SimpleHashableStateFactory hashingFactory = new SimpleHashableStateFactory();

	//create reward function features to use
	LocationFeatures features = new LocationFeatures(this.domain, 5);

	//create a reward function that is linear with respect to those features; each parameter
	//starts at a random value in [-0.1, 0.1)
	LinearStateDifferentiableRF rf = new LinearStateDifferentiableRF(features, 5);
	for(int i = 0; i < rf.numParameters(); i++){
		rf.setParameter(i, RandomFactory.getMapped(0).nextDouble()*0.2 - 0.1);
	}

	//load our saved demonstrations from disk
	List<Episode> episodes = Episode.readEpisodes(pathToEpisodes);

	//use either DifferentiableVI or DifferentiableSparseSampling for planning. The latter enables receding horizon IRL,
	//but you will probably want to use a fairly large horizon for this kind of reward function.
	double beta = 10;
	//DifferentiableVI dplanner = new DifferentiableVI(this.domain, rf, 0.99, beta, hashingFactory, 0.01, 100);
	DifferentiableSparseSampling dplanner = new DifferentiableSparseSampling(this.domain, rf, 0.99, hashingFactory, 10, -1, beta);

	dplanner.toggleDebugPrinting(false);

	//define the IRL problem
	MLIRLRequest request = new MLIRLRequest(this.domain, dplanner, episodes, rf);
	request.setBoltzmannBeta(beta);

	//run MLIRL on it
	MLIRL irl = new MLIRL(request, 0.1, 0.1, 10);
	irl.performIRL();

	//get all states in the domain so we can visualize the learned reward function for them
	List<State> allStates = StateReachability.getReachableStates(basicState(), this.domain, hashingFactory);

	//get a standard grid world value function visualizer, but give it RewardValueProjection, which returns the
	//reward value received upon reaching each state. This renders the learned reward function
	//rather than the value function for it.
	ValueFunctionVisualizerGUI gui = GridWorldDomain.getGridWorldValueFunctionVisualization(
			allStates,
			5,
			5,
			new RewardValueProjection(rf),
			new GreedyQPolicy((QProvider) request.getPlanner())
	);

	gui.initGUI();


}
 
Example #19
Source File: GridGameExample.java    From burlap_examples with MIT License 4 votes vote down vote up
/**
 * Trains two independent single-agent Q-learners against each other in the simple grid
 * game for 1000 games, then visualizes every game played.
 *
 * <p>Each learner is wrapped in a {@link LearningAgentToSGAgentInterface} so it can act in
 * the stochastic-game {@link World}. Progress is printed every 10 games.
 */
public static void saInterface(){

	GridGame gridGame = new GridGame();
	final OOSGDomain domain = gridGame.generateDomain();

	// single hashing factory shared by both learners
	final HashableStateFactory hashingFactory = new SimpleHashableStateFactory();

	final State s = GridGame.getSimpleGameInitialState();
	JointRewardFunction rf = new GridGame.GGJointRewardFunction(domain, -1, 100, false);
	TerminalFunction tf = new GridGame.GGTerminalFunction(domain);
	SGAgentType at = GridGame.getStandardGridGameAgentType(domain);

	World w = new World(domain, rf, tf, s);

	//single agent Q-learning algorithms which will operate in our stochastic game
	//don't need to specify the domain, because the single agent interface will provide it
	QLearning ql1 = new QLearning(null, 0.99, hashingFactory, 0, 0.1);
	QLearning ql2 = new QLearning(null, 0.99, hashingFactory, 0, 0.1);

	//create a single-agent interface for each of our learning algorithm instances
	LearningAgentToSGAgentInterface a1 = new LearningAgentToSGAgentInterface(domain, ql1, "agent0", at);
	LearningAgentToSGAgentInterface a2 = new LearningAgentToSGAgentInterface(domain, ql2, "agent1", at);

	w.join(a1);
	w.join(a2);

	//don't have the world print out debug info (comment out if you want to see it!)
	DPrint.toggleCode(w.getDebugId(), false);

	System.out.println("Starting training");
	int ngames = 1000;
	List<GameEpisode> gas = new ArrayList<GameEpisode>(ngames);
	for(int i = 0; i < ngames; i++){
		GameEpisode ga = w.runGame();
		gas.add(ga);
		if(i % 10 == 0){
			System.out.println("Game: " + i + ": " + ga.maxTimeStep());
		}
	}

	System.out.println("Finished training");


	Visualizer v = GGVisualizer.getVisualizer(9, 9);
	new GameSequenceVisualizer(v, domain, gas);


}
 
Example #20
Source File: GridGameExample.java    From burlap_examples with MIT License 4 votes vote down vote up
/**
 * Trains two {@link MultiAgentQLearning} agents that use the {@link CoCoQ} operator on the
 * grid-game prisoner's dilemma for 1000 games, then visualizes every game played.
 * Progress is printed every 10 games.
 */
public static void QLCoCoTest(){

	GridGame gridGame = new GridGame();
	final OOSGDomain domain = gridGame.generateDomain();

	final HashableStateFactory hashingFactory = new SimpleHashableStateFactory();

	final State s = GridGame.getPrisonersDilemmaInitialState();
	JointRewardFunction rf = new GridGame.GGJointRewardFunction(domain, -1, 100, false);
	TerminalFunction tf = new GridGame.GGTerminalFunction(domain);
	SGAgentType agentType = GridGame.getStandardGridGameAgentType(domain);

	World world = new World(domain, rf, tf, s);

	// learner settings shared by both agents
	final double discount = 0.95;
	final double learningRate = 0.1;
	final double defaultQ = 100;

	MultiAgentQLearning first = new MultiAgentQLearning(domain, discount, learningRate, hashingFactory, defaultQ, new CoCoQ(), true, "agent0", agentType);
	MultiAgentQLearning second = new MultiAgentQLearning(domain, discount, learningRate, hashingFactory, defaultQ, new CoCoQ(), true, "agent1", agentType);

	world.join(first);
	world.join(second);

	// keep the world's debug output quiet (comment out to see it)
	DPrint.toggleCode(world.getDebugId(), false);

	System.out.println("Starting training");
	int ngames = 1000;
	List<GameEpisode> games = new ArrayList<GameEpisode>();
	for (int game = 0; game < ngames; game++) {
		GameEpisode played = world.runGame();
		games.add(played);
		boolean reportProgress = game % 10 == 0;
		if (reportProgress) {
			System.out.println("Game: " + game + ": " + played.maxTimeStep());
		}
	}

	System.out.println("Finished training");


	Visualizer visualizer = GGVisualizer.getVisualizer(9, 9);
	new GameSequenceVisualizer(visualizer, domain, games);

}
 
Example #21
Source File: GridGameExample.java    From burlap_examples with MIT License 4 votes vote down vote up
/**
 * Solves the grid-game prisoner's dilemma with multi-agent value iteration using the
 * {@link CoCoQ} operator. Two agents then play 3 games by following their halves of a greedy
 * max-welfare joint policy, and all played games are visualized.
 */
public static void VICoCoTest(){

	//grid game domain
	GridGame gridGame = new GridGame();
	final OOSGDomain domain = gridGame.generateDomain();

	final HashableStateFactory hashingFactory = new SimpleHashableStateFactory();

	//run the grid game version of prisoner's dilemma
	final State s = GridGame.getPrisonersDilemmaInitialState();

	//define joint reward function and termination conditions for this game
	JointRewardFunction rf = new GridGame.GGJointRewardFunction(domain, -1, 100, false);
	TerminalFunction tf = new GridGame.GGTerminalFunction(domain);

	//both agents are standard: access to all actions
	SGAgentType at = GridGame.getStandardGridGameAgentType(domain);

	//create our multi-agent planner
	MAValueIteration vi = new MAValueIteration(domain, rf, tf, 0.99, hashingFactory, 0., new CoCoQ(), 0.00015, 50);

	//instantiate a world in which our agents will play
	World w = new World(domain, rf, tf, s);


	//create a greedy joint policy from our planner's Q-values
	EGreedyMaxWellfare jp0 = new EGreedyMaxWellfare(0.);
	jp0.setBreakTiesRandomly(false); //don't break ties randomly

	//create agents that each follow their end of the computed joint policy
	MultiAgentDPPlanningAgent a0 = new MultiAgentDPPlanningAgent(domain, vi, new PolicyFromJointPolicy(0, jp0), "agent0", at);
	MultiAgentDPPlanningAgent a1 = new MultiAgentDPPlanningAgent(domain, vi, new PolicyFromJointPolicy(1, jp0), "agent1", at);

	w.join(a0);
	w.join(a1);

	//run some games of the agents playing that policy, keeping every episode
	List<GameEpisode> games = new ArrayList<GameEpisode>();
	for(int i = 0; i < 3; i++){
		games.add(w.runGame());
	}

	//visualize all played games
	Visualizer v = GGVisualizer.getVisualizer(9, 9);
	new GameSequenceVisualizer(v, domain, games);


}
 
Example #22
Source File: MinecraftSolver.java    From burlapcraft with GNU Lesser General Public License v3.0 3 votes vote down vote up
/**
 * Plans with value iteration from the current dungeon state, then executes the resulting
 * policy in the live Minecraft environment.
 *
 * <p>Value iteration uses identifier-dependent hashing and the arguments 0.001 and 1000
 * (presumably the max-delta threshold and the iteration cap — confirm against
 * {@code ValueIteration}).
 *
 * @param gamma discount factor used by value iteration
 */
public static void stocasticPlan(double gamma){

	SADomain domain = new MinecraftDomainGenerator().generateDomain();

	State initialState = MinecraftStateGeneratorHelper.getCurrentState(BurlapCraft.currentDungeon);

	Planner valueIteration = new ValueIteration(domain, gamma, new SimpleHashableStateFactory(false), 0.001, 1000);
	Policy policy = valueIteration.planFromState(initialState);

	PolicyUtils.rollout(policy, new MinecraftEnvironment());
}
 
Example #23
Source File: FittedVI.java    From burlap with Apache License 2.0 3 votes vote down vote up
/**
 * Initializes without state samples. You must set the state samples used for planning
 * with {@link #setSamples(java.util.List)} before calling {@link #planFromState(State)},
 * {@link #runIteration()}, or {@link #runVI()}; otherwise a runtime exception will be thrown.
 *
 * @param domain the domain in which to plan
 * @param gamma the discount factor
 * @param valueFunctionTrainer the supervised learning algorithm to use for each value iteration
 * @param transitionSamples the number of transition samples to use when computing the bellman operator; set to -1 if you want to use the full transition dynamics without sampling.
 * @param maxDelta the maximum change in the value function that will cause planning to terminate.
 * @param maxIterations the maximum number of iterations to run.
 */
public FittedVI(SADomain domain, double gamma, SupervisedVFA valueFunctionTrainer, int transitionSamples, double maxDelta, int maxIterations){
	this.solverInit(domain, gamma, new SimpleHashableStateFactory());

	// learner and Bellman-backup settings
	this.valueFunctionTrainer = valueFunctionTrainer;
	this.transitionSamples = transitionSamples;

	// termination criteria
	this.maxIterations = maxIterations;
	this.maxDelta = maxDelta;

	this.debugCode = 5263;

	// the current value function begins as whatever vinit holds
	this.valueFunction = this.vinit;
}
 
Example #24
Source File: FittedVI.java    From burlap with Apache License 2.0 3 votes vote down vote up
/**
 * Initializes with a fixed set of state samples. If you later replace them with
 * {@link #setSamples(java.util.List)}, do so before calling {@link #planFromState(State)},
 * {@link #runIteration()}, or {@link #runVI()}. Planning without samples throws a runtime
 * exception.
 *
 * @param domain the domain in which to plan
 * @param gamma the discount factor
 * @param valueFunctionTrainer the supervised learning algorithm to use for each value iteration
 * @param samples the set of state samples to use for planning.
 * @param transitionSamples the number of transition samples to use when computing the bellman operator; set to -1 if you want to use the full transition dynamics without sampling.
 * @param maxDelta the maximum change in the value function that will cause planning to terminate.
 * @param maxIterations the maximum number of iterations to run.
 */
public FittedVI(SADomain domain, double gamma, SupervisedVFA valueFunctionTrainer, List<State> samples, int transitionSamples, double maxDelta, int maxIterations){
	// identical setup to the sample-less constructor, then attach the samples
	this(domain, gamma, valueFunctionTrainer, transitionSamples, maxDelta, maxIterations);
	this.samples = samples;
}
 
Example #25
Source File: TigerDomain.java    From burlap with Apache License 2.0 3 votes vote down vote up
/**
 * Builds the Tiger POMDP domain. It contains:
 * <ul>
 *   <li>the left, right and listen actions, plus do-nothing when {@code includeDoNothing}
 *       is set;</li>
 *   <li>the tiger observation function;</li>
 *   <li>the tiger model;</li>
 *   <li>a state enumerator pre-populated with the left state first, then the right state.</li>
 * </ul>
 */
@Override
public Domain generateDomain() {

	PODomain domain = new PODomain();

	// core actions
	domain.addActionType(new UniversalActionType(ACTION_LEFT));
	domain.addActionType(new UniversalActionType(ACTION_RIGHT));
	domain.addActionType(new UniversalActionType(ACTION_LISTEN));

	// optional fourth action
	if (this.includeDoNothing) {
		domain.addActionType(new UniversalActionType(ACTION_DO_NOTHING));
	}

	domain.setObservationFunction(new TigerObservations(this.listenAccuracy, this.includeDoNothing));
	domain.setModel(new TigerModel(correctDoorReward, wrongDoorReward, listenReward, nothingReward));

	// register both tiger states with the enumerator up front
	StateEnumerator enumerator = new StateEnumerator(domain, new SimpleHashableStateFactory());
	enumerator.getEnumeratedID(new TigerState(VAL_LEFT));
	enumerator.getEnumeratedID(new TigerState(VAL_RIGHT));
	domain.setStateEnumerator(enumerator);

	return domain;
}
 
Example #26
Source File: GridGameExample.java    From burlap_examples with MIT License 3 votes vote down vote up
/**
 * Solves the grid-game prisoner's dilemma with multi-agent value iteration using the
 * utilitarian correlated-Q operator. Two agents then play 10 games under a matching
 * correlated-equilibrium joint policy, and every game played is visualized.
 */
public static void VICorrelatedTest(){

	GridGame gridGame = new GridGame();
	final OOSGDomain domain = gridGame.generateDomain();

	final HashableStateFactory hashingFactory = new SimpleHashableStateFactory();

	final State s = GridGame.getPrisonersDilemmaInitialState();

	JointRewardFunction rf = new GridGame.GGJointRewardFunction(domain, -1, 100, false);
	TerminalFunction tf = new GridGame.GGTerminalFunction(domain);

	SGAgentType agentType = GridGame.getStandardGridGameAgentType(domain);
	CorrelatedEquilibriumSolver.CorrelatedEquilibriumObjective objective =
			CorrelatedEquilibriumSolver.CorrelatedEquilibriumObjective.UTILITARIAN;
	MAValueIteration vi = new MAValueIteration(domain, rf, tf, 0.99, hashingFactory, 0., new CorrelatedQ(objective), 0.00015, 50);

	World world = new World(domain, rf, tf, s);


	// correlated Q needs a correlated-equilibrium joint policy with the same objective
	ECorrelatedQJointPolicy jointPolicy = new ECorrelatedQJointPolicy(objective, 0.);


	MultiAgentDPPlanningAgent player0 = new MultiAgentDPPlanningAgent(domain, vi, new PolicyFromJointPolicy(0, jointPolicy, true), "agent0", agentType);
	MultiAgentDPPlanningAgent player1 = new MultiAgentDPPlanningAgent(domain, vi, new PolicyFromJointPolicy(1, jointPolicy, true), "agent1", agentType);

	world.join(player0);
	world.join(player1);

	List<GameEpisode> games = new ArrayList<GameEpisode>();
	for (int game = 0; game < 10; game++) {
		games.add(world.runGame());
	}

	Visualizer visualizer = GGVisualizer.getVisualizer(9, 9);
	new GameSequenceVisualizer(visualizer, domain, games);


}
 
Example #27
Source File: PlotTest.java    From burlap_examples with MIT License 2 votes vote down vote up
/**
 * Runs a Q-learning experiment on a stochastic 11x11 four-rooms grid world and plots
 * cumulative steps per episode and average episode reward.
 *
 * <p>Episodes end when the at-location propositional function holds. The reward is a
 * {@code GoalBasedRF} built from that terminal condition with 5. and -0.1.
 */
public static void main(String [] args){

	GridWorldDomain gw = new GridWorldDomain(11,11); //11x11 grid world
	gw.setMapToFourRooms(); //four rooms layout
	gw.setProbSucceedTransitionDynamics(0.8); //stochastic transitions with 0.8 success rate

	//ends when the agent reaches a location
	final TerminalFunction tf = new SinglePFTF(
			PropositionalFunction.findPF(gw.generatePfs(), GridWorldDomain.PF_AT_LOCATION));

	//reward function definition
	final RewardFunction rf = new GoalBasedRF(new TFGoalCondition(tf), 5., -0.1);

	gw.setTf(tf);
	gw.setRf(rf);


	final OOSADomain domain = gw.generateDomain(); //generate the grid world domain

	//setup initial state
	GridWorldState s = new GridWorldState(new GridAgent(0, 0), new GridLocation(10, 10, "loc0"));



	//initial state generator
	final ConstantStateGenerator sg = new ConstantStateGenerator(s);


	//set up the state hashing system for looking up states
	final SimpleHashableStateFactory hashingFactory = new SimpleHashableStateFactory();


	// factory the experimenter uses to create Q-learning agents
	LearningAgentFactory qLearningFactory = new LearningAgentFactory() {

		@Override
		public String getAgentName() {
			return "Q-learning";
		}

		@Override
		public LearningAgent generateAgent() {
			return new QLearning(domain, 0.99, hashingFactory, 0.3, 0.1);
		}
	};

	//define learning environment
	SimulatedEnvironment env = new SimulatedEnvironment(domain, sg);

	//define experiment (arguments 10 and 100: presumably number of trials and trial length — confirm)
	LearningAlgorithmExperimenter exp = new LearningAlgorithmExperimenter(env,
			10, 100, qLearningFactory);

	exp.setUpPlottingConfiguration(500, 250, 2, 1000, TrialMode.MOST_RECENT_AND_AVERAGE,
			PerformanceMetric.CUMULATIVE_STEPS_PER_EPISODE, PerformanceMetric.AVERAGE_EPISODE_REWARD);


	//start experiment
	exp.startExperiment();


}