import random
import networkx as nx
from collections import Counter
from inspect import signature

from .types import NodeSpec, EdgeSpec
from .generate_graph import StationProperties, LineProperties

from typing import List, Dict

import logging
logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------
# Executable syntax tree to represent and calculate answers
# --------------------------------------------------------------------------

class FunctionalOperator(object):
	"""
	Base node of an executable syntax tree.

	An instance stores its positional arguments in ``self.args``. Arguments
	may be other FunctionalOperator instances (sub-programs), plain values,
	or callables (lambdas that build sub-programs).
	"""

	def __init__(self, *args):
		self.args = args

	def __call__(self, graph):
		"""
		Execute this whole program against ``graph`` and return the answer.

		Operator arguments are evaluated first (recursively), non-operator
		arguments are passed through unchanged, and the results are handed,
		after ``graph``, to this node's ``op``. Any exception raised by ``op``
		is logged at debug level and re-raised unchanged.
		"""

		def evaluate(item):
			if isinstance(item, FunctionalOperator):
				return item(graph)
			return item

		vals = [evaluate(i) for i in self.args]
		try:
			return self.op(graph, *vals)
		except Exception as err:
			# Lazy %-style args: the message is only built if debug logging is on
			logger.debug("Failed to execute operation %s(%s) %s", type(self).__name__, vals, err)
			# Bare raise keeps the original traceback intact
			raise

	def op(self, *args):
		"""
		Perform this individual operation.

		Subclasses override this. It receives ``graph`` followed by the
		already-evaluated arguments.

		Operations should raise ValueError if it is not possible to generate
		a valid answer, but no error has occurred. This exception will be
		silently swallowed.
		"""
		raise NotImplementedError()

	def stripped(self):
		"""
		Return a plain nested representation of this program for export.

		The result is ``{ClassName: [exported args...]}``. Arguments that have
		a ``stripped`` method are exported recursively. Callable arguments
		(lambdas) are replaced symbolically. The callable is invoked with one
		``LambdaArg`` per parameter name, and the resulting program is wrapped
		in a ``Lambda`` node, because YAML export cannot handle function
		objects. Any other argument is exported as-is.
		"""

		def export(item):
			try:
				return item.stripped()
			except AttributeError:
				if callable(item):
					params = signature(item).parameters
					placeholders = [LambdaArg(name) for name in params]
					return Lambda(item(*placeholders)).stripped()
				return item

		return {type(self).__name__: [export(i) for i in self.args]}

def macro(f):
	"""
	Identity decorator that tags ``f`` as a macro.

	A macro is a helper that composes existing operators into a larger
	program. The function is returned unchanged.
	"""
	tagged = f
	return tagged


# --------------------------------------------------------------------------
#  Noun operations
# --------------------------------------------------------------------------

class Station(FunctionalOperator):
	"""Noun operator wrapping a single station of the graph."""

	@classmethod
	def get(cls, graph):
		"""Return a Station around a station picked at random from ``graph.nodes``."""
		candidates = list(graph.nodes.values())
		picked = random.choice(candidates)
		return Station(picked)

class FakeStationName(FunctionalOperator):
	"""Noun operator wrapping a station name that does not exist in the graph."""

	@classmethod
	def get(cls, graph):
		"""
		Return a FakeStationName holding an integer in ``[0, 2 * len(graph.nodes))``
		whose string form is not the name of any station in the graph.
		"""
		# This needs generalising later.
		# NOTE(review): this presumably assumes real station names are numeric
		# strings; confirm against the graph generator.
		taken = set()
		for node in graph.nodes.values():
			taken.add(str(node.name()))

		upper = len(graph.nodes) * 2
		unused = [candidate for candidate in range(upper) if str(candidate) not in taken]
		return FakeStationName(random.choice(unused))

class StationPropertyName(FunctionalOperator):
	"""Noun operator wrapping the name (key) of a station property."""

	@classmethod
	def get(cls, graph):
		"""Return a StationPropertyName around a random key of StationProperties."""
		# random.choice needs an indexable sequence. A dict keys view is not
		# one, so materialise it first, as StationProperty.get does.
		property_names = list(StationProperties.keys())
		return StationPropertyName(random.choice(property_names))

class StationProperty(FunctionalOperator):
	"""Noun operator wrapping a station property name together with its value list."""

	@classmethod
	def get(cls, graph):
		"""Return a StationProperty(key, StationProperties[key]) for a random key."""
		names = list(StationProperties.keys())
		chosen = random.choice(names)
		values = StationProperties[chosen]
		return StationProperty(chosen, values)

