burlap.behavior.singleagent.auxiliary.performance.TrialMode Java Examples

The following examples show how to use burlap.behavior.singleagent.auxiliary.performance.TrialMode. They are drawn from open source projects; the source file and originating project are noted above each example.
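For reference, TrialMode is an enum; the two constants exercised in the examples below are MOST_RECENT_TRIAL_ONLY and MOST_RECENT_AND_AVERAGE. A minimal sketch (the demo class name is ours, not part of BURLAP):

import burlap.behavior.singleagent.auxiliary.performance.TrialMode;

public class TrialModeDemo {
	public static void main(String[] args) {
		//plot only the trial currently being run
		TrialMode recentOnly = TrialMode.MOST_RECENT_TRIAL_ONLY;
		//plot the current trial plus an average over all trials
		TrialMode both = TrialMode.MOST_RECENT_AND_AVERAGE;
		System.out.println(recentOnly + ", " + both);
	}
}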
Example #1
Source File: Main.java    From cs7641-assignment4 with MIT License
/**
 * Runs a learning experiment and shows some cool charts. Apparently, this is only useful for
 * Q-Learning, so I only call this method when Q-Learning is selected and the appropriate flag
 * is enabled.
 */
private static void learningExperimenter(Problem problem, LearningAgent agent, SimulatedEnvironment simulatedEnvironment) {
	LearningAlgorithmExperimenter experimenter = new LearningAlgorithmExperimenter(simulatedEnvironment, 10, problem.getNumberOfIterations(Algorithm.QLearning), new LearningAgentFactory() {

		public String getAgentName() {
			return Algorithm.QLearning.getTitle();
		}

		public LearningAgent generateAgent() {
			return agent;
		}
	});

	/*
	 * Try different PerformanceMetric values below to display different charts.
	 */
	experimenter.setUpPlottingConfiguration(500, 250, 2, 1000, TrialMode.MOST_RECENT_AND_AVERAGE, PerformanceMetric.CUMULATIVE_STEPS_PER_EPISODE, PerformanceMetric.AVERAGE_EPISODE_REWARD);
	experimenter.startExperiment();
}
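For context, a hypothetical call site for the method above might look like the following; problem, domain, hashingFactory, and stateGenerator are assumptions about the rest of this project, and the QLearning parameters mirror Example #5 below.

//hypothetical call site; the surrounding objects are assumed to exist
QLearning agent = new QLearning(domain, 0.99, hashingFactory, 0.3, 0.1);
SimulatedEnvironment env = new SimulatedEnvironment(domain, stateGenerator);
learningExperimenter(problem, agent, env);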
 
Example #2
Source File: BasicBehavior.java    From burlap_examples with MIT License
public void experimentAndPlotter(){

		//different reward function for more structured performance plots
		((FactoredModel)domain.getModel()).setRf(new GoalBasedRF(this.goalCondition, 5.0, -0.1));

		/**
		 * Create factories for Q-learning agent and SARSA agent to compare
		 */
		LearningAgentFactory qLearningFactory = new LearningAgentFactory() {

			public String getAgentName() {
				return "Q-Learning";
			}


			public LearningAgent generateAgent() {
				return new QLearning(domain, 0.99, hashingFactory, 0.3, 0.1);
			}
		};

		LearningAgentFactory sarsaLearningFactory = new LearningAgentFactory() {

			public String getAgentName() {
				return "SARSA";
			}


			public LearningAgent generateAgent() {
				return new SarsaLam(domain, 0.99, hashingFactory, 0.0, 0.1, 1.);
			}
		};

		LearningAlgorithmExperimenter exp = new LearningAlgorithmExperimenter(env, 10, 100, qLearningFactory, sarsaLearningFactory);
		exp.setUpPlottingConfiguration(500, 250, 2, 1000,
				TrialMode.MOST_RECENT_AND_AVERAGE,
				PerformanceMetric.CUMULATIVE_STEPS_PER_EPISODE,
				PerformanceMetric.AVERAGE_EPISODE_REWARD);

		exp.startExperiment();
		exp.writeStepAndEpisodeDataToCSV("expData");

	}
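When a single quick run is enough, the same API supports a one-trial variant; a sketch reusing env and qLearningFactory from above (the layout numbers are illustrative):

//one trial, most-recent-only plotting: averaging a single trial adds nothing
LearningAlgorithmExperimenter quick = new LearningAlgorithmExperimenter(env, 1, 100, qLearningFactory);
quick.setUpPlottingConfiguration(500, 250, 1, 1000,
		TrialMode.MOST_RECENT_TRIAL_ONLY,
		PerformanceMetric.STEPS_PER_EPISODE);
quick.startExperiment();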
 
Example #3
Source File: MultiAgentPerformancePlotter.java    From burlap with Apache License 2.0
/**
 * Initializes
 * @param tf the terminal function that will be used for detecting the end of episodes
 * @param chartWidth the width of a chart
 * @param chartHeight the height of a chart
 * @param columns the number of columns of charts
 * @param maxWindowHeight the maximum window height before a scroll bar is added
 * @param trialMode the kinds of trial data that will be displayed
 * @param metrics which metrics will be plotted.
 */
public MultiAgentPerformancePlotter(TerminalFunction tf, int chartWidth, int chartHeight, int columns, int maxWindowHeight,
							TrialMode trialMode, PerformanceMetric...metrics){


	this.tf = tf;
	
	colCSR = new XYSeriesCollection();
	colCER = new XYSeriesCollection();
	colAER = new XYSeriesCollection();
	colMER = new XYSeriesCollection();
	colCSE = new XYSeriesCollection();
	colSE = new XYSeriesCollection();
	
	colCSRAvg = new YIntervalSeriesCollection();
	colCERAvg = new YIntervalSeriesCollection();
	colAERAvg = new YIntervalSeriesCollection();
	colMERAvg = new YIntervalSeriesCollection();
	colCSEAvg = new YIntervalSeriesCollection();
	colSEAvg = new YIntervalSeriesCollection();
	
	
	if(metrics.length == 0){
		metricsSet.add(PerformanceMetric.CUMULATIVE_REWARD_PER_STEP);
		
		metrics = new PerformanceMetric[]{PerformanceMetric.CUMULATIVE_REWARD_PER_STEP};
	}
	
	this.trialMode = trialMode;
	
	Container plotContainer = new Container();
	plotContainer.setLayout(new GridBagLayout());
	GridBagConstraints c = new GridBagConstraints();
	c.gridx = 0;
	c.gridy = 0;
	c.insets = new Insets(0, 0, 10, 10);
       
	for(PerformanceMetric m : metrics){

		this.metricsSet.add(m);

		if(m == PerformanceMetric.CUMULATIVE_REWARD_PER_STEP){
			this.insertChart(plotContainer, c, columns, chartWidth, chartHeight, "Cumulative Reward", "Time Step", "Cumulative Reward", colCSR, colCSRAvg);
		}
		else if(m == PerformanceMetric.CUMULATIVE_REWARD_PER_EPISODE){
			this.insertChart(plotContainer, c, columns, chartWidth, chartHeight, "Cumulative Reward", "Episode", "Cumulative Reward", colCER, colCERAvg);
		}
		else if(m == PerformanceMetric.AVERAGE_EPISODE_REWARD){
			this.insertChart(plotContainer, c, columns, chartWidth, chartHeight, "Average Reward", "Episode", "Average Reward", colAER, colAERAvg);
		}
		else if(m == PerformanceMetric.MEDIAN_EPISODE_REWARD){
			this.insertChart(plotContainer, c, columns, chartWidth, chartHeight, "Median Reward", "Episode", "Median Reward", colMER, colMERAvg);
		}
		else if(m == PerformanceMetric.CUMULATIVE_STEPS_PER_EPISODE){
			this.insertChart(plotContainer, c, columns, chartWidth, chartHeight, "Cumulative Steps", "Episode", "Cumulative Steps", colCSE, colCSEAvg);
		}
		else if(m == PerformanceMetric.STEPS_PER_EPISODE){
			this.insertChart(plotContainer, c, columns, chartWidth, chartHeight, "Number of Steps", "Episode", "Number of Steps", colSE, colSEAvg);
		}

	}

	int totalChartHeight = ((metrics.length / columns)+1)*(chartHeight+10);
	if(totalChartHeight > maxWindowHeight){
		JScrollPane scrollPane = new JScrollPane(plotContainer);
		scrollPane.setPreferredSize(new Dimension(chartWidth*columns+50, maxWindowHeight));
		this.add(scrollPane);
	}
	else{
		this.add(plotContainer);
	}
	
}
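Note the empty-metrics fallback above: if no metrics are supplied, the constructor plots CUMULATIVE_REWARD_PER_STEP by default. A minimal construction sketch (size arguments illustrative; tf is assumed to exist):

//no metrics passed, so per the constructor above this shows a single
//cumulative-reward-per-step chart
MultiAgentPerformancePlotter plotter = new MultiAgentPerformancePlotter(
		tf, 500, 250, 2, 500, TrialMode.MOST_RECENT_AND_AVERAGE);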
 
Example #4
Source File: MultiAgentExperimenter.java    From burlap with Apache License 2.0
/**
 * Starts the experiment and runs all trials for all agents.
 */
public void startExperiment(){

	if(this.completedExperiment){
		System.out.println("Experiment was already run and has completed. If you want to run a new experiment create a new Experiment object.");
		return;
	}
	
	if(this.plotter == null){
		
		TrialMode trialMode = TrialMode.MOST_RECENT_AND_AVERAGE;
		if(this.nTrials == 1){
			trialMode = TrialMode.MOST_RECENT_TRIAL_ONLY;
		}
		
		this.plotter = new MultiAgentPerformancePlotter(this.tf, 500, 250, 2, 500, trialMode);
			
	}
	
	if(this.displayPlots){
		this.plotter.startGUI();
	}
	
	for(int i = 0; i < this.nTrials; i++){
		
		DPrint.cl(this.debugCode, "Beginning trial " + (i+1) + "/" + this.nTrials);
		
		World w = worldGenerator.generateWorld();
		if(this.plotter != null){
			this.plotter.setWorld(w);
		}

		DPrint.toggleCode(w.getDebugId(), false);
		w.addWorldObserver(this.plotter);
		int id = 0;
		for(AgentFactoryAndType aft : this.agentFactoriesAndTypes){
			//aft.agentFactory.generateAgent().joinWorld(w, aft.at);
			w.join(aft.agentFactory.generateAgent("agent"+id, aft.at));
			id++;
		}
		
		this.plotter.startNewTrial();
		if(this.trialLengthIsInEpisodes){
			this.runEpisodewiseTrial(w);
		}
		else{
			this.runStepwiseTrial(w);
		}
		
	}
	
	this.plotter.endAllTrials();
	this.completedExperiment = true;
	
}
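Note the default above: if no plotter was configured, one is created with TrialMode.MOST_RECENT_AND_AVERAGE, downgraded to MOST_RECENT_TRIAL_ONLY when there is only one trial (an average over a single trial adds nothing). Configuring the plotter explicitly beforehand (see Example #6) bypasses that branch; a sketch assuming a MultiAgentExperimenter named exp:

//explicit configuration before startExperiment() skips the default-plotter branch
exp.setUpPlottingConfiguration(500, 250, 2, 500,
		TrialMode.MOST_RECENT_AND_AVERAGE,
		PerformanceMetric.CUMULATIVE_REWARD_PER_STEP);
exp.startExperiment();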
 
Example #5
Source File: PlotTest.java    From burlap_examples with MIT License
public static void main(String [] args){

		GridWorldDomain gw = new GridWorldDomain(11,11); //11x11 grid world
		gw.setMapToFourRooms(); //four rooms layout
		gw.setProbSucceedTransitionDynamics(0.8); //stochastic transitions with 0.8 success rate

		//ends when the agent reaches a location
		final TerminalFunction tf = new SinglePFTF(
				PropositionalFunction.findPF(gw.generatePfs(), GridWorldDomain.PF_AT_LOCATION));

		//reward function definition
		final RewardFunction rf = new GoalBasedRF(new TFGoalCondition(tf), 5., -0.1);

		gw.setTf(tf);
		gw.setRf(rf);


		final OOSADomain domain = gw.generateDomain(); //generate the grid world domain

		//setup initial state
		GridWorldState s = new GridWorldState(new GridAgent(0, 0), new GridLocation(10, 10, "loc0"));



		//initial state generator
		final ConstantStateGenerator sg = new ConstantStateGenerator(s);


		//set up the state hashing system for looking up states
		final SimpleHashableStateFactory hashingFactory = new SimpleHashableStateFactory();


		/**
		 * Create factory for Q-learning agent
		 */
		LearningAgentFactory qLearningFactory = new LearningAgentFactory() {

			public String getAgentName() {
				return "Q-learning";
			}

			public LearningAgent generateAgent() {
				return new QLearning(domain, 0.99, hashingFactory, 0.3, 0.1);
			}
		};

		//define learning environment
		SimulatedEnvironment env = new SimulatedEnvironment(domain, sg);

		//define experiment
		LearningAlgorithmExperimenter exp = new LearningAlgorithmExperimenter(env,
				10, 100, qLearningFactory);

		exp.setUpPlottingConfiguration(500, 250, 2, 1000, TrialMode.MOST_RECENT_AND_AVERAGE,
				PerformanceMetric.CUMULATIVE_STEPS_PER_EPISODE, PerformanceMetric.AVERAGE_EPISODE_REWARD);


		//start experiment
		exp.startExperiment();


	}
 
Example #6
Source File: MultiAgentExperimenter.java    From burlap with Apache License 2.0
/**
 * Sets up the plotting configuration.
 * @param chartWidth the width of each chart/plot
 * @param chartHeight the height of each chart/plot
 * @param columns the number of columns of plots displayed. Plots fill each row of columns first, then move down to the next row.
 * @param maxWindowHeight the maximum window height allowed before a scroll view is used.
 * @param trialMode which plots to use: most recent trial, average over all trials, or both. If both, the most recent plot will be inserted into the window first, then the average.
 * @param metrics the metrics that should be plotted. The metrics will appear in the window in the order that they are specified (columns first)
 */
public void setUpPlottingConfiguration(int chartWidth, int chartHeight, int columns, int maxWindowHeight, TrialMode trialMode, PerformanceMetric...metrics){
	
	this.plotter = new MultiAgentPerformancePlotter(this.tf, chartWidth, chartHeight, columns, maxWindowHeight, trialMode, metrics);
	this.plotter.setRefreshDelay(this.plotRefresh);
	this.plotter.setSignificanceForCI(this.plotCISignificance);
	
	
}
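The refresh delay and confidence-interval significance forwarded above can also be set directly on a hand-built plotter; a sketch (the delay, assumed to be in milliseconds, and the significance level are illustrative values; tf is assumed to exist):

//the two setters are the same ones forwarded in the method above
MultiAgentPerformancePlotter plotter = new MultiAgentPerformancePlotter(
		tf, 500, 250, 2, 1000,
		TrialMode.MOST_RECENT_AND_AVERAGE,
		PerformanceMetric.AVERAGE_EPISODE_REWARD);
plotter.setRefreshDelay(1000); //assumed milliseconds
plotter.setSignificanceForCI(0.05); //95% confidence intervals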