diff --git a/sample-initial.amo b/sample-initial.amo index 2f4e8d4..474c32c 100644 --- a/sample-initial.amo +++ b/sample-initial.amo @@ -1,6 +1,5 @@ main: Int -main = 1 + (if (distance_squared 10 2) > 10 then 1000 else 2000) - +main = 1 + distance_squared 4 3 distance_squared: Int, Int -> Int distance_squared a b = a * a + b * b diff --git a/src/ir/evaluation.rs b/src/ir/evaluation.rs new file mode 100644 index 0000000..9712ae0 --- /dev/null +++ b/src/ir/evaluation.rs @@ -0,0 +1,38 @@ +use std::{collections::HashMap, mem::Discriminant}; + +use super::{Identifier, value::Value, expr::Operation, types::Type}; + +#[derive(Debug, Default)] +pub struct ValueBindings<'a>(Option<&'a ValueBindings<'a>>, HashMap); + +impl<'a> ValueBindings<'a> { + pub fn bind(mut self, ident: &Identifier, value: Value) -> Self { + self.1.insert(ident.1 as u128, value); + self + } + + pub fn lookup(&self, ident: &Identifier) -> Option<&Value> { + self.1.get(&(ident.1 as u128)) + .or_else(|| self.0.and_then(|binding| binding.lookup(ident))) + } + + pub fn nested_scope(&'a self) -> ValueBindings<'a> { + ValueBindings(Some(self), HashMap::new()) + } + + pub fn get_main(&self) -> Option<&Value> { + self.1.get(&3).or_else(|| self.0.and_then(ValueBindings::get_main)) + } +} + +pub enum EvaluateError { + UndefinedValue(Identifier), + EvaluatingZeroLengthExpr, + ArgumentCountMismatch(Operation, usize, usize), + FunctionArgumentCountMismatch(usize, usize), + TypeMismatch(Discriminant, Discriminant), + NoMain, + MainHasArgs, +} + +pub type Result = std::result::Result; diff --git a/src/ir/expr.rs b/src/ir/expr.rs index e62fdc6..e624b80 100644 --- a/src/ir/expr.rs +++ b/src/ir/expr.rs @@ -1,15 +1,163 @@ use core::fmt; -use std::{collections::LinkedList, mem}; +use std::{collections::{LinkedList, HashSet}, mem::{self, discriminant}}; -use super::{value::Value, pattern::Pattern, Identifier}; +use super::{value::Value, pattern::Pattern, Identifier, evaluation::{ValueBindings, self, EvaluateError}}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Expr { pub identifier: Identifier, pub operation: Operation, pub formals: Vec, } +impl Expr { + pub fn isnt_noop(&self) -> bool { + if let Operation::NoOp = self.operation { + false + } else { + true + } + } + pub fn evaluate(&self, bindings: &ValueBindings) -> evaluation::Result { + let arguments = self.formals + .iter() + .map(|formal| bindings.lookup(formal).cloned().ok_or_else(|| EvaluateError::UndefinedValue(formal.clone()))) + .collect::, EvaluateError>>()?; + match &self.operation { + Operation::Add => + arguments.into_iter() + .map(TryInto::try_into) + .try_fold(0, |a, b| b.map(|b: usize| a + b)) + .map(Value::Int), + Operation::Sub => { + let mut args = arguments.into_iter() + .map(TryInto::try_into); + let first_arg: usize = args.next() + .ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0)) + .and_then(std::convert::identity)?; + args.try_fold(first_arg, |a, b| b.map(|b| a - b)) + .map(Value::Int) + } + + Operation::Mul => + arguments.into_iter() + .map(TryInto::try_into) + .try_fold(1, |a, b| b.map(|b: usize| a * b)) + .map(Value::Int), + Operation::Div => { + let mut args = arguments.into_iter() + .map(TryInto::try_into); + let first_arg: usize = args.next() + .ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0)) + .and_then(std::convert::identity)?; + args.try_fold(first_arg, |a, b| b.map(|b| a / b)) + .map(Value::Int) + } + Operation::Mod => { + let mut args = arguments.into_iter() + .map(TryInto::try_into); + let first_arg: usize = args.next() + .ok_or(EvaluateError::ArgumentCountMismatch(self.operation.clone(), 1, 0)) + .and_then(std::convert::identity)?; + args.try_fold(first_arg, |a, b| b.map(|b| a % b)) + .map(Value::Int) + } + Operation::Range => todo!(), + Operation::Eq => + arguments.into_iter() + .map(TryInto::try_into) + .try_fold( + (true, None), + |(current_val, compare_to), to_compare| + to_compare.map(|to_compare: usize| + (current_val && if let Some(v) = compare_to { + v == to_compare + } else { + true + }, Some(to_compare)) + ) + ) + .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), + Operation::NEq => + if + arguments.into_iter() + .map(TryInto::try_into) + .collect::, EvaluateError>>()? + .len() == self.formals.len() + { Ok(Value::Int(1)) } else { Ok(Value::Int(2)) }, + Operation::LessThan => + arguments.into_iter() + .map(TryInto::try_into) + .try_fold( + (true, None), + |(current_val, compare_to), to_compare| + to_compare.map(|to_compare: usize| + (current_val && compare_to.map_or(true, |compare_to| + compare_to < to_compare + ), Some(to_compare)) + ) + ) + .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), + Operation::GreaterThan => + arguments.into_iter() + .map(TryInto::try_into) + .try_fold( + (true, None), + |(current_val, compare_to), to_compare| + to_compare.map(|to_compare: usize| + (current_val && compare_to.map_or(true, |compare_to| + compare_to > to_compare + ), Some(to_compare)) + ) + ) + .map(|(result, _)| if result { Value::Int(1) } else { Value::Int(0) }), + Operation::LAnd => + arguments.into_iter() + .map(TryInto::try_into) + .try_fold(true, |a, b| b.map(|b: bool| a && b)) + .map(|result| if result { Value::Int(1) } else { Value::Int(0) }), + Operation::LOr => + arguments.into_iter() + .map(TryInto::try_into) + .try_fold(true, |a, b| b.map(|b: bool| a || b)) + .map(|result| if result { Value::Int(1) } else { Value::Int(0) }), + Operation::Const(v) => Ok(v.clone()), + Operation::Call => { + let mut args = arguments.into_iter(); + let first_arg = args.next() + .ok_or(EvaluateError::ArgumentCountMismatch(Operation::Call, 1, 0))?; + match first_arg { + Value::Function(formals, code) => { + if formals.len() == args.size_hint().0 { + let new_scope = bindings.nested_scope(); + let bound_scope = formals.into_iter() + .zip(args) + .fold( + new_scope, + |scope, (formal, value)| + scope.bind(&formal, value) + ); + evaluate(code, &bound_scope) + } else { + Err(EvaluateError::FunctionArgumentCountMismatch(args.size_hint().0, formals.len())) + } + }, + first_arg @ _ => { + Err(EvaluateError::TypeMismatch( + discriminant(&first_arg), + discriminant(&Value::Function(Vec::new(), LinkedList::new())), + )) + } + } + } + Operation::VariantUnion => todo!(), + Operation::FunctionType => todo!(), + Operation::NoOp => Ok(Value::Int(0)), + Operation::Conditional(_) => todo!(), + } + } +} + impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Operation::NoOp = self.operation { @@ -27,6 +175,7 @@ impl fmt::Display for Expr { } } +#[derive(Clone)] pub enum Operation { Add, Sub, @@ -90,3 +239,22 @@ pub fn set_last_ident_name(mut exprs: LinkedList, name: String) -> LinkedL pub fn get_last_ident(exprs: &LinkedList) -> Option { exprs.back().map(|expr| expr.identifier.clone()) } + +pub fn evaluate(exprs: LinkedList, bindings: &ValueBindings) -> evaluation::Result { + let bindings = bindings.nested_scope(); + let (all_bindings, last_ident) = exprs.into_iter() + .filter(Expr::isnt_noop) + .try_fold( + (bindings, None), + |(bindings, _last_ident), expr| { + let last_value = expr.evaluate(&bindings)?; + Ok((bindings.bind(&expr.identifier, last_value.clone()), Some(expr.identifier))) + })?; + last_ident.ok_or(EvaluateError::EvaluatingZeroLengthExpr) + .and_then( + |ident| + all_bindings.lookup(&ident) + .cloned() + .ok_or(EvaluateError::UndefinedValue(ident)) + ) +} diff --git a/src/ir/mod.rs b/src/ir/mod.rs index 4c213a7..549ad3a 100644 --- a/src/ir/mod.rs +++ b/src/ir/mod.rs @@ -1,13 +1,14 @@ use core::fmt; -use crate::cons; +use crate::{cons, ir::expr::evaluate}; -use self::{types::Type, value::Value}; +use self::{types::Type, value::Value, evaluation::{ValueBindings, EvaluateError}}; pub mod types; pub mod expr; pub mod value; pub mod pattern; +pub mod evaluation; #[derive(Debug, Clone, Eq, PartialEq)] // name id depth @@ -16,6 +17,11 @@ pub struct Identifier(pub String, usize, usize); impl Identifier { pub const ROOT: Identifier = Identifier(String::new(), 1, 0); + /// A special identifier used only by the `main` method + pub fn main() -> Identifier { + Identifier("main".to_string(), 3, 1) + } + pub fn subid_with_name(&self, new_name: String, index: u32) -> Identifier { Identifier ( new_name, @@ -94,6 +100,27 @@ impl fmt::Display for Declaration { } } +pub fn hasty_evaluate(program: Vec) -> evaluation::Result { + let bindings = ValueBindings::default(); + let bindings = program.into_iter() + .fold(bindings, |bindings, declaration| + bindings.bind(&declaration.identifier, declaration.value) + ); + if let Some(main) = bindings.get_main() { + if let Value::Function(args, code) = main { + if args.is_empty() { + evaluate(code.clone(), &bindings) + } else { + Err(EvaluateError::MainHasArgs) + } + } else { + Ok(main.clone()) + } + } else { + Err(EvaluateError::NoMain) + } +} + #[derive(Debug, Clone)] pub struct BindingScope<'a>(Option<&'a BindingScope<'a>>, Vec); diff --git a/src/ir/pattern.rs b/src/ir/pattern.rs index f1135bd..5803a63 100644 --- a/src/ir/pattern.rs +++ b/src/ir/pattern.rs @@ -2,7 +2,7 @@ use crate::token::Literal; use super::Identifier; -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum Pattern { Literal(Literal), Binding(Identifier), diff --git a/src/ir/value.rs b/src/ir/value.rs index c0375a0..027d9fa 100644 --- a/src/ir/value.rs +++ b/src/ir/value.rs @@ -1,9 +1,9 @@ use core::fmt; -use std::collections::LinkedList; +use std::{collections::LinkedList, mem::discriminant}; -use super::{expr::Expr, Identifier}; +use super::{expr::Expr, Identifier, evaluation::EvaluateError}; -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum Value { Int(usize), String(String), @@ -30,3 +30,33 @@ impl fmt::Display for Value { } } } + +impl TryInto for Value { + type Error = EvaluateError; + + fn try_into(self) -> Result { + if let Value::Int(v) = self { + Ok(v) + } else { + Err(EvaluateError::TypeMismatch( + discriminant(&Value::Int(0)), + discriminant(&self), + )) + } + } +} + +impl TryInto for Value { + type Error = EvaluateError; + + fn try_into(self) -> Result { + if let Value::Int(v) = self { + Ok(v != 0) + } else { + Err(EvaluateError::TypeMismatch( + discriminant(&Value::Int(0)), + discriminant(&self), + )) + } + } +} diff --git a/src/main.rs b/src/main.rs index ff9057b..15ae257 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ use std::{fs::File, io::Read, collections::LinkedList, process::exit}; use logos::Logos; use parser::{program::Program, Parsable, WrappedLexer}; +use crate::ir::{hasty_evaluate, evaluation::EvaluateError}; + mod token; mod parser; mod ir; @@ -34,13 +36,32 @@ fn main() -> std::io::Result<()> { println!("Building IR..."); match program.gen_ir() { - Ok(program) => + Ok(program) => { println!( "IR built successfully!\n{}", - program.into_iter() - .map(|decl| format!("\n\n{decl}")) + program.iter() + .map(|decl| format!("\n\n{decl}\n\n")) .collect::(), - ), + ); + match hasty_evaluate(program) { + Ok(value) => + println!("Evaluated successfully! Got: {value}"), + Err(EvaluateError::UndefinedValue(v)) => + println!("Hit an error: {v} is undefined"), + Err(EvaluateError::EvaluatingZeroLengthExpr) => + println!("Hit an error: Tried to evaluate a series of zero instructions"), + Err(EvaluateError::ArgumentCountMismatch(op, real, ex)) => + println!("Problem while evaluating operation {op:?}: expected {ex} arguments, got {real}"), + Err(EvaluateError::FunctionArgumentCountMismatch(real, ex)) => + println!("Problem while evaluating a function: expected {ex} arguments, got {real}"), + Err(EvaluateError::TypeMismatch(a, b)) => + println!("Type mismatch between {a:?} and {b:?}!"), + Err(EvaluateError::NoMain) => + println!("Huh, there's no main method"), + Err(EvaluateError::MainHasArgs) => + println!("Your main method has args, but that's not allowed"), + } + } Err(e) => println!("Oh noes! {e:?}"), } diff --git a/src/parser/declaration.rs b/src/parser/declaration.rs index 10fe3f1..3279f56 100644 --- a/src/parser/declaration.rs +++ b/src/parser/declaration.rs @@ -51,11 +51,11 @@ impl Parsable for Declaration { impl Declaration { pub fn gen_ir(self, bindings: &BindingScope, index: u32, parent_id: Identifier) -> Result, IrError> { + let own_id = self.get_identifier(&parent_id, index).unwrap(); //TODO fix me let Declaration { name, type_, name2, args, value } = self; if name != name2 { Err(IrError::MismatchedNames(name, name2)) } else { - let own_id = parent_id.subid_with_name(name, index); let args = args.into_iter() .enumerate() .map(|(i, variable_name)| own_id.subid_with_name(variable_name, i as u32)) @@ -74,6 +74,10 @@ impl Declaration { } pub fn get_identifier(&self, parent: &Identifier, index: u32) -> Option { - Some(parent.subid_with_name(self.name.clone(), index)) + if self.name.eq("main") { + Some(Identifier::main()) + } else { + Some(parent.subid_with_name(self.name.clone(), index)) + } } }