class Line(FunctionalOperator):
	"""Noun operator wrapping a single line of the graph."""

	@classmethod
	def get(cls, graph):
		"""Return a Line around a line picked at random from ``graph.lines``."""
		all_lines = list(graph.lines.values())
		return Line(random.choice(all_lines))

class Architecture(FunctionalOperator):
	"""Noun operator wrapping one value of the 'architecture' station property."""

	@classmethod
	def get(cls, graph):
		"""Return an Architecture around a random 'architecture' value (``graph`` is unused)."""
		options = StationProperties["architecture"]
		return Architecture(random.choice(options))

class Size(FunctionalOperator):
	"""Noun operator wrapping one value of the 'size' station property."""

	@classmethod
	def get(cls, graph):
		"""Return a Size around a random 'size' value (``graph`` is unused)."""
		options = StationProperties["size"]
		return Size(random.choice(options))

class Music(FunctionalOperator):
	"""Noun operator wrapping one value of the 'music' station property."""

	@classmethod
	def get(cls, graph):
		"""Return a Music around a random 'music' value (``graph`` is unused)."""
		options = StationProperties["music"]
		return Music(random.choice(options))

class Cleanliness(FunctionalOperator):
	"""Noun operator wrapping one value of the 'cleanliness' station property."""

	@classmethod
	def get(cls, graph):
		"""Return a Cleanliness around a random 'cleanliness' value (``graph`` is unused)."""
		options = StationProperties["cleanliness"]
		return Cleanliness(random.choice(options))

class Boolean(FunctionalOperator):
	"""Noun operator wrapping a boolean literal."""

	@classmethod
	def get(cls, graph):
		"""Return a Boolean around True or False chosen at random (``graph`` is unused)."""
		coin = random.choice([True, False])
		return Boolean(coin)

# --------------------------------------------------------------------------
# General operations
# --------------------------------------------------------------------------

class Const(FunctionalOperator):
	"""Wraps a literal; evaluates to its (already-evaluated) argument."""
	def op(self, graph, a):
		value = a
		return value

class Lambda(FunctionalOperator):
	"""
	Symbolic wrapper around the program a lambda builds (used by ``stripped``).
	Evaluating it returns its single argument.
	"""
	def op(self, graph, a):
		body = a
		return body

class LambdaArg(FunctionalOperator):
	"""
	Symbolic stand-in for a lambda parameter. ``stripped`` creates it with the
	parameter name. Evaluating it returns its single argument.
	"""
	def op(self, graph, a):
		placeholder = a
		return placeholder

class Pluck(FunctionalOperator):
	"""Project one field out of every item of a collection."""
	def op(self, graph, a, b):
		"""Return ``[item[b] for item in a]``."""
		plucked = []
		for item in a:
			plucked.append(item[b])
		return plucked

class Pick(FunctionalOperator):
	"""Index or key into a single value."""
	def op(self, graph, a, b):
		"""Return ``a[b]``."""
		container, key = a, b
		return container[key]

class Equal(FunctionalOperator):
	"""Equality test between two evaluated arguments."""
	def op(self, graph, a, b):
		"""Return ``a == b``."""
		same = a == b
		return same



# --------------------------------------------------------------------------
# Graph operations
# --------------------------------------------------------------------------

class AllEdges(FunctionalOperator):
	"""Evaluates to ``graph.edges``, returned as-is."""
	def op(self, graph):
		edges = graph.edges
		return edges

class AllNodes(FunctionalOperator):
	"""Evaluates to the values view of ``graph.nodes`` (not a list)."""
	def op(self, graph):
		node_map = graph.nodes
		return node_map.values()

class Edges(FunctionalOperator):
	"""Edge attribute dicts for the edges incident to one or more stations."""
	def op(self, graph, a):
		"""
		Return ``data['attr_dict']`` for each edge of ``graph.gnx`` incident to ``a``.

		``a`` is either a single NodeSpec or an iterable of stations. Each
		station is queried separately, so an edge between two stations that
		are both in ``a`` appears once per endpoint.
		"""
		stations = [a] if isinstance(a, NodeSpec) else a
		found = []
		for station in stations:
			for edge in graph.gnx.edges([station["id"]], data=True):
				found.append(edge[2]['attr_dict'])
		return found

