Compare commits
6 Commits
2238d363e5
...
b0ccfc6309
Author | SHA1 | Date |
---|---|---|
Emi Simpson | b0ccfc6309 | |
Emi Simpson | 6f89dfa1c5 | |
Emi Simpson | 719de87ea9 | |
Emi Simpson | e6d8933ccf | |
Emi Simpson | 55fa7add0e | |
Emi Simpson | 79deebabf4 |
|
@ -5,6 +5,7 @@ from ir import BUILTIN_SUBSTITUTIONS, Expression, ReplHole, subst_all
|
||||||
from genir import json_to_ir, PatternParseProblem, BranchTypesDiffer, UndefinedVariable
|
from genir import json_to_ir, PatternParseProblem, BranchTypesDiffer, UndefinedVariable
|
||||||
from types_ import BUILTINS_CONTEXT, UnificationError
|
from types_ import BUILTINS_CONTEXT, UnificationError
|
||||||
from silly_thing import evaluate
|
from silly_thing import evaluate
|
||||||
|
from opt import optimize_to_fixpoint, all_optimizations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
@ -16,7 +17,9 @@ def main():
|
||||||
case [_, file]:
|
case [_, file]:
|
||||||
# TODO handle this
|
# TODO handle this
|
||||||
expr, ty, substs = unwrap_r(json_to_ir(json.loads(open(sys.argv[1]).read()), BUILTINS_CONTEXT))
|
expr, ty, substs = unwrap_r(json_to_ir(json.loads(open(sys.argv[1]).read()), BUILTINS_CONTEXT))
|
||||||
result = evaluate(subst_all(BUILTIN_SUBSTITUTIONS, expr))
|
expr_with_builtins = subst_all(BUILTIN_SUBSTITUTIONS, expr)
|
||||||
|
expr_with_optimization = optimize_to_fixpoint(all_optimizations, expr_with_builtins)
|
||||||
|
result = evaluate(expr_with_optimization)
|
||||||
if isinstance(result, ReplHole):
|
if isinstance(result, ReplHole):
|
||||||
print(result.render())
|
print(result.render())
|
||||||
else:
|
else:
|
||||||
|
|
17
genir.py
17
genir.py
|
@ -2,15 +2,16 @@ from emis_funky_funktions import *
|
||||||
from typing import Sequence, Mapping
|
from typing import Sequence, Mapping
|
||||||
|
|
||||||
from pattern import lex_and_parse_pattern
|
from pattern import lex_and_parse_pattern
|
||||||
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole
|
from ir import Expression, MonoFunc, Application, Int, Variable, LetBinding, ReplHole
|
||||||
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
|
from patterns import Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
|
||||||
|
from match_tree import MatchException
|
||||||
from types_ import *
|
from types_ import *
|
||||||
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import json
|
import json
|
||||||
|
|
||||||
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
|
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
|
||||||
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable'
|
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable | MatchException'
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PatternParseProblem:
|
class PatternParseProblem:
|
||||||
|
@ -170,8 +171,16 @@ def let_2_ir(
|
||||||
|
|
||||||
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
|
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
|
||||||
if isinstance(j, Mapping):
|
if isinstance(j, Mapping):
|
||||||
return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst:
|
match branches_2_ir(tuple(j.items()), type_ctx):
|
||||||
(Function(ir_ty_subst[0]), *ir_ty_subst[1:]))
|
case Ok((branches, ty, substitutions)):
|
||||||
|
match MonoFunc.from_match_function(branches):
|
||||||
|
case Ok(monofunc):
|
||||||
|
return Ok((monofunc, ty, substitutions))
|
||||||
|
case Err(e):
|
||||||
|
return Err(e)
|
||||||
|
case Err(e):
|
||||||
|
return Err(e)
|
||||||
|
raise Exception('Unreachable')
|
||||||
elif isinstance(j, str):
|
elif isinstance(j, str):
|
||||||
match type_ctx.instantiate(j):
|
match type_ctx.instantiate(j):
|
||||||
case Some(j_type):
|
case Some(j_type):
|
||||||
|
|
96
ir.py
96
ir.py
|
@ -8,8 +8,8 @@ from patterns import Pattern
|
||||||
import types_
|
import types_
|
||||||
|
|
||||||
|
|
||||||
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
|
Expression: TypeAlias = 'MonoFunc | Application | Int | Variable | Builtin | LetBinding | ReplHole | Switch'
|
||||||
Value: TypeAlias = 'Function | Int | Builtin | ReplHole'
|
Value: TypeAlias = 'MonoFunc | Int | Builtin | ReplHole'
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ReplHole:
|
class ReplHole:
|
||||||
|
@ -91,14 +91,15 @@ BUILTIN_SUBSTITUTIONS: Sequence[Tuple[str, Expression]] = (
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Function:
|
class MonoFunc:
|
||||||
forms: 'Sequence[Tuple[Pattern, Expression]]'
|
arg: str
|
||||||
|
body: Expression
|
||||||
|
|
||||||
def subst(self, expression: Expression, variable: str) -> Expression:
|
def subst(self, expression: Expression, variable: str) -> Expression:
|
||||||
return Function([
|
if variable == self.arg:
|
||||||
(p, e if p.binds(variable) else e.subst(expression, variable))
|
return self
|
||||||
for (p, e) in self.forms
|
else:
|
||||||
])
|
return MonoFunc(self.arg, self.body.subst(expression, variable))
|
||||||
|
|
||||||
def is_value(self) -> bool:
|
def is_value(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
@ -106,22 +107,42 @@ class Function:
|
||||||
def step(self) -> Option[Expression]:
|
def step(self) -> Option[Expression]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def eliminate(self, v: Expression) -> Result[Expression, MatchException]:
|
@staticmethod
|
||||||
match_trees = tuple(pattern.match_tree(EMPTY_STRUCT_PATH, LeafNode.from_value(bindings_to_lets(pattern.bindings(), v, body))) for (pattern, body) in self.forms)
|
def from_match_function(forms: 'Sequence[Tuple[Pattern, Expression]]') -> Result[Expression, MatchException]:
|
||||||
unified_match_tree = merge_all_trees(match_trees)
|
# In certain cases, starting a function with a full match tree may be unnecessary.
|
||||||
return compile_tree(unified_match_tree, v)
|
# Specifically, if there exists only one possible branch and that branch binds only
|
||||||
|
# one value and that value is equal to the whole entire input, rather than assigning
|
||||||
|
# that input to a new variable, we may simply use argument variable instead.
|
||||||
|
match forms:
|
||||||
|
case [(patt, body)]: # A single possible branch
|
||||||
|
match patt.bindings():
|
||||||
|
case []: # Binds nothing
|
||||||
|
return Ok(MonoFunc('_', body))
|
||||||
|
case [(var, [])]: # Binds a single variable to the entire input
|
||||||
|
return Ok(MonoFunc(var, body))
|
||||||
|
|
||||||
|
# If those special cases fail, we eliminate the pattern matching to produce a
|
||||||
|
# single body:
|
||||||
|
match_trees = tuple( # Construct a match tree for each possible branch
|
||||||
|
pattern.match_tree(
|
||||||
|
EMPTY_STRUCT_PATH,
|
||||||
|
LeafNode.from_value(bindings_to_lets(pattern.bindings(), Variable('$'), body))
|
||||||
|
)
|
||||||
|
for (pattern, body) in forms
|
||||||
|
)
|
||||||
|
unified_match_tree = merge_all_trees(match_trees) # Unify all the trees
|
||||||
|
compiled_tree = compile_tree(unified_match_tree, Variable('$')) # Turn each tree into IR
|
||||||
|
return compiled_tree <= p(MonoFunc, '$')
|
||||||
|
|
||||||
def try_apply(self, v: Expression) -> Option[Expression]:
|
def try_apply(self, v: Expression) -> Option[Expression]:
|
||||||
return hush(self.eliminate(v))
|
return Some(self.body.subst(v, self.arg))
|
||||||
|
|
||||||
def codegen_inner(self) -> str:
|
|
||||||
return unwrap_r(self.eliminate(Variable('$'))).codegen()
|
|
||||||
def codegen(self) -> str:
|
def codegen(self) -> str:
|
||||||
return '$=>' + self.codegen_inner()
|
return f'({self.arg}=>{self.body.codegen()})'
|
||||||
def codegen_named(self, name) -> str:
|
def codegen_named(self, name) -> str:
|
||||||
return f'function {name}($){{return {self.codegen_inner()}}}'
|
return f'(function {name}({self.arg}){{return {self.body.codegen()}}})'
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return '{ ' + ', '.join('"' + repr(repr(p))[1:-1] + '" : ' + repr(e) for (p, e) in self.forms) + ' }'
|
return f'{{{repr(self.arg)}: {repr(self.body)}}}'
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LetBinding:
|
class LetBinding:
|
||||||
|
@ -158,11 +179,14 @@ class LetBinding:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'( {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)} )'
|
return f'( "{self.lhs}", {repr(self.rhs)}, {repr(self.body)} )'
|
||||||
|
|
||||||
def codegen(self) -> str:
|
def codegen(self) -> str:
|
||||||
rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, Function) else self.rhs.codegen()
|
rhs_cg = self.rhs.codegen_named(self.lhs) if isinstance(self.rhs, MonoFunc) else self.rhs.codegen()
|
||||||
return f'({self.lhs}={rhs_cg},{self.body.codegen()})'
|
if self.body == Variable(self.lhs):
|
||||||
|
return rhs_cg
|
||||||
|
else:
|
||||||
|
return f'({self.lhs}=>{self.body.codegen()})({rhs_cg})'
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Application:
|
class Application:
|
||||||
|
@ -187,7 +211,7 @@ class Application:
|
||||||
case Some(arg_stepped):
|
case Some(arg_stepped):
|
||||||
return Some(Application(self.first, arg_stepped))
|
return Some(Application(self.first, arg_stepped))
|
||||||
case None:
|
case None:
|
||||||
assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
|
assert isinstance(self.first, MonoFunc) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
|
||||||
return self.first.try_apply(self.arg)
|
return self.first.try_apply(self.arg)
|
||||||
raise Exception('Unreachable')
|
raise Exception('Unreachable')
|
||||||
|
|
||||||
|
@ -195,7 +219,7 @@ class Application:
|
||||||
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
|
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
|
||||||
|
|
||||||
def codegen(self) -> str:
|
def codegen(self) -> str:
|
||||||
if isinstance(self.first, Function | Builtin) and self.arg.is_value():
|
if isinstance(self.first, MonoFunc | Builtin) and self.arg.is_value():
|
||||||
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
|
return unwrap_opt(self.first.try_apply(self.arg)).codegen()
|
||||||
else:
|
else:
|
||||||
match self.first:
|
match self.first:
|
||||||
|
@ -262,7 +286,7 @@ class Switch:
|
||||||
def subst(self, expression: Expression, variable: str) -> Expression:
|
def subst(self, expression: Expression, variable: str) -> Expression:
|
||||||
return Switch(
|
return Switch(
|
||||||
{i: e.subst(expression, variable) for i, e in self.branches.items()},
|
{i: e.subst(expression, variable) for i, e in self.branches.items()},
|
||||||
self.fallback,
|
self.fallback.subst(expression, variable),
|
||||||
self.switching_on.subst(expression, variable))
|
self.switching_on.subst(expression, variable))
|
||||||
|
|
||||||
def is_value(self) -> bool:
|
def is_value(self) -> bool:
|
||||||
|
@ -344,4 +368,26 @@ def subst_all(bindings: Sequence[Tuple[str, Expression]], body: Expression) -> E
|
||||||
return body
|
return body
|
||||||
case [(var, replacement), *rest]:
|
case [(var, replacement), *rest]:
|
||||||
return subst_all(rest, body.subst(replacement, var))
|
return subst_all(rest, body.subst(replacement, var))
|
||||||
raise Exception('Unreachable')
|
raise Exception('Unreachable')
|
||||||
|
|
||||||
|
def count_uses(variable: str, expression: Expression) -> int:
|
||||||
|
match expression:
|
||||||
|
case MonoFunc(arg, body):
|
||||||
|
return 0 if arg == variable else count_uses(variable, body)
|
||||||
|
case Application(first, arg):
|
||||||
|
return count_uses(variable, first) + count_uses(variable, arg)
|
||||||
|
case Int(_):
|
||||||
|
return 0
|
||||||
|
case Variable(name):
|
||||||
|
return 1 if name == variable else 0
|
||||||
|
case Builtin(_, _):
|
||||||
|
return 0
|
||||||
|
case LetBinding(lhs, rhs, body):
|
||||||
|
return count_uses(variable, rhs) + count_uses(variable, body)
|
||||||
|
case ReplHole(_, _):
|
||||||
|
return 0
|
||||||
|
case Switch(branches, fallback, switching_on):
|
||||||
|
return (
|
||||||
|
count_uses(variable, switching_on) +
|
||||||
|
count_uses(variable, fallback) +
|
||||||
|
sum(count_uses(variable, branch) for branch in branches.values()))
|
|
@ -0,0 +1,80 @@
|
||||||
|
from emis_funky_funktions import *
|
||||||
|
|
||||||
|
from ir import Expression, LetBinding, count_uses, MonoFunc, Application, LetBinding, Switch, Int, Variable, Builtin, ReplHole
|
||||||
|
|
||||||
|
from typing import TypeAlias, Collection
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
Optimization: TypeAlias = Callable[[Expression], Option[Expression]]
|
||||||
|
|
||||||
|
def eliminate_single_let(expr: Expression) -> Option[Expression]:
|
||||||
|
match expr:
|
||||||
|
case LetBinding(lhs, rhs, body):
|
||||||
|
if count_uses(lhs, body) <= 1:
|
||||||
|
# RHS is used at most once i nthe body
|
||||||
|
if count_uses(lhs, rhs) > 0:
|
||||||
|
# RHS is recursive
|
||||||
|
if not isinstance(body, Variable):
|
||||||
|
# But can still be pushed down
|
||||||
|
return Some(body.subst(LetBinding(lhs, rhs, Variable(lhs)), lhs))
|
||||||
|
else:
|
||||||
|
# And is already maximally pushed down
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# RHS is not recursive
|
||||||
|
return Some(body.subst(rhs, lhs))
|
||||||
|
else:
|
||||||
|
# RHS is used multiple times in the body
|
||||||
|
return None
|
||||||
|
case _:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def apply_opts(optimizations: Collection[Optimization], expression: Expression) -> Tuple[Expression, int]:
|
||||||
|
count: int
|
||||||
|
optimized_expr: Expression
|
||||||
|
match expression:
|
||||||
|
case MonoFunc(arg, body):
|
||||||
|
(optimized_body, count) = apply_opts(optimizations, body)
|
||||||
|
optimized_expr = MonoFunc(arg, optimized_body)
|
||||||
|
case Application(first, arg):
|
||||||
|
(optimized_first, count_first) = apply_opts(optimizations, first)
|
||||||
|
(optimized_arg , count_arg ) = apply_opts(optimizations, arg)
|
||||||
|
optimized_expr = Application(first, arg)
|
||||||
|
count = count_first + count_arg
|
||||||
|
case LetBinding(lhs, rhs, body):
|
||||||
|
(optimized_rhs , count_rhs ) = apply_opts(optimizations, rhs)
|
||||||
|
(optimized_body , count_body ) = apply_opts(optimizations, body)
|
||||||
|
optimized_expr = LetBinding(lhs, optimized_rhs, optimized_body)
|
||||||
|
count = count_rhs + count_body
|
||||||
|
case Switch(branches, fallback, switching_on):
|
||||||
|
branch_optimizations = tuple(
|
||||||
|
(branch_num, apply_opts(optimizations, branch_body))
|
||||||
|
for branch_num, branch_body in branches.items())
|
||||||
|
optimized_fallback, count_fallback = apply_opts(optimizations, fallback)
|
||||||
|
optimized_switching_on, count_switching_on = apply_opts(optimizations, switching_on)
|
||||||
|
count = sum(count_branch for _, (_, count_branch) in branch_optimizations) + count_fallback + count_switching_on
|
||||||
|
optimized_branches = {branch_num: optimized_branch for branch_num, (optimized_branch, _) in branch_optimizations}
|
||||||
|
optimized_expr = Switch(optimized_branches, optimized_fallback, optimized_switching_on)
|
||||||
|
case Int(_) | Variable(_) | Builtin(_, _) | ReplHole(_, _) as optimized_expr:
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
def fold_optimizations(acc: Tuple[Expression, int], optimization: Optimization) -> Tuple[Expression, int]:
|
||||||
|
acc_expression, acc_count = acc
|
||||||
|
match optimization(acc_expression):
|
||||||
|
case Some(expr_post_opt):
|
||||||
|
return (expr_post_opt, acc_count + 1)
|
||||||
|
case None:
|
||||||
|
return (acc_expression, acc_count)
|
||||||
|
raise Exception('Unreachable')
|
||||||
|
|
||||||
|
return reduce(fold_optimizations, optimizations, (optimized_expr, 0))
|
||||||
|
|
||||||
|
def optimize_to_fixpoint(optimizations: Collection[Optimization], expression: Expression) -> Expression:
|
||||||
|
match apply_opts(optimizations, expression):
|
||||||
|
case (optimal_expression, 0):
|
||||||
|
return optimal_expression
|
||||||
|
case (optimized_expression, _):
|
||||||
|
return optimize_to_fixpoint(optimizations, optimized_expression)
|
||||||
|
raise Exception('Unreachable')
|
||||||
|
|
||||||
|
all_optimizations: Sequence[Optimization] = (eliminate_single_let,)
|
Loading…
Reference in New Issue