from emis_funky_funktions import *

from typing import Collection, Mapping, Sequence, Tuple, TypeAlias
from functools import reduce
from match_tree import MatchTree, MatchException, StructurePath, LeafNode, merge_all_trees, IntNode, EMPTY_STRUCT_PATH, FAIL_NODE
from patterns import Pattern

import types_


# Union of every IR node class in this module. The names are quoted because the
# classes are defined further down the file.
Expression: TypeAlias = 'MonoFunc | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
# The IR node classes whose `is_value()` always returns True (fully evaluated terms).
Value: TypeAlias = 'MonoFunc | Int | Builtin | ReplHole'

@dataclass(frozen=True)
class ReplHole:
	"""A REPL placeholder term.

	A hole is inert: it is a value and never steps. Substitution does not
	rewrite it. Instead, each substitution is appended to `val_bindings`, so the
	hole accumulates every (variable, expression) pair pushed into it, oldest
	first. `render` turns that record into JavaScript `const` declarations.
	"""
	typ_bindings: types_.Context
	val_bindings: Sequence[Tuple[str, Expression]] = ()

	def subst(self, expression: Expression, variable: str) -> Expression:
		# Record the substitution rather than applying it.
		recorded = tuple(self.val_bindings) + ((variable, expression),)
		return ReplHole(self.typ_bindings, recorded)

	def is_value(self) -> bool:
		return True

	def step(self) -> Option[Expression]:
		# Already a value; there is nothing to reduce.
		return None

	def __repr__(self) -> str:
		return '[]'

	def codegen(self) -> str:
		return '[]'

	def render(self) -> str:
		"""Emit one `const name = (code);` line per recorded binding.

		Names present in `types_.BUILTINS_CONTEXT` are skipped. Lines are joined
		with newlines, in the order the bindings were recorded.
		"""
		lines = []
		for name, expr in self.val_bindings:
			if name in types_.BUILTINS_CONTEXT:
				continue
			lines.append(f'const {name} = ({expr.codegen()});')
		return '\n'.join(lines)

# The behaviors a Builtin can carry: the nested BB_* classes of Builtin, quoted
# because Builtin is not defined yet at this point.
BuiltinBehavior: TypeAlias = 'Builtin.BB_PLUS_CONST | Builtin.BB_PLUS'

@dataclass(frozen=True)
class Builtin:
	"""A primitive function implemented natively.

	`behavior` is an instance of one of the nested BB_* classes. It supplies
	the display name (`name`), the JavaScript translation (`js`) and the
	evaluation rule (`run`).
	"""
	behavior: BuiltinBehavior

	@dataclass(frozen=True)
	class BB_PLUS_CONST:
		"""Adds the fixed integer `amt` to its (integer) argument."""
		amt: int

		def name(self) -> str:
			# Always carries a sign, e.g. '+1' or '-1'.
			return f'{self.amt:+}'

		def js(self) -> str:
			return f'(x=>x{self.amt:+})'

		def run(self, e: Expression) -> Option[Expression]:
			if not isinstance(e, Int):
				return None
			return Some(Int(e.value + self.amt))

	@dataclass(frozen=True)
	class BB_PLUS:
		"""Curried addition: applied to Int(n) it yields a BB_PLUS_CONST(n) builtin."""

		def name(self) -> str:
			return '+'

		def js(self) -> str:
			return '(x=>y=>x+y)'

		def run(self, e: Expression) -> Option[Expression]:
			if not isinstance(e, Int):
				return None
			return Some(Builtin(Builtin.BB_PLUS_CONST(e.value)))

	def subst(self, expression: Expression, variable: str) -> Expression:
		# A builtin mentions no variables, so substitution leaves it untouched.
		return self

	def is_value(self) -> bool:
		return True

	def step(self) -> Option[Expression]:
		return None

	def try_apply(self, v: Expression) -> Option[Expression]:
		# None means the behavior rejected the argument (it was not an Int).
		return self.behavior.run(v)

	def __repr__(self) -> str:
		escaped = repr(self.behavior.name())[1:-1]
		return f"'{escaped}'"

	def codegen(self) -> str:
		return self.behavior.js()

	# NOTE(review): because these two carry annotations inside a @dataclass, they
	# are dataclass *fields* (defaulting to the lambdas), not plain class
	# attributes. Every Builtin instance therefore also holds PLUS and S, and
	# __match_args__ is ('behavior', 'PLUS', 'S'). They still work as
	# `Builtin.PLUS()` / `Builtin.S()` factories because a plain field default
	# remains on the class. Marking them ClassVar would change the constructor
	# and the match arguments, so audit patterns such as `Builtin(_, _)` first.
	PLUS: 'Callable[[], Builtin]' = lambda: Builtin(Builtin.BB_PLUS())
	S: 'Callable[[], Builtin]' = lambda: Builtin(Builtin.BB_PLUS_CONST(1))

