338 lines
9.9 KiB
Rust
338 lines
9.9 KiB
Rust
use core::fmt;
|
|
use std::{collections::{LinkedList, HashSet}, mem::{self, discriminant}, borrow::Cow, ops::FromResidual, convert::Infallible};
|
|
|
|
use super::{value::Value, pattern::Pattern, Identifier, evaluation::{ValueBindings, self, EvaluateError}, types::{PrimitiveType, Type}};
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Expr {
|
|
pub identifier: Identifier,
|
|
pub operation: Operation,
|
|
pub formals: Vec<Identifier>,
|
|
}
|
|
|
|
pub enum ExprEvalResult<'a> {
|
|
Succeeded(Value),
|
|
EvaluateThis(&'a LinkedList<Expr>, LinkedList<(&'a Identifier, &'a Value)>),
|
|
Failed(EvaluateError),
|
|
}
|
|
|
|
impl From<evaluation::Result> for ExprEvalResult<'_> {
    /// Maps `Ok(value)` to [`ExprEvalResult::Succeeded`] and `Err(error)` to
    /// [`ExprEvalResult::Failed`].
    fn from(res: evaluation::Result) -> Self {
        res.map_or_else(Self::Failed, Self::Succeeded)
    }
}
|
|
|
|
impl FromResidual<Result<Infallible, EvaluateError>> for ExprEvalResult<'_> {
|
|
fn from_residual(residual: Result<Infallible, EvaluateError>) -> Self {
|
|
match residual {
|
|
Err(e) => Self::Failed(e),
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Expr {
|
|
pub fn isnt_noop(&self) -> bool {
|
|
if let Operation::NoOp = self.operation {
|
|
false
|
|
} else {
|
|
true
|
|
}
|
|
}
|
|
pub fn evaluate<'a>(&'a self, bindings: &'a ValueBindings) -> ExprEvalResult<'a> {
|
|
let arguments = self.formals
|
|
.iter()
|
|
.map(|formal| bindings.lookup(formal).ok_or_else(|| EvaluateError::UndefinedValue(formal.clone())))
|
|
.collect::<Result<Vec<_>, EvaluateError>>()?;
|
|
match &self.operation {
|
|
Operation::Add =>
|
|
arguments.into_iter()
|
|
.map(TryInto::try_into)
|
|
.try_fold(0, |a, b| b.map(|b: usize| a + b))
|
|
.map(Value::Int)
|
|
.into(),
|
|
Operation::Sub => {
|
|
let mut args = arguments.into_iter()
|
|
.map(TryInto::try_into);
|
|
let first_arg: usize = args.next()
|
|
.ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))
|
|
.and_then(std::convert::identity)?;
|
|
args.try_fold(first_arg, |a, b| b.map(|b| a - b))
|
|
.map(Value::Int)
|
|
.into()
|
|
}
|
|
|
|
Operation::Mul =>
|
|
arguments.into_iter()
|
|
.map(TryInto::try_into)
|
|
.try_fold(1, |a, b| b.map(|b: usize| a * b))
|
|
.map(Value::Int)
|
|
.into(),
|
|
Operation::Div => {
|
|
let mut args = arguments.into_iter()
|
|
.map(TryInto::try_into);
|
|
let first_arg: usize = args.next()
|
|
.ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))
|
|
.and_then(std::convert::identity)?;
|
|
args.try_fold(first_arg, |a, b| b.map(|b| a / b))
|
|
.map(Value::Int)
|
|
.into()
|
|
}
|
|
Operation::Mod => {
|
|
let mut args = arguments.into_iter()
|
|
.map(TryInto::try_into);
|
|
let first_arg: usize = args.next()
|
|
.ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))
|
|
.and_then(std::convert::identity)?;
|
|
args.try_fold(first_arg, |a, b| b.map(|b| a % b))
|
|
.map(Value::Int)
|
|
.into()
|
|
}
|
|
Operation::Range => todo!(),
|
|
Operation::Eq =>
|
|
arguments.into_iter()
|
|
.map(TryInto::try_into)
|
|
.try_fold(
|
|
(true, None),
|
|
|(current_val, compare_to), to_compare|
|
|
to_compare.map(|to_compare: usize|
|
|
(current_val && if let Some(v) = compare_to {
|
|
v == to_compare
|
|
} else {
|
|
true
|
|
}, Some(to_compare))
|
|
)
|
|
)
|
|
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) })
|
|
.into(),
|
|
Operation::NEq =>
|
|
ExprEvalResult::Succeeded(if
|
|
arguments.into_iter()
|
|
.map(TryInto::try_into)
|
|
.collect::<Result<HashSet<usize>, EvaluateError>>()?
|
|
.len() == self.formals.len()
|
|
{ Value::Int(1) } else { Value::Int(2) }),
|
|
Operation::LessThan =>
|
|
arguments.into_iter()
|
|
.map(TryInto::try_into)
|
|
.try_fold(
|
|
(true, None),
|
|
|(current_val, compare_to), to_compare|
|
|
to_compare.map(|to_compare: usize|
|
|
(current_val && compare_to.map_or(true, |compare_to|
|
|
compare_to < to_compare
|
|
), Some(to_compare))
|
|
)
|
|
)
|
|
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) })
|
|
.into(),
|
|
Operation::GreaterThan =>
|
|
arguments.into_iter()
|
|
.map(TryInto::try_into)
|
|
.try_fold(
|
|
(true, None),
|
|
|(current_val, compare_to), to_compare|
|
|
to_compare.map(|to_compare: usize|
|
|
(current_val && compare_to.map_or(true, |compare_to|
|
|
compare_to > to_compare
|
|
), Some(to_compare))
|
|
)
|
|
)
|
|
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) })
|
|
.into(),
|
|
Operation::LAnd =>
|
|
arguments.into_iter()
|
|
.map(TryInto::try_into)
|
|
.try_fold(true, |a, b| b.map(|b: bool| a && b))
|
|
.map(|result| if result { Value::Int(1) } else { Value::Int(0) })
|
|
.into(),
|
|
Operation::LOr =>
|
|
arguments.into_iter()
|
|
.map(TryInto::try_into)
|
|
.try_fold(true, |a, b| b.map(|b: bool| a || b))
|
|
.map(|result| if result { Value::Int(1) } else { Value::Int(0) })
|
|
.into(),
|
|
Operation::Const(v) => ExprEvalResult::Succeeded(v.clone()),
|
|
Operation::Call => {
|
|
let mut args = arguments.into_iter();
|
|
let first_arg = args.next()
|
|
.ok_or(EvaluateError::ArgumentCountMismatch(Operation::Call, 1, 0))?;
|
|
match first_arg {
|
|
Value::Function(formals, code) => {
|
|
if formals.len() == args.size_hint().0 {
|
|
let bindings = formals.into_iter()
|
|
.zip(args)
|
|
.collect();
|
|
ExprEvalResult::EvaluateThis(code, bindings)
|
|
} else {
|
|
ExprEvalResult::Failed(
|
|
EvaluateError::FunctionArgumentCountMismatch(args.size_hint().0, formals.len())
|
|
)
|
|
}
|
|
},
|
|
first_arg @ _ => {
|
|
ExprEvalResult::Failed(EvaluateError::TypeMismatch(
|
|
discriminant(&first_arg),
|
|
discriminant(&Value::Function(Vec::new(), LinkedList::new())),
|
|
))
|
|
}
|
|
}
|
|
}
|
|
Operation::VariantUnion => todo!(),
|
|
Operation::FunctionType => {
|
|
if let [a, b] = arguments[..] {
|
|
if let (Value::Type(a), Value::Type(b)) = (a, b) {
|
|
ExprEvalResult::Succeeded(Value::Type(Type::Function(Box::new(a.clone()), Box::new(b.clone()))))
|
|
} else {
|
|
ExprEvalResult::Failed(EvaluateError::TypeMismatch(discriminant(b), discriminant(&Value::Type(Type::Primitive(PrimitiveType::Int)))))
|
|
}
|
|
} else {
|
|
ExprEvalResult::Failed(EvaluateError::ArgumentCountMismatch(Operation::FunctionType, arguments.len(), 2))
|
|
}
|
|
},
|
|
Operation::NoOp => {
|
|
// Look up the value, if it's a zero-arg function, evaluate it
|
|
let apparent_value =
|
|
bindings.lookup(&self.identifier)
|
|
.ok_or_else(|| EvaluateError::UndefinedValue(self.identifier.clone()))?;
|
|
match apparent_value {
|
|
Value::Function(args, exprs) if args.is_empty() =>
|
|
ExprEvalResult::EvaluateThis(exprs, LinkedList::new()),
|
|
val => ExprEvalResult::Succeeded(val.clone()),
|
|
}
|
|
}
|
|
Operation::Conditional(cases) => {
|
|
let mut arguments = arguments;
|
|
let value =
|
|
arguments.pop().ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))?;
|
|
let (new_bindings, code) = cases.into_iter()
|
|
.find_map(|(pattern, code)| pattern.matches(&value).map(|bindings| (bindings, code)))
|
|
.ok_or(EvaluateError::IncompleteConditional)?;
|
|
ExprEvalResult::EvaluateThis(code, new_bindings)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for Expr {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
if let Operation::NoOp = self.operation {
|
|
self.identifier.fmt(f)
|
|
} else {
|
|
write!(f,
|
|
"{} = {:?}{}",
|
|
self.identifier,
|
|
self.operation,
|
|
self.formals.iter()
|
|
.map(|f| format!(" {f}"))
|
|
.collect::<String>(),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub enum Operation {
|
|
Add,
|
|
Sub,
|
|
Mul,
|
|
Div,
|
|
Mod,
|
|
Range,
|
|
Eq,
|
|
NEq,
|
|
LessThan,
|
|
GreaterThan,
|
|
LAnd,
|
|
LOr,
|
|
Const(Value),
|
|
Call,
|
|
VariantUnion,
|
|
FunctionType,
|
|
NoOp,
|
|
Conditional(Vec<(Pattern, LinkedList<Expr>)>),
|
|
}
|
|
|
|
impl fmt::Debug for Operation {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.write_str(match self {
|
|
Self::Add => "plus",
|
|
Self::Sub => "minus",
|
|
Self::Mul => "times",
|
|
Self::Div => "div",
|
|
Self::Mod => "mod",
|
|
Self::Range => "to",
|
|
Self::Eq => "equals",
|
|
Self::NEq => "nequals",
|
|
Self::LessThan => "lessthan",
|
|
Self::GreaterThan => "morethan",
|
|
Self::LAnd => "and",
|
|
Self::LOr => "or",
|
|
Self::Const(v) => {
|
|
return write!(f, "const[{v:?}]");
|
|
},
|
|
Self::Call => "call",
|
|
Self::VariantUnion => "orvariant",
|
|
Self::FunctionType => "yields",
|
|
Self::NoOp => "noop",
|
|
Self::Conditional(branches) => {
|
|
return write!(f, "cond({branches:?})");
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
pub fn set_last_ident_name(mut exprs: LinkedList<Expr>, name: String) -> LinkedList<Expr> {
|
|
if let Some(expr) = exprs.back_mut() {
|
|
let mut placeholder = Identifier::ROOT;
|
|
mem::swap(&mut expr.identifier, &mut placeholder);
|
|
placeholder = placeholder.set_name(name);
|
|
mem::swap(&mut expr.identifier, &mut placeholder);
|
|
}
|
|
exprs
|
|
}
|
|
|
|
/// Returns a clone of the last expression's identifier, or `None` if `exprs` is empty.
pub fn get_last_ident(exprs: &LinkedList<Expr>) -> Option<Identifier> {
    let last = exprs.back()?;
    Some(last.identifier.clone())
}
|
|
|
|
pub fn evaluate(mut exprs: LinkedList<Expr>, mut bindings: ValueBindings) -> evaluation::Result { loop { return
|
|
if exprs.len() == 1 {
|
|
let instr = exprs.back().unwrap();
|
|
let res = instr.evaluate(&bindings);
|
|
match res {
|
|
ExprEvalResult::Succeeded(v) => Ok(v),
|
|
ExprEvalResult::EvaluateThis(code, new_bindings) => {
|
|
let code = code.clone();
|
|
let new_bindings = new_bindings.into_iter()
|
|
.map(|(ident, val)| (ident.clone(), val.clone()))
|
|
.collect();
|
|
|
|
exprs = code;
|
|
bindings = bindings.bind_all_owned(new_bindings);
|
|
continue; // Tail-recursive "call"
|
|
}
|
|
ExprEvalResult::Failed(e) => Err(e),
|
|
}
|
|
} else {
|
|
let first_instruction = exprs.pop_front().ok_or(EvaluateError::EvaluatingZeroLengthExpr)?;
|
|
let res = first_instruction.evaluate(&bindings);
|
|
let res_value = match res {
|
|
ExprEvalResult::Succeeded(v) => v,
|
|
ExprEvalResult::EvaluateThis(code, new_bindings) => {
|
|
let local_scope = bindings.nested_scope()
|
|
.bind_all(new_bindings);
|
|
evaluate(code.clone(), local_scope)?
|
|
},
|
|
ExprEvalResult::Failed(e) => return Err(e),
|
|
};
|
|
|
|
exprs = exprs;
|
|
bindings = bindings.bind(&first_instruction.identifier, res_value);
|
|
continue; // Tail-recursive "call"
|
|
}
|
|
}}
|