Amo/src/ir/expr.rs

277 lines
8.0 KiB
Rust

use core::fmt;
use std::{collections::{LinkedList, HashSet}, mem::{self, discriminant}};
use super::{value::Value, pattern::Pattern, Identifier, evaluation::{ValueBindings, self, EvaluateError}};
/// A single IR instruction: `operation` applied to the values named by `formals`,
/// with the result known under `identifier`.
#[derive(Debug, Clone)]
pub struct Expr {
/// Name of this expression; the free function `evaluate` binds the computed
/// value to this identifier so that later expressions can refer to it.
pub identifier: Identifier,
/// What this expression computes.
pub operation: Operation,
/// Identifiers that `Expr::evaluate` looks up in the bindings, in order, to
/// obtain the operation's arguments.
pub formals: Vec<Identifier>,
}
impl Expr {
pub fn isnt_noop(&self) -> bool {
if let Operation::NoOp = self.operation {
false
} else {
true
}
}
pub fn evaluate(&self, bindings: &ValueBindings) -> evaluation::Result {
let arguments = self.formals
.iter()
.map(|formal| bindings.lookup(formal).cloned().ok_or_else(|| EvaluateError::UndefinedValue(formal.clone())))
.collect::<Result<Vec<_>, EvaluateError>>()?;
match &self.operation {
Operation::Add =>
arguments.into_iter()
.map(TryInto::try_into)
.try_fold(0, |a, b| b.map(|b: usize| a + b))
.map(Value::Int),
Operation::Sub => {
let mut args = arguments.into_iter()
.map(TryInto::try_into);
let first_arg: usize = args.next()
.ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))
.and_then(std::convert::identity)?;
args.try_fold(first_arg, |a, b| b.map(|b| a - b))
.map(Value::Int)
}
Operation::Mul =>
arguments.into_iter()
.map(TryInto::try_into)
.try_fold(1, |a, b| b.map(|b: usize| a * b))
.map(Value::Int),
Operation::Div => {
let mut args = arguments.into_iter()
.map(TryInto::try_into);
let first_arg: usize = args.next()
.ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))
.and_then(std::convert::identity)?;
args.try_fold(first_arg, |a, b| b.map(|b| a / b))
.map(Value::Int)
}
Operation::Mod => {
let mut args = arguments.into_iter()
.map(TryInto::try_into);
let first_arg: usize = args.next()
.ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))
.and_then(std::convert::identity)?;
args.try_fold(first_arg, |a, b| b.map(|b| a % b))
.map(Value::Int)
}
Operation::Range => todo!(),
Operation::Eq =>
arguments.into_iter()
.map(TryInto::try_into)
.try_fold(
(true, None),
|(current_val, compare_to), to_compare|
to_compare.map(|to_compare: usize|
(current_val && if let Some(v) = compare_to {
v == to_compare
} else {
true
}, Some(to_compare))
)
)
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }),
Operation::NEq =>
if
arguments.into_iter()
.map(TryInto::try_into)
.collect::<Result<HashSet<usize>, EvaluateError>>()?
.len() == self.formals.len()
{ Ok(Value::Int(1)) } else { Ok(Value::Int(2)) },
Operation::LessThan =>
arguments.into_iter()
.map(TryInto::try_into)
.try_fold(
(true, None),
|(current_val, compare_to), to_compare|
to_compare.map(|to_compare: usize|
(current_val && compare_to.map_or(true, |compare_to|
compare_to < to_compare
), Some(to_compare))
)
)
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }),
Operation::GreaterThan =>
arguments.into_iter()
.map(TryInto::try_into)
.try_fold(
(true, None),
|(current_val, compare_to), to_compare|
to_compare.map(|to_compare: usize|
(current_val && compare_to.map_or(true, |compare_to|
compare_to > to_compare
), Some(to_compare))
)
)
.map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }),
Operation::LAnd =>
arguments.into_iter()
.map(TryInto::try_into)
.try_fold(true, |a, b| b.map(|b: bool| a && b))
.map(|result| if result { Value::Int(1) } else { Value::Int(0) }),
Operation::LOr =>
arguments.into_iter()
.map(TryInto::try_into)
.try_fold(true, |a, b| b.map(|b: bool| a || b))
.map(|result| if result { Value::Int(1) } else { Value::Int(0) }),
Operation::Const(v) => Ok(v.clone()),
Operation::Call => {
let mut args = arguments.into_iter();
let first_arg = args.next()
.ok_or(EvaluateError::ArgumentCountMismatch(Operation::Call, 1, 0))?;
match first_arg {
Value::Function(formals, code) => {
if formals.len() == args.size_hint().0 {
let new_scope = bindings.nested_scope();
let bound_scope = formals.into_iter()
.zip(args)
.fold(
new_scope,
|scope, (formal, value)|
scope.bind(&formal, value)
);
evaluate(code, &bound_scope)
} else {
Err(EvaluateError::FunctionArgumentCountMismatch(args.size_hint().0, formals.len()))
}
},
first_arg @ _ => {
Err(EvaluateError::TypeMismatch(
discriminant(&first_arg),
discriminant(&Value::Function(Vec::new(), LinkedList::new())),
))
}
}
}
Operation::VariantUnion => todo!(),
Operation::FunctionType => todo!(),
Operation::NoOp => Ok(Value::Int(0)),
Operation::Conditional(cases) => {
let mut arguments = arguments;
let value =
arguments.pop().ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))?;
let (new_bindings, code) = cases.into_iter()
.find_map(|(pattern, code)| pattern.matches(&value).map(|bindings| (bindings, code)))
.ok_or(EvaluateError::IncompleteConditional)?;
let local_scope = bindings.nested_scope()
.bind_all(new_bindings);
evaluate(code.clone(), &local_scope)
}
}
}
}
impl fmt::Display for Expr {
    /// Renders a no-op as its bare identifier, and anything else as
    /// `<identifier> = <operation>` followed by each formal preceded by a space.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.operation {
            // A no-op carries no computation, so only its name is shown.
            Operation::NoOp => self.identifier.fmt(f),
            _ => {
                let mut operands = String::new();
                for formal in &self.formals {
                    operands.push_str(&format!(" {formal}"));
                }
                write!(f, "{} = {:?}{}", self.identifier, self.operation, operands)
            }
        }
    }
}
/// The operation an `Expr` performs on its arguments.
///
/// The keyword quoted for each variant is the one produced by the `Debug`
/// implementation.
#[derive(Clone)]
pub enum Operation {
/// Integer addition of all arguments (`plus`).
Add,
/// First argument minus every following argument (`minus`).
Sub,
/// Integer multiplication of all arguments (`times`).
Mul,
/// First argument divided by every following argument (`div`).
Div,
/// First argument reduced modulo every following argument (`mod`).
Mod,
/// Printed as `to`; evaluation is not implemented yet (`todo!`).
Range,
/// Equality test over the integer arguments (`equals`).
Eq,
/// Distinctness test over the integer arguments (`nequals`).
NEq,
/// Tests that each argument is smaller than the one after it (`lessthan`).
LessThan,
/// Tests that each argument is greater than the one after it (`morethan`).
GreaterThan,
/// Logical AND (`and`).
LAnd,
/// Logical OR (`or`).
LOr,
/// A constant; evaluates to a clone of the wrapped value (`const[..]`).
Const(Value),
/// Calls the function value passed as first argument with the remaining
/// arguments (`call`).
Call,
/// Printed as `orvariant`; evaluation is not implemented yet (`todo!`).
VariantUnion,
/// Printed as `yields`; evaluation is not implemented yet (`todo!`).
FunctionType,
/// No operation (`noop`): evaluates to `Value::Int(0)`, and such expressions
/// are filtered out when a sequence is evaluated.
NoOp,
/// Pattern match (`cond(..)`): each case pairs a pattern with the code that
/// is evaluated when that pattern is the first to match.
Conditional(Vec<(Pattern, LinkedList<Expr>)>),
}
impl fmt::Debug for Operation {
    /// Writes the operation's keyword; `Const` and `Conditional` additionally
    /// include the `Debug` rendering of their payload.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants carrying data are written out immediately; every other
        // variant maps to a fixed keyword that is emitted at the end.
        let keyword = match self {
            Self::Const(v) => return write!(f, "const[{v:?}]"),
            Self::Conditional(branches) => return write!(f, "cond({branches:?})"),
            Self::Add => "plus",
            Self::Sub => "minus",
            Self::Mul => "times",
            Self::Div => "div",
            Self::Mod => "mod",
            Self::Range => "to",
            Self::Eq => "equals",
            Self::NEq => "nequals",
            Self::LessThan => "lessthan",
            Self::GreaterThan => "morethan",
            Self::LAnd => "and",
            Self::LOr => "or",
            Self::Call => "call",
            Self::VariantUnion => "orvariant",
            Self::FunctionType => "yields",
            Self::NoOp => "noop",
        };
        f.write_str(keyword)
    }
}
/// Renames the identifier of the last expression in `exprs` to `name` and
/// hands the list back; an empty list is returned unchanged.
pub fn set_last_ident_name(mut exprs: LinkedList<Expr>, name: String) -> LinkedList<Expr> {
    if let Some(last) = exprs.back_mut() {
        // `set_name` is called on an owned identifier, so park `ROOT` in the
        // slot while the real identifier is taken out and renamed.
        let current = mem::replace(&mut last.identifier, Identifier::ROOT);
        last.identifier = current.set_name(name);
    }
    exprs
}
/// Returns a copy of the identifier of the last expression in `exprs`, or
/// `None` when the list is empty.
pub fn get_last_ident(exprs: &LinkedList<Expr>) -> Option<Identifier> {
    let last = exprs.back()?;
    Some(last.identifier.clone())
}
/// Evaluates a sequence of expressions in a scope nested inside `bindings`
/// and returns the value of the last one.
///
/// No-op expressions are skipped. Every other expression is evaluated in
/// order and its result is bound to its identifier, so later expressions can
/// refer to earlier ones. The result is whatever the scope holds for the
/// identifier of the last evaluated expression.
///
/// When the sequence consists solely of no-ops, nothing is evaluated and the
/// identifier of the final expression is looked up instead.
///
/// # Errors
///
/// * `EvaluatingZeroLengthExpr` if `exprs` is empty.
/// * `UndefinedValue` if the identifier that should provide the result is not bound.
/// * Any error reported by `Expr::evaluate` for one of the expressions.
pub fn evaluate(exprs: LinkedList<Expr>, bindings: &ValueBindings) -> evaluation::Result {
    let mut scope = bindings.nested_scope();

    if !exprs.iter().any(Expr::isnt_noop) {
        // Nothing to compute: the sequence merely names an existing value.
        let last = exprs
            .back()
            .ok_or(EvaluateError::EvaluatingZeroLengthExpr)?;
        return scope
            .lookup(&last.identifier)
            .cloned()
            .ok_or_else(|| EvaluateError::UndefinedValue(last.identifier.clone()));
    }

    let mut last_ident = None;
    for expr in exprs.into_iter().filter(Expr::isnt_noop) {
        let value = expr.evaluate(&scope)?;
        // The value is moved into the scope; no copy is needed here.
        scope = scope.bind(&expr.identifier, value);
        last_ident = Some(expr.identifier);
    }

    let ident = last_ident.ok_or(EvaluateError::EvaluatingZeroLengthExpr)?;
    scope
        .lookup(&ident)
        .cloned()
        .ok_or_else(|| EvaluateError::UndefinedValue(ident))
}