# Each builtin name paired with the Builtin term it denotes. These are the same
# two names ('+' and 'S') that Variable.step resolves to Builtin.PLUS() / Builtin.S().
BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
	('+', Builtin.PLUS()),
	('S', Builtin.S()),
)

@dataclass(frozen=True)
class MonoFunc:
	arg: str
	body: Expression

	def subst(self, expression: Expression, variable: str) -> Expression:
		if variable == self.arg:
			return self
		else:
			return MonoFunc(self.arg, self.body.subst(expression, variable))

	def is_value(self) -> bool:
		return True

	def step(self) -> Option[Expression]:
		return None

	@staticmethod
	def from_match_function(forms: 'Sequence[Tuple[Pattern, Expression]]') -> Result[Expression, MatchException]:
		# In certain cases, starting a function with a full match tree may be unnecessary.
		# Specifically, if there exists only one possible branch and that branch binds only
		# one value and that value is equal to the whole entire input, rather than assigning
		# that input to a new variable, we may simply use argument variable instead.
		match forms:
			case [(patt, body)]: # A single possible branch
				match patt.bindings():
					case []: # Binds nothing
						return Ok(MonoFunc('_', body))
					case [(var, [])]: # Binds a single variable to the entire input
						return Ok(MonoFunc(var, body))

		# If those special cases fail, we eliminate the pattern matching to produce a
		# single body:
		match_trees = tuple( # Construct a match tree for each possible branch
			pattern.match_tree(
				EMPTY_STRUCT_PATH,
				LeafNode.from_value(bindings_to_lets(pattern.bindings(), Variable('$'), body))
			)
			for (pattern, body) in forms
		)
		unified_match_tree = merge_all_trees(match_trees) # Unify all the trees
		compiled_tree = compile_tree(unified_match_tree, Variable('$')) # Turn each tree into IR
		return compiled_tree <= p(MonoFunc, '$')

	def try_apply(self, v: Expression) -> Option[Expression]:
		return Some(self.body.subst(v, self.arg))

	def codegen(self) -> str:
		return f'({self.arg}=>{self.body.codegen()})'
	def codegen_named(self, name) -> str:
		return f'(function {name}({self.arg}){{return {self.body.codegen()}}})'
	def __repr__(self) -> str:
		return f'{{{repr(self.arg)}: {repr(self.body)}}}'

