From 1cf1acaa08253fe733261dc6a20e8801dc97003c Mon Sep 17 00:00:00 2001 From: Emi Simpson Date: Sun, 5 Mar 2023 20:45:10 -0500 Subject: [PATCH] Add in support for unification --- ir.py | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/ir.py b/ir.py index 5efb405..0cfc300 100644 --- a/ir.py +++ b/ir.py @@ -1,4 +1,7 @@ +from emis_funky_funktions import * + from dataclasses import dataclass +from functools import reduce from typing import Sequence, TypeAlias @dataclass(frozen=True) @@ -19,6 +22,31 @@ class Subst: def __repr__(self) -> str: return f'{self.replacement}/{self.variable}' +Substitutions: TypeAlias = Sequence[Subst] + +@dataclass(frozen=True) +class UnificationMismatch: + """ + Indicates that two terms failed to unify + + Contains the two terms, each of which is a valid subterm of one of the two original + terms, and which do not superficially unify. + """ + term1: 'IRTerm' + term2: 'IRTerm' + +@dataclass(frozen=True) +class LengthMismatch: + """ + Indicates that two clauses/argument lists failed to unify due to a length mismatch + + Contains the first element of the two lists which didn't have a corresponding term on + the other side + """ + term: 'IRTerm' + +UnificationError = UnificationMismatch | LengthMismatch + @dataclass(frozen=True) class IRProp: """ @@ -111,6 +139,86 @@ class IRNeg: return f'¬{self.inner}' IRTerm: TypeAlias = IRVar | IRProp | IRNeg +Clause: TypeAlias = Sequence[IRTerm] + +sub_all: Callable[[Substitutions, IRTerm], IRTerm] = p(reduce, lambda t, s: t.subst(s)) #type:ignore +""" +Perform a series of substitutions on a term + +Applies every substitution to the term in order + +>>> sub_all( +... [Subst('x1', IRVar('x2')), Subst('x2', IRProp('Karkat'))], +... IRProp('kismesis', [IRVar('x1'), IRVar('x2')]), +... ) +kismesis(Karkat(), Karkat()) +""" + +def unify(t1: IRTerm, t2: IRTerm) -> Result[Substitutions, UnificationError]: + """ + Attempt to find a substitution that unifies two terms + + If successful, the returned substitutions will cause both term to be equal, when + applied to both. + + If this method fails, then the pair of subterms which caused the unification to fail + are returned. + + >>> unify( + ... IRProp('imaginary', [IRProp('Rufio')]), + ... IRProp('imaginary', [IRVar('x1')]) + ... ) + Ok((Rufio()/x1,)) + + >>> unify( + ... IRProp('dating', [IRProp('Jade'), IRVar('x1')]), + ... IRProp('dating', [IRVar('x1'), IRProp('John')]) + ... ) + Err(UnificationMismatch(term1=Jade(), term2=John())) + """ + match (t1, t2): + case (IRVar(v1), IRVar(v2)) if v1 == v2: + return Ok(tuple()) + case (IRVar(v), t_other) | (t_other, IRVar(v)):#type:ignore #TODO if v not in t_other: + return Ok((Subst(v, t_other),)) + case (IRProp(n1, a1), IRProp(n2, a2)) if n1 == n2 and len(a1) == len(a2): + return unify_clauses(a1, a2) + case (IRNeg(i1), IRNeg(i2)): + return unify(i1, i2) + return Err(UnificationMismatch(t1, t2)) + +def unify_clauses(c1: Clause, c2: Clause) -> Result[Substitutions, UnificationError]: + """ + Attempt to perform unification on two clauses or argument lists + + See `unify()` for the details of how this works. When working with clauses, the same + rules apply. The substitutions, when applied to every term of both clauses, will + cause the clauses to become exactly the same. + + Lists which are not the same length cannot be unified, and will always fail. + + >>> unify_clauses( + ... [ IRProp('imaginary', [IRProp('Rufio')]), IRProp('friend', [IRVar('x1'), IRVar('x3')]) ], + ... [ IRProp('imaginary', [IRVar('x1')]), IRProp('friend', [IRVar('x2'), IRProp('Tavros')]) ] + ... ) + Ok((Rufio()/x1, Rufio()/x2, Tavros()/x3)) + + >>> unify_clauses( + ... [ IRProp('imaginary', [IRProp('Rufio')]), IRProp('friend', [IRVar('x1'), IRVar('x3')]) ], + ... [ IRProp('imaginary', [IRVar('x1')]) ] + ... ) + Err(LengthMismatch(term=friend(Rufio(),*x3))) + """ + match (c1, c2): + case ([], []): + return Ok(tuple()) + case ([h1, *t1], [h2, *t2]): + return unify(h1, h2) << (lambda subs: + unify_clauses((*map(p(sub_all,subs),t1),), (*map(p(sub_all,subs),t2),)) <= ( + lambda final_subs: (*subs, *final_subs))) + case ([h, *t], []) | ([], [h, *t]): + return Err(LengthMismatch(h)) + raise Exception('Unreachable') if __name__ == '__main__': import doctest