Buggy typechecking

This commit is contained in:
Emi Simpson 2023-03-09 17:00:54 -05:00
parent 907bc5d505
commit 792674797d
Signed by: Emi
GPG Key ID: A12F2C2FFDC3D847
5 changed files with 424 additions and 54 deletions

199
genir.py
View File

@ -1,28 +1,203 @@
from emis_funky_funktions import *
from typing import *
from typing import Sequence, Mapping
from pattern import lex_and_parse_pattern
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole, Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
from types_ import *
from functools import reduce
import json
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable'
def json_to_ir(j: JsonType) -> Expression:
@dataclass(frozen=True)
class PatternParseProblem:
pattern: str
@dataclass(frozen=True)
class PatternTypeMismatch:
pattern: Pattern
real_type: PolyType
required_type: PolyType
@dataclass(frozen=True)
class UndefinedVariable:
name: str
@dataclass(frozen=True)
class BranchTypesDiffer:
branch_1_pattern: str
branch_2_pattern: str
branch_1_type: MonoType
branch_2_type: MonoType
def pattern_type_bindings(pattern: Pattern, ctx: Context) -> Result[Tuple[MonoType, Sequence[Tuple[str, MonoType]]], PatternTypeMismatch]:
match pattern:
case NamePattern(name):
new_ty = ctx.new_type_var()
return Ok((new_ty, ((name, new_ty),)))
case IgnorePattern():
new_ty = ctx.new_type_var()
return Ok((new_ty, tuple()))
case IntPattern():
return Ok((IntTy(), tuple()))
case SPattern(pred):
match pattern_type_bindings(pred, ctx):
case Ok((desc_ty, desc_bindings)):
match desc_ty.unify(IntTy()):
case Ok(_):
return Ok((IntTy(), desc_bindings))
case Err(_):
return Err(PatternTypeMismatch(pred, ctx.generalize(desc_ty), PolyType([], IntTy())))
case Err(_) as e:
return e
raise Exception('Unreachable')
def branches_2_ir(
branches: Sequence[Tuple[str, JsonType]],
context: Context,
) -> Result[tuple[Sequence[tuple[Pattern, Expression]], MonoType, Sequence[Substitution]], SemanticError]:
# TOdO Affirm argument type with pattern
match branches:
case [(raw_pattern, raw_expr), *rest_branches_raw]:
# Parse the pattern
match lex_and_parse_pattern(raw_pattern):
case Ok(parsed_pattern):
# Type the pattern
match pattern_type_bindings(parsed_pattern, context):
case Ok((pattern_accept_type, bindings)):
# Parse the body
context_with_pattern_bindings = context.with_many_mono(bindings)
match json_to_ir(raw_expr, context.with_many_mono(bindings)):
case Ok(( parsed_expr, branch_ty, branch_subs )):
updated_ctx = context.subst_all(branch_subs)
# Compute the type of this branch
this_branch = FunctionTy(subst_all_monotype(pattern_accept_type, branch_subs), branch_ty)
# Parse the rest of the branches
match branches_2_ir(rest_branches_raw, updated_ctx):
case Ok(( rest_branches, rest_ty, rest_substs )):
# Unify this branch and the rest of the branches
match subst_all_monotype(this_branch, rest_substs).unify(rest_ty):
case Ok(branch_unif_subst):
return Ok((
((parsed_pattern, parsed_expr), *rest_branches),
subst_all_monotype(rest_ty, branch_unif_subst),
(*branch_subs, *rest_substs, *branch_unif_subst)
))
case Err(unif_err): # This branch's type disagrees with the rest of the branches
return Err(BranchTypesDiffer(raw_pattern, rest_branches_raw[0][0], branch_ty, rest_ty))
case Err(e): # Problem parsing one of the remaining branches
return Err(e)
case Err(_) as e: # Problem parsing expression
return e
case Err(_) as e: # Pattern type mismatch
return e
case Err(_): # Problem parsing pattern
return Err(PatternParseProblem(raw_pattern))
pass
case []:
return Ok(( tuple(), context.new_type_var(), tuple() ))
raise Exception('Unreachable, I hope') #god why can't mypy check this
def seq_nonmt_2_ir(
first_expr: Expression,
first_ty: MonoType,
right: Sequence[JsonType],
context: Context
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
match right:
case [arg, *rest]:
match json_to_ir(arg, context):
case Ok(( arg_expr, arg_ty, arg_substs )):
ret_ty = context.new_type_var()
match first_ty.unify(FunctionTy(arg_ty, ret_ty)):
case Ok(unification_substs):
updated_ctx = context.subst_all(arg_substs).subst_all(unification_substs)
return seq_nonmt_2_ir(
Application(first_expr, arg_expr),
subst_all_monotype(ret_ty, unification_substs),
rest,
updated_ctx
) <= (lambda expr__ty__subst:
( expr__ty__subst[0]
, expr__ty__subst[1]
, (*arg_substs, *unification_substs, *expr__ty__subst[2])
)
)
case Err(_) as e:
return e
case Err(_) as e:
return e
case []:
return Ok(( first_expr, first_ty, tuple() ))
raise Exception('Unreachable')
def let_2_ir(
lhs: str,
rhs: JsonType,
body: Sequence[JsonType],
context: Context,
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
# Parse the rhs
standin_ty = context.new_type_var()
match json_to_ir(rhs, context.with_mono(lhs, standin_ty)):
case Ok(( rhs_expr, rhs_ty, rhs_subst )):
# Unify the rhs type with the generated type of the rhs from earlier
match subst_all_monotype(standin_ty, rhs_subst).unify(rhs_ty):
case Ok(recursion_substs):
updated_ctx = context.subst_all(rhs_subst).with_mono(lhs, rhs_ty).subst_all(recursion_substs)
# Parse the body
match json_to_ir(body, updated_ctx):
case Ok(( body_expr, body_ty, body_substs )):
return Ok(( LetBinding(lhs, rhs_expr, body_expr), body_ty, (*rhs_subst, *recursion_substs, *body_substs) ))
case Err(_) as e:
return e
case Err(_) as e:
return e
case Err(_) as e:
return e
raise Exception('Unreachable')
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
if isinstance(j, Mapping):
return Function(tuple(
#TODO handle parse errors
(unwrap_r(lex_and_parse_pattern(k)), json_to_ir(v))
for (k, v) in j.items()
))
return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst:
(Function(ir_ty_subst[0]), *ir_ty_subst[1:]))
elif isinstance(j, str):
return Variable(j)
match type_ctx.instantiate(j):
case Some(j_type):
return Ok(( Variable(j), j_type, tuple() ))
case None:
return Err(UndefinedVariable(j))
raise Exception('Unreachable')
elif isinstance(j, Sequence):
match j:
case [first, *rest]:
return Application(json_to_ir(first), [json_to_ir(a) for a in rest])
match json_to_ir(first, type_ctx):
case Ok(( fst, fst_ty, fst_substs)):
# Application
updated_ctx = type_ctx.subst_all(fst_substs)
return seq_nonmt_2_ir(fst, fst_ty, rest, updated_ctx) <= (lambda exp__ty__subs:
(exp__ty__subs[0], exp__ty__subs[1], (*fst_substs, *exp__ty__subs[2]))
)
case Err(UndefinedVariable(v)) if isinstance(first, str) and v == first:
# Let or String
match rest:
case [rhs, *body]:
# Let
return let_2_ir(first, rhs, body, type_ctx)
case []:
# String
raise Exception('TODO: Strings')
case []:
return ReplHole()
return Ok(( ReplHole(type_ctx), HoleTy(), tuple() ))
raise Exception('Unreachable')
else:
return Int(j)
return Ok(( Int(j), IntTy(), tuple() ))

54
ir.py
View File

@ -2,6 +2,8 @@ from emis_funky_funktions import *
from typing import Mapping, Sequence, Tuple, TypeAlias
import types_
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole'
Pattern: TypeAlias = 'NamePattern | IntPattern | SPattern | IgnorePattern'
@ -120,10 +122,11 @@ class SPattern:
@dataclass(frozen=True)
class ReplHole:
bindings: Sequence[Tuple[str, Expression]] = tuple()
typ_bindings: types_.Context
val_bindings: Sequence[Tuple[str, Expression]] = tuple()
def subst(self, expression: Expression, variable: str) -> Expression:
return ReplHole((*self.bindings, (variable, expression)))
return ReplHole(self.typ_bindings, (*self.val_bindings, (variable, expression)))
def is_value(self) -> bool:
return True
@ -239,58 +242,37 @@ class LetBinding:
)
def __repr__(self) -> str:
if isinstance(self.body, LetBinding) or isinstance(self.body, Application):
return f'L[ {repr(self.lhs)}, {repr(self.rhs)},{repr(self.body)[1:-1]}]'
else:
return f'L[ {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)}]'
return f'( {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)} )'
@dataclass
class Application:
first: Expression
args: Sequence[Expression]
arg: Expression
def subst(self, expression: Expression, variable: str) -> Expression:
return Application(
self.first.subst(expression, variable),
[r.subst(expression, variable) for r in self.args]
self.arg.subst(expression, variable)
)
def is_value(self) -> bool:
return False
def step(self) -> Option[Expression]:
match self.args:
case []:
return Some(self.first)
case [a, *rest]:
match self.first.step():
case Some(first_stepped):
return Some(Application(first_stepped, self.args))
match self.first.step():
case Some(first_stepped):
return Some(Application(first_stepped, self.arg))
case None:
match self.arg.step():
case Some(arg_stepped):
return Some(Application(self.first, arg_stepped))
case None:
match a.step():
case Some(a_stepped):
return Some(Application(self.first, [a_stepped, *rest]))
case None:
if isinstance(self.first, Function) or isinstance(self.first, Builtin):
return map_opt(
lambda f_sub: Application(f_sub, rest),
self.first.try_apply(a)
)
elif isinstance(self.first, Variable):
lhs = self.first.name
rhs = a
body = rest
match body:
case []:
return Some(ReplHole())
case [body_first, *body_rest]:
return Some(LetBinding(lhs, rhs, Application(body_first, body_rest)))
else:
return None
assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
return self.first.try_apply(self.arg)
raise Exception('Unreachable')
def __repr__(self) -> str:
return f'[ {repr(self.first)}, ' + ', '.join(repr(e) for e in self.args) + ' ]'
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
@dataclass
class Int:

