Buggy typechecking
This commit is contained in:
parent
907bc5d505
commit
792674797d
199
genir.py
199
genir.py
|
@ -1,28 +1,203 @@
|
|||
from emis_funky_funktions import *
|
||||
from typing import *
|
||||
from typing import Sequence, Mapping
|
||||
|
||||
from pattern import lex_and_parse_pattern
|
||||
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole
|
||||
from ir import Expression, Function, Application, Int, Variable, LetBinding, ReplHole, Pattern, NamePattern, IgnorePattern, IntPattern, SPattern
|
||||
from types_ import *
|
||||
|
||||
from functools import reduce
|
||||
import json
|
||||
|
||||
JsonType: TypeAlias = 'Mapping[str, JsonType] | Sequence[JsonType] | int | str'
|
||||
SemanticError: TypeAlias = 'PatternParseProblem | UnificationError | BranchTypesDiffer | UndefinedVariable'
|
||||
|
||||
def json_to_ir(j: JsonType) -> Expression:
|
||||
@dataclass(frozen=True)
|
||||
class PatternParseProblem:
|
||||
pattern: str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PatternTypeMismatch:
|
||||
pattern: Pattern
|
||||
real_type: PolyType
|
||||
required_type: PolyType
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UndefinedVariable:
|
||||
name: str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BranchTypesDiffer:
|
||||
branch_1_pattern: str
|
||||
branch_2_pattern: str
|
||||
branch_1_type: MonoType
|
||||
branch_2_type: MonoType
|
||||
|
||||
def pattern_type_bindings(pattern: Pattern, ctx: Context) -> Result[Tuple[MonoType, Sequence[Tuple[str, MonoType]]], PatternTypeMismatch]:
|
||||
match pattern:
|
||||
case NamePattern(name):
|
||||
new_ty = ctx.new_type_var()
|
||||
return Ok((new_ty, ((name, new_ty),)))
|
||||
case IgnorePattern():
|
||||
new_ty = ctx.new_type_var()
|
||||
return Ok((new_ty, tuple()))
|
||||
case IntPattern():
|
||||
return Ok((IntTy(), tuple()))
|
||||
case SPattern(pred):
|
||||
match pattern_type_bindings(pred, ctx):
|
||||
case Ok((desc_ty, desc_bindings)):
|
||||
match desc_ty.unify(IntTy()):
|
||||
case Ok(_):
|
||||
return Ok((IntTy(), desc_bindings))
|
||||
case Err(_):
|
||||
return Err(PatternTypeMismatch(pred, ctx.generalize(desc_ty), PolyType([], IntTy())))
|
||||
case Err(_) as e:
|
||||
return e
|
||||
raise Exception('Unreachable')
|
||||
|
||||
def branches_2_ir(
|
||||
branches: Sequence[Tuple[str, JsonType]],
|
||||
context: Context,
|
||||
) -> Result[tuple[Sequence[tuple[Pattern, Expression]], MonoType, Sequence[Substitution]], SemanticError]:
|
||||
# TOdO Affirm argument type with pattern
|
||||
match branches:
|
||||
case [(raw_pattern, raw_expr), *rest_branches_raw]:
|
||||
# Parse the pattern
|
||||
match lex_and_parse_pattern(raw_pattern):
|
||||
case Ok(parsed_pattern):
|
||||
|
||||
# Type the pattern
|
||||
match pattern_type_bindings(parsed_pattern, context):
|
||||
case Ok((pattern_accept_type, bindings)):
|
||||
|
||||
# Parse the body
|
||||
context_with_pattern_bindings = context.with_many_mono(bindings)
|
||||
match json_to_ir(raw_expr, context.with_many_mono(bindings)):
|
||||
case Ok(( parsed_expr, branch_ty, branch_subs )):
|
||||
updated_ctx = context.subst_all(branch_subs)
|
||||
|
||||
# Compute the type of this branch
|
||||
this_branch = FunctionTy(subst_all_monotype(pattern_accept_type, branch_subs), branch_ty)
|
||||
|
||||
# Parse the rest of the branches
|
||||
match branches_2_ir(rest_branches_raw, updated_ctx):
|
||||
case Ok(( rest_branches, rest_ty, rest_substs )):
|
||||
|
||||
# Unify this branch and the rest of the branches
|
||||
match subst_all_monotype(this_branch, rest_substs).unify(rest_ty):
|
||||
case Ok(branch_unif_subst):
|
||||
return Ok((
|
||||
((parsed_pattern, parsed_expr), *rest_branches),
|
||||
subst_all_monotype(rest_ty, branch_unif_subst),
|
||||
(*branch_subs, *rest_substs, *branch_unif_subst)
|
||||
))
|
||||
case Err(unif_err): # This branch's type disagrees with the rest of the branches
|
||||
return Err(BranchTypesDiffer(raw_pattern, rest_branches_raw[0][0], branch_ty, rest_ty))
|
||||
case Err(e): # Problem parsing one of the remaining branches
|
||||
return Err(e)
|
||||
case Err(_) as e: # Problem parsing expression
|
||||
return e
|
||||
case Err(_) as e: # Pattern type mismatch
|
||||
return e
|
||||
case Err(_): # Problem parsing pattern
|
||||
return Err(PatternParseProblem(raw_pattern))
|
||||
pass
|
||||
case []:
|
||||
return Ok(( tuple(), context.new_type_var(), tuple() ))
|
||||
raise Exception('Unreachable, I hope') #god why can't mypy check this
|
||||
|
||||
def seq_nonmt_2_ir(
|
||||
first_expr: Expression,
|
||||
first_ty: MonoType,
|
||||
right: Sequence[JsonType],
|
||||
context: Context
|
||||
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
|
||||
match right:
|
||||
case [arg, *rest]:
|
||||
match json_to_ir(arg, context):
|
||||
case Ok(( arg_expr, arg_ty, arg_substs )):
|
||||
ret_ty = context.new_type_var()
|
||||
match first_ty.unify(FunctionTy(arg_ty, ret_ty)):
|
||||
case Ok(unification_substs):
|
||||
updated_ctx = context.subst_all(arg_substs).subst_all(unification_substs)
|
||||
return seq_nonmt_2_ir(
|
||||
Application(first_expr, arg_expr),
|
||||
subst_all_monotype(ret_ty, unification_substs),
|
||||
rest,
|
||||
updated_ctx
|
||||
) <= (lambda expr__ty__subst:
|
||||
( expr__ty__subst[0]
|
||||
, expr__ty__subst[1]
|
||||
, (*arg_substs, *unification_substs, *expr__ty__subst[2])
|
||||
)
|
||||
)
|
||||
case Err(_) as e:
|
||||
return e
|
||||
case Err(_) as e:
|
||||
return e
|
||||
case []:
|
||||
return Ok(( first_expr, first_ty, tuple() ))
|
||||
raise Exception('Unreachable')
|
||||
|
||||
def let_2_ir(
|
||||
lhs: str,
|
||||
rhs: JsonType,
|
||||
body: Sequence[JsonType],
|
||||
context: Context,
|
||||
) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
|
||||
# Parse the rhs
|
||||
standin_ty = context.new_type_var()
|
||||
match json_to_ir(rhs, context.with_mono(lhs, standin_ty)):
|
||||
case Ok(( rhs_expr, rhs_ty, rhs_subst )):
|
||||
|
||||
# Unify the rhs type with the generated type of the rhs from earlier
|
||||
match subst_all_monotype(standin_ty, rhs_subst).unify(rhs_ty):
|
||||
case Ok(recursion_substs):
|
||||
updated_ctx = context.subst_all(rhs_subst).with_mono(lhs, rhs_ty).subst_all(recursion_substs)
|
||||
|
||||
# Parse the body
|
||||
match json_to_ir(body, updated_ctx):
|
||||
case Ok(( body_expr, body_ty, body_substs )):
|
||||
return Ok(( LetBinding(lhs, rhs_expr, body_expr), body_ty, (*rhs_subst, *recursion_substs, *body_substs) ))
|
||||
case Err(_) as e:
|
||||
return e
|
||||
case Err(_) as e:
|
||||
return e
|
||||
case Err(_) as e:
|
||||
return e
|
||||
raise Exception('Unreachable')
|
||||
|
||||
def json_to_ir(j: JsonType, type_ctx: Context) -> Result[tuple[Expression, MonoType, Sequence[Substitution]], SemanticError]:
|
||||
if isinstance(j, Mapping):
|
||||
return Function(tuple(
|
||||
#TODO handle parse errors
|
||||
(unwrap_r(lex_and_parse_pattern(k)), json_to_ir(v))
|
||||
for (k, v) in j.items()
|
||||
))
|
||||
return branches_2_ir(tuple(j.items()), type_ctx) <= (lambda ir_ty_subst:
|
||||
(Function(ir_ty_subst[0]), *ir_ty_subst[1:]))
|
||||
elif isinstance(j, str):
|
||||
return Variable(j)
|
||||
match type_ctx.instantiate(j):
|
||||
case Some(j_type):
|
||||
return Ok(( Variable(j), j_type, tuple() ))
|
||||
case None:
|
||||
return Err(UndefinedVariable(j))
|
||||
raise Exception('Unreachable')
|
||||
elif isinstance(j, Sequence):
|
||||
match j:
|
||||
case [first, *rest]:
|
||||
return Application(json_to_ir(first), [json_to_ir(a) for a in rest])
|
||||
match json_to_ir(first, type_ctx):
|
||||
case Ok(( fst, fst_ty, fst_substs)):
|
||||
# Application
|
||||
updated_ctx = type_ctx.subst_all(fst_substs)
|
||||
return seq_nonmt_2_ir(fst, fst_ty, rest, updated_ctx) <= (lambda exp__ty__subs:
|
||||
(exp__ty__subs[0], exp__ty__subs[1], (*fst_substs, *exp__ty__subs[2]))
|
||||
)
|
||||
case Err(UndefinedVariable(v)) if isinstance(first, str) and v == first:
|
||||
# Let or String
|
||||
match rest:
|
||||
case [rhs, *body]:
|
||||
# Let
|
||||
return let_2_ir(first, rhs, body, type_ctx)
|
||||
case []:
|
||||
# String
|
||||
raise Exception('TODO: Strings')
|
||||
case []:
|
||||
return ReplHole()
|
||||
return Ok(( ReplHole(type_ctx), HoleTy(), tuple() ))
|
||||
raise Exception('Unreachable')
|
||||
else:
|
||||
return Int(j)
|
||||
return Ok(( Int(j), IntTy(), tuple() ))
|
54
ir.py
54
ir.py
|
@ -2,6 +2,8 @@ from emis_funky_funktions import *
|
|||
|
||||
from typing import Mapping, Sequence, Tuple, TypeAlias
|
||||
|
||||
import types_
|
||||
|
||||
|
||||
Expression: TypeAlias = 'Function | Application | Int | Variable | Builtin | LetBinding | ReplHole'
|
||||
Pattern: TypeAlias = 'NamePattern | IntPattern | SPattern | IgnorePattern'
|
||||
|
@ -120,10 +122,11 @@ class SPattern:
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class ReplHole:
|
||||
bindings: Sequence[Tuple[str, Expression]] = tuple()
|
||||
typ_bindings: types_.Context
|
||||
val_bindings: Sequence[Tuple[str, Expression]] = tuple()
|
||||
|
||||
def subst(self, expression: Expression, variable: str) -> Expression:
|
||||
return ReplHole((*self.bindings, (variable, expression)))
|
||||
return ReplHole(self.typ_bindings, (*self.val_bindings, (variable, expression)))
|
||||
|
||||
def is_value(self) -> bool:
|
||||
return True
|
||||
|
@ -239,58 +242,37 @@ class LetBinding:
|
|||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if isinstance(self.body, LetBinding) or isinstance(self.body, Application):
|
||||
return f'L[ {repr(self.lhs)}, {repr(self.rhs)},{repr(self.body)[1:-1]}]'
|
||||
else:
|
||||
return f'L[ {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)}]'
|
||||
return f'( {repr(self.lhs)}, {repr(self.rhs)}, {repr(self.body)} )'
|
||||
|
||||
@dataclass
|
||||
class Application:
|
||||
first: Expression
|
||||
args: Sequence[Expression]
|
||||
arg: Expression
|
||||
|
||||
def subst(self, expression: Expression, variable: str) -> Expression:
|
||||
return Application(
|
||||
self.first.subst(expression, variable),
|
||||
[r.subst(expression, variable) for r in self.args]
|
||||
self.arg.subst(expression, variable)
|
||||
)
|
||||
|
||||
def is_value(self) -> bool:
|
||||
return False
|
||||
|
||||
def step(self) -> Option[Expression]:
|
||||
match self.args:
|
||||
case []:
|
||||
return Some(self.first)
|
||||
case [a, *rest]:
|
||||
match self.first.step():
|
||||
case Some(first_stepped):
|
||||
return Some(Application(first_stepped, self.args))
|
||||
match self.first.step():
|
||||
case Some(first_stepped):
|
||||
return Some(Application(first_stepped, self.arg))
|
||||
case None:
|
||||
match self.arg.step():
|
||||
case Some(arg_stepped):
|
||||
return Some(Application(self.first, arg_stepped))
|
||||
case None:
|
||||
match a.step():
|
||||
case Some(a_stepped):
|
||||
return Some(Application(self.first, [a_stepped, *rest]))
|
||||
case None:
|
||||
if isinstance(self.first, Function) or isinstance(self.first, Builtin):
|
||||
return map_opt(
|
||||
lambda f_sub: Application(f_sub, rest),
|
||||
self.first.try_apply(a)
|
||||
)
|
||||
elif isinstance(self.first, Variable):
|
||||
lhs = self.first.name
|
||||
rhs = a
|
||||
body = rest
|
||||
match body:
|
||||
case []:
|
||||
return Some(ReplHole())
|
||||
case [body_first, *body_rest]:
|
||||
return Some(LetBinding(lhs, rhs, Application(body_first, body_rest)))
|
||||
else:
|
||||
return None
|
||||
assert isinstance(self.first, Function) or isinstance(self.first, Builtin), "Type checking failed to produce valid IR, or preservation of types failed"
|
||||
return self.first.try_apply(self.arg)
|
||||
raise Exception('Unreachable')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'[ {repr(self.first)}, ' + ', '.join(repr(e) for e in self.args) + ' ]'
|
||||
return f'[ {repr(self.first)}, {repr(self.arg)} ]'
|
||||
|
||||
@dataclass
|
||||
class Int:
|
||||
|
|
5
main.py
5
main.py
|
@ -1,6 +1,7 @@
|
|||
from emis_funky_funktions import *
|
||||
|
||||
from genir import json_to_ir
|
||||
from types_ import BUILTINS_CONTEXT
|
||||
from silly_thing import repl, repl_expr
|
||||
|
||||
import json, sys
|
||||
|
@ -8,7 +9,9 @@ import json, sys
|
|||
def main():
|
||||
match sys.argv:
|
||||
case [_, file]:
|
||||
repl_expr(json_to_ir(json.loads(open(sys.argv[1]).read())))
|
||||
# TODO handle this
|
||||
expr, ty, substs = unwrap_r(json_to_ir(json.loads(open(sys.argv[1]).read()), BUILTINS_CONTEXT))
|
||||
repl_expr(expr)
|
||||
case _:
|
||||
repl()
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from typing import Collection, Sequence, TypeAlias
|
|||
|
||||
from ir import Expression, ReplHole, subst_all
|
||||
from genir import json_to_ir
|
||||
from types_ import BUILTINS_CONTEXT
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
@ -37,17 +38,18 @@ def evaluate(expr: Expression) -> Expression:
|
|||
raise AssertionError('Evaluate called on a value which cannot step:', expr)
|
||||
raise Exception('Unreachable')
|
||||
|
||||
def repl_expr(expr: Expression, bindings: Sequence[Tuple[str, Expression]] = tuple()):
|
||||
expr_subst = subst_all(bindings, expr)
|
||||
def repl_expr(expr: Expression, bindings: ReplHole = ReplHole(BUILTINS_CONTEXT)):
|
||||
expr_subst = subst_all(bindings.val_bindings, expr)
|
||||
result = evaluate(expr_subst)
|
||||
if isinstance(result, ReplHole):
|
||||
print('Environment updated\n')
|
||||
repl(result.bindings)
|
||||
print(result.typ_bindings)
|
||||
repl(result)
|
||||
else:
|
||||
print(result, end='\n\n')
|
||||
repl(bindings)
|
||||
|
||||
def repl(bindings: Sequence[Tuple[str, Expression]] = tuple()):
|
||||
def repl(bindings: ReplHole = ReplHole(BUILTINS_CONTEXT)):
|
||||
print('Enter a JSON expression:')
|
||||
try:
|
||||
expr = input('-> ')
|
||||
|
@ -59,7 +61,9 @@ def repl(bindings: Sequence[Tuple[str, Expression]] = tuple()):
|
|||
except json.decoder.JSONDecodeError as e:
|
||||
print(f'Bad json: ', e.args[0], end='\n\n')
|
||||
return repl(bindings)
|
||||
repl_expr(json_to_ir(ast), bindings)
|
||||
# TODO handle this
|
||||
new_expr, new_ty, substs = unwrap_r(json_to_ir(ast, bindings.typ_bindings))
|
||||
repl_expr(new_expr, bindings)
|
||||
|
||||
if __name__ == '__main__':
|
||||
import doctest
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
from emis_funky_funktions import *
|
||||
from typing import *
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from functools import reduce
|
||||
|
||||
MonoType: TypeAlias = 'FunctionTy | HoleTy | IntTy | VarTy'
|
||||
Substitution: TypeAlias = 'tuple[MonoType, VarTy]'
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UnificationError:
|
||||
ty1: MonoType
|
||||
ty2: MonoType
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VarTy:
|
||||
id: int
|
||||
name: str = field(compare=False)
|
||||
|
||||
def subst(self, replacement: MonoType, var: 'VarTy') -> MonoType:
|
||||
if var.id == self.id:
|
||||
return replacement
|
||||
else:
|
||||
return self
|
||||
|
||||
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
|
||||
match other:
|
||||
case VarTy(id, str) if id == self.id:
|
||||
return Ok(tuple())
|
||||
case _:
|
||||
return Ok(((other, self),))
|
||||
raise Exception('Unreachable')
|
||||
|
||||
def free_vars(self) -> FrozenSet['VarTy']:
|
||||
return fset(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.name
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionTy:
|
||||
arg_type: MonoType
|
||||
ret_type: MonoType
|
||||
|
||||
def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
|
||||
return FunctionTy(
|
||||
self.arg_type.subst(replacement, var),
|
||||
self.ret_type.subst(replacement, var)
|
||||
)
|
||||
|
||||
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
|
||||
match other:
|
||||
case VarTy(id, str) as var:
|
||||
return Ok(((self, var),))
|
||||
case FunctionTy(other_arg, other_ret):
|
||||
return self.arg_type.unify(other_arg) << (lambda arg_substs:
|
||||
subst_all_monotype(self.ret_type, arg_substs)
|
||||
.unify(
|
||||
subst_all_monotype(other_ret, arg_substs)) << (lambda ret_substs:
|
||||
Ok((*arg_substs, *ret_substs))))
|
||||
case _:
|
||||
return Err(UnificationError(self, other))
|
||||
|
||||
def free_vars(self) -> FrozenSet[VarTy]:
|
||||
return self.arg_type.free_vars() | self.ret_type.free_vars()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'({self.arg_type}) -> {self.ret_type}'
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HoleTy:
|
||||
|
||||
def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
|
||||
return self
|
||||
|
||||
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
|
||||
match other:
|
||||
case VarTy(id, str) as var:
|
||||
return Ok(((self, var),))
|
||||
case HoleTy():
|
||||
return Ok(tuple())
|
||||
case _:
|
||||
return Err(UnificationError(self, other))
|
||||
|
||||
def free_vars(self) -> FrozenSet[VarTy]:
|
||||
return FSet()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '[]'
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IntTy:
|
||||
|
||||
def subst(self, replacement: MonoType, var: VarTy) -> MonoType:
|
||||
return self
|
||||
|
||||
def unify(self, other: MonoType) -> Result[Sequence[Substitution], UnificationError]:
|
||||
match other:
|
||||
case VarTy(id, str) as var:
|
||||
return Ok(((self, var),))
|
||||
case IntTy():
|
||||
return Ok(tuple())
|
||||
case _:
|
||||
return Err(UnificationError(self, other))
|
||||
|
||||
def free_vars(self) -> FrozenSet[VarTy]:
|
||||
return FSet()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return 'int'
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PolyType:
|
||||
quantified_variables: Collection[VarTy]
|
||||
monotype: MonoType
|
||||
|
||||
def subst(self, replacement: MonoType, var: VarTy) -> 'PolyType':
|
||||
if var in self.quantified_variables:
|
||||
return self
|
||||
else:
|
||||
return PolyType(
|
||||
self.quantified_variables,
|
||||
self.monotype.subst(replacement, var)
|
||||
)
|
||||
|
||||
def instantiate(self, ctx: 'Context', name: str) -> MonoType:
|
||||
match self.quantified_variables:
|
||||
case []:
|
||||
return self.monotype
|
||||
case [alpha, *rest]:
|
||||
return PolyType(
|
||||
rest,
|
||||
self.monotype.subst(ctx.new_type_var(name), alpha)
|
||||
).instantiate(ctx, name)
|
||||
raise Exception('Unreachable')
|
||||
|
||||
def free_vars(self) -> FrozenSet[VarTy]:
|
||||
return self.monotype.free_vars() - FSet(self.quantified_variables)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if len(self.quantified_variables):
|
||||
qv = ' '.join(qv.name for qv in self.quantified_variables)
|
||||
return f'forall {qv}. {self.monotype}'
|
||||
else:
|
||||
return repr(self.monotype)
|
||||
|
||||
UNIQUE_VAR_COUNTER:int = 0
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Context:
|
||||
variable_types: Mapping[str, PolyType]
|
||||
|
||||
def with_(self, var_name: str, var_type: PolyType) -> 'Context':
|
||||
return Context({var_name: var_type, **self.variable_types})
|
||||
|
||||
def with_mono(self, var_name: str, var_type: MonoType) -> 'Context':
|
||||
return Context({var_name: PolyType([], var_type), **self.variable_types})
|
||||
|
||||
def with_many_mono(self, bindings: Sequence[tuple[str, MonoType]]) -> 'Context':
|
||||
match bindings:
|
||||
case []:
|
||||
return self
|
||||
case [(var_name, var_type), *rest]:
|
||||
return self.with_mono(var_name, var_type).with_many_mono(rest)
|
||||
raise Exception('Unreachable')
|
||||
|
||||
def subst(self, replacement: MonoType, var: VarTy) -> 'Context':
|
||||
return Context({
|
||||
name: val
|
||||
for (name, val) in self.variable_types.items()
|
||||
})
|
||||
|
||||
def subst_all(self, subst: Sequence[Tuple[MonoType, VarTy]]) -> 'Context':
|
||||
return reduce(lambda a, s: a.subst(*s), subst, self)
|
||||
|
||||
def new_type_var(self, name_prefix = 'T') -> VarTy:
|
||||
global UNIQUE_VAR_COUNTER
|
||||
UNIQUE_VAR_COUNTER += 1
|
||||
return VarTy(UNIQUE_VAR_COUNTER, f'{name_prefix}${UNIQUE_VAR_COUNTER}')
|
||||
|
||||
def instantiate(self, name: str) -> Option[MonoType]:
|
||||
if name in self.variable_types:
|
||||
return Some(self.variable_types[name].instantiate(self, name))
|
||||
else:
|
||||
return None
|
||||
|
||||
def free_vars(self) -> FrozenSet[VarTy]:
|
||||
return FSet(fv for ty in self.variable_types.values() for fv in ty.free_vars())
|
||||
|
||||
def generalize(self, mt: MonoType) -> PolyType:
|
||||
return PolyType(mt.free_vars(), mt)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self.variable_types
|
||||
|
||||
def __getitem__(self, i: str) -> PolyType:
|
||||
return self.variable_types[i]
|
||||
|
||||
BUILTINS_CONTEXT: Context = (
|
||||
Context({})
|
||||
.with_mono('+', FunctionTy(IntTy(), FunctionTy(IntTy(), IntTy())))
|
||||
.with_mono('S', FunctionTy(IntTy(), IntTy()))
|
||||
)
|
||||
|
||||
def subst_all_monotype(m: MonoType, substs: Sequence[Tuple[MonoType, VarTy]]) -> MonoType:
|
||||
return reduce(lambda a, s: a.subst(*s), substs, m)
|
Loading…
Reference in New Issue