@dataclass
class LetBinding:
	"""`let lhs = rhs in body`, where `lhs` is also in scope inside `rhs`.

	The recursion shows in two places. `subst` blocks a same-named
	substitution for rhs and body alike. `step` substitutes the binding into
	its own rhs before that rhs replaces `lhs` in the body.
	"""
	lhs: str
	rhs: Expression
	body: Expression

	def subst(self, expression: Expression, variable: str) -> Expression:
		if self.lhs != variable:
			new_rhs = self.rhs.subst(expression, variable)
			new_body = self.body.subst(expression, variable)
			return LetBinding(self.lhs, new_rhs, new_body)
		# `lhs` shadows `variable` throughout both rhs and body.
		return self

	def is_value(self) -> bool:
		return False

	def step(self) -> Option[Expression]:
		if not self.rhs.is_value():
			# Reduce the right-hand side until it becomes a value.
			return map_opt(
				lambda stepped_rhs: LetBinding(self.lhs, stepped_rhs, self.body),
				self.rhs.step()
			)
		# Unroll one level of recursion. Inside rhs, `lhs` becomes
		# `let lhs = rhs in lhs`; the unrolled rhs then replaces `lhs` in body.
		self_reference = LetBinding(self.lhs, self.rhs, Variable(self.lhs))
		unrolled_rhs = self.rhs.subst(self_reference, self.lhs)
		return Some(self.body.subst(unrolled_rhs, self.lhs))

	def __repr__(self) -> str:
		return f'( "{self.lhs}", {self.rhs!r}, {self.body!r} )'

	def codegen(self) -> str:
		if isinstance(self.rhs, MonoFunc):
			# Named function expression: `lhs` is visible inside the function,
			# so it can refer to itself.
			rhs_code = self.rhs.codegen_named(self.lhs)
		else:
			rhs_code = self.rhs.codegen()
		if self.body != Variable(self.lhs):
			# Bind through an immediately-applied arrow function.
			return f'({self.lhs}=>{self.body.codegen()})({rhs_code})'
		# `let x = rhs in x` is just rhs.
		return rhs_code

@dataclass
class Application:
	first: Expression
	arg: Expression

	def subst(self, expression: Expression, variable: str) -> Expression:
		return Application(
			self.first.subst(expression, variable),
			self.arg.subst(expression, variable)
		)

	def is_value(self) -> bool:
		return False

	def step(self) -> Option[Expression]:
		match self.first.step():
			case Some(first_stepped):
				return Some(Application(first_stepped, self.arg))
			case None:
				match self.arg.step():
					case Some(arg_stepped):
						return Some(Application(self.first, arg_stepped))
					case None:
						assert isinstance(self.first, MonoFunc) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
						return self.first.try_apply(self.arg)
		raise Exception('Unreachable')

	def __repr__(self) -> str:
		return f'[ {repr(self.first)}, {repr(self.arg)} ]'

	def codegen(self) -> str:
		if isinstance(self.first, MonoFunc | Builtin) and self.arg.is_value():
			return unwrap_opt(self.first.try_apply(self.arg)).codegen()
		else:
			match self.first:
				case Application(Builtin(Builtin.BB_PLUS), addend1):
					return f'({addend1.codegen()} + {self.arg.codegen()})'
				case Builtin(Builtin.BB_PLUS_CONST(n)):
					return f'({self.arg.codegen()}{n:+})'
			return f'({self.first.codegen()})({self.arg.codegen()})'

@dataclass
class Int:
	"""An integer literal; closed and already fully evaluated."""
	value: int

	def subst(self, expression: Expression, variable: str) -> Expression:
		# A literal mentions no variables.
		return self

	def is_value(self) -> bool:
		return True

	def step(self) -> Option[Expression]:
		return None

	def __repr__(self) -> str:
		# The debug form and the generated code are the same decimal text.
		return self.codegen()

	def codegen(self) -> str:
		return str(self.value)

@dataclass
class Variable:
	"""A reference to a named binding."""
	name: str

	def subst(self, expression: Expression, variable: str) -> Expression:
		return expression if self.name == variable else self

	def is_value(self) -> bool:
		return False

	def step(self) -> Option[Expression]:
		# The builtin names '+' and 'S' reduce to fresh Builtin terms; any other
		# unresolved variable cannot step.
		if self.name == '+':
			return Some(Builtin.PLUS())
		if self.name == 'S':
			return Some(Builtin.S())
		return None

	def __repr__(self) -> str:
		inner = repr(self.name)[1:-1]
		return f'"{inner}"'

	def codegen(self) -> str:
		return self.name