class Nodes(FunctionalOperator):
	"""Endpoint stations of a collection of edges, without duplicates."""
	def op(self, graph, edges:EdgeSpec):
		"""
		Return the distinct stations at either end of ``edges``.

		Each edge must provide ``station1`` and ``station2`` ids, which are
		looked up in ``graph.nodes``. Duplicates are dropped keeping the first
		occurrence, so the result order follows ``edges`` and is reproducible.
		``list(set(...))`` order can vary between runs when element hashes are
		not stable.
		"""
		endpoints = []
		for edge in edges:
			endpoints.append(graph.nodes[edge["station1"]])
			endpoints.append(graph.nodes[edge["station2"]])
		# dict keys keep insertion order, giving an order-preserving dedupe
		return list(dict.fromkeys(endpoints))


def ids_to_nodes(graph, ids):
	"""Map each id in ``ids`` to ``graph.nodes[id]``, preserving order; returns a list."""
	found = []
	for node_id in ids:
		found.append(graph.nodes[node_id])
	return found

class ShortestPath(FunctionalOperator):
	"""Shortest route between two stations, as a list of stations."""
	def op(self, graph, a:NodeSpec, b:NodeSpec, fallback):
		"""
		Return the stations on a shortest path from ``a`` to ``b`` in
		``graph.gnx``, including both ends, or ``fallback`` if no path exists.
		"""
		try:
			route = nx.shortest_path(graph.gnx, a["id"], b["id"])
		except nx.exception.NetworkXNoPath:
			return fallback
		return ids_to_nodes(graph, route)

class ShortestPathOnlyUsing(FunctionalOperator):
	"""Shortest route between two stations restricted to a set of allowed stations."""
	def op(self, graph, a:NodeSpec, b:NodeSpec, only_using_nodes:List[NodeSpec], fallback):
		"""
		Return the stations on a shortest path from ``a`` to ``b`` using only
		``a``, ``b`` and ``only_using_nodes``, or ``fallback`` if none exists.

		``only_using_nodes`` may be any iterable of stations, e.g. a list or the
		``dict_values`` view produced by AllNodes. Concatenating it with a list
		would fail for non-list iterables, so the ids are collected explicitly.
		"""
		allowed_ids = [i["id"] for i in only_using_nodes]
		allowed_ids += [a["id"], b["id"]]
		try:
			induced_subgraph = nx.induced_subgraph(graph.gnx, allowed_ids)
			return ids_to_nodes(graph, nx.shortest_path(induced_subgraph, a["id"], b["id"]))
		except nx.exception.NetworkXNoPath:
			return fallback

class Paths(FunctionalOperator):
	"""All simple routes between two stations."""
	def op(self, graph, a:NodeSpec, b:NodeSpec):
		"""
		Return every simple path from ``a`` to ``b`` in ``graph.gnx``, each as a
		list of stations. The number of simple paths can grow very fast with
		graph size.
		"""
		found = []
		for id_path in nx.all_simple_paths(graph.gnx, a["id"], b["id"]):
			found.append(ids_to_nodes(graph, id_path))
		return found


class HasCycle(FunctionalOperator):
	"""Whether some cycle in the graph passes through a given station."""
	def op(self, graph, a:NodeSpec):
		"""
		Return True if a cycle through ``a`` exists in ``graph.gnx``.

		Depth-first search from ``a``. An edge is identified by its unordered
		endpoint pair plus its ``line_id``, so connections between the same two
		stations on different lines count as different edges. No edge is
		walked twice and no station other than ``a`` is revisited. Reaching
		``a`` again through an unused edge means a cycle exists. Cycles that do
		not include ``a`` are not detected. In the worst case every simple path
		from ``a`` is explored.
		"""

		# Would all_simple_paths also solve this for us?

		def canonical_edge(e):
			# Direction-independent key: {u, v} plus the line the edge belongs to
			return (frozenset(e[:2]), e[2]["attr_dict"]["line_id"])

		def dfs_undirected_cycle(head_id, path_nodes=frozenset(), path_edges=frozenset()):
			for e in graph.gnx.edges([head_id], data=True):
				assert e[0] == head_id
				assert head_id in path_nodes

				to_id = e[1]

				# Never reuse an edge, in particular the one we just arrived by
				if canonical_edge(e) in path_edges:
					continue

				# Back at the start via a fresh edge: cycle found
				if to_id == a["id"]:
					return True

				# Keep the path simple: don't revisit intermediate stations
				if to_id in path_nodes:
					continue

				if dfs_undirected_cycle(
					to_id,
					path_nodes | {to_id},
					path_edges | {canonical_edge(e)},
				):
					return True

			return False

		return dfs_undirected_cycle(a["id"], frozenset([a["id"]]))

