From d2981700224be5e1bbaa73f0667ecf1d937bff8c Mon Sep 17 00:00:00 2001 From: Emi Simpson Date: Mon, 25 Apr 2022 12:23:10 -0400 Subject: [PATCH] Got tail-call optimized evaluation working --- src/ir/evaluation.rs | 10 ++- src/ir/expr.rs | 151 ++++++++++++++++++++++++++++--------------- src/ir/mod.rs | 2 +- src/ir/value.rs | 8 +-- src/main.rs | 2 + 5 files changed, 115 insertions(+), 58 deletions(-) diff --git a/src/ir/evaluation.rs b/src/ir/evaluation.rs index 159c0d1..a09de46 100644 --- a/src/ir/evaluation.rs +++ b/src/ir/evaluation.rs @@ -2,7 +2,7 @@ use std::{collections::{HashMap, LinkedList}, mem::Discriminant}; use super::{Identifier, value::Value, expr::Operation, types::Type}; -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct ValueBindings<'a>(Option<&'a ValueBindings<'a>>, HashMap); impl<'a> ValueBindings<'a> { @@ -19,6 +19,14 @@ impl<'a> ValueBindings<'a> { self } + pub fn bind_all_owned(mut self, other: LinkedList<(Identifier, Value)>) -> Self{ + self.1.extend( + other.into_iter() + .map(|(ident, val)| (ident.1 as u128, val)) + ); + self + } + pub fn lookup(&self, ident: &Identifier) -> Option<&Value> { self.1.get(&(ident.1 as u128)) .or_else(|| self.0.and_then(|binding| binding.lookup(ident))) diff --git a/src/ir/expr.rs b/src/ir/expr.rs index bab0d9e..8e42848 100644 --- a/src/ir/expr.rs +++ b/src/ir/expr.rs @@ -1,5 +1,5 @@ use core::fmt; -use std::{collections::{LinkedList, HashSet}, mem::{self, discriminant}}; +use std::{collections::{LinkedList, HashSet}, mem::{self, discriminant}, borrow::Cow, ops::FromResidual, convert::Infallible}; use super::{value::Value, pattern::Pattern, Identifier, evaluation::{ValueBindings, self, EvaluateError}}; @@ -10,6 +10,30 @@ pub struct Expr { pub formals: Vec, } +pub enum ExprEvalResult<'a> { + Succeeded(Value), + EvaluateThis(&'a LinkedList, LinkedList<(&'a Identifier, &'a Value)>), + Failed(EvaluateError), +} + +impl From for ExprEvalResult<'_> { + fn from(res: evaluation::Result) -> Self { + match res { + Ok(v) => Self::Succeeded(v), + Err(e) => Self::Failed(e), + } + } +} + +impl FromResidual> for ExprEvalResult<'_> { + fn from_residual(residual: Result) -> Self { + match residual { + Err(e) => Self::Failed(e), + _ => unreachable!(), + } + } +} + impl Expr { pub fn isnt_noop(&self) -> bool { if let Operation::NoOp = self.operation { @@ -18,17 +42,18 @@ impl Expr { true } } - pub fn evaluate(&self, bindings: &ValueBindings) -> evaluation::Result { + pub fn evaluate<'a>(&'a self, bindings: &'a ValueBindings) -> ExprEvalResult<'a> { let arguments = self.formals .iter() - .map(|formal| bindings.lookup(formal).cloned().ok_or_else(|| EvaluateError::UndefinedValue(formal.clone()))) + .map(|formal| bindings.lookup(formal).ok_or_else(|| EvaluateError::UndefinedValue(formal.clone()))) .collect::, EvaluateError>>()?; match &self.operation { Operation::Add => arguments.into_iter() .map(TryInto::try_into) .try_fold(0, |a, b| b.map(|b: usize| a + b)) - .map(Value::Int), + .map(Value::Int) + .into(), Operation::Sub => { let mut args = arguments.into_iter() .map(TryInto::try_into); @@ -37,13 +62,15 @@ impl Expr { .and_then(std::convert::identity)?; args.try_fold(first_arg, |a, b| b.map(|b| a - b)) .map(Value::Int) + .into() } Operation::Mul => arguments.into_iter() .map(TryInto::try_into) .try_fold(1, |a, b| b.map(|b: usize| a * b)) - .map(Value::Int), + .map(Value::Int) + .into(), Operation::Div => { let mut args = arguments.into_iter() .map(TryInto::try_into); @@ -52,6 +79,7 @@ impl Expr { .and_then(std::convert::identity)?; args.try_fold(first_arg, |a, b| b.map(|b| a / b)) .map(Value::Int) + .into() } Operation::Mod => { let mut args = arguments.into_iter() @@ -61,6 +89,7 @@ impl Expr { .and_then(std::convert::identity)?; args.try_fold(first_arg, |a, b| b.map(|b| a % b)) .map(Value::Int) + .into() } Operation::Range => todo!(), Operation::Eq => @@ -77,14 +106,15 @@ impl Expr { }, Some(to_compare)) ) ) - .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), + .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }) + .into(), Operation::NEq => - if + ExprEvalResult::Succeeded(if arguments.into_iter() .map(TryInto::try_into) .collect::, EvaluateError>>()? .len() == self.formals.len() - { Ok(Value::Int(1)) } else { Ok(Value::Int(2)) }, + { Value::Int(1) } else { Value::Int(2) }), Operation::LessThan => arguments.into_iter() .map(TryInto::try_into) @@ -97,7 +127,8 @@ impl Expr { ), Some(to_compare)) ) ) - .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), + .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }) + .into(), Operation::GreaterThan => arguments.into_iter() .map(TryInto::try_into) @@ -110,18 +141,21 @@ impl Expr { ), Some(to_compare)) ) ) - .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), + .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }) + .into(), Operation::LAnd => arguments.into_iter() .map(TryInto::try_into) .try_fold(true, |a, b| b.map(|b: bool| a && b)) - .map(|result| if result { Value::Int(1) } else { Value::Int(0) }), + .map(|result| if result { Value::Int(1) } else { Value::Int(0) }) + .into(), Operation::LOr => arguments.into_iter() .map(TryInto::try_into) .try_fold(true, |a, b| b.map(|b: bool| a || b)) - .map(|result| if result { Value::Int(1) } else { Value::Int(0) }), - Operation::Const(v) => Ok(v.clone()), + .map(|result| if result { Value::Int(1) } else { Value::Int(0) }) + .into(), + Operation::Const(v) => ExprEvalResult::Succeeded(v.clone()), Operation::Call => { let mut args = arguments.into_iter(); let first_arg = args.next() @@ -129,21 +163,18 @@ impl Expr { match first_arg { Value::Function(formals, code) => { if formals.len() == args.size_hint().0 { - let new_scope = bindings.nested_scope(); - let bound_scope = formals.into_iter() + let bindings = formals.into_iter() .zip(args) - .fold( - new_scope, - |scope, (formal, value)| - scope.bind(&formal, value) - ); - evaluate(code, &bound_scope) + .collect(); + ExprEvalResult::EvaluateThis(code, bindings) } else { - Err(EvaluateError::FunctionArgumentCountMismatch(args.size_hint().0, formals.len())) + ExprEvalResult::Failed( + EvaluateError::FunctionArgumentCountMismatch(args.size_hint().0, formals.len()) + ) } }, first_arg @ _ => { - Err(EvaluateError::TypeMismatch( + ExprEvalResult::Failed(EvaluateError::TypeMismatch( discriminant(&first_arg), discriminant(&Value::Function(Vec::new(), LinkedList::new())), )) @@ -152,7 +183,7 @@ impl Expr { } Operation::VariantUnion => todo!(), Operation::FunctionType => todo!(), - Operation::NoOp => Ok(Value::Int(0)), + Operation::NoOp => ExprEvalResult::Succeeded(Value::Int(0)), Operation::Conditional(cases) => { let mut arguments = arguments; let value = @@ -160,16 +191,14 @@ impl Expr { let (new_bindings, code) = cases.into_iter() .find_map(|(pattern, code)| pattern.matches(&value).map(|bindings| (bindings, code))) .ok_or(EvaluateError::IncompleteConditional)?; - let local_scope = bindings.nested_scope() - .bind_all(new_bindings); - evaluate(code.clone(), &local_scope) + ExprEvalResult::EvaluateThis(code, new_bindings) } } } } impl fmt::Display for Expr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Operation::NoOp = self.operation { self.identifier.fmt(f) } else { @@ -182,7 +211,7 @@ impl fmt::Display for Expr { .collect::(), ) } - } + } } #[derive(Clone)] @@ -208,7 +237,7 @@ pub enum Operation { } impl fmt::Debug for Operation { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { Self::Add => "plus", Self::Sub => "minus", @@ -233,7 +262,7 @@ impl fmt::Debug for Operation { return write!(f, "cond({branches:?})"); }, }) - } + } } pub fn set_last_ident_name(mut exprs: LinkedList, name: String) -> LinkedList { @@ -250,31 +279,49 @@ pub fn get_last_ident(exprs: &LinkedList) -> Option { exprs.back().map(|expr| expr.identifier.clone()) } -pub fn evaluate(exprs: LinkedList, bindings: &ValueBindings) -> evaluation::Result { +pub fn evaluate(mut exprs: LinkedList, mut bindings: ValueBindings) -> evaluation::Result { loop { return if exprs.iter().all(|e| !e.isnt_noop()) { bindings.lookup(&exprs.back().ok_or(EvaluateError::EvaluatingZeroLengthExpr)?.identifier) .ok_or_else(|| EvaluateError::UndefinedValue(exprs.back().unwrap().identifier.clone())) .cloned() } else if exprs.len() == 1 { - exprs.back() - .unwrap() - .evaluate(&bindings) + let instr = exprs.back().unwrap(); + let res = instr.evaluate(&bindings); + match res { + ExprEvalResult::Succeeded(v) => Ok(v), + ExprEvalResult::EvaluateThis(code, new_bindings) => { + let code = code.clone(); + let new_bindings = new_bindings.into_iter() + .map(|(ident, val)| (ident.clone(), val.clone())) + .collect(); + + exprs = code; + bindings = bindings.bind_all_owned(new_bindings); + continue; // Tail-recursive "call" + } + ExprEvalResult::Failed(e) => Err(e), + } } else { - let bindings = bindings.nested_scope(); - let (all_bindings, last_ident) = exprs.into_iter() - .filter(Expr::isnt_noop) - .try_fold( - (bindings, None), - |(bindings, _last_ident), expr| { - let last_value = expr.evaluate(&bindings)?; - Ok((bindings.bind(&expr.identifier, last_value.clone()), Some(expr.identifier))) - })?; - last_ident.ok_or(EvaluateError::EvaluatingZeroLengthExpr) - .and_then( - |ident| - all_bindings.lookup(&ident) - .cloned() - .ok_or(EvaluateError::UndefinedValue(ident)) - ) + let first_instruction = exprs.pop_front().ok_or(EvaluateError::EvaluatingZeroLengthExpr)?; + if first_instruction.isnt_noop() { + let res = first_instruction.evaluate(&bindings); + let res_value = match res { + ExprEvalResult::Succeeded(v) => v, + ExprEvalResult::EvaluateThis(code, new_bindings) => { + let local_scope = bindings.nested_scope() + .bind_all(new_bindings); + evaluate(code.clone(), local_scope)? + }, + ExprEvalResult::Failed(e) => return Err(e), + }; + + exprs = exprs; + bindings = bindings.bind(&first_instruction.identifier, res_value); + continue; // Tail-recursive "call" + } else { + exprs = exprs; + bindings = bindings; + continue; // Tail-recursive "call" + } } -} +}} diff --git a/src/ir/mod.rs b/src/ir/mod.rs index 549ad3a..e007e04 100644 --- a/src/ir/mod.rs +++ b/src/ir/mod.rs @@ -109,7 +109,7 @@ pub fn hasty_evaluate(program: Vec) -> evaluation::Result { if let Some(main) = bindings.get_main() { if let Value::Function(args, code) = main { if args.is_empty() { - evaluate(code.clone(), &bindings) + evaluate(code.clone(), bindings) } else { Err(EvaluateError::MainHasArgs) } diff --git a/src/ir/value.rs b/src/ir/value.rs index 0f6c321..720fbee 100644 --- a/src/ir/value.rs +++ b/src/ir/value.rs @@ -31,12 +31,12 @@ impl fmt::Display for Value { } } -impl TryInto for Value { +impl TryInto for &Value { type Error = EvaluateError; fn try_into(self) -> Result { if let Value::Int(v) = self { - Ok(v) + Ok(*v) } else { Err(EvaluateError::TypeMismatch( discriminant(&Value::Int(0)), @@ -46,12 +46,12 @@ impl TryInto for Value { } } -impl TryInto for Value { +impl TryInto for &Value { type Error = EvaluateError; fn try_into(self) -> Result { if let Value::Int(v) = self { - Ok(v != 0) + Ok(*v != 0) } else { Err(EvaluateError::TypeMismatch( discriminant(&Value::Int(0)), diff --git a/src/main.rs b/src/main.rs index 94b4de9..3de041d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +#![feature(try_trait_v2)] + use std::{fs::File, io::Read, collections::LinkedList, process::exit}; use logos::Logos;