# Logan Noel (github.com/lmnoel)
#
# ©2017-2019, Center for Spatial Data Science

import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot
import json
from spatial_access.p2p import TransitMatrix

from spatial_access.SpatialAccessExceptions import UnrecognizedCategoriesException
from spatial_access.SpatialAccessExceptions import SourceDataNotFoundException
from spatial_access.SpatialAccessExceptions import DestDataNotFoundException
from spatial_access.SpatialAccessExceptions import SourceDataNotParsableException
from spatial_access.SpatialAccessExceptions import DestDataNotParsableException
from spatial_access.SpatialAccessExceptions import PrimaryDataNotFoundException
from spatial_access.SpatialAccessExceptions import SecondaryDataNotFoundException
from spatial_access.SpatialAccessExceptions import ShapefileNotFoundException
from spatial_access.SpatialAccessExceptions import ModelNotAggregatedException
from spatial_access.SpatialAccessExceptions import ModelNotCalculatedException
from spatial_access.SpatialAccessExceptions import ModelNotAggregatableException
from spatial_access.SpatialAccessExceptions import SpatialIndexNotMatchedException
from spatial_access.SpatialAccessExceptions import TooManyCategoriesToPlotException
from spatial_access.SpatialAccessExceptions import UnexpectedPlotColumnException
from spatial_access.SpatialAccessExceptions import AggregateOutputTypeNotExpectedException
from spatial_access.SpatialAccessExceptions import UnexpectedAggregationTypeException


import os.path
import logging


