"""
Implements a basic decision tree class, to be used in decision tree
construction algorithms.
"""

import csv
from collections import Counter


class DTree(object):
    """
    A decision tree object, consisting of a recursive set of decision tree
    nodes and basic parsing functions.
    """

    def __init__(self, training_file):
        """
        Initialize the decision tree from the given open CSV file by parsing
        its data and setting the necessary attributes.

        Args:
            training_file: an open, readable file object containing CSV
                data. The first row is read as the attribute header. The
                file is closed once parsing has finished.

        Returns:
            A decision tree instance ready for learning with a decision
            tree creation algorithm.
        """
        self.training_file = training_file
        self.root = None
        self.parse_csv()
        # Default the decision order to the CSV column order so decide()
        # is usable even if set_attributes() is never called.
        self.attribute_order = list(self.attributes)
        self.get_distinct_values()

    def parse_csv(self, dependent_index=-1):
        """
        Set the object's attributes and data, where attributes is a list of
        attributes and data is a list of row dictionaries keyed by
        attribute.

        Also sets the dependent variable, which defaults to the last one. An
        option to change the position of this dependent variable has not yet
        been implemented.

        Args:
            dependent_index: the index to be specified as the dependent
                variable (default -1).

        Raises:
            NotImplementedError: If dependent_index is anything other
                than -1.
            ValueError: If the training file has no header row.
        """
        if dependent_index != -1:
            raise NotImplementedError

        try:
            reader = csv.reader(self.training_file)
            # next() with a default works on Python 2.6+ and Python 3 and
            # avoids leaking StopIteration for an empty file.
            attributes = next(reader, None)
            if not attributes:
                raise ValueError("training file has no header row")
            data = [dict(zip(attributes, row)) for row in reader]
        finally:
            # Close the file even if parsing fails part-way through.
            self.training_file.close()

        self.dependent = attributes[dependent_index]
        self.attributes = [a for a in attributes if a != self.dependent]
        self.all_attributes = attributes
        self.data = data

    def get_distinct_values(self):
        """
        Get the distinct values for each attribute in the CSV data and
        store them on the instance as ``self.values``.

        Returns:
            A dictionary with attribute keys and set values corresponding
            to the unique items in each attribute.
        """
        # All attributes (including the dependent one) are collected.
        values = {}
        for attr in self.all_attributes:
            values[attr] = set(row[attr] for row in self.data)
        self.values = values
        return values

    def plot(self, x=1, y=1):
        """
        Recursively plot the root node and its children with matplotlib.

        Args:
            x: the desired width of the plot (default 1).
            y: the desired height of the plot (default 1).

        Raises:
            NotImplementedError: Not yet implemented.
        """
        # _plot() requires two positional arguments; calling it with none
        # raised TypeError instead of the documented NotImplementedError.
        # NOTE(review): x/y are forwarded as-is to the node's offsets --
        # confirm the intended mapping once plotting is implemented.
        self.root._plot(x, y)

    def decide(self, attributes):
        """
        Make a decision on the dependent variable of the tree given the
        provided attributes.

        Args:
            attributes: the list of independent attribute values, ordered
                like ``self.attribute_order``, with which to make a
                decision on the dependent value.

        Returns:
            A dependent variable representing the decision tree decision.

        Raises:
            ValueError: if the number of supplied values does not match the
                number of attributes, or if a value is found which is not
                represented in the decision tree.
        """
        if len(attributes) != len(self.attribute_order):
            # The expected order is reported in the exception itself rather
            # than being printed to stdout as a side effect.
            raise ValueError(
                "supplied attributes do not match data "
                "(expected {0} values for {1}, got {2})".format(
                    len(self.attribute_order),
                    self.attribute_order,
                    len(attributes)
                )
            )
        attrs_dict = dict(zip(self.attribute_order, attributes))
        return self.root._decide(attrs_dict)

    def test_file(self, testing_file, csv=None):
        """
        Test the given CSV file on this instance's decision tree, either
        printing decisions to stdout or writing them to a CSV file.

        Note: Testing CSV files must have the same format as training CSV
        files, including column order. Repeated headers are optional.

        Args:
            testing_file: an open testing CSV file object. Testing CSV
                files must have the same format as the training CSV files!
                This function closes the file after usage.
            csv: if specified, a writable file object that receives one
                header row (independent attributes followed by the
                dependent attribute) and one row per test case holding the
                independent values followed by the tree's decision. The
                caller keeps ownership of this file; it is not closed here.
                When omitted, each decision is printed to stdout instead.
        """
        # The ``csv`` parameter shadows the module of the same name, so the
        # module is imported under a distinct local name.
        import csv as csv_module

        try:
            reader = csv_module.reader(testing_file)
            first_row = next(reader, None)
            test_data = []
            # A first row equal to the header is skipped; anything else is
            # treated as data. An empty file simply yields no rows.
            is_header = (first_row == self.all_attributes or
                         first_row == self.attributes)
            if first_row is not None and not is_header:
                test_data.append(dict(zip(self.all_attributes, first_row)))
            for row in reader:
                test_data.append(dict(zip(self.all_attributes, row)))
        finally:
            testing_file.close()

        writer = None
        if csv is not None:
            writer = csv_module.writer(csv)
            writer.writerow(list(self.attributes) + [self.dependent])

        # Keep track of statistics. A float keeps the division below from
        # truncating on Python 2.
        correct = 0.
        for row in test_data:
            formatted = [row[a] for a in self.attributes]
            decision = self.decide(formatted)
            # Rows without a dependent column have no expected value.
            if self.dependent in row:
                expected_str = "(expected {0})".format(row[self.dependent])
                if row[self.dependent] == decision:
                    correct += 1
                    expected_str += ", CORRECT"
                else:
                    expected_str += ", INCORRECT"
            else:
                expected_str = ""
            if writer is not None:
                writer.writerow(formatted + [decision])
            else:
                print("{0} -> {1} {2}".format(
                    formatted, decision, expected_str))

        # Reported as a fraction of rows; guarded so a file with no data
        # rows does not divide by zero.
        accuracy = correct / len(test_data) if test_data else 0.
        print("% correct: {0}".format(accuracy))

    def filter_subset(self, subset, attr, value):
        """
        Filter a subset of CSV data further by selecting only the rows of
        subset which have the given attribute and value.

        Args:
            subset: the subset of the CSV data to filter upon.
            attr: the attribute of the value to filter upon.
            value: the value to filter upon.

        Returns:
            A list of the filtered rows according to the attribute and
            value.
        """
        return [r for r in subset if r[attr] == value]

    def value_counts(self, subset, attr, value, base=False):
        """
        Get the number of occurrences per value of the dependent variable
        when the given attribute is equal to the given value.

        FIXME: Can attr/value be eliminated??

        Args:
            subset: the subset with which to act upon.
            attr: the attribute of the value.
            value: the value with which to track counts.
            base: if True, count every row of subset regardless of
                attr/value (default False).

        Returns:
            A Counter instance detailing the number of occurrences per
            dependent variable.
        """
        # ``base`` is tested first so that attr is never looked up when
        # every row is being counted anyway.
        return Counter(
            row[self.dependent]
            for row in subset
            if base or row[attr] == value
        )

    def rules(self):
        """
        Return all of the tree's branch traversals, which can be used as
        if/then rules for simulating the decision process.

        Returns:
            A sorted list of branch traversals. Each traversal is a tuple
            of (attribute, value) pairs ending with the leaf label; shorter
            traversals sort first, then by the values along the branch.
        """
        return sorted(
            self.root._rules(),
            key=lambda t: (len(t),
                           [p[1] for p in t if isinstance(p, tuple)])
        )

    def set_attributes(self, attributes):
        """
        Set the correct order of the attributes in the decision tree based
        on the parsed CSV data.

        Args:
            attributes: the correctly ordered list of independent
                attributes from the CSV data.
        """
        self.attribute_order = attributes

    def attr_counts(self, subset, attr):
        """
        Get the number of occurrences per value of the given attribute.

        Args:
            subset: the subset with which to act upon.
            attr: the selected attribute.

        Returns:
            A Counter instance detailing the number of occurrences per
            attribute value.
        """
        return Counter(row[attr] for row in subset)

    @property
    def depth(self):
        """
        Return the maximum depth of the tree, measured from the root.

        Returns:
            An integer calculated from the longest tree branch traversal,
            or 0 if no tree has been built yet.
        """
        if self.root is None:
            return 0
        return self.root._depth(0)

    @property
    def num_leaves(self):
        """
        Return the total number of leaves for the current tree.

        Returns:
            An integer of the number of leaves, or 0 if no tree has been
            built yet.
        """
        if self.root is None:
            return 0
        return self.root._num_leaves

    @property
    def distinct_values(self):
        """
        Returns a readable list of all values in the CSV data set.

        Returns:
            A flattened list of all distinct values.
        """
        return [val for s in self.values.values() for val in s]

    def _source_name(self):
        """
        Return the name of the training file, falling back to a
        placeholder for file-like objects that have no ``name`` attribute.
        """
        return getattr(self.training_file, "name", "<unnamed file>")

    def __str__(self):
        """
        Return the filename of the decision tree, the dependent variable,
        and the string representation of the decision tree.
        """
        return "decision tree for {0}:\nDependent variable: {1}\n{2}".format(
            self._source_name(), self.dependent, self.root
        )

    def __repr__(self):
        """
        Return the filename of the decision tree and other useful
        diagnostics.
        """
        # get_base_entropy is not defined on this class; it is only called
        # when the instance actually provides it, otherwise None is shown.
        entropy_fn = getattr(self, "get_base_entropy", None)
        entropy = entropy_fn(self.data) if callable(entropy_fn) else None
        return ("decision tree for {0}:\nDependent variable: {1}\n{2}\n" +
                "Rows: {3}\nValues: {4}\nBase Data Entropy: {5}").format(
                    self._source_name(), self.dependent, repr(self.root),
                    len(self.data), self.values, entropy
                )

    def decision_repl(self):
        """
        An interactive REPL for making decisions based on the created
        decision tree. Runs until end-of-file is reached on standard input.
        """
        # raw_input only exists on Python 2; input is its Python 3
        # equivalent.
        try:
            read_line = raw_input
        except NameError:
            read_line = input

        print("")
        print(','.join("{{{0}}}".format(a) for a in self.attributes))
        print("Decision tree REPL. Enter above parameters separated by "
              "commas,")
        print("no spaces between commas or brackets.")
        while True:
            try:
                x = read_line('> ').split(',')
            except EOFError:
                # End of input: leave the loop instead of crashing.
                print("")
                break
            print("{0} ->".format(x))
            try:
                print(self.decide(x))
            except Exception as e:
                # Deliberately broad: a bad query must not end the session.
                print("Error with decision: {0}".format(e))


