From 6e69b3ea065dfe0cc9e1911d2c7d692ee88f5758 Mon Sep 17 00:00:00 2001 From: Geoffroy Couprie Date: Sat, 8 Jun 2024 17:52:18 +0200 Subject: [PATCH] tests for maps --- biscuit-auth/src/datalog/expression.rs | 209 +++++++++++++++++++++++-- 1 file changed, 193 insertions(+), 16 deletions(-) diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index ff98a4d0..cfa8d2cb 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -548,7 +548,7 @@ mod tests { use std::collections::BTreeSet; use super::*; - use crate::datalog::{SymbolTable, TemporarySymbolTable}; + use crate::datalog::{MapKey, SymbolTable, TemporarySymbolTable}; #[test] fn negate() { @@ -1204,7 +1204,7 @@ mod tests { let res = e.evaluate(&values, &mut tmp_symbols); assert_eq!(res, Ok(Term::Bool(false))); - // array get + // get let ops = vec![ Op::Value(Term::Array(vec![ Term::Integer(0), @@ -1220,7 +1220,7 @@ mod tests { let res = e.evaluate(&values, &mut tmp_symbols); assert_eq!(res, Ok(Term::Integer(1))); - // array get out of bounds + // get out of bounds let ops = vec![ Op::Value(Term::Array(vec![ Term::Integer(0), @@ -1236,14 +1236,171 @@ mod tests { let res = e.evaluate(&values, &mut tmp_symbols); assert_eq!(res, Ok(Term::Null)); - // array get out of bounds + // all + let p = tmp_symbols.insert("param") as u32; + let ops1 = vec![ + Op::Value(Term::Array([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::GreaterThan), + ], + ), + Op::Binary(Binary::All), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(true)); + + // any + let ops1 = vec![ + Op::Value(Term::Array([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Equal), + ], + ), + Op::Binary(Binary::Any), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(false)); + } + + #[test] + fn map() { + let symbols = SymbolTable::new(); + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); let ops = vec![ - Op::Value(Term::Array(vec![ - Term::Integer(0), - Term::Integer(1), - Term::Integer(2), - ])), - Op::Value(Term::Integer(3)), + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Map( + [ + (MapKey::Str(2), Term::Integer(1)), + (MapKey::Str(1), Term::Integer(0)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Binary(Binary::Equal), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Map( + [(MapKey::Str(1), Term::Integer(0))] + .iter() + .cloned() + .collect(), + )), + Op::Binary(Binary::Equal), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(false))); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Str(1)), + Op::Binary(Binary::Contains), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Contains), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(false))); + + // get + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Str(2)), + Op::Binary(Binary::Get), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Integer(1))); + + // get non existing key + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Integer(0)), Op::Binary(Binary::Get), ]; @@ -1252,16 +1409,26 @@ mod tests { let res = e.evaluate(&values, &mut tmp_symbols); assert_eq!(res, Ok(Term::Null)); - // array all + // all let p = tmp_symbols.insert("param") as u32; let ops1 = vec![ - Op::Value(Term::Array([Term::Integer(1), Term::Integer(2)].into())), + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), Op::Closure( vec![p], vec![ Op::Value(Term::Variable(p)), - Op::Value(Term::Integer(0)), - Op::Binary(Binary::GreaterThan), + Op::Value(Term::Integer(1)), + Op::Binary(Binary::Get), + Op::Value(Term::Integer(2)), + Op::Binary(Binary::LessThan), ], ), Op::Binary(Binary::All), @@ -1272,14 +1439,24 @@ mod tests { let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); assert_eq!(res1, Term::Bool(true)); - // array any + // any let ops1 = vec![ - Op::Value(Term::Array([Term::Integer(1), Term::Integer(2)].into())), + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), Op::Closure( vec![p], vec![ Op::Value(Term::Variable(p)), Op::Value(Term::Integer(0)), + Op::Binary(Binary::Get), + Op::Value(Term::Str(1)), Op::Binary(Binary::Equal), ], ),