diff --git a/src/ir/evaluation.rs b/src/ir/evaluation.rs index 9712ae0..159c0d1 100644 --- a/src/ir/evaluation.rs +++ b/src/ir/evaluation.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, mem::Discriminant}; +use std::{collections::{HashMap, LinkedList}, mem::Discriminant}; use super::{Identifier, value::Value, expr::Operation, types::Type}; @@ -11,6 +11,14 @@ impl<'a> ValueBindings<'a> { self } + pub fn bind_all(mut self, other: LinkedList<(&Identifier, &Value)>) -> Self{ + self.1.extend( + other.into_iter() + .map(|(ident, val)| (ident.1 as u128, val.clone())) + ); + self + } + pub fn lookup(&self, ident: &Identifier) -> Option<&Value> { self.1.get(&(ident.1 as u128)) .or_else(|| self.0.and_then(|binding| binding.lookup(ident))) @@ -33,6 +41,7 @@ pub enum EvaluateError { TypeMismatch(Discriminant, Discriminant), NoMain, MainHasArgs, + IncompleteConditional, } pub type Result = std::result::Result; diff --git a/src/ir/expr.rs b/src/ir/expr.rs index e624b80..c35f658 100644 --- a/src/ir/expr.rs +++ b/src/ir/expr.rs @@ -153,7 +153,17 @@ impl Expr { Operation::VariantUnion => todo!(), Operation::FunctionType => todo!(), Operation::NoOp => Ok(Value::Int(0)), - Operation::Conditional(_) => todo!(), + Operation::Conditional(cases) => { + let mut arguments = arguments; + let value = + arguments.pop().ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0))?; + let (new_bindings, code) = cases.into_iter() + .find_map(|(pattern, code)| pattern.matches(&value).map(|bindings| (bindings, code))) + .ok_or(EvaluateError::IncompleteConditional)?; + let local_scope = bindings.nested_scope() + .bind_all(new_bindings); + evaluate(code.clone(), &local_scope) + } } } } diff --git a/src/ir/pattern.rs b/src/ir/pattern.rs index 5803a63..fb354d0 100644 --- a/src/ir/pattern.rs +++ b/src/ir/pattern.rs @@ -1,6 +1,8 @@ -use crate::token::Literal; +use std::collections::LinkedList; -use super::Identifier; +use crate::{token::Literal, join}; + +use super::{Identifier, value::Value}; #[derive(Clone, Debug)] pub enum Pattern { @@ -8,3 +10,34 @@ pub enum Pattern { Binding(Identifier), Variant(Identifier, Vec), } + +impl Pattern { + pub fn matches<'a, 'b>(&'a self, value: &'b Value) -> Option> { + match (value, self) { + (Value::Int(x1), Pattern::Literal(Literal::Int(x2))) => + if *x1 as u64 == *x2 { + Some(LinkedList::new()) + } else { + None + }, + (Value::Structural(ident1, values1), Pattern::Variant(ident2, patterns2)) => + if ident1 == ident2 && values1.len() == patterns2.len() { + values1.into_iter() + .zip(patterns2) + .fold(Some(LinkedList::new()), |bindings, (v1, p2)| + bindings.and_then(|bindings| + p2.matches(v1).map(|new_bindings| + join( + bindings, + new_bindings, + ) + ) + ) + ) + } else { + None + }, + _ => None, + } + } +} diff --git a/src/ir/value.rs b/src/ir/value.rs index 027d9fa..0f6c321 100644 --- a/src/ir/value.rs +++ b/src/ir/value.rs @@ -8,7 +8,7 @@ pub enum Value { Int(usize), String(String), Function(Vec, LinkedList), - Structural(usize, Vec), + Structural(Identifier, Vec), } impl fmt::Display for Value { diff --git a/src/main.rs b/src/main.rs index 15ae257..94b4de9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,6 +60,8 @@ fn main() -> std::io::Result<()> { println!("Huh, there's no main method"), Err(EvaluateError::MainHasArgs) => println!("Your main method has args, but that's not allowed"), + Err(EvaluateError::IncompleteConditional) => + println!("Uh oh, a conditional somewhere isn't complete!"), } } Err(e) =>