class DTreeNode(object):
    """
    A recursively defined decision tree node.
    """

    def __init__(self, label, parent_value=None, properties=None,
                 leaf=False):
        """
        Initialize a decision tree node.

        Args:
            label: the label of the node, which can either be a decision
                attribute or a leaf result.
            parent_value: the name of the link from the current node to its
                parent (default None, used in cases of root nodes).
            properties: a JSON-like dictionary containing various
                diagnostic properties of the given node (e.g. information
                gain or entropy). Defaults to a new empty dictionary for
                each node.
            leaf: a boolean indicating whether or not this node is a leaf
                node (default False).
        """
        self.label = label
        self.children = []
        self.parent_value = parent_value
        # A fresh dict per node: a ``{}`` default in the signature would be
        # one object shared by every node created without properties.
        self.properties = {} if properties is None else properties
        self.leaf = leaf

    def _plot(self, xoffset, yoffset):
        """
        Plot the given node at the xoffset and yoffset coordinates.

        Args:
            xoffset: the x coordinate for plotting of the given node.
            yoffset: the y coordinate for plotting of the given node.

        Raises:
            NotImplementedError: Not yet implemented.
        """
        raise NotImplementedError

    def _decide(self, attrs_dict):
        """
        Recursively decide using the given attribute/value dictionary.

        Internal function is separated from the more friendly decide()
        method.

        Args:
            attrs_dict: dictionary mapping attribute names to the supplied
                values.

        Returns:
            The label of the leaf reached by following the matching
            branches.

        Raises:
            ValueError: if no child branch matches the supplied value for
                this node's attribute.
        """
        if self.leaf:
            return self.label
        val = attrs_dict[self.label]
        for node in self.children:
            if val == node.parent_value:
                return node._decide(attrs_dict)
        raise ValueError("Invalid property found: {0}".format(val))

    def add_child(self, node):
        """
        Add the given child node to the list of children of the current
        node.

        Args:
            node: the DTree node to be appended as a child.
        """
        self.children.append(node)

    @property
    def num_children(self):
        """
        Return the total number of immediate child nodes under the current
        node.

        Returns:
            An integer of the number of children.
        """
        return len(self.children)

    @property
    def _num_leaves(self):
        """
        Return the total number of leaves that exist under the current
        node.
        """
        if self.leaf:
            return 1
        # Children are DTreeNode instances, whose leaf count lives in the
        # ``_num_leaves`` property (``num_leaves`` exists only on DTree).
        return sum(c._num_leaves for c in self.children)

    def _depth(self, init):
        """
        Accumulate the depth of the tree at the given node taking into
        account the previous depth of the tree.

        init is the existing depth accumulated from previous levels of the
        tree.
        """
        # A non-leaf node without children ends the branch as well;
        # max() of an empty sequence would raise ValueError.
        if self.leaf or not self.children:
            return init
        return max(c._depth(init + 1) for c in self.children)

    def _rules(self, parent=None, previous=()):
        """
        Return a list of decision rules with the given parent node and the
        tuple of previous (attribute, value) steps.
        """
        rows = []
        if parent is not None:
            previous += ((parent.label, self.parent_value), )
        if self.leaf:
            # The leaf label closes the rule as a bare (non-tuple) element.
            previous += (self.label, )
            rows.append(previous)
        else:
            for node in self.children:
                rows.extend(node._rules(self, previous))
        return rows

    def __str__(self):
        """
        Recursively build a string representation of the tree starting at
        the current node.
        """
        return "--{0}--({1}, {2})".format(
            self.parent_value,
            self.label,
            ', '.join(str(c) for c in self.children)
        )

    def __repr__(self):
        """
        Recursively build a string representation of the tree starting at
        the current node. Differs from __str__ by including additional
        diagnostic information.
        """
        return "--{0}--({1} {2}, {3})".format(
            self.parent_value,
            self.label,
            self.properties,
            ', '.join(repr(c) for c in self.children)
        )