View File

@ -1,6 +1,7 @@
from emis_funky_funktions import *
from genir import json_to_ir
from types_ import BUILTINS_CONTEXT
from silly_thing import repl, repl_expr
import json, sys
@ -8,7 +9,9 @@ import json, sys
def main():
match sys.argv:
case [_, file]:
repl_expr(json_to_ir(json.loads(open(sys.argv[1]).read())))
# TODO handle this
expr, ty, substs = unwrap_r(json_to_ir(json.loads(open(sys.argv[1]).read()), BUILTINS_CONTEXT))
repl_expr(expr)
case _:
repl()

View File

@ -3,6 +3,7 @@ from typing import Collection, Sequence, TypeAlias
from ir import Expression, ReplHole, subst_all
from genir import json_to_ir
from types_ import BUILTINS_CONTEXT
import json
from dataclasses import dataclass
@ -37,17 +38,18 @@ def evaluate(expr: Expression) -> Expression:
raise AssertionError('Evaluate called on a value which cannot step:', expr)
raise Exception('Unreachable')
def repl_expr(expr: Expression, bindings: Sequence[Tuple[str, Expression]] = tuple()):
expr_subst = subst_all(bindings, expr)
def repl_expr(expr: Expression, bindings: ReplHole = ReplHole(BUILTINS_CONTEXT)):
expr_subst = subst_all(bindings.val_bindings, expr)
result = evaluate(expr_subst)
if isinstance(result, ReplHole):
print('Environment updated\n')
repl(result.bindings)
print(result.typ_bindings)
repl(result)
else:
print(result, end='\n\n')
repl(bindings)
def repl(bindings: Sequence[Tuple[str, Expression]] = tuple()):
def repl(bindings: ReplHole = ReplHole(BUILTINS_CONTEXT)):
print('Enter a JSON expression:')
try:
expr = input('-> ')
@ -59,7 +61,9 @@ def repl(bindings: Sequence[Tuple[str, Expression]] = tuple()):
except json.decoder.JSONDecodeError as e:
print(f'Bad json: ', e.args[0], end='\n\n')
return repl(bindings)
repl_expr(json_to_ir(ast), bindings)
# TODO handle this
new_expr, new_ty, substs = unwrap_r(json_to_ir(ast, bindings.typ_bindings))
repl_expr(new_expr, bindings)
if __name__ == '__main__':
import doctest

