Got tail-call optimized evaluation working

This commit is contained in:
Emi Simpson 2022-04-25 12:23:10 -04:00
parent 7766a8ae30
commit d298170022
Signed by: Emi
GPG Key ID: A12F2C2FFDC3D847
5 changed files with 115 additions and 58 deletions

View File

@ -2,7 +2,7 @@ use std::{collections::{HashMap, LinkedList}, mem::Discriminant};
use super::{Identifier, value::Value, expr::Operation, types::Type}; use super::{Identifier, value::Value, expr::Operation, types::Type};
#[derive(Debug, Default)] #[derive(Debug, Default, Clone)]
pub struct ValueBindings<'a>(Option<&'a ValueBindings<'a>>, HashMap<u128, Value>); pub struct ValueBindings<'a>(Option<&'a ValueBindings<'a>>, HashMap<u128, Value>);
impl<'a> ValueBindings<'a> { impl<'a> ValueBindings<'a> {
@ -19,6 +19,14 @@ impl<'a> ValueBindings<'a> {
self self
} }
pub fn bind_all_owned(mut self, other: LinkedList<(Identifier, Value)>) -> Self{
self.1.extend(
other.into_iter()
.map(|(ident, val)| (ident.1 as u128, val))
);
self
}
pub fn lookup(&self, ident: &Identifier) -> Option<&Value> { pub fn lookup(&self, ident: &Identifier) -> Option<&Value> {
self.1.get(&(ident.1 as u128)) self.1.get(&(ident.1 as u128))
.or_else(|| self.0.and_then(|binding| binding.lookup(ident))) .or_else(|| self.0.and_then(|binding| binding.lookup(ident)))

View File

@ -1,5 +1,5 @@
use core::fmt; use core::fmt;
use std::{collections::{LinkedList, HashSet}, mem::{self, discriminant}}; use std::{collections::{LinkedList, HashSet}, mem::{self, discriminant}, borrow::Cow, ops::FromResidual, convert::Infallible};
use super::{value::Value, pattern::Pattern, Identifier, evaluation::{ValueBindings, self, EvaluateError}}; use super::{value::Value, pattern::Pattern, Identifier, evaluation::{ValueBindings, self, EvaluateError}};
@ -10,6 +10,30 @@ pub struct Expr {
pub formals: Vec<Identifier>, pub formals: Vec<Identifier>,
} }
pub enum ExprEvalResult<'a> {
Succeeded(Value),
EvaluateThis(&'a LinkedList<Expr>, LinkedList<(&'a Identifier, &'a Value)>),
Failed(EvaluateError),
}
impl From<evaluation::Result> for ExprEvalResult<'_> {
fn from(res: evaluation::Result) -> Self {
match res {
Ok(v) => Self::Succeeded(v),
Err(e) => Self::Failed(e),
}
}
}
impl FromResidual<Result<Infallible, EvaluateError>> for ExprEvalResult<'_> {
fn from_residual(residual: Result<Infallible, EvaluateError>) -> Self {
match residual {
Err(e) => Self::Failed(e),
_ => unreachable!(),
}
}
}
impl Expr { impl Expr {
pub fn isnt_noop(&self) -> bool { pub fn isnt_noop(&self) -> bool {
if let Operation::NoOp = self.operation { if let Operation::NoOp = self.operation {
@ -18,17 +42,18 @@ impl Expr {
true true
} }
} }
pub fn evaluate(&self, bindings: &ValueBindings) -> evaluation::Result { pub fn evaluate<'a>(&'a self, bindings: &'a ValueBindings) -> ExprEvalResult<'a> {
let arguments = self.formals let arguments = self.formals
.iter() .iter()
.map(|formal| bindings.lookup(formal).cloned().ok_or_else(|| EvaluateError::UndefinedValue(formal.clone()))) .map(|formal| bindings.lookup(formal).ok_or_else(|| EvaluateError::UndefinedValue(formal.clone())))
.collect::<Result<Vec<_>, EvaluateError>>()?; .collect::<Result<Vec<_>, EvaluateError>>()?;
match &self.operation { match &self.operation {
Operation::Add => Operation::Add =>
arguments.into_iter() arguments.into_iter()
.map(TryInto::try_into) .map(TryInto::try_into)
.try_fold(0, |a, b| b.map(|b: usize| a + b)) .try_fold(0, |a, b| b.map(|b: usize| a + b))
.map(Value::Int), .map(Value::Int)
.into(),
Operation::Sub => { Operation::Sub => {
let mut args = arguments.into_iter() let mut args = arguments.into_iter()
.map(TryInto::try_into); .map(TryInto::try_into);
@ -37,13 +62,15 @@ impl Expr {
.and_then(std::convert::identity)?; .and_then(std::convert::identity)?;
args.try_fold(first_arg, |a, b| b.map(|b| a - b)) args.try_fold(first_arg, |a, b| b.map(|b| a - b))
.map(Value::Int) .map(Value::Int)
.into()
} }
Operation::Mul => Operation::Mul =>
arguments.into_iter() arguments.into_iter()
.map(TryInto::try_into) .map(TryInto::try_into)
.try_fold(1, |a, b| b.map(|b: usize| a * b)) .try_fold(1, |a, b| b.map(|b: usize| a * b))
.map(Value::Int), .map(Value::Int)
.into(),
Operation::Div => { Operation::Div => {
let mut args = arguments.into_iter() let mut args = arguments.into_iter()
.map(TryInto::try_into); .map(TryInto::try_into);
@ -52,6 +79,7 @@ impl Expr {
.and_then(std::convert::identity)?; .and_then(std::convert::identity)?;
args.try_fold(first_arg, |a, b| b.map(|b| a / b)) args.try_fold(first_arg, |a, b| b.map(|b| a / b))
.map(Value::Int) .map(Value::Int)
.into()
} }
Operation::Mod => { Operation::Mod => {
let mut args = arguments.into_iter() let mut args = arguments.into_iter()
@ -61,6 +89,7 @@ impl Expr {
.and_then(std::convert::identity)?; .and_then(std::convert::identity)?;
args.try_fold(first_arg, |a, b| b.map(|b| a % b)) args.try_fold(first_arg, |a, b| b.map(|b| a % b))
.map(Value::Int) .map(Value::Int)
.into()
} }
Operation::Range => todo!(), Operation::Range => todo!(),
Operation::Eq => Operation::Eq =>
@ -77,14 +106,15 @@ impl Expr {
}, Some(to_compare)) }, Some(to_compare))
) )
) )
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) })
.into(),
Operation::NEq => Operation::NEq =>
if ExprEvalResult::Succeeded(if
arguments.into_iter() arguments.into_iter()
.map(TryInto::try_into) .map(TryInto::try_into)
.collect::<Result<HashSet<usize>, EvaluateError>>()? .collect::<Result<HashSet<usize>, EvaluateError>>()?
.len() == self.formals.len() .len() == self.formals.len()
{ Ok(Value::Int(1)) } else { Ok(Value::Int(2)) }, { Value::Int(1) } else { Value::Int(2) }),
Operation::LessThan => Operation::LessThan =>
arguments.into_iter() arguments.into_iter()
.map(TryInto::try_into) .map(TryInto::try_into)
@ -97,7 +127,8 @@ impl Expr {
), Some(to_compare)) ), Some(to_compare))
) )
) )
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) })
.into(),
Operation::GreaterThan => Operation::GreaterThan =>
arguments.into_iter() arguments.into_iter()
.map(TryInto::try_into) .map(TryInto::try_into)
@ -110,18 +141,21 @@ impl Expr {
), Some(to_compare)) ), Some(to_compare))
) )
) )
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) })
.into(),
Operation::LAnd => Operation::LAnd =>
arguments.into_iter() arguments.into_iter()
.map(TryInto::try_into) .map(TryInto::try_into)
.try_fold(true, |a, b| b.map(|b: bool| a && b)) .try_fold(true, |a, b| b.map(|b: bool| a && b))
.map(|result| if result { Value::Int(1) } else { Value::Int(0) }), .map(|result| if result { Value::Int(1) } else { Value::Int(0) })
.into(),
Operation::LOr => Operation::LOr =>
arguments.into_iter() arguments.into_iter()
.map(TryInto::try_into) .map(TryInto::try_into)
.try_fold(true, |a, b| b.map(|b: bool| a || b)) .try_fold(true, |a, b| b.map(|b: bool| a || b))
.map(|result| if result { Value::Int(1) } else { Value::Int(0) }), .map(|result| if result { Value::Int(1) } else { Value::Int(0) })
Operation::Const(v) => Ok(v.clone()), .into(),
Operation::Const(v) => ExprEvalResult::Succeeded(v.clone()),
Operation::Call => { Operation::Call => {
let mut args = arguments.into_iter(); let mut args = arguments.into_iter();
let first_arg = args.next() let first_arg = args.next()
@ -129,21 +163,18 @@ impl Expr {
match first_arg { match first_arg {
Value::Function(formals, code) => { Value::Function(formals, code) => {
if formals.len() == args.size_hint().0 { if formals.len() == args.size_hint().0 {
let new_scope = bindings.nested_scope(); let bindings = formals.into_iter()
let bound_scope = formals.into_iter()
.zip(args) .zip(args)
.fold( .collect();
new_scope, ExprEvalResult::EvaluateThis(code, bindings)
|scope, (formal, value)|
scope.bind(&formal, value)
);
evaluate(code, &bound_scope)
} else { } else {
Err(EvaluateError::FunctionArgumentCountMismatch(args.size_hint().0, formals.len())) ExprEvalResult::Failed(
EvaluateError::FunctionArgumentCountMismatch(args.size_hint().0, formals.len())
)
} }
}, },
first_arg @ _ => { first_arg @ _ => {
Err(EvaluateError::TypeMismatch( ExprEvalResult::Failed(EvaluateError::TypeMismatch(
discriminant(&first_arg), discriminant(&first_arg),
discriminant(&Value::Function(Vec::new(), LinkedList::new())), discriminant(&Value::Function(Vec::new(), LinkedList::new())),
)) ))
@ -152,7 +183,7 @@ impl Expr {
} }
Operation::VariantUnion => todo!(), Operation::VariantUnion => todo!(),
Operation::FunctionType => todo!(), Operation::FunctionType => todo!(),
Operation::NoOp => Ok(Value::Int(0)), Operation::NoOp => ExprEvalResult::Succeeded(Value::Int(0)),
Operation::Conditional(cases) => { Operation::Conditional(cases) => {
let mut arguments = arguments; let mut arguments = arguments;
let value = let value =
@ -160,16 +191,14 @@ impl Expr {
let (new_bindings, code) = cases.into_iter() let (new_bindings, code) = cases.into_iter()
.find_map(|(pattern, code)| pattern.matches(&value).map(|bindings| (bindings, code))) .find_map(|(pattern, code)| pattern.matches(&value).map(|bindings| (bindings, code)))
.ok_or(EvaluateError::IncompleteConditional)?; .ok_or(EvaluateError::IncompleteConditional)?;
let local_scope = bindings.nested_scope() ExprEvalResult::EvaluateThis(code, new_bindings)
.bind_all(new_bindings);
evaluate(code.clone(), &local_scope)
} }
} }
} }
} }
impl fmt::Display for Expr { impl fmt::Display for Expr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Operation::NoOp = self.operation { if let Operation::NoOp = self.operation {
self.identifier.fmt(f) self.identifier.fmt(f)
} else { } else {
@ -182,7 +211,7 @@ impl fmt::Display for Expr {
.collect::<String>(), .collect::<String>(),
) )
} }
} }
} }
#[derive(Clone)] #[derive(Clone)]
@ -208,7 +237,7 @@ pub enum Operation {
} }
impl fmt::Debug for Operation { impl fmt::Debug for Operation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self { f.write_str(match self {
Self::Add => "plus", Self::Add => "plus",
Self::Sub => "minus", Self::Sub => "minus",
@ -233,7 +262,7 @@ impl fmt::Debug for Operation {
return write!(f, "cond({branches:?})"); return write!(f, "cond({branches:?})");
}, },
}) })
} }
} }
pub fn set_last_ident_name(mut exprs: LinkedList<Expr>, name: String) -> LinkedList<Expr> { pub fn set_last_ident_name(mut exprs: LinkedList<Expr>, name: String) -> LinkedList<Expr> {
@ -250,31 +279,49 @@ pub fn get_last_ident(exprs: &LinkedList<Expr>) -> Option<Identifier> {
exprs.back().map(|expr| expr.identifier.clone()) exprs.back().map(|expr| expr.identifier.clone())
} }
pub fn evaluate(exprs: LinkedList<Expr>, bindings: &ValueBindings) -> evaluation::Result { pub fn evaluate(mut exprs: LinkedList<Expr>, mut bindings: ValueBindings) -> evaluation::Result { loop { return
if exprs.iter().all(|e| !e.isnt_noop()) { if exprs.iter().all(|e| !e.isnt_noop()) {
bindings.lookup(&exprs.back().ok_or(EvaluateError::EvaluatingZeroLengthExpr)?.identifier) bindings.lookup(&exprs.back().ok_or(EvaluateError::EvaluatingZeroLengthExpr)?.identifier)
.ok_or_else(|| EvaluateError::UndefinedValue(exprs.back().unwrap().identifier.clone())) .ok_or_else(|| EvaluateError::UndefinedValue(exprs.back().unwrap().identifier.clone()))
.cloned() .cloned()
} else if exprs.len() == 1 { } else if exprs.len() == 1 {
exprs.back() let instr = exprs.back().unwrap();
.unwrap() let res = instr.evaluate(&bindings);
.evaluate(&bindings) match res {
ExprEvalResult::Succeeded(v) => Ok(v),
ExprEvalResult::EvaluateThis(code, new_bindings) => {
let code = code.clone();
let new_bindings = new_bindings.into_iter()
.map(|(ident, val)| (ident.clone(), val.clone()))
.collect();
exprs = code;
bindings = bindings.bind_all_owned(new_bindings);
continue; // Tail-recursive "call"
}
ExprEvalResult::Failed(e) => Err(e),
}
} else { } else {
let bindings = bindings.nested_scope(); let first_instruction = exprs.pop_front().ok_or(EvaluateError::EvaluatingZeroLengthExpr)?;
let (all_bindings, last_ident) = exprs.into_iter() if first_instruction.isnt_noop() {
.filter(Expr::isnt_noop) let res = first_instruction.evaluate(&bindings);
.try_fold( let res_value = match res {
(bindings, None), ExprEvalResult::Succeeded(v) => v,
|(bindings, _last_ident), expr| { ExprEvalResult::EvaluateThis(code, new_bindings) => {
let last_value = expr.evaluate(&bindings)?; let local_scope = bindings.nested_scope()
Ok((bindings.bind(&expr.identifier, last_value.clone()), Some(expr.identifier))) .bind_all(new_bindings);
})?; evaluate(code.clone(), local_scope)?
last_ident.ok_or(EvaluateError::EvaluatingZeroLengthExpr) },
.and_then( ExprEvalResult::Failed(e) => return Err(e),
|ident| };
all_bindings.lookup(&ident)
.cloned() exprs = exprs;
.ok_or(EvaluateError::UndefinedValue(ident)) bindings = bindings.bind(&first_instruction.identifier, res_value);
) continue; // Tail-recursive "call"
} else {
exprs = exprs;
bindings = bindings;
continue; // Tail-recursive "call"
}
} }
} }}

View File

@ -109,7 +109,7 @@ pub fn hasty_evaluate(program: Vec<Declaration>) -> evaluation::Result {
if let Some(main) = bindings.get_main() { if let Some(main) = bindings.get_main() {
if let Value::Function(args, code) = main { if let Value::Function(args, code) = main {
if args.is_empty() { if args.is_empty() {
evaluate(code.clone(), &bindings) evaluate(code.clone(), bindings)
} else { } else {
Err(EvaluateError::MainHasArgs) Err(EvaluateError::MainHasArgs)
} }

View File

@ -31,12 +31,12 @@ impl fmt::Display for Value {
} }
} }
impl TryInto<usize> for Value { impl TryInto<usize> for &Value {
type Error = EvaluateError; type Error = EvaluateError;
fn try_into(self) -> Result<usize, EvaluateError> { fn try_into(self) -> Result<usize, EvaluateError> {
if let Value::Int(v) = self { if let Value::Int(v) = self {
Ok(v) Ok(*v)
} else { } else {
Err(EvaluateError::TypeMismatch( Err(EvaluateError::TypeMismatch(
discriminant(&Value::Int(0)), discriminant(&Value::Int(0)),
@ -46,12 +46,12 @@ impl TryInto<usize> for Value {
} }
} }
impl TryInto<bool> for Value { impl TryInto<bool> for &Value {
type Error = EvaluateError; type Error = EvaluateError;
fn try_into(self) -> Result<bool, EvaluateError> { fn try_into(self) -> Result<bool, EvaluateError> {
if let Value::Int(v) = self { if let Value::Int(v) = self {
Ok(v != 0) Ok(*v != 0)
} else { } else {
Err(EvaluateError::TypeMismatch( Err(EvaluateError::TypeMismatch(
discriminant(&Value::Int(0)), discriminant(&Value::Int(0)),

View File

@ -1,3 +1,5 @@
#![feature(try_trait_v2)]
use std::{fs::File, io::Read, collections::LinkedList, process::exit}; use std::{fs::File, io::Read, collections::LinkedList, process::exit};
use logos::Logos; use logos::Logos;