class ModelData:
    """
    Common resources for spatial_access.Models.
    """
    def __init__(self, network_type, sources_filename,
                 destinations_filename,
                 source_column_names=None, dest_column_names=None,
                 configs=None, debug=False):
        """
        Args:
            network_type: string, one of {'walk', 'bike', 'drive', 'otp'}.
            sources_filename: string, csv filename.
            destinations_filename: string, csv filename.
            source_column_names: dictionary, map column names to expected values.
            dest_column_names: dictionary, map column names to expected values.
            configs: defaults to None, else pass in an instance of Configs to override
                default values for p2p.
            debug: boolean, enable to see more detailed logging output.
        """
        self.network_type = network_type
        self.transit_matrix = None
        self.dests = None
        self.sources = None
        self.model_results = None
        self.aggregated_results = None
        self.all_categories = {}
        self.focus_categories = {}

        self.configs = configs

        self._aggregation_args = {}
        self._is_source = True
        self._is_aggregatable = True
        self._requires_user_aggregation_type = False
        self._result_column_names = None

        self.sources_filename = sources_filename
        self.destinations_filename = destinations_filename

        # column_names and file_hints are similar, both map intended_name->actual_data_name
        # the difference is column names should be complete/contain all needed fields
        self.source_column_names = source_column_names
        self.dest_column_names = dest_column_names

        # hints are partial/potentially incomplete, and supplied
        # by p2p.TransitMatrix
        self._source_file_hints = None
        self._dest_file_hints = None

        self.sources_in_range = {}
        self.dests_in_range = {}

        # initialize logger
        self.debug = debug
        self.logger = None
        self.set_logging(debug)


    def write_transit_matrix_to_csv(self, filename=None):
        """
        Write the transit matrix to a csv file.

        Args:
            filename: string, or None so that a filename is generated
                automatically by the transit matrix.
        """
        matrix = self.transit_matrix
        matrix.write_csv(filename)

    def write_transit_matrix_to_tmx(self, filename=None):
        """
        Write the transit matrix to a tmx file.

        Args:
            filename: string, or None so that a filename is generated
                automatically by the transit matrix.
        """
        matrix = self.transit_matrix
        matrix.write_tmx(filename)

    @staticmethod
    def _get_output_filename(keyword, extension='csv', file_path='data/'):
        """
        Args:
            keyword: string such as "model_results" or "aggregated_results"
                to build the filename.
            extension: file type extension (no ".")
            file_path: subdirectory.
        Returns: string of unused filename.
        """
        if file_path is None:
            file_path = "data/"
        if not os.path.exists(file_path):
            os.makedirs(file_path)
        filename = os.path.join(file_path, '{}_0.{}'.format(keyword, extension))
        counter = 1
        while os.path.isfile(filename):
            filename = os.path.join(file_path, '{}_{}.{}'.format(keyword, counter, extension))
            counter += 1

        return filename

    def get_population(self, source_id):
        """
        Args:
            source_id: string/int
        Returns: the population at a source point.
        """
        return self.sources.loc[source_id, 'population']

    def get_capacity(self, dest_id):
        """
        Args:
            dest_id: string/int
        Returns: the capacity value at a dest point.
        """
        return self.dests.loc[dest_id, 'capacity']

    def get_category(self, dest_id):
        """
        Args:
            dest_id: string/int
        Returns: the category value at a dest point.
        """
        return self.dests.loc[dest_id, 'category']

    def get_all_dest_ids(self):
        """
        Returns: all ids of destination data frame.
        """
        return list(self.dests.index)

    def get_all_source_ids(self):
        """
        Returns: all ids of source data frame.
        """
        return list(self.sources.index)

    def get_ids_for_category(self, category):
        """
        Given category, return an array of all indeces
        which match. If category is all_categories, return all indeces.
        """
        return list(self.dests[self.dests['category'] == category].index)

    def get_transit_source_ids(self):
        """
        Returns: the source IDs known to the transit matrix interface.
        """
        interface = self.transit_matrix.matrix_interface
        return interface.get_source_ids()

    def get_transit_dest_ids(self):
        """
        Returns: the destination IDs known to the transit matrix interface.
        """
        interface = self.transit_matrix.matrix_interface
        return interface.get_dest_ids()

    def get_common_source_ids(self):
        """
        Returns: set of source IDs present both in the loaded source data
            and in the transit matrix.
        """
        loaded_ids = set(self.get_all_source_ids())
        return loaded_ids.intersection(set(self.get_transit_source_ids()))

    def get_common_dest_ids(self):
        """
        Returns: set of destination IDs present both in the loaded
            destination data and in the transit matrix.
        """
        loaded_ids = set(self.get_all_dest_ids())
        return loaded_ids.intersection(set(self.get_transit_dest_ids()))

    def _missing_transit_data_warning(self, id_type):
        """
        Log a warning listing source or destination IDs present in the
        loaded data but without corresponding travel data in the transit
        matrix.

        Args:
            id_type: 'source', 'destination', or an iterable of those values.
        Raises:
            ValueError: if an entry of id_type is not 'source' or
                'destination'.
        """
        # A bare string would otherwise be iterated character by character.
        id_types = (id_type,) if isinstance(id_type, str) else id_type
        for kind in id_types:
            # explicit check instead of assert, which is stripped under -O
            if kind == 'destination':
                data_ids = set(self.get_all_dest_ids())
                transit_ids = set(self.get_transit_dest_ids())
            elif kind == 'source':
                data_ids = set(self.get_all_source_ids())
                transit_ids = set(self.get_transit_source_ids())
            else:
                raise ValueError("id_type must be 'source', or 'destination'")

            no_transit_data = data_ids - transit_ids
            if no_transit_data:
                self.logger.warning('No transit matrix data available for %d %s ID(s). Missing: %s',
                                    len(no_transit_data), kind, no_transit_data)

    def set_logging(self, debug):
        """
        Args:
            debug: set to true for more detailed logging
                output.
        """

        if debug:
            logging.basicConfig(level=logging.DEBUG)
        else:
            logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def load_transit_matrix(self, read_from_file=None):
        """
        Load the transit matrix (and sources/dests).

        Args:
            read_from_file: filename of a tmx or csv file to load.
                This allows the user to bypass computing the
                transit matrix from scratch. If read_from_file is
                None, the transit matrix is computed from the given
                source/dest data, using self.configs (if any) to
                override p2p defaults.
        Raises:
            SourceDataNotFoundException: Cannot find source data.
            DestDataNotFoundException: Cannot find dest data.
        """
        if read_from_file:
            self.transit_matrix = TransitMatrix(self.network_type,
                                                read_from_file=read_from_file,
                                                debug=self.debug)
        else:
            # Forward the configs stored in __init__ so they actually
            # override the p2p defaults as documented there.
            self.transit_matrix = TransitMatrix(self.network_type,
                                                primary_input=self.sources_filename,
                                                secondary_input=self.destinations_filename,
                                                primary_hints=self.source_column_names,
                                                secondary_hints=self.dest_column_names,
                                                configs=self.configs,
                                                debug=self.debug)
            try:
                self.transit_matrix.process()
            except PrimaryDataNotFoundException as err:
                raise SourceDataNotFoundException() from err
            except SecondaryDataNotFoundException as err:
                raise DestDataNotFoundException() from err

            # borrow hints for use in reload_sources() and reload_dests() if not user supplied
            if self._source_file_hints is None:
                self._source_file_hints = self.transit_matrix.primary_hints
            if self._dest_file_hints is None:
                self._dest_file_hints = self.transit_matrix.secondary_hints

        self.reload_sources()
        self.reload_dests()

    def reload_sources(self, filename=None):
        """
        Load the source points for the model (from csv).
        For each point, the table should contain:
        -unique identifier (integer or string)
        -latitude & longitude
        -population (integer) [only for some models]

        Args:
            filename: string
        Raises:
            SourceDataNotFoundException: Cannot find source
                data.
            SourceDataNotParsableException: Provided source_column_names
                do not correspond to column names.
        """

        if filename:
            self.sources_filename = filename
        try:
            self.sources = pd.read_csv(self.sources_filename)
        except FileNotFoundError:
            raise SourceDataNotFoundException()

        if self.source_column_names is None:
            # extract the column names from the table
            population = ''
            idx = ''
            lat = ''
            lon = ''

            if self._source_file_hints is not None:
                if 'idx' in self._source_file_hints:
                    idx = self._source_file_hints['idx']
                if 'population' in self._source_file_hints:
                    population = self._source_file_hints['population']
                if 'lat' in self._source_file_hints:
                    lat = self._source_file_hints['lat']
                if 'lon' in self._source_file_hints:
                    lon = self._source_file_hints['lon']

            # extract the column names from the table for whichever fields
            # were not gleaned from self.source_file_hints
            source_data_columns = self.sources.columns.values
            print('The variables in your data set are:')
            for var in source_data_columns:
                print('> ', var)
            while idx not in source_data_columns:
                idx = input('Enter the unique index variable: ')
            print('If you have no population variable, write "skip" (no quotations)')
            while population not in source_data_columns and population != 'skip':
                population = input('Enter the population variable: ')
            while lat not in source_data_columns:
                lat = input('Enter the latitude variable: ')
            while lon not in source_data_columns:
                lon = input('Enter the longitude variable: ')

            # store the col names for later use
            self.source_column_names = {'lat': lat, 'lon': lon, 'idx': idx,
                                        'population': population}

        try:
            # insert filler values for the population column if
            # user does not want to include it. need it for coverage
            if self.source_column_names['population'] == 'skip':
                self.sources['population'] = 1

            # rename columns, clean the data frame
            rename_cols = {self.source_column_names['population']: 'population',
                           self.source_column_names['lat']: 'lat',
                           self.source_column_names['lon']: 'lon'}
            self.sources.set_index(self.source_column_names['idx'], inplace=True)
            self.sources.rename(columns=rename_cols, inplace=True)
        except KeyError:
            raise SourceDataNotParsableException()

        # drop unused columns
        columns_to_keep = list(rename_cols.values())
        self.sources = self.sources[columns_to_keep]

    def reload_dests(self, filename=None):
        """
        Load the dest points for the model (from csv).
        For each point, the table should contain:
        -unique identifier (integer or string)
        -latitude & longitude
        -category (string/int) [only for some models]
        -capacity (numeric) [only for some models]

        Args:
            filename: string
        Raises:
            DestDataNotFoundException: Cannot find dest
                data.
            DestDataNotParsableException: Provided dest_column_names
                do not correspond to column names.
        """

        if filename:
            self.destinations_filename = filename

        try:
            self.dests = pd.read_csv(self.destinations_filename)
        except FileNotFoundError:
            raise DestDataNotFoundException()

        if self.dest_column_names is None:
            # extract the column names from the table
            category = ''
            capacity = ''
            idx = ''
            lat = ''
            lon = ''

            if self._dest_file_hints is not None:
                if 'idx' in self._dest_file_hints:
                    idx = self._dest_file_hints['idx']
                if 'category' in self._dest_file_hints:
                    category = self._dest_file_hints['category']
                if 'capacity' in self._dest_file_hints:
                    capacity = self._dest_file_hints['capacity']
                if 'lat' in self._dest_file_hints:
                    lat = self._dest_file_hints['lat']
                if 'lon' in self._dest_file_hints:
                    lon = self._dest_file_hints['lon']

            # extract the column names from the table for whichever fields
            # were not gleaned from self.dest_file_hints
            dest_data_columns = self.dests.columns.values
            print('The variables in your data set are:')
            for var in dest_data_columns:
                print('> ', var)
            while idx not in dest_data_columns:
                idx = input('Enter the unique index variable: ')
            print('If you have no capacity variable, write "skip" (no quotations)')
            while capacity not in dest_data_columns and capacity != 'skip':
                capacity = input('Enter the capacity variable: ')
            print('If you have no category variable, write "skip" (no quotations)')
            while category not in dest_data_columns and category != 'skip':
                category = input('Enter the category variable: ')
            while lat not in dest_data_columns:
                lat = input('Enter the latitude variable: ')
            while lon not in dest_data_columns:
                lon = input('Enter the longitude variable: ')
            self.dest_column_names = {'lat': lat, 'lon': lon, 'idx': idx,
                                      'category': category, 'capacity': capacity}

        try:
            # insert filler values for the capacity column if
            # user does not want to include it.
            if self.dest_column_names['capacity'] == 'skip':
                self.dests['capacity'] = 1

            # insert filler values for the category column if
            # user does not want to include it.
            if self.dest_column_names['category'] == 'skip':
                self.dests['category'] = 1

            # rename columns, clean the data frame
            rename_cols = {self.dest_column_names['lat']: 'lat', self.dest_column_names['lon']: 'lon'}
            if self.dest_column_names['capacity'] != 'skip':
                rename_cols[self.dest_column_names['capacity']] = 'capacity'
            if self.dest_column_names['category'] != 'skip':
                rename_cols[self.dest_column_names['category']] = 'category'

            self.dests.set_index(self.dest_column_names['idx'], inplace=True)
            self.dests.rename(columns=rename_cols, inplace=True)

        except KeyError:
            raise DestDataNotParsableException()

        # drop unused columns
        columns_to_keep = list(rename_cols.values())
        self.dests = self.dests[columns_to_keep]
        self.all_categories = set(self.dests['category'])

    def get_dests_in_range_of_source(self, source_id):
        """
        Args:
            source_id: string/int
        Returns: a list of dest ids in range of the source.
        """
        return self.dests_in_range[source_id]

    def get_sources_in_range_of_dest(self, dest_id):
        """
        Args:
            dest_id: string/int
        Returns: a list of source ids in range of the dest.
        """
        return self.sources_in_range[dest_id]

    def calculate_dests_in_range(self, upper_threshold):
        """
        Compute and cache, per source, the destinations within a threshold.

        Args:
            upper_threshold: numeric, upper threshold of what
                points are considered to be in range.
        """
        interface = self.transit_matrix.matrix_interface
        self.dests_in_range = interface.get_dests_in_range(upper_threshold)

    def calculate_sources_in_range(self, upper_threshold):
        """
        Compute and cache, per destination, the sources within a threshold.

        Args:
            upper_threshold: numeric, upper threshold of what
                points are considered to be in range.
        """
        interface = self.transit_matrix.matrix_interface
        self.sources_in_range = interface.get_sources_in_range(upper_threshold)

    def get_values_by_source(self, source_id, sort=False):
        """
        Fetch the transit matrix row for a source.

        Args:
            source_id: string/int
            sort: boolean, True to have the pairs sorted in
                nondecreasing order of value.
        Returns: list of (dest_id, value) pairs.
        """
        interface = self.transit_matrix.matrix_interface
        return interface.get_values_by_source(source_id, sort)

    def get_values_by_dest(self, dest_id, sort=False):
        """
        Fetch the transit matrix column for a destination.

        Args:
            dest_id: string/int
            sort: boolean, True to have the pairs sorted in
                nondecreasing order of value.
        Returns: list of (source_id, value) pairs.
        """
        interface = self.transit_matrix.matrix_interface
        return interface.get_values_by_dest(dest_id, sort)

    def get_population_in_range(self, dest_id):
        """
        Args:
            dest_id: string/int
        Returns: the population within the capacity range for the given
            destination id.
        """
        cumulative_population = 0
        for source_id in self.get_sources_in_range_of_dest(dest_id):
            source_population = self.get_population(source_id)
            if source_population > 0:
                cumulative_population += source_population

        return cumulative_population

    def _map_categories_to_sp_matrix(self):
        """
        Register the category of every destination shared by the loaded
        data and the transit matrix with the matrix interface.
        """
        for shared_dest in self.get_common_dest_ids():
            self._add_to_category_map(shared_dest, self.get_category(shared_dest))

    def _add_to_category_map(self, dest_id, category):
        """
        Record in the transit matrix that dest_id belongs to category.

        Args:
            dest_id: string/int
            category: string
        """
        interface = self.transit_matrix.matrix_interface
        interface.add_to_category_map(dest_id, category)

    def time_to_nearest_dest(self, source_id, category):
        """
        Args:
            source_id: string/int
            category: string; 'all_categories' means any destination.
        Returns: the time from source_id to its nearest destination of the
            given category (or of any category for 'all_categories').
        """
        # the matrix interface uses None to mean "no category filter"
        category_filter = None if category == 'all_categories' else category
        return self.transit_matrix.matrix_interface.time_to_nearest_dest(source_id, category_filter)

    def count_dests_in_range_by_categories(self, source_id, category, upper_threshold):
        """
        Args:
            source_id: int/string
            category: string; 'all_categories' counts every destination.
            upper_threshold: numeric, upper limit of what is
                considered to be 'in range'.
        Returns: the number of destinations of the given category within
            upper_threshold of source_id.
        """
        # the matrix interface uses None to mean "no category filter"
        category_filter = None if category == 'all_categories' else category
        interface = self.transit_matrix.matrix_interface
        return interface.count_dests_in_range(source_id, upper_threshold, category_filter)

    # TODO: optimize count_sum_in_range_by_categories (it looks up category and capacity one destination at a time)
    def count_sum_in_range_by_categories(self, source_id, category):
        """
        Args:
            source_id: int/string
            category: string
        Returns: the count of destinations in range
            of the source id per category
        """
        running_sum = 0
        for dest_id in self.get_dests_in_range_of_source(source_id):
            if self.get_category(dest_id) == category or category == 'all_categories':
                running_sum += self.get_capacity(dest_id)
        return running_sum

    def _print_data_frame(self):
        """
        Print the whole transit matrix via its interface.
        Intended only for trivially small matrices.
        """
        interface = self.transit_matrix.matrix_interface
        interface.print_data_frame()

    def _spatial_join_community_index(self, dataframe, shapefile='data/chicago_boundaries/chicago_boundaries.shp',
                                      spatial_index='community',  projection='epsg:4326'):
        """
        Attach the containing boundary area to each row of a point dataframe.

        Each row becomes a point built from its 'lon'/'lat' columns. The points are
        spatially joined ('intersects', how='right' so every point row is kept)
        against the polygons read from the shapefile.

        Args:
            dataframe: pandas dataframe with unique id and 'lat'/'lon' columns.
            shapefile: shapefile containing geometry.
            spatial_index: column name of the aggregation area in the shapefile;
                it is renamed to 'spatial_index' in the result.
            projection: defaults to 'epsg:4326'
        Returns: dataframe with the original columns plus 'spatial_index'
            and 'geometry'.
        Raises:
            ShapefileNotFoundException: Shapefile not found.
            SpatialIndexNotMatchedException: spatial_index not found in shapefile.
        """
        point_geometry = [Point(lon_lat) for lon_lat in zip(dataframe['lon'], dataframe['lat'])]
        points_gdf = gpd.GeoDataFrame(dataframe, crs={'init': projection}, geometry=point_geometry)
        try:
            boundaries_gdf = gpd.read_file(shapefile)
        except FileNotFoundError:
            raise ShapefileNotFoundException('shapefile not found: {}'.format(shapefile))

        joined = gpd.sjoin(boundaries_gdf, points_gdf, how='right',
                           op='intersects')

        wanted_columns = list(dataframe.columns) + ['spatial_index', 'geometry']

        joined.rename(columns={spatial_index: 'spatial_index'}, inplace=True)
        try:
            joined = joined[wanted_columns]
        except KeyError:
            raise SpatialIndexNotMatchedException('Unable to match spatial_index:{}'.format(spatial_index))

        # a point intersecting several polygons produces several rows
        if len(joined) != len(dataframe):
            self.logger.warning('Length of joined dataframe ({}) != length of input dataframe ({})'
                                .format(len(joined), len(dataframe)))
        return joined

    def _rejoin_results_with_coordinates(self, model_results, is_source):
        """
        Args:
            model_results: dataframe
            is_source: boolean (tells where to draw lat/long data from)
        Returns: deep copty of dataframe with lat/long data.
        """
        model_results = model_results.copy(deep=True)
        if is_source:
            model_results['lat'] = self.sources['lat']
            model_results['lon'] = self.sources['lon']
        else:
            model_results['lat'] = self.dests['lat']
            model_results['lon'] = self.dests['lon']
        return model_results

    def _build_aggregate(self, data_frame, aggregation_args, shapefile, spatial_index, projection):
        """
        Private method invoked to aggregate dataframe on spatial area.
        Args:
            data_frame: dataframe
            aggregation_args: dictionary mapping each column name to the method by which that
                column should be aggregated, e.g. mean, sum, etc...
            shapefile: filename of shapefile
            spatial_index: index of geospatial area in shapefile
            projection: defaults to 'epsg:4326'

        Returns: aggregated data frame.

        """
        if 'lat' not in data_frame.columns or 'lon' not in data_frame.columns or 'spatial_index' not in data_frame.columns:
            data_frame = self._spatial_join_community_index(dataframe=data_frame,
                                                                    shapefile=shapefile,
                                                                    spatial_index=spatial_index,
                                                                    projection=projection)
        aggregated_data_frame = data_frame.groupby('spatial_index').agg(aggregation_args)
        return aggregated_data_frame

    def set_focus_categories(self, categories):
        """
        Set the categories that the model should perform computations for.
        Args:
            categories: list of categories.
        Raises:
            UnrecognizedCategoriesException: User passes categories not
                found in the dest data.
        """
        if categories is None:
            self.focus_categories = self.all_categories
        else:
            self.focus_categories = categories
            unrecognized_categories = set(categories) - self.all_categories
            if len(unrecognized_categories) > 0:
                raise UnrecognizedCategoriesException(','.join([category for category in unrecognized_categories]))


    def aggregate(self, aggregation_type=None, shapefile='data/chicago_boundaries/chicago_boundaries.shp',
                        spatial_index='community',  projection='epsg:4326'):
        """

        Args:
            aggregation_type: string, required for models with multiple possiblities
                for aggregating.
            shapefile: filename of shapefile
            spatial_index: index of geospatial area in shapefile
            projection: defaults to 'epsg:4326'

        Returns: aggregated data frame.

        Raises:
            ModelNotCalculatedException: If the model has not yet been
                calculated.
            UnexpectedAggregationTypeException: If the user passes an
                unexpected aggregation type, or no aggregation type
                when one is expected, or an aggregation type when none
                is expected.
        """
        if self.model_results is None:
            raise ModelNotCalculatedException()
        if not self._is_aggregatable:
            raise ModelNotAggregatableException()
        if self._requires_user_aggregation_type:
            if aggregation_type is None:
                raise UnexpectedAggregationTypeException(aggregation_type)
            else:
                if aggregation_type not in {'min', 'max', 'mean'}:
                    raise UnexpectedAggregationTypeException(aggregation_type)
                else:
                    self._aggregation_args = {}
                    for column in self.model_results.columns:
                        self._aggregation_args[column] = aggregation_type

        else:
            if aggregation_type is not None:
                raise UnexpectedAggregationTypeException(aggregation_type)

        results_with_coordinates = self._rejoin_results_with_coordinates(self.model_results, self._is_source)

        self.aggregated_results = self._build_aggregate(data_frame=results_with_coordinates,
                                                        aggregation_args=self._aggregation_args,
                                                        shapefile=shapefile,
                                                        spatial_index=spatial_index,
                                                        projection=projection)
        return self.aggregated_results

    def write_aggregated_results(self, filename=None, output_type='csv'):
        """
        Args:
            filename: file to write results. If none is given, a valid
                filename will be automatically generated.
            output_type: 'csv' or 'json'.

        Raises:
            AggregateOutputTypeNotExpectedException:

        """
        if filename is not None:
            output_type = filename.split('.')[1]
        else:
            filename = self._get_output_filename(keyword='aggregate',
                                                       extension=output_type,
                                                       file_path='data/')

        if self.aggregated_results is None:
            raise ModelNotAggregatedException()

        if output_type == 'csv':
            self.aggregated_results.to_csv(filename)
        elif output_type == 'json':
            output = {}
            for row in self.aggregated_results.itertuples():
                output[row[0]] = {}
                for i, column in enumerate(self.aggregated_results.columns):
                    output[row[0]][column] = row[i + 1]
            with open(filename, 'w') as file:
                json.dump(output, file)
        else:
            raise AggregateOutputTypeNotExpectedException(output_type)

    def write_results(self, filename=None):
        """
        Write the model results to a CSV file.

        Args:
            filename: destination path. When omitted, a filename is produced
                by self._get_output_filename under 'data/'.

        Raises:
            ModelNotCalculatedException: if the model has not been calculated.
        """
        results = self.model_results
        if results is None:
            raise ModelNotCalculatedException()

        target = filename
        if target is None:
            target = self._get_output_filename(keyword='model',
                                               extension='csv',
                                               file_path='data/')
        results.to_csv(target)


    @staticmethod
    def _join_aggregated_data_with_boundaries(aggregated_results, spatial_index,
                                              shapefile='data/chicago_boundaries/chicago_boundaries.shp'):
        """
        Attach boundary geometries from a shapefile to aggregated results.

        The shapefile's features (keyed on their spatial_index column) are
        outer-merged with aggregated_results (keyed on its index). Missing
        non-geometry values are filled with 0; missing geometries are left
        empty rather than filled.

        Args:
            aggregated_results: dataframe
            shapefile: filename of shapefile
            spatial_index: index of geospatial area in shapefile

        Returns: dataframe with the aggregated columns, 'geometry' and
            spatial_index.

        Raises:
            ShapefileNotFoundException: shapefile not found.
        """
        try:
            boundaries_gdf = gpd.read_file(shapefile)
        except FileNotFoundError as err:
            raise ShapefileNotFoundException('shapefile not found: {}'.format(shapefile)) from err

        geometry_column = 'geometry'
        columns_to_keep = list(aggregated_results.columns)
        columns_to_keep.append(geometry_column)
        columns_to_keep.append(spatial_index)

        results = boundaries_gdf.merge(aggregated_results, left_on=spatial_index,
                                       right_index=True, how='outer')
        # Zero-fill only the non-geometry columns. A blanket fillna(0) would
        # also target the geometry column, which GeoPandas may reject (or fill
        # with a non-geometry 0) whenever an aggregated row has no boundary.
        fill_values = {column: 0 for column in results.columns
                       if column != geometry_column}
        results = results.fillna(value=fill_values)
        return results[columns_to_keep]

    def plot_cdf(self, plot_column=None, xlabel="xlabel", ylabel="ylabel", title="title",
                 bins=100, is_density=False, filename=None):
        """
        Plot cumulative step histograms of the model results and save the figure.

        Every column of the model results whose name contains plot_column is
        drawn as one cumulative step histogram with its own color.

        Args:
            plot_column: If the model has multiple possibilities to plot, specify which
                one.
            xlabel: xlabel for figure.
            ylabel: ylabel for figure.
            title: title for figure.
            bins: integer (or bin spec accepted by Axes.hist), applied to every
                plotted column.
            is_density: boolean, true for density plot.
            filename: filename to write to.

        Raises:
            ModelNotAggregatedException: Model is not aggregated.
            UnexpectedPlotColumnException: User passes unexpected plot type, or
                no result column matches plot_column.
            TooManyCategoriesToPlotException: Too many categories to plot.
        """
        if self.aggregated_results is None:
            raise ModelNotAggregatedException()

        if self._is_source:
            # Only sources with a positive population are included.
            cdf_eligible = self.model_results[self.sources['population'] > 0]
        else:
            cdf_eligible = self.model_results

        if isinstance(self._result_column_names, str):
            if plot_column is not None:
                raise UnexpectedPlotColumnException(plot_column)
            plot_column = self._result_column_names
        elif plot_column is None:
            raise UnexpectedPlotColumnException(plot_column)

        columns_to_plot = [column for column in cdf_eligible.columns if plot_column in column]
        if not columns_to_plot:
            # Previously this silently saved an empty figure.
            raise UnexpectedPlotColumnException(plot_column)

        # Distinct colors (the old palette listed 'black' twice).
        available_colors = ['black', 'magenta', 'lime', 'red', 'blue', 'orange', 'grey', 'yellow', 'brown', 'teal']
        if len(columns_to_plot) > len(available_colors):
            raise TooManyCategoriesToPlotException()

        # initialize block parameters
        mpl.pyplot.close()
        mpl.pyplot.rcParams['axes.facecolor'] = '#cfcfd1'
        _, ax = mpl.pyplot.subplots(figsize=(8, 4))
        ax.grid(zorder=0)

        color_keys = []
        for column, color in zip(columns_to_plot, available_colors):
            color_keys.append(mpatches.Patch(color=color, label=column))
            # Pass the caller's `bins` every time. Rebinding it to the edges
            # returned by the first hist() clipped later columns to the first
            # column's value range.
            ax.hist(cdf_eligible[column], bins=bins, density=is_density, histtype='step',
                    cumulative=True, label=column, color=color, zorder=3)
        ax.legend(loc='right', handles=color_keys)

        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if filename is None:
            filename = self._get_output_filename(keyword='figure', extension='png',
                                                 file_path='figures/')
        mpl.pyplot.savefig(filename, dpi=400)
        self.logger.info('Plot was saved to: {}'.format(filename))

    def plot_choropleth(self, column, include_destinations=True, title='Title', color_map='Greens',
                        shapefile='data/chicago_boundaries/chicago_boundaries.shp', spatial_index='community',
                        filename=None):
        """
        Plot an aggregated column as a choropleth map and save it to disk.

        Optionally overlays the destinations as a scatter plot. Each focus
        category gets its own color, and marker size is proportional to
        capacity.

        Args:
            column: Which column to plot.
            include_destinations: boolean, will plot circles for destinations if true.
            title: Figure title.
            color_map: See https://matplotlib.org/tutorials/colors/colormaps.html
            shapefile: filename of shapefile
            spatial_index: index of geospatial area in shapefile
            filename: file to write figure to.

        Raises:
            ModelNotAggregatedException: Model is not aggregated.
            UnexpectedPlotColumnException: User passes unexpected column.
        """
        if self.aggregated_results is None:
            raise ModelNotAggregatedException()
        categories = self.focus_categories if include_destinations else None
        results_with_geometry = self._join_aggregated_data_with_boundaries(aggregated_results=self.aggregated_results,
                                                                           spatial_index=spatial_index,
                                                                           shapefile=shapefile)
        if column not in results_with_geometry.columns:
            raise UnexpectedPlotColumnException('Did not expect column argument: {}'.format(column))

        mpl.pyplot.close()

        mpl.pyplot.rcParams['axes.facecolor'] = '#cfcfd1'

        results_with_geometry.plot(column=column, cmap=color_map, edgecolor='black', linewidth=0.1)

        # add a scatter plot of the destinations over the choropleth
        if categories is not None:
            available_colors = ['magenta', 'lime', 'red', 'black', 'orange', 'grey', 'yellow', 'brown', 'teal']
            # if we have too many categories of vendors, use black dots and no legend
            monochrome = len(categories) > len(available_colors)
            color_keys = []
            # Series.max() skips NaN; it returns NaN for an empty frame.
            max_dest_capacity = self.dests['capacity'].max()
            for category in categories:
                if monochrome:
                    color = 'black'
                else:
                    color = available_colors.pop(0)
                    color_keys.append(mpatches.Patch(color=color, label=category))
                dest_subset = self.dests.loc[self.dests['category'] == category]
                if max_dest_capacity > 0:
                    marker_sizes = 50 * (dest_subset['capacity'] / max_dest_capacity)
                else:
                    # All capacities zero/missing: 0/0 would give NaN (invisible) sizes.
                    marker_sizes = 50
                mpl.pyplot.scatter(y=dest_subset['lat'], x=dest_subset['lon'], color=color, marker='o',
                                   s=marker_sizes, label=category)
            # Draw the legend once, after every category has its patch.
            if color_keys:
                mpl.pyplot.legend(loc='best', handles=color_keys)

        mpl.pyplot.title(title)
        if filename is None:
            filename = self._get_output_filename(keyword='figure', extension='png',
                                                 file_path='figures/')
        mpl.pyplot.savefig(filename, dpi=400)

        self.logger.info('Figure was saved to: {}'.format(filename))