""" This file defines an Env, which is a collection of agents that share a timeline and a Space. """ import getpass import json import os from types import FunctionType # we mean to add logging soon! # import logging import indra.display_methods as disp import registry.registry as regis from registry.registry import get_prop from indra.agent import join, switch, Agent, AgentEncoder from indra.space import Space import traceback from indra.user import TEST, TestUser, USER_EXIT, APIUser from indra.user import TermUser, TERMINAL, API from indra.user import user_log_notif DEBUG = False DEBUG2 = False DEF_USER = "User" DEF_TIME = 10 UNLIMITED = 1000 X = 0 Y = 1 CENSUS_FUNC = "census_func" POP_HIST_HDR = "PopHist for " POP_SEP = ", " color_num = 0 def agent_by_name(agent): return agent if isinstance(agent, str) else agent.name class PopHist: """ Data structure to record the fluctuating numbers of various agent types. """ def __init__(self, serial_pops=None): self.pops = {} self.periods = 0 if serial_pops is not None: self.from_json(serial_pops) def __str__(self): s = POP_HIST_HDR for mbr in self.pops: s += mbr + POP_SEP return s def __repr__(self): return str(self) # for now! def __iter__(self): return iter(self.pops) def __getitem__(self, key): return self.pops[key] def add_period(self): self.periods += 1 def record_pop(self, mbr, count): if mbr not in self.pops: self.pops[mbr] = [] self.pops[mbr].append(count) def from_json(self, pop_data): self.periods = pop_data['periods'] self.pops = pop_data['pops'] def to_json(self): return {"periods": self.periods, "pops": self.pops} class Env(Space): """ A collection of entities that share a space and time. An env *is* a space and *has* a timeline (PopHist). That makes the inheritance work out as we want it to. There are four functions possibly passed in here: - census - line_data_func - pop_hist_func These will all be cutover to be attributes of the env: the handy new way to support serialization. """ def __init__(self, name, action=None, random_placing=True, serial_obj=None, exclude_member=None, census=None, line_data_func=None, pop_hist_setup=None, pop_hist_func=None, members=None, reg=True, **kwargs): super().__init__(name, action=action, random_placing=random_placing, serial_obj=serial_obj, reg=False, members=members, **kwargs) self.type = type(self).__name__ self.user_type = os.getenv("user_type", TERMINAL) # this func is only used once, so no need to restore it self.pop_hist_setup = pop_hist_setup self.num_switches = 0 if serial_obj is not None: # are we restoring env from json? self.restore_env(serial_obj) else: self.construct_anew(line_data_func, exclude_member, census, pop_hist_func) self.set_menu_excludes() # now we set our global singleton: regis.set_env(self) def set_attr(self, key, val): self.attrs[key] = val def get_attr(self, key, default=None): if key in self.attrs: return self.attrs[key] else: return default def set_menu_excludes(self): if not get_prop('use_line', True): self.exclude_menu_item("line_graph") if not get_prop('use_scatter', True): self.exclude_menu_item("scatter_plot") def construct_anew(self, line_data_func=None, exclude_member=None, census=None, pop_hist_func=None): self.pop_hist = PopHist() # this will record pops across time # Make sure varieties are present in the history if self.pop_hist_setup is not None: self.pop_hist_setup(self.pop_hist) else: for mbr in self.members: self.pop_hist.record_pop(mbr, self.pop_count(mbr)) self.plot_title = self.name self.user = None # these funcs will be stored as attrs... # but only if they're really funcs! # cause we're gonna try to call them if isinstance(census, FunctionType): print("Adding custom census func") self.attrs[CENSUS_FUNC] = census if isinstance(pop_hist_func, FunctionType): self.attrs["pop_hist_func"] = pop_hist_func if isinstance(line_data_func, FunctionType): self.attrs["line_data_func"] = line_data_func self.exclude_member = exclude_member self.womb = [] # for agents waiting to be born self.switches = [] # for agents waiting to switch groups self.handle_user_type() def handle_user_type(self): if self.user_type == TERMINAL: self.user = TermUser(getpass.getuser(), self) self.user.tell("Welcome to Indra, " + str(self.user) + "!") elif self.user_type == TEST: self.user = TestUser(getpass.getuser(), self) elif self.user_type == API: self.user = APIUser(getpass.getuser(), self) def from_json(self, serial_obj): super().from_json(serial_obj) self.pop_hist = PopHist(serial_pops=serial_obj["pop_hist"]) self.plot_title = serial_obj["plot_title"] nm = serial_obj["user"]["name"] msg = serial_obj["user"]["user_msgs"] self.user = APIUser(nm, self) self.user.tell(msg) self.name = serial_obj["name"] self.switches = serial_obj["switches"] self.womb = serial_obj["womb"] self.num_members_ever = serial_obj["num_members_ever"] def to_json(self): rep = super().to_json() rep["type"] = self.type rep["user"] = self.user.to_json() rep["plot_title"] = self.plot_title rep["pop_hist"] = self.pop_hist.to_json() rep["womb"] = self.womb rep["switches"] = self.switches rep["num_members_ever"] = self.num_members_ever return rep def __repr__(self): return json.dumps(self.to_json(), cls=AgentEncoder, indent=4, sort_keys=True) def restore_env(self, serial_obj): self.from_json(serial_obj) def exclude_menu_item(self, to_exclude): """ Just a pass-through call to our user object. """ self.user.exclude_menu_item(to_exclude) def get_periods(self): return self.pop_hist.periods def __call__(self): """ Calling the env makes it run. If we are on a terminal, we ask the user to put up a menu and choose. For tests, we just run N (default) turns. """ if self.action is not None: # the action was defined outside this class, so pass self: self.action(self) if (self.user is None) or (self.user_type == TEST): self.runN() else: while True: # run until user exit! if self.user() == USER_EXIT: break def add_member(self, member): """ Don't think we really need this here! It is just a pass-through call at present. Must examine further: eliminate if not needed. """ return super().add_member(member) def add_child(self, group): """ Put a child agent in the womb. group: which group will add new agent """ grp_nm = agent_by_name(group) self.womb.append(grp_nm) user_log_notif("An agent was added to the womb for " + grp_nm) def pending_switches(self): return str(len(self.switches)) def rpt_switches(self): return "# switches = " + self.pending_switches() + "; id: "\ + str(id(self.switches)) def add_switch(self, agent, from_grp, to_grp): """ Switch agent from 1 grp to another We allow the parameters to be passed as the names of the agents, or as the agents themselves. In the future, it should be just names. """ agent_nm = agent_by_name(agent) from_grp_nm = agent_by_name(from_grp) to_grp_nm = agent_by_name(to_grp) self.switches.append((agent_nm, from_grp_nm, to_grp_nm)) def handle_womb(self): """ The womb just contains group names -- they will be repeated as many times as that group needs to add members. We name the new members in the `member_creator()` method. This should be re-written as dict with: {"group_name": #agents_to_create} """ if self.womb is not None: for group_nm in self.womb: group = regis.get_group(group_nm) if group is not None and group.member_creator is not None: group.num_members_ever += 1 agent = group.member_creator("", group.num_members_ever) regis.register(agent.name, agent) join(group, agent) self.womb.clear() def handle_switches(self): if self.switches is not None: user_log_notif("In handle: " + self.rpt_switches()) user_log_notif("Switching " + self.pending_switches() + " agents between groups") for (agent_nm, from_grp_nm, to_grp_nm) in self.switches: switch(agent_nm, from_grp_nm, to_grp_nm) self.num_switches += 1 self.switches.clear() def handle_pop_hist(self): self.pop_hist.add_period() if "pop_hist_func" in self.attrs: self.attrs["pop_hist_func"](self.pop_hist) else: for mbr in self.pop_hist.pops: if mbr in self.members and self.is_mbr_comp(mbr): self.pop_hist.record_pop(mbr, self.pop_count(mbr)) else: self.pop_hist.record_pop(mbr, 0) def runN(self, periods=DEF_TIME): """ Run our model for N periods. Return the total number of actions taken. """ user_log_notif("Running env " + self.name + " for " + str(periods) + " periods.") num_acts = 0 num_moves = 0 for i in range(periods): # these things need to be done before action loop: self.handle_womb() self.handle_switches() self.handle_pop_hist() (a, m) = super().__call__() num_acts += a num_moves += m census_rpt = self.get_census(num_moves) self.user.tell(census_rpt) self.num_switches = 0 return num_acts def get_census(self, num_moves): """ Gets the census data for all the agents stored in the member dictionary. Takes in how many agent has moved from one place to another and how many agent has switched groups and returns a string of these census data. census_func overrides the default behavior. """ if CENSUS_FUNC in self.attrs: return self.attrs[CENSUS_FUNC](self) else: SEP_STR = "==================\n" census_str = ("\nTotal census for period " + str(self.get_periods()) + ":\n" + SEP_STR + "Group census:\n" + SEP_STR) for name in self.members: grp = self.members[name] population = len(grp) census_str += (" " + name + " (id: " + str(id(grp)) + "): " + str(population) + "\n") census_str += (SEP_STR + "Agent census:\n" + SEP_STR + " Agents who moved: " + str(num_moves) + "\n" + " Agents who switched groups: " + str(self.num_switches)) return census_str def has_disp(self): if not disp.plt_present: self.user.tell("ERROR: Graphing package encounters a problem: " + disp.plt_present_error_message) return False else: return True def line_graph(self): """ Show agent populations. """ if self.has_disp(): try: # TODO: improve implementation of the iterator of composite? period, data = self.line_data() if period is None: self.user.tell("No data to display.") return None line_plot = disp.LineGraph(self.plot_title, data, period, is_headless=self.headless(), attrs=self.attrs) line_plot.show() return line_plot except Exception as e: self.user.tell("Error when drawing line graph: " + str(e)) else: return None def scatter_graph(self): """ Show agent locations. """ if self.has_disp(): try: data = self.plot_data() scatter_plot = disp.ScatterPlot( self.plot_title, data, int(self.width), int(self.height), anim=True, data_func=self.plot_data, is_headless=self.headless(), attrs=self.attrs) scatter_plot.show() return scatter_plot except ValueError as e: # Exception as e: self.user.tell("Error when drawing scatter plot: " + str(e)) traceback.print_stack() else: return None def get_color(self, variety): if variety in self.members and self.members[variety].has_color(): return self.members[variety].get_color() else: global color_num color_num += 1 return disp.get_color(variety, color_num) def get_marker(self, variety): if variety in self.members: return self.members[variety].get_marker() else: return None def line_data(self): period = None if self.exclude_member is not None: exclude = self.exclude_member else: exclude = None if "line_data_func" in self.attrs: (period, data) = self.attrs["line_data_func"](self) else: data = {} for var in self.pop_hist.pops: if var != exclude: data[var] = {} data[var]["data"] = self.pop_hist.pops[var] data[var]["color"] = self.get_color(var) if not period: period = len(data[var]["data"]) return period, data def plot_data(self): """ This is the data for our scatter plot. This code assumes the env holds groups, and the groups hold agents with positions. This assumption is dangerous, and we should address it. """ if not disp.plt_present: self.user.tell("ERROR: Graphing package encountered a problem: " + disp.plt_present_error_message) return data = {} for variety in self.members: data[variety] = {} # matplotlib wants a list of x coordinates, and a list of y # coordinates: data[variety][X] = [] data[variety][Y] = [] data[variety]["color"] = self.members[variety].get_color() data[variety]["marker"] = self.members[variety].get_marker() current_variety = self.members[variety] for agent_nm in current_variety: # temp fix for one of the dangers mentioned above: # we might not be at the level of agents! if isinstance(current_variety[agent_nm], Agent): current_agent_pos = current_variety[agent_nm].pos if current_agent_pos is not None: (x, y) = current_agent_pos data[variety][X].append(x) data[variety][Y].append(y) return data def headless(self): return (self.user_type == API) or (self.user_type == TEST)