use core::fmt; use std::{collections::{LinkedList, HashSet}, mem::{self, discriminant}, borrow::Cow, ops::FromResidual, convert::Infallible}; use super::{value::Value, pattern::Pattern, Identifier, evaluation::{ValueBindings, self, EvaluateError}, types::{PrimitiveType, Type}}; #[derive(Debug, Clone)] pub struct Expr { pub identifier: Identifier, pub operation: Operation, pub formals: Vec, } pub enum ExprEvalResult<'a> { Succeeded(Value), EvaluateThis(&'a LinkedList, LinkedList<(&'a Identifier, &'a Value)>), Failed(EvaluateError), } impl From for ExprEvalResult<'_> { fn from(res: evaluation::Result) -> Self { match res { Ok(v) => Self::Succeeded(v), Err(e) => Self::Failed(e), } } } impl FromResidual> for ExprEvalResult<'_> { fn from_residual(residual: Result) -> Self { match residual { Err(e) => Self::Failed(e), _ => unreachable!(), } } } impl Expr { pub fn isnt_noop(&self) -> bool { if let Operation::NoOp = self.operation { false } else { true } } pub fn evaluate<'a>(&'a self, bindings: &'a ValueBindings) -> ExprEvalResult<'a> { let arguments = self.formals .iter() .map(|formal| bindings.lookup(formal).ok_or_else(|| EvaluateError::UndefinedValue(formal.clone()))) .collect::, EvaluateError>>()?; match &self.operation { Operation::Add => arguments.into_iter() .map(TryInto::try_into) .try_fold(0, |a, b| b.map(|b: usize| a + b)) .map(Value::Int) .into(), Operation::Sub => { let mut args = arguments.into_iter() .map(TryInto::try_into); let first_arg: usize = args.next() .ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0)) .and_then(std::convert::identity)?; args.try_fold(first_arg, |a, b| b.map(|b| a - b)) .map(Value::Int) .into() } Operation::Mul => arguments.into_iter() .map(TryInto::try_into) .try_fold(1, |a, b| b.map(|b: usize| a * b)) .map(Value::Int) .into(), Operation::Div => { let mut args = arguments.into_iter() .map(TryInto::try_into); let first_arg: usize = args.next() .ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0)) .and_then(std::convert::identity)?; args.try_fold(first_arg, |a, b| b.map(|b| a / b)) .map(Value::Int) .into() } Operation::Mod => { let mut args = arguments.into_iter() .map(TryInto::try_into); let first_arg: usize = args.next() .ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0)) .and_then(std::convert::identity)?; args.try_fold(first_arg, |a, b| b.map(|b| a % b)) .map(Value::Int) .into() } Operation::Range => todo!(), Operation::Eq => arguments.into_iter() .map(TryInto::try_into) .try_fold( (true, None), |(current_val, compare_to), to_compare| to_compare.map(|to_compare: usize| (current_val && if let Some(v) = compare_to { v == to_compare } else { true }, Some(to_compare)) ) ) .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }) .into(), Operation::NEq => ExprEvalResult::Succeeded(if arguments.into_iter() .map(TryInto::try_into) .collect::, EvaluateError>>()? .len() == self.formals.len() { Value::Int(1) } else { Value::Int(2) }), Operation::LessThan => arguments.into_iter() .map(TryInto::try_into) .try_fold( (true, None), |(current_val, compare_to), to_compare| to_compare.map(|to_compare: usize| (current_val && compare_to.map_or(true, |compare_to| compare_to < to_compare ), Some(to_compare)) ) ) .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }) .into(), Operation::GreaterThan => arguments.into_iter() .map(TryInto::try_into) .try_fold( (true, None), |(current_val, compare_to), to_compare| to_compare.map(|to_compare: usize| (current_val && compare_to.map_or(true, |compare_to| compare_to > to_compare ), Some(to_compare)) ) ) .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }) .into(), Operation::LAnd => arguments.into_iter() .map(TryInto::try_into) .try_fold(true, |a, b| b.map(|b: bool| a && b)) .map(|result| if result { Value::Int(1) } else { Value::Int(0) }) .into(), Operation::LOr => arguments.into_iter() .map(TryInto::try_into) .try_fold(true, |a, b| b.map(|b: bool| a || b)) .map(|result| if result { Value::Int(1) } else { Value::Int(0) }) .into(), Operation::Const(v) => ExprEvalResult::Succeeded(v.clone()), Operation::Call => { let mut args = arguments.into_iter(); let first_arg = args.next() .ok_or(EvaluateError::ArgumentCountMismatch(Operation::Call, 1, 0))?; match first_arg { Value::Function(formals, code) => { if formals.len() == args.size_hint().0 { let bindings = formals.into_iter() .zip(args) .collect(); ExprEvalResult::EvaluateThis(code, bindings) } else { ExprEvalResult::Failed( EvaluateError::FunctionArgumentCountMismatch(args.size_hint().0, formals.len()) ) } }, first_arg @ _ => { ExprEvalResult::Failed(EvaluateError::TypeMismatch( discriminant(&first_arg), discriminant(&Value::Function(Vec::new(), LinkedList::new())), )) } } } Operation::VariantUnion => todo!(), Operation::FunctionType => { if let [a, b] = arguments[..] { if let (Value::Type(a), Value::Type(b)) = (a, b) { ExprEvalResult::Succeeded(Value::Type(Type::Function(Box::new(a.clone()), Box::new(b.clone())))) } else { ExprEvalResult::Failed(EvaluateError::TypeMismatch(discriminant(b), discriminant(&Value::Type(Type::Primitive(PrimitiveType::Int))))) } } else { ExprEvalResult::Failed(EvaluateError::ArgumentCountMismatch(Operation::FunctionType, arguments.len(), 2)) } }, Operation::NoOp => { // Look up the value, if it's a zero-arg function, evaluate it let apparent_value = bindings.lookup(&self.identifier) .ok_or_else(|| EvaluateError::UndefinedValue(self.identifier.clone()))?; match apparent_value { Value::Function(args, exprs) if args.is_empty() => ExprEvalResult::EvaluateThis(exprs, LinkedList::new()), val => ExprEvalResult::Succeeded(val.clone()), } } Operation::Conditional(cases) => { let mut arguments = arguments; let value = arguments.pop().ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))?; let (new_bindings, code) = cases.into_iter() .find_map(|(pattern, code)| pattern.matches(&value).map(|bindings| (bindings, code))) .ok_or(EvaluateError::IncompleteConditional)?; ExprEvalResult::EvaluateThis(code, new_bindings) } } } } impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Operation::NoOp = self.operation { self.identifier.fmt(f) } else { write!(f, "{} = {:?}{}", self.identifier, self.operation, self.formals.iter() .map(|f| format!(" {f}")) .collect::(), ) } } } #[derive(Clone)] pub enum Operation { Add, Sub, Mul, Div, Mod, Range, Eq, NEq, LessThan, GreaterThan, LAnd, LOr, Const(Value), Call, VariantUnion, FunctionType, NoOp, Conditional(Vec<(Pattern, LinkedList)>), } impl fmt::Debug for Operation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { Self::Add => "plus", Self::Sub => "minus", Self::Mul => "times", Self::Div => "div", Self::Mod => "mod", Self::Range => "to", Self::Eq => "equals", Self::NEq => "nequals", Self::LessThan => "lessthan", Self::GreaterThan => "morethan", Self::LAnd => "and", Self::LOr => "or", Self::Const(v) => { return write!(f, "const[{v:?}]"); }, Self::Call => "call", Self::VariantUnion => "orvariant", Self::FunctionType => "yields", Self::NoOp => "noop", Self::Conditional(branches) => { return write!(f, "cond({branches:?})"); }, }) } } pub fn set_last_ident_name(mut exprs: LinkedList, name: String) -> LinkedList { if let Some(expr) = exprs.back_mut() { let mut placeholder = Identifier::ROOT; mem::swap(&mut expr.identifier, &mut placeholder); placeholder = placeholder.set_name(name); mem::swap(&mut expr.identifier, &mut placeholder); } exprs } pub fn get_last_ident(exprs: &LinkedList) -> Option { exprs.back().map(|expr| expr.identifier.clone()) } pub fn evaluate(mut exprs: LinkedList, mut bindings: ValueBindings) -> evaluation::Result { loop { return if exprs.len() == 1 { let instr = exprs.back().unwrap(); let res = instr.evaluate(&bindings); match res { ExprEvalResult::Succeeded(v) => Ok(v), ExprEvalResult::EvaluateThis(code, new_bindings) => { let code = code.clone(); let new_bindings = new_bindings.into_iter() .map(|(ident, val)| (ident.clone(), val.clone())) .collect(); exprs = code; bindings = bindings.bind_all_owned(new_bindings); continue; // Tail-recursive "call" } ExprEvalResult::Failed(e) => Err(e), } } else { let first_instruction = exprs.pop_front().ok_or(EvaluateError::EvaluatingZeroLengthExpr)?; let res = first_instruction.evaluate(&bindings); let res_value = match res { ExprEvalResult::Succeeded(v) => v, ExprEvalResult::EvaluateThis(code, new_bindings) => { let local_scope = bindings.nested_scope() .bind_all(new_bindings); evaluate(code.clone(), local_scope)? }, ExprEvalResult::Failed(e) => return Err(e), }; exprs = exprs; bindings = bindings.bind(&first_instruction.identifier, res_value); continue; // Tail-recursive "call" } }}