206
types_.py Normal file
View File

@ -0,0 +1,206 @@
from emis_funky_funktions import *
from typing import *
from dataclasses import dataclass, field
from functools import reduce
MonoType: TypeAlias = 'FunctionTy | HoleTy | IntTy | VarTy'
Substitution: TypeAlias = 'tuple[MonoType, VarTy]'
@dataclass(frozen=True)
class UnificationError:
ty1: MonoType
ty2: MonoType
@dataclass(frozen=True)
class VarTy:
id: int
name: str = field(compare=False)
def subst(self, replacement: MonoType, var: 'VarTy') -> MonoType:
if var.id == self.id:
return replacement
else:
return self
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
match other:
case VarTy(id, str) if id == self.id:
return Ok(tuple())
case _:
return Ok(((other, self),))
raise Exception('Unreachable')
def free_vars(self) -> FrozenSet['VarTy']:
return fset(self)
def __repr__(self) -> str:
return self.name
@dataclass(frozen=True)
class FunctionTy:
arg_type: MonoType
ret_type: MonoType
def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
return FunctionTy(
self.arg_type.subst(replacement, var),
self.ret_type.subst(replacement, var)
)
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
match other:
case VarTy(id, str) as var:
return Ok(((self, var),))
case FunctionTy(other_arg, other_ret):
return self.arg_type.unify(other_arg) << (lambda arg_substs:
subst_all_monotype(self.ret_type, arg_substs)
.unify(
subst_all_monotype(other_ret, arg_substs)) << (lambda ret_substs:
Ok((*arg_substs, *ret_substs))))
case _:
return Err(UnificationError(self, other))
def free_vars(self) -> FrozenSet[VarTy]:
return self.arg_type.free_vars() | self.ret_type.free_vars()
def __repr__(self) -> str:
return f'({self.arg_type}) -> {self.ret_type}'
@dataclass(frozen=True)
class HoleTy:
def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
return self
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
match other:
case VarTy(id, str) as var:
return Ok(((self, var),))
case HoleTy():
return Ok(tuple())
case _:
return Err(UnificationError(self, other))
def free_vars(self) -> FrozenSet[VarTy]:
return FSet()
def __repr__(self) -> str:
return '[]'
@dataclass(frozen=True)
class IntTy:
def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
return self
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
match other:
case VarTy(id, str) as var:
return Ok(((self, var),))
case IntTy():
return Ok(tuple())
case _:
return Err(UnificationError(self, other))
def free_vars(self) -> FrozenSet[VarTy]:
return FSet()
def __repr__(self) -> str:
return 'int'
@dataclass(frozen=True)
class PolyType:
quantified_variables: Collection[VarTy]
monotype: MonoType
def subst(self, replacement: MonoType, var: VarTy) -> 'PolyType':
if var in self.quantified_variables:
return self
else:
return PolyType(
self.quantified_variables,
self.monotype.subst(replacement, var)
)
def instantiate(self, ctx: 'Context', name: str) -> MonoType:
match self.quantified_variables:
case []:
return self.monotype
case [alpha, *rest]:
return PolyType(
rest,
self.monotype.subst(ctx.new_type_var(name), alpha)
).instantiate(ctx, name)
raise Exception('Unreachable')
def free_vars(self) -> FrozenSet[VarTy]:
return self.monotype.free_vars() - FSet(self.quantified_variables)
def __repr__(self) -> str:
if len(self.quantified_variables):
qv = ' '.join(qv.name for qv in self.quantified_variables)
return f'forall {qv}. {self.monotype}'
else:
return repr(self.monotype)
UNIQUE_VAR_COUNTER:int = 0
@dataclass(frozen=True)
class Context:
variable_types: Mapping[str, PolyType]
def with_(self, var_name: str, var_type: PolyType) -> 'Context':
return Context({var_name: var_type, **self.variable_types})
def with_mono(self, var_name: str, var_type: MonoType) -> 'Context':
return Context({var_name: PolyType([], var_type), **self.variable_types})
def with_many_mono(self, bindings: Sequence[tuple[str, MonoType]]) -> 'Context':
match bindings:
case []:
return self
case [(var_name, var_type), *rest]:
return self.with_mono(var_name, var_type).with_many_mono(rest)
raise Exception('Unreachable')
def subst(self, replacement: MonoType, var: VarTy) -> 'Context':
return Context({
name: val
for (name, val) in self.variable_types.items()
})
def subst_all(self, subst: Sequence[Tuple[MonoType, VarTy]]) -> 'Context':
return reduce(lambda a, s: a.subst(*s), subst, self)
def new_type_var(self, name_prefix = 'T') -> VarTy:
global UNIQUE_VAR_COUNTER
UNIQUE_VAR_COUNTER += 1
return VarTy(UNIQUE_VAR_COUNTER, f'{name_prefix}${UNIQUE_VAR_COUNTER}')
def instantiate(self, name: str) -> Option[MonoType]:
if name in self.variable_types:
return Some(self.variable_types[name].instantiate(self, name))
else:
return None
def free_vars(self) -> FrozenSet[VarTy]:
return FSet(fv for ty in self.variable_types.values() for fv in ty.free_vars())
def generalize(self, mt: MonoType) -> PolyType:
return PolyType(mt.free_vars(), mt)
def __contains__(self, name: str) -> bool:
return name in self.variable_types
def __getitem__(self, i: str) -> PolyType:
return self.variable_types[i]
BUILTINS_CONTEXT: Context = (
Context({})
.with_mono('+', FunctionTy(IntTy(), FunctionTy(IntTy(), IntTy())))
.with_mono('S', FunctionTy(IntTy(), IntTy()))
)
def subst_all_monotype(m: MonoType, substs: Sequence[Tuple[MonoType, VarTy]]) -> MonoType:
return reduce(lambda a, s: a.subst(*s), substs, m)