@dataclass
class Switch:
	branches: Mapping[int, Expression]
	fallback: Expression
	switching_on: Expression

	def subst(self, expression: Expression, variable: str) -> Expression:
		return Switch(
			{i: e.subst(expression, variable) for i, e in self.branches.items()},
			self.fallback.subst(expression, variable),
			self.switching_on.subst(expression, variable))

	def is_value(self) -> bool:
		return False

	def step(self) -> Option[Expression]:
		match self.switching_on.step():
			case Some(switch_expr_stepped):
				return Some(Switch(self.branches, self.fallback, switch_expr_stepped))
			case None:
				match self.switching_on:
					case Int(n):
						if n in self.branches:
							return Some(self.branches[n])
						else:
							return Some(self.fallback)
				raise Exception('Attempted to switch on non-integer value')
		raise Exception('Unreachable')

	def __repr__(self) -> str:
		return '{ ' + ', '.join(f'{n}: ' + repr(e) for (n, e) in self.branches.items()) + f', _: {repr(self.fallback)}' + ' }'

	def codegen(self) -> str:
		switching_on_code = self.switching_on.codegen()
		return ':'.join(
			f'{switching_on_code}=={val}?({branch.codegen()})'
			for val, branch in self.branches.items()
		) + f':{self.fallback.codegen()}'

def compile_tree(tree: 'MatchTree[Expression]', match_against: Expression) -> Result[Expression, MatchException]:
	match tree:
		case LeafNode([match]):
			return Ok(match)
		case LeafNode([]):
			return Err(MatchException.Incomplete)
		case LeafNode([a, b, *rest]):
			return Err(MatchException.Ambiguous)
		case IntNode(location, specific_trees, fallback_tree):
			access_location = location_to_ir(location)(match_against)
			match sequence(tuple(compile_tree(tree, match_against) for tree in specific_trees.values())):
				case Err(e):
					return Err(e)
				case Ok(exprs):
					match compile_tree(fallback_tree, match_against):
						case Err(e):
							return Err(e)
						case Ok(fallback):
							return Ok(Switch(dict(zip(specific_trees.keys(), exprs)), fallback, match_against))
	raise Exception('Unreachable')

def location_to_ir(location: StructurePath) -> Callable[[Expression], Expression]:
	"""Build a function that extracts the sub-value at `location` from an expression.

	Every path component must be less than 1. Each one wraps the expression in
	`Application(Builtin(Builtin.BB_PLUS_CONST(-1)), ...)`, i.e. one decrement.
	A component >= 1 is unsupported and raises AssertionError('A!') when the
	returned function is called. An empty path yields the identity.
	"""
	path = tuple(location)

	def extract(expr: Expression) -> Expression:
		for part in path:
			if not part < 1:
				raise AssertionError('A!')
			expr = Application(Builtin(Builtin.BB_PLUS_CONST(-1)), expr)
		return expr
	return extract

def bindings_to_lets(bindings: Collection[Tuple[str, StructurePath]], deconstructing_term: Expression, body_expr: Expression) -> Expression:
	"""Wrap `body_expr` in one LetBinding per pattern binding.

	Each `(name, location)` pair binds `name` to the part of
	`deconstructing_term` found at `location` (via `location_to_ir`). The first
	binding becomes the outermost let. Any collection is accepted, as the
	annotation promises; with no bindings, `body_expr` is returned unchanged.
	"""
	result = body_expr
	# Build inside-out so the first binding ends up outermost.
	for binding_name, location in reversed(tuple(bindings)):
		result = LetBinding(binding_name, location_to_ir(location)(deconstructing_term), result)
	return result

def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> Expression:
	"""Apply every substitution in `bindings` to `body`, first pair first.

	Each `(var, replacement)` pair is applied with `.subst(replacement, var)`
	to the result of the previous substitution.
	"""
	result = body
	for var, replacement in bindings:
		result = result.subst(replacement, var)
	return result

def count_uses(variable: str, expression: Expression) -> int:
	match expression:
		case MonoFunc(arg, body):
			return 0 if arg == variable else count_uses(variable, body)
		case Application(first, arg):
			return count_uses(variable, first) + count_uses(variable, arg)
		case Int(_):
			return 0
		case Variable(name):
			return 1 if name == variable else 0
		case Builtin(_, _):
			return 0
		case LetBinding(lhs, rhs, body):
			return count_uses(variable, rhs) + count_uses(variable, body)
		case ReplHole(_, _):
			return 0
		case Switch(branches, fallback, switching_on):
			return (
				count_uses(variable, switching_on) +
				count_uses(variable, fallback) +
				sum(count_uses(variable, branch) for branch in branches.values()))