class FilterAdjacent(FunctionalOperator):
	"""Pairs of stations (one from each list) that are direct neighbours."""
	def op(self, graph, a:List, b:List):
		"""
		Return every ``[i, j]`` with ``i`` in ``a`` and ``j`` in ``b`` such that
		``j`` is a neighbour of ``i`` in ``graph.gnx``.

		Pairs are produced in ``a``-major, ``b``-minor order.
		"""
		pairs = []
		for i in a:
			# Query i's neighbours once, as a set, instead of re-querying the
			# graph and linearly scanning the neighbour iterator for every j.
			neighbour_ids = set(graph.gnx.neighbors(i["id"]))
			for j in b:
				if j["id"] in neighbour_ids:
					pairs.append([i, j])
		return pairs

class Neighbors(FunctionalOperator):
	"""Direct neighbours of a station."""
	def op(self, graph, station:NodeSpec):
		"""Return the stations adjacent to ``station`` in ``graph.gnx`` as a list."""
		neighbour_ids = graph.gnx.neighbors(station["id"])
		return ids_to_nodes(graph, neighbour_ids)

class WithinHops(FunctionalOperator):
	"""Stations reachable within a bounded number of hops."""
	def op(self, graph, station:NodeSpec, hops:int):
		"""
		Return the stations reachable from ``station`` in at most ``hops`` hops,
		excluding ``station`` itself. The order is unspecified.

		Breadth-first expansion over ``graph.gnx``: each round expands only the
		stations that were newly reached in the previous round. ``hops <= 0``
		yields an empty list.
		"""
		if hops <= 0:
			return []

		reached = {station}
		frontier = {station}
		for _ in range(hops):
			expanded = set()
			for node in frontier:
				expanded |= set(ids_to_nodes(graph, graph.gnx.neighbors(node["id"])))
			# Only genuinely new stations need expanding next round
			frontier = expanded - reached
			reached |= frontier

		reached.remove(station)
		return list(reached)





class FilterHasPathTo(FunctionalOperator):
	"""Keep the stations from which a given station is reachable."""
	def op(self, graph, a:List, b:NodeSpec):
		"""Return the items of ``a`` that have a path to ``b`` in ``graph.gnx``, in order."""
		reachable = []
		for candidate in a:
			if nx.has_path(graph.gnx, candidate["id"], b["id"]):
				reachable.append(candidate)
		return reachable


# --------------------------------------------------------------------------
# List operators
# --------------------------------------------------------------------------

class NotEmpty(FunctionalOperator):
	"""True when the collection has at least one element."""
	def op(self, graph, l):
		return len(l) != 0

class Count(FunctionalOperator):
	"""Number of elements in a collection."""
	def op(self, graph, l):
		size = len(l)
		return size

class CountIfEqual(FunctionalOperator):
	"""Number of elements of a collection equal to a target value."""
	def op(self, graph, l, t):
		return sum(1 for item in l if item == t)

class Mode(FunctionalOperator):
	"""The single most common element of a collection."""
	def op(self, graph, l):
		"""
		Return the strictly most common element of ``l``.

		Raises ValueError if ``l`` is empty or if the top count is tied.
		"""
		if not len(l):
			raise ValueError("Cannot find mode of empty sequence")

		top = Counter(l).most_common(2)
		best_value, best_count = top[0]

		# Either only one distinct value exists, or the leader beats the runner-up
		if len(top) == 1 or best_count > top[1][1]:
			return best_value

		raise ValueError("No unique mode")

class Unique(FunctionalOperator):
	"""Distinct elements of a collection."""
	def op(self, graph, l):
		"""
		Return the distinct elements of ``l`` (which must be hashable) in
		first-occurrence order.

		``list(set(l))`` has an order that can change between runs, e.g. with
		string hash randomisation. An order-preserving dedupe keeps generated
		answers reproducible.
		"""
		return list(dict.fromkeys(l))

class SlidingPairs(FunctionalOperator):
	"""Consecutive overlapping pairs of an indexable sequence."""
	def op(self, graph, l):
		"""Return ``[(l[0], l[1]), (l[1], l[2]), ...]``; empty when ``len(l) < 2``."""
		return [(l[k - 1], l[k]) for k in range(1, len(l))]


@macro
def GetLines(a):
	"""
	Build a program evaluating to the distinct ``line_name`` values of the
	edges incident to station(s) ``a``.
	"""
	line_names = Pluck(Edges(a), "line_name")
	return Unique(line_names)

@macro
def Adjacent(a, b):
	"""
	Build a program that is True when the shortest path between ``a`` and
	``b`` consists of exactly two stations. If no path exists, the empty-list
	fallback has count 0, so the result is False.
	"""
	route = ShortestPath(a, b, [])
	return Equal(Count(route), 2)

@macro
def CountNodesBetween(a):
	"""
	Build a program evaluating to ``len(a) - 2``, the number of stations
	strictly between the two endpoints of path ``a``.

	NOTE(review): an empty path (e.g. a ShortestPath fallback of ``[]``)
	yields -2.
	"""
	total = Count(a)
	return Subtract(total, 2)

class HasIntersection(FunctionalOperator):
	"""Whether two collections share at least one element."""
	def op(self, graph, a, b):
		"""Return True as soon as some item of ``a`` is found in ``b``; otherwise False."""
		return any(item in b for item in a)

class Intersection(FunctionalOperator):
	"""Elements common to two collections."""
	def op(self, graph, a, b):
		"""Return the (hashable) elements in both ``a`` and ``b`` as a list, in no particular order."""
		common = set(a).intersection(set(b))
		return list(common)

class Filter(FunctionalOperator):
	"""Keep items whose field equals a value."""
	def op(self, graph, a:List, b, c):
		"""Return the items of ``a`` for which ``item[b] == c``, in order."""
		kept = []
		for item in a:
			if item[b] == c:
				kept.append(item)
		return kept

class Without(FunctionalOperator):
	"""Drop items whose field equals a value."""
	def op(self, graph, a:List, b, c):
		"""Return the items of ``a`` for which ``item[b] != c``, in order."""
		kept = []
		for item in a:
			if item[b] != c:
				kept.append(item)
		return kept

class UnpackUnitList(FunctionalOperator):
	"""
	Return the sole element of a one-element list. Raises ValueError for any
	other length, which guards against generating ambiguous questions.
	"""
	def op(self, graph, l:List):
		size = len(l)
		if size != 1:
			raise ValueError(f"List is length {size}, expected 1")
		return l[0]

class Sample(FunctionalOperator):
	"""Random selection of items from a list, without replacement."""
	def op(self, graph, l:List, n:int):
		"""
		Return ``n`` items drawn at random from distinct positions of ``l``.

		Raises ValueError when ``l`` has fewer than ``n`` items. The length
		guard only makes sense without replacement, so ``random.sample`` is
		used. ``random.choices`` samples with replacement and could return the
		same item twice.
		"""
		if len(l) < n:
			raise ValueError(f"Cannot sample {n} items from list of length {len(l)}")
		return random.sample(l, n)

class First(FunctionalOperator):
	"""First element of a sequence."""
	def op(self, graph, l:List):
		"""
		Return ``l[0]``.

		Raises ValueError (the "no valid answer" signal used by other
		operators such as MinBy and Mode) instead of IndexError when ``l`` is
		empty.
		"""
		if len(l) == 0:
			raise ValueError("Cannot take First of empty list")
		return l[0]

class MinBy(FunctionalOperator):
	"""Element of a collection minimising a per-element sub-program."""
	def op(self, graph, a, b):
		"""
		Return the element of ``a`` for which ``b(element)(graph)`` is smallest.

		``b`` is a callable that returns an executable program for an element.
		Raises ValueError if ``a`` is empty.
		"""
		if not len(a):
			raise ValueError("Cannot perform MinBy on empty list")

		def score(item):
			return b(item)(graph)

		return min(a, key=score)


# --------------------------------------------------------------------------
# Numerical operations
# --------------------------------------------------------------------------


class Subtract(FunctionalOperator):
	"""Arithmetic difference ``a - b``."""
	def op(self, graph, a, b):
		difference = a - b
		return difference


class Round(FunctionalOperator):
	"""Round a number, or every element of an iterable, to an int."""
	def op(self, graph, a):
		"""
		Return ``round(float(a))`` for a non-iterable ``a``. For an iterable
		``a``, return a list of ``round(float(i))`` for each element.

		Python's ``round`` rounds halves to even. Only the iterability check
		decides the mode. A TypeError from a bad element therefore propagates
		from that element, instead of being misread as "a is a scalar".
		"""
		try:
			items = iter(a)
		except TypeError:
			return round(float(a))
		return [round(float(i)) for i in items]