diff --git a/biscuit-auth/Cargo.toml b/biscuit-auth/Cargo.toml index 4e100841..1b6caa58 100644 --- a/biscuit-auth/Cargo.toml +++ b/biscuit-auth/Cargo.toml @@ -49,6 +49,7 @@ uuid = { version = "1", optional = true } biscuit-parser = { version = "0.1.2", path = "../biscuit-parser" } biscuit-quote = { version = "0.2.2", optional = true, path = "../biscuit-quote" } chrono = { version = "0.4.26", optional = true, default-features = false, features = ["serde"] } +serde_json = "1.0.117" [dev-dependencies] diff --git a/biscuit-auth/examples/testcases.rs b/biscuit-auth/examples/testcases.rs index e0558a1b..aec0992a 100644 --- a/biscuit-auth/examples/testcases.rs +++ b/biscuit-auth/examples/testcases.rs @@ -154,6 +154,8 @@ fn run(target: String, root_key: Option, test: bool, json: bool) { add_test_result(&mut results, closures(&target, &root, test)); + add_test_result(&mut results, array_map(&target, &root, test)); + if json { let s = serde_json::to_string_pretty(&TestCases { root_private_key: hex::encode(root.private().to_bytes()), @@ -1330,22 +1332,22 @@ fn expressions(target: &str, root: &KeyPair, test: bool) -> TestResult { check if hex:12ab === hex:12ab; // set contains - check if [1, 2].contains(2); - check if [2020-12-04T09:46:41+00:00, 2019-12-04T09:46:41+00:00].contains(2020-12-04T09:46:41+00:00); - check if [true, false, true].contains(true); - check if ["abc", "def"].contains("abc"); - check if [hex:12ab, hex:34de].contains(hex:34de); - check if [1, 2].contains([2]); + check if {1, 2}.contains(2); + check if { 2020-12-04T09:46:41+00:00, 2019-12-04T09:46:41+00:00}.contains(2020-12-04T09:46:41+00:00); + check if {true, false, true}.contains(true); + check if {"abc", "def"}.contains("abc"); + check if {hex:12ab, hex:34de}.contains(hex:34de); + check if {1, 2}.contains({2}); // set strict equal - check if [1, 2] === [1, 2]; + check if {1, 2} === {1, 2}; // set intersection - check if [1, 2].intersection([2, 3]) === [2]; + check if {1, 2}.intersection({2, 3}) === {2}; // set union - check if [1, 2].union([2, 3]) === [1, 2, 3]; + check if {1, 2}.union({2, 3}) === {1, 2, 3}; // chained method calls - check if [1, 2, 3].intersection([1, 2]).contains(1); + check if {1, 2, 3}.intersection({1, 2}).contains(1); // chained method calls with unary method - check if [1, 2, 3].intersection([1, 2]).length() === 2; + check if {1, 2, 3}.intersection({1, 2}).length() === 2; "#) .build_with_rng(&root, SymbolTable::default(), &mut rng) .unwrap(); @@ -2091,15 +2093,15 @@ fn closures(target: &str, root: &KeyPair, test: bool) -> TestResult { // boolean or laziness check if true || "x".intersection("x"); // all - check if [1,2,3].all($p -> $p > 0); + check if {1,2,3}.all($p -> $p > 0); // all - check if ![1,2,3].all($p -> $p == 2); + check if !{1,2,3}.all($p -> $p == 2); // any - check if [1,2,3].any($p -> $p > 2); + check if {1,2,3}.any($p -> $p > 2); // any - check if ![1,2,3].any($p -> $p > 3); + check if !{1,2,3}.any($p -> $p > 3); // nested closures - check if [1,2,3].any($p -> $p > 1 && [3,4,5].any($q -> $p == $q)); + check if {1,2,3}.any($p -> $p > 1 && {3,4,5}.any($q -> $p == $q)); "# ) .build_with_rng(&root, SymbolTable::default(), &mut rng) @@ -2131,6 +2133,59 @@ fn closures(target: &str, root: &KeyPair, test: bool) -> TestResult { } } +fn array_map(target: &str, root: &KeyPair, test: bool) -> TestResult { + let mut rng: StdRng = SeedableRng::seed_from_u64(1234); + let title = "test array and map operations (v5 blocks)".to_string(); + let filename = "test033_array_map".to_string(); + let token; + + let biscuit = biscuit!( + r#" + // array + check if [1, 2, 1].length() == 3; + check if ["a", "b"] != [1, 2, 3]; + check if ["a", "b"] == ["a", "b"]; + check if ["a", "b", "c"].contains("c"); + check if [1, 2, 3].starts_with([1, 2]); + check if [4, 5, 6 ].ends_with([6]); + check if [1,2, "a"].get(2) == "a"; + check if [1, 2].get(3) == null; + check if [1,2,3].all($p -> $p > 0); + check if [1,2,3].any($p -> $p > 2); + // map + check if { "a": 1 , "b": 2, "c": 3, "d": 4}.length() == 4; + check if { 1: "a" , 2: "b"} != { "a": 1 , "b": 2}; + check if { 1: "a" , 2: "b"} == { 2: "b", 1: "a" }; + check if { "a": 1 , "b": 2, "c": 3, "d": 4}.contains("d"); + check if { "a": 1 , "b": 2, 1: "A" }.get("a") == 1; + check if { "a": 1 , "b": 2, 1: "A" }.get(1) == "A"; + check if { "a": 1 , "b": 2, 1: "A" }.get("c") == null; + check if { "a": 1 , "b": 2, 1: "A" }.get(2) == null; + check if { "a": 1 , "b": 2 }.all($kv -> $kv.get(0) != "c" && $kv.get(1) < 3 ); + check if { "a": 1 , "b": 2, 1: "A" }.any($kv -> $kv.get(0) == 1 && $kv.get(1) == "A" ); + // nesting + check if { "user": { "id": 1, "roles": ["admin"] } }.get("user").get("roles").contains("admin"); + "# + ) + .build_with_rng(&root, SymbolTable::default(), &mut rng) + .unwrap(); + token = print_blocks(&biscuit); + + let data = write_or_load_testcase(target, &filename, root, &biscuit, test); + + let mut validations = BTreeMap::new(); + validations.insert( + "".to_string(), + validate_token(root, &data[..], "allow if true"), + ); + + TestResult { + title, + filename, + token, + validations, + } +} fn print_blocks(token: &Biscuit) -> Vec { let mut v = Vec::new(); diff --git a/biscuit-auth/samples/README.md b/biscuit-auth/samples/README.md index 667cbedb..27a8027d 100644 --- a/biscuit-auth/samples/README.md +++ b/biscuit-auth/samples/README.md @@ -841,7 +841,7 @@ allow if true; revocation ids: - `c46d071ff3f33434223c8305fdad529f62bf78bb5d9cbfc2a345d4bca6bf314014840e18ba353f86fdb9073d58b12b8c872ac1f8e593c2e9064b90f6c2ede006` -- `a0c4c163a0b3ca406df4ece3d1371356190df04208eccef72f77e875ed0531b5d37e243d6f388b1967776a5dfd16ef228f19c5bdd6d2820f145c5ed3c3dcdc00` +- `da16dfc6d0db04e3378dedce4f0250792646e53408a9116e6d5e1651a4ed692d257e1f7b107cdc40fe6e47257d9c189b0d66a83991d67459608ea1807a9a9b04` authorizer world: ``` @@ -919,7 +919,7 @@ allow if true; revocation ids: - `c46d071ff3f33434223c8305fdad529f62bf78bb5d9cbfc2a345d4bca6bf314014840e18ba353f86fdb9073d58b12b8c872ac1f8e593c2e9064b90f6c2ede006` -- `a0c4c163a0b3ca406df4ece3d1371356190df04208eccef72f77e875ed0531b5d37e243d6f388b1967776a5dfd16ef228f19c5bdd6d2820f145c5ed3c3dcdc00` +- `da16dfc6d0db04e3378dedce4f0250792646e53408a9116e6d5e1651a4ed692d257e1f7b107cdc40fe6e47257d9c189b0d66a83991d67459608ea1807a9a9b04` authorizer world: ``` @@ -1246,17 +1246,17 @@ check if 2020-12-04T09:46:41Z >= 2019-12-04T09:46:41Z; check if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z; check if 2020-12-04T09:46:41Z === 2020-12-04T09:46:41Z; check if hex:12ab === hex:12ab; -check if [1, 2].contains(2); -check if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z); -check if [false, true].contains(true); -check if ["abc", "def"].contains("abc"); -check if [hex:12ab, hex:34de].contains(hex:34de); -check if [1, 2].contains([2]); -check if [1, 2] === [1, 2]; -check if [1, 2].intersection([2, 3]) === [2]; -check if [1, 2].union([2, 3]) === [1, 2, 3]; -check if [1, 2, 3].intersection([1, 2]).contains(1); -check if [1, 2, 3].intersection([1, 2]).length() === 2; +check if {1, 2}.contains(2); +check if {2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z}.contains(2020-12-04T09:46:41Z); +check if {false, true}.contains(true); +check if {"abc", "def"}.contains("abc"); +check if {hex:12ab, hex:34de}.contains(hex:34de); +check if {1, 2}.contains({2}); +check if {1, 2} === {1, 2}; +check if {1, 2}.intersection({2, 3}) === {2}; +check if {1, 2}.union({2, 3}) === {1, 2, 3}; +check if {1, 2, 3}.intersection({1, 2}).contains(1); +check if {1, 2, 3}.intersection({1, 2}).length() === 2; ``` ### validation @@ -1303,21 +1303,21 @@ World { "check if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z", "check if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z", "check if 3 === 3", - "check if [\"abc\", \"def\"].contains(\"abc\")", - "check if [1, 2, 3].intersection([1, 2]).contains(1)", - "check if [1, 2, 3].intersection([1, 2]).length() === 2", - "check if [1, 2] === [1, 2]", - "check if [1, 2].contains(2)", - "check if [1, 2].contains([2])", - "check if [1, 2].intersection([2, 3]) === [2]", - "check if [1, 2].union([2, 3]) === [1, 2, 3]", - "check if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z)", - "check if [false, true].contains(true)", - "check if [hex:12ab, hex:34de].contains(hex:34de)", "check if false === false", "check if hex:12ab === hex:12ab", "check if true", "check if true === true", + "check if {\"abc\", \"def\"}.contains(\"abc\")", + "check if {1, 2, 3}.intersection({1, 2}).contains(1)", + "check if {1, 2, 3}.intersection({1, 2}).length() === 2", + "check if {1, 2} === {1, 2}", + "check if {1, 2}.contains(2)", + "check if {1, 2}.contains({2})", + "check if {1, 2}.intersection({2, 3}) === {2}", + "check if {1, 2}.union({2, 3}) === {1, 2, 3}", + "check if {2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z}.contains(2020-12-04T09:46:41Z)", + "check if {false, true}.contains(true)", + "check if {hex:12ab, hex:34de}.contains(hex:34de)", ], }, ] @@ -1916,7 +1916,7 @@ allow if true; ``` revocation ids: -- `c456817012e1d523c6d145b6d6a3475d9f7dd4383c535454ff3f745ecf4234984ce09b9dec0551f3d783abe850f826ce43b12f1fd91999a4753a56ecf4c56d0d` +- `899e1fa26d72b860fa6a6e6d58e71cc873230260dcb41d3390e0703c6e134d955defbd0741c23272ac6e6abb2066a23cff2fe815dc5e5bfd712d177cf74ee108` authorizer world: ``` @@ -1971,7 +1971,7 @@ allow if true; ``` revocation ids: -- `c456817012e1d523c6d145b6d6a3475d9f7dd4383c535454ff3f745ecf4234984ce09b9dec0551f3d783abe850f826ce43b12f1fd91999a4753a56ecf4c56d0d` +- `899e1fa26d72b860fa6a6e6d58e71cc873230260dcb41d3390e0703c6e134d955defbd0741c23272ac6e6abb2066a23cff2fe815dc5e5bfd712d177cf74ee108` authorizer world: ``` @@ -2325,7 +2325,7 @@ allow if true; ``` revocation ids: -- `117fa653744c859561555e6a6f5990e3a8e7817f91b87aa6991b6d64297158b4e884c92d10f49f74c96069df722aa676839b72751ca9d1fe83a7025b591de00b` +- `04f9b08f5cf677aa890fd830a4acc2a0ec7d4c9e2657d65ac691ae6512b549184fd7c6deaf17c446f12324a1c454fe373290fe8981bae69cc6054de7312da00f` authorizer world: ``` @@ -2776,11 +2776,11 @@ check if false || true; check if (true || false) && true; check if !(false && "x".intersection("x")); check if true || "x".intersection("x"); -check if [1, 2, 3].all($p -> $p > 0); -check if ![1, 2, 3].all($p -> $p == 2); -check if [1, 2, 3].any($p -> $p > 2); -check if ![1, 2, 3].any($p -> $p > 3); -check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q)); +check if {1, 2, 3}.all($p -> $p > 0); +check if !{1, 2, 3}.all($p -> $p == 2); +check if {1, 2, 3}.any($p -> $p > 2); +check if !{1, 2, 3}.any($p -> $p > 3); +check if {1, 2, 3}.any($p -> $p > 1 && {3, 4, 5}.any($q -> $p == $q)); ``` ### validation @@ -2805,15 +2805,15 @@ World { ), checks: [ "check if !(false && \"x\".intersection(\"x\"))", - "check if ![1, 2, 3].all($p -> $p == 2)", - "check if ![1, 2, 3].any($p -> $p > 3)", "check if !false && true", + "check if !{1, 2, 3}.all($p -> $p == 2)", + "check if !{1, 2, 3}.any($p -> $p > 3)", "check if (true || false) && true", - "check if [1, 2, 3].all($p -> $p > 0)", - "check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q))", - "check if [1, 2, 3].any($p -> $p > 2)", "check if false || true", "check if true || \"x\".intersection(\"x\")", + "check if {1, 2, 3}.all($p -> $p > 0)", + "check if {1, 2, 3}.any($p -> $p > 1 && {3, 4, 5}.any($q -> $p == $q))", + "check if {1, 2, 3}.any($p -> $p > 2)", ], }, ] @@ -2846,15 +2846,15 @@ World { ), checks: [ "check if !(false && \"x\".intersection(\"x\"))", - "check if ![1, 2, 3].all($p -> $p == 2)", - "check if ![1, 2, 3].any($p -> $p > 3)", "check if !false && true", + "check if !{1, 2, 3}.all($p -> $p == 2)", + "check if !{1, 2, 3}.any($p -> $p > 3)", "check if (true || false) && true", - "check if [1, 2, 3].all($p -> $p > 0)", - "check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q))", - "check if [1, 2, 3].any($p -> $p > 2)", "check if false || true", "check if true || \"x\".intersection(\"x\")", + "check if {1, 2, 3}.all($p -> $p > 0)", + "check if {1, 2, 3}.any($p -> $p > 1 && {3, 4, 5}.any($q -> $p == $q))", + "check if {1, 2, 3}.any($p -> $p > 2)", ], }, ] @@ -2866,3 +2866,91 @@ World { result: `Err(Execution(ShadowedVariable))` + +------------------------------ + +## test array and map operations (v5 blocks): test033_array_map.bc +### token + +authority: +symbols: ["a", "b", "c", "p", "d", "A", "kv", "id", "roles"] + +public keys: [] + +``` +check if [1, 2, 1].length() == 3; +check if ["a", "b"] != [1, 2, 3]; +check if ["a", "b"] == ["a", "b"]; +check if ["a", "b", "c"].contains("c"); +check if [1, 2, 3].starts_with([1, 2]); +check if [4, 5, 6].ends_with([6]); +check if [1, 2, "a"].get(2) == "a"; +check if [1, 2].get(3) == null; +check if [1, 2, 3].all($p -> $p > 0); +check if [1, 2, 3].any($p -> $p > 2); +check if {"a": 1, "b": 2, "c": 3, "d": 4}.length() == 4; +check if {1: "a", 2: "b"} != {"a": 1, "b": 2}; +check if {1: "a", 2: "b"} == {1: "a", 2: "b"}; +check if {"a": 1, "b": 2, "c": 3, "d": 4}.contains("d"); +check if {1: "A", "a": 1, "b": 2}.get("a") == 1; +check if {1: "A", "a": 1, "b": 2}.get(1) == "A"; +check if {1: "A", "a": 1, "b": 2}.get("c") == null; +check if {1: "A", "a": 1, "b": 2}.get(2) == null; +check if {"a": 1, "b": 2}.all($kv -> $kv.get(0) != "c" && $kv.get(1) < 3); +check if {1: "A", "a": 1, "b": 2}.any($kv -> $kv.get(0) == 1 && $kv.get(1) == "A"); +check if {"user": {"id": 1, "roles": ["admin"]}}.get("user").get("roles").contains("admin"); +``` + +### validation + +authorizer code: +``` +allow if true; +``` + +revocation ids: +- `7096e2ad9ad5dcae778cea1cee800ffc38017196e56aed693810d0933bcecc804a723768c3b494fa23d99be59ca3588bfa806e3fe2dac29d0ca9e452b69ead09` + +authorizer world: +``` +World { + facts: [] + rules: [] + checks: [ + Checks { + origin: Some( + 0, + ), + checks: [ + "check if [\"a\", \"b\", \"c\"].contains(\"c\")", + "check if [\"a\", \"b\"] != [1, 2, 3]", + "check if [\"a\", \"b\"] == [\"a\", \"b\"]", + "check if [1, 2, \"a\"].get(2) == \"a\"", + "check if [1, 2, 1].length() == 3", + "check if [1, 2, 3].all($p -> $p > 0)", + "check if [1, 2, 3].any($p -> $p > 2)", + "check if [1, 2, 3].starts_with([1, 2])", + "check if [1, 2].get(3) == null", + "check if [4, 5, 6].ends_with([6])", + "check if {\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4}.contains(\"d\")", + "check if {\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4}.length() == 4", + "check if {\"a\": 1, \"b\": 2}.all($kv -> $kv.get(0) != \"c\" && $kv.get(1) < 3)", + "check if {\"user\": {\"id\": 1, \"roles\": [\"admin\"]}}.get(\"user\").get(\"roles\").contains(\"admin\")", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.any($kv -> $kv.get(0) == 1 && $kv.get(1) == \"A\")", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.get(\"a\") == 1", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.get(\"c\") == null", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.get(1) == \"A\"", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.get(2) == null", + "check if {1: \"a\", 2: \"b\"} != {\"a\": 1, \"b\": 2}", + "check if {1: \"a\", 2: \"b\"} == {1: \"a\", 2: \"b\"}", + ], + }, +] + policies: [ + "allow if true", +] +} +``` + +result: `Ok(0)` + diff --git a/biscuit-auth/samples/samples.json b/biscuit-auth/samples/samples.json index 82856d84..3fb9fa30 100644 --- a/biscuit-auth/samples/samples.json +++ b/biscuit-auth/samples/samples.json @@ -924,7 +924,7 @@ "authorizer_code": "resource(\"file1\");\ntime(2020-12-21T09:23:12Z);\n\nallow if true;\n", "revocation_ids": [ "c46d071ff3f33434223c8305fdad529f62bf78bb5d9cbfc2a345d4bca6bf314014840e18ba353f86fdb9073d58b12b8c872ac1f8e593c2e9064b90f6c2ede006", - "a0c4c163a0b3ca406df4ece3d1371356190df04208eccef72f77e875ed0531b5d37e243d6f388b1967776a5dfd16ef228f19c5bdd6d2820f145c5ed3c3dcdc00" + "da16dfc6d0db04e3378dedce4f0250792646e53408a9116e6d5e1651a4ed692d257e1f7b107cdc40fe6e47257d9c189b0d66a83991d67459608ea1807a9a9b04" ] }, "file2": { @@ -993,7 +993,7 @@ "authorizer_code": "resource(\"file2\");\ntime(2020-12-21T09:23:12Z);\n\nallow if true;\n", "revocation_ids": [ "c46d071ff3f33434223c8305fdad529f62bf78bb5d9cbfc2a345d4bca6bf314014840e18ba353f86fdb9073d58b12b8c872ac1f8e593c2e9064b90f6c2ede006", - "a0c4c163a0b3ca406df4ece3d1371356190df04208eccef72f77e875ed0531b5d37e243d6f388b1967776a5dfd16ef228f19c5bdd6d2820f145c5ed3c3dcdc00" + "da16dfc6d0db04e3378dedce4f0250792646e53408a9116e6d5e1651a4ed692d257e1f7b107cdc40fe6e47257d9c189b0d66a83991d67459608ea1807a9a9b04" ] } } @@ -1245,7 +1245,7 @@ ], "public_keys": [], "external_key": null, - "code": "check if true;\ncheck if !false;\ncheck if true === true;\ncheck if false === false;\ncheck if 1 < 2;\ncheck if 2 > 1;\ncheck if 1 <= 2;\ncheck if 1 <= 1;\ncheck if 2 >= 1;\ncheck if 2 >= 2;\ncheck if 3 === 3;\ncheck if 1 + 2 * 3 - 4 / 2 === 5;\ncheck if \"hello world\".starts_with(\"hello\"), \"hello world\".ends_with(\"world\");\ncheck if \"aaabde\".matches(\"a*c?.e\");\ncheck if \"aaabde\".contains(\"abd\");\ncheck if \"aaabde\" === \"aaa\" + \"b\" + \"de\";\ncheck if \"abcD12\" === \"abcD12\";\ncheck if \"abcD12\".length() === 6;\ncheck if \"é\".length() === 2;\ncheck if 2019-12-04T09:46:41Z < 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z > 2019-12-04T09:46:41Z;\ncheck if 2019-12-04T09:46:41Z <= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2019-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z === 2020-12-04T09:46:41Z;\ncheck if hex:12ab === hex:12ab;\ncheck if [1, 2].contains(2);\ncheck if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z);\ncheck if [false, true].contains(true);\ncheck if [\"abc\", \"def\"].contains(\"abc\");\ncheck if [hex:12ab, hex:34de].contains(hex:34de);\ncheck if [1, 2].contains([2]);\ncheck if [1, 2] === [1, 2];\ncheck if [1, 2].intersection([2, 3]) === [2];\ncheck if [1, 2].union([2, 3]) === [1, 2, 3];\ncheck if [1, 2, 3].intersection([1, 2]).contains(1);\ncheck if [1, 2, 3].intersection([1, 2]).length() === 2;\n" + "code": "check if true;\ncheck if !false;\ncheck if true === true;\ncheck if false === false;\ncheck if 1 < 2;\ncheck if 2 > 1;\ncheck if 1 <= 2;\ncheck if 1 <= 1;\ncheck if 2 >= 1;\ncheck if 2 >= 2;\ncheck if 3 === 3;\ncheck if 1 + 2 * 3 - 4 / 2 === 5;\ncheck if \"hello world\".starts_with(\"hello\"), \"hello world\".ends_with(\"world\");\ncheck if \"aaabde\".matches(\"a*c?.e\");\ncheck if \"aaabde\".contains(\"abd\");\ncheck if \"aaabde\" === \"aaa\" + \"b\" + \"de\";\ncheck if \"abcD12\" === \"abcD12\";\ncheck if \"abcD12\".length() === 6;\ncheck if \"é\".length() === 2;\ncheck if 2019-12-04T09:46:41Z < 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z > 2019-12-04T09:46:41Z;\ncheck if 2019-12-04T09:46:41Z <= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2019-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z === 2020-12-04T09:46:41Z;\ncheck if hex:12ab === hex:12ab;\ncheck if {1, 2}.contains(2);\ncheck if {2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z}.contains(2020-12-04T09:46:41Z);\ncheck if {false, true}.contains(true);\ncheck if {\"abc\", \"def\"}.contains(\"abc\");\ncheck if {hex:12ab, hex:34de}.contains(hex:34de);\ncheck if {1, 2}.contains({2});\ncheck if {1, 2} === {1, 2};\ncheck if {1, 2}.intersection({2, 3}) === {2};\ncheck if {1, 2}.union({2, 3}) === {1, 2, 3};\ncheck if {1, 2, 3}.intersection({1, 2}).contains(1);\ncheck if {1, 2, 3}.intersection({1, 2}).length() === 2;\n" } ], "validations": { @@ -1280,21 +1280,21 @@ "check if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z", "check if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z", "check if 3 === 3", - "check if [\"abc\", \"def\"].contains(\"abc\")", - "check if [1, 2, 3].intersection([1, 2]).contains(1)", - "check if [1, 2, 3].intersection([1, 2]).length() === 2", - "check if [1, 2] === [1, 2]", - "check if [1, 2].contains(2)", - "check if [1, 2].contains([2])", - "check if [1, 2].intersection([2, 3]) === [2]", - "check if [1, 2].union([2, 3]) === [1, 2, 3]", - "check if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z)", - "check if [false, true].contains(true)", - "check if [hex:12ab, hex:34de].contains(hex:34de)", "check if false === false", "check if hex:12ab === hex:12ab", "check if true", - "check if true === true" + "check if true === true", + "check if {\"abc\", \"def\"}.contains(\"abc\")", + "check if {1, 2, 3}.intersection({1, 2}).contains(1)", + "check if {1, 2, 3}.intersection({1, 2}).length() === 2", + "check if {1, 2} === {1, 2}", + "check if {1, 2}.contains(2)", + "check if {1, 2}.contains({2})", + "check if {1, 2}.intersection({2, 3}) === {2}", + "check if {1, 2}.union({2, 3}) === {1, 2, 3}", + "check if {2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z}.contains(2020-12-04T09:46:41Z)", + "check if {false, true}.contains(true)", + "check if {hex:12ab, hex:34de}.contains(hex:34de)" ] } ], @@ -1857,7 +1857,7 @@ }, "authorizer_code": "operation(\"A\");\noperation(\"B\");\n\nallow if true;\n", "revocation_ids": [ - "c456817012e1d523c6d145b6d6a3475d9f7dd4383c535454ff3f745ecf4234984ce09b9dec0551f3d783abe850f826ce43b12f1fd91999a4753a56ecf4c56d0d" + "899e1fa26d72b860fa6a6e6d58e71cc873230260dcb41d3390e0703c6e134d955defbd0741c23272ac6e6abb2066a23cff2fe815dc5e5bfd712d177cf74ee108" ] }, "A, invalid": { @@ -1916,7 +1916,7 @@ }, "authorizer_code": "operation(\"A\");\noperation(\"invalid\");\n\nallow if true;\n", "revocation_ids": [ - "c456817012e1d523c6d145b6d6a3475d9f7dd4383c535454ff3f745ecf4234984ce09b9dec0551f3d783abe850f826ce43b12f1fd91999a4753a56ecf4c56d0d" + "899e1fa26d72b860fa6a6e6d58e71cc873230260dcb41d3390e0703c6e134d955defbd0741c23272ac6e6abb2066a23cff2fe815dc5e5bfd712d177cf74ee108" ] } } @@ -2170,7 +2170,7 @@ }, "authorizer_code": "allow if true;\n", "revocation_ids": [ - "117fa653744c859561555e6a6f5990e3a8e7817f91b87aa6991b6d64297158b4e884c92d10f49f74c96069df722aa676839b72751ca9d1fe83a7025b591de00b" + "04f9b08f5cf677aa890fd830a4acc2a0ec7d4c9e2657d65ac691ae6512b549184fd7c6deaf17c446f12324a1c454fe373290fe8981bae69cc6054de7312da00f" ] } } @@ -2616,7 +2616,7 @@ ], "public_keys": [], "external_key": null, - "code": "check if !false && true;\ncheck if false || true;\ncheck if (true || false) && true;\ncheck if !(false && \"x\".intersection(\"x\"));\ncheck if true || \"x\".intersection(\"x\");\ncheck if [1, 2, 3].all($p -> $p > 0);\ncheck if ![1, 2, 3].all($p -> $p == 2);\ncheck if [1, 2, 3].any($p -> $p > 2);\ncheck if ![1, 2, 3].any($p -> $p > 3);\ncheck if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q));\n" + "code": "check if !false && true;\ncheck if false || true;\ncheck if (true || false) && true;\ncheck if !(false && \"x\".intersection(\"x\"));\ncheck if true || \"x\".intersection(\"x\");\ncheck if {1, 2, 3}.all($p -> $p > 0);\ncheck if !{1, 2, 3}.all($p -> $p == 2);\ncheck if {1, 2, 3}.any($p -> $p > 2);\ncheck if !{1, 2, 3}.any($p -> $p > 3);\ncheck if {1, 2, 3}.any($p -> $p > 1 && {3, 4, 5}.any($q -> $p == $q));\n" } ], "validations": { @@ -2629,15 +2629,15 @@ "origin": 0, "checks": [ "check if !(false && \"x\".intersection(\"x\"))", - "check if ![1, 2, 3].all($p -> $p == 2)", - "check if ![1, 2, 3].any($p -> $p > 3)", "check if !false && true", + "check if !{1, 2, 3}.all($p -> $p == 2)", + "check if !{1, 2, 3}.any($p -> $p > 3)", "check if (true || false) && true", - "check if [1, 2, 3].all($p -> $p > 0)", - "check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q))", - "check if [1, 2, 3].any($p -> $p > 2)", "check if false || true", - "check if true || \"x\".intersection(\"x\")" + "check if true || \"x\".intersection(\"x\")", + "check if {1, 2, 3}.all($p -> $p > 0)", + "check if {1, 2, 3}.any($p -> $p > 1 && {3, 4, 5}.any($q -> $p == $q))", + "check if {1, 2, 3}.any($p -> $p > 2)" ] } ], @@ -2662,15 +2662,15 @@ "origin": 0, "checks": [ "check if !(false && \"x\".intersection(\"x\"))", - "check if ![1, 2, 3].all($p -> $p == 2)", - "check if ![1, 2, 3].any($p -> $p > 3)", "check if !false && true", + "check if !{1, 2, 3}.all($p -> $p == 2)", + "check if !{1, 2, 3}.any($p -> $p > 3)", "check if (true || false) && true", - "check if [1, 2, 3].all($p -> $p > 0)", - "check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q))", - "check if [1, 2, 3].any($p -> $p > 2)", "check if false || true", - "check if true || \"x\".intersection(\"x\")" + "check if true || \"x\".intersection(\"x\")", + "check if {1, 2, 3}.all($p -> $p > 0)", + "check if {1, 2, 3}.any($p -> $p > 1 && {3, 4, 5}.any($q -> $p == $q))", + "check if {1, 2, 3}.any($p -> $p > 2)" ] } ], @@ -2689,6 +2689,74 @@ ] } } + }, + { + "title": "test array and map operations (v5 blocks)", + "filename": "test033_array_map.bc", + "token": [ + { + "symbols": [ + "a", + "b", + "c", + "p", + "d", + "A", + "kv", + "id", + "roles" + ], + "public_keys": [], + "external_key": null, + "code": "check if [1, 2, 1].length() == 3;\ncheck if [\"a\", \"b\"] != [1, 2, 3];\ncheck if [\"a\", \"b\"] == [\"a\", \"b\"];\ncheck if [\"a\", \"b\", \"c\"].contains(\"c\");\ncheck if [1, 2, 3].starts_with([1, 2]);\ncheck if [4, 5, 6].ends_with([6]);\ncheck if [1, 2, \"a\"].get(2) == \"a\";\ncheck if [1, 2].get(3) == null;\ncheck if [1, 2, 3].all($p -> $p > 0);\ncheck if [1, 2, 3].any($p -> $p > 2);\ncheck if {\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4}.length() == 4;\ncheck if {1: \"a\", 2: \"b\"} != {\"a\": 1, \"b\": 2};\ncheck if {1: \"a\", 2: \"b\"} == {1: \"a\", 2: \"b\"};\ncheck if {\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4}.contains(\"d\");\ncheck if {1: \"A\", \"a\": 1, \"b\": 2}.get(\"a\") == 1;\ncheck if {1: \"A\", \"a\": 1, \"b\": 2}.get(1) == \"A\";\ncheck if {1: \"A\", \"a\": 1, \"b\": 2}.get(\"c\") == null;\ncheck if {1: \"A\", \"a\": 1, \"b\": 2}.get(2) == null;\ncheck if {\"a\": 1, \"b\": 2}.all($kv -> $kv.get(0) != \"c\" && $kv.get(1) < 3);\ncheck if {1: \"A\", \"a\": 1, \"b\": 2}.any($kv -> $kv.get(0) == 1 && $kv.get(1) == \"A\");\ncheck if {\"user\": {\"id\": 1, \"roles\": [\"admin\"]}}.get(\"user\").get(\"roles\").contains(\"admin\");\n" + } + ], + "validations": { + "": { + "world": { + "facts": [], + "rules": [], + "checks": [ + { + "origin": 0, + "checks": [ + "check if [\"a\", \"b\", \"c\"].contains(\"c\")", + "check if [\"a\", \"b\"] != [1, 2, 3]", + "check if [\"a\", \"b\"] == [\"a\", \"b\"]", + "check if [1, 2, \"a\"].get(2) == \"a\"", + "check if [1, 2, 1].length() == 3", + "check if [1, 2, 3].all($p -> $p > 0)", + "check if [1, 2, 3].any($p -> $p > 2)", + "check if [1, 2, 3].starts_with([1, 2])", + "check if [1, 2].get(3) == null", + "check if [4, 5, 6].ends_with([6])", + "check if {\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4}.contains(\"d\")", + "check if {\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4}.length() == 4", + "check if {\"a\": 1, \"b\": 2}.all($kv -> $kv.get(0) != \"c\" && $kv.get(1) < 3)", + "check if {\"user\": {\"id\": 1, \"roles\": [\"admin\"]}}.get(\"user\").get(\"roles\").contains(\"admin\")", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.any($kv -> $kv.get(0) == 1 && $kv.get(1) == \"A\")", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.get(\"a\") == 1", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.get(\"c\") == null", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.get(1) == \"A\"", + "check if {1: \"A\", \"a\": 1, \"b\": 2}.get(2) == null", + "check if {1: \"a\", 2: \"b\"} != {\"a\": 1, \"b\": 2}", + "check if {1: \"a\", 2: \"b\"} == {1: \"a\", 2: \"b\"}" + ] + } + ], + "policies": [ + "allow if true" + ] + }, + "result": { + "Ok": 0 + }, + "authorizer_code": "allow if true;\n", + "revocation_ids": [ + "7096e2ad9ad5dcae778cea1cee800ffc38017196e56aed693810d0933bcecc804a723768c3b494fa23d99be59ca3588bfa806e3fe2dac29d0ca9e452b69ead09" + ] + } + } } ] } diff --git a/biscuit-auth/samples/test013_block_rules.bc b/biscuit-auth/samples/test013_block_rules.bc index 149b4ee8..2f2d4d99 100644 Binary files a/biscuit-auth/samples/test013_block_rules.bc and b/biscuit-auth/samples/test013_block_rules.bc differ diff --git a/biscuit-auth/samples/test025_check_all.bc b/biscuit-auth/samples/test025_check_all.bc index 221df2ca..8930a445 100644 Binary files a/biscuit-auth/samples/test025_check_all.bc and b/biscuit-auth/samples/test025_check_all.bc differ diff --git a/biscuit-auth/samples/test028_expressions_v4.bc b/biscuit-auth/samples/test028_expressions_v4.bc index c34d7a10..bb28a285 100644 Binary files a/biscuit-auth/samples/test028_expressions_v4.bc and b/biscuit-auth/samples/test028_expressions_v4.bc differ diff --git a/biscuit-auth/samples/test033_array_map.bc b/biscuit-auth/samples/test033_array_map.bc new file mode 100644 index 00000000..b4cd4983 Binary files /dev/null and b/biscuit-auth/samples/test033_array_map.bc differ diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index c07e42e4..51b73071 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -1,9 +1,10 @@ use crate::error; -use super::Term; +use super::{MapKey, Term}; use super::{SymbolTable, TemporarySymbolTable}; use regex::Regex; use std::collections::{HashMap, HashSet}; +use std::convert::TryFrom; #[derive(Debug, Clone, PartialEq, Hash, Eq)] pub struct Expression { @@ -41,6 +42,9 @@ impl Unary { .ok_or(error::Expression::UnknownSymbol(i)), (Unary::Length, Term::Bytes(s)) => Ok(Term::Integer(s.len() as i64)), (Unary::Length, Term::Set(s)) => Ok(Term::Integer(s.len() as i64)), + (Unary::Length, Term::Array(a)) => Ok(Term::Integer(a.len() as i64)), + (Unary::Length, Term::Map(m)) => Ok(Term::Integer(m.len() as i64)), + _ => { //println!("unexpected value type on the stack"); Err(error::Expression::InvalidType) @@ -87,6 +91,7 @@ pub enum Binary { LazyOr, All, Any, + Get, } impl Binary { @@ -99,6 +104,7 @@ impl Binary { symbols: &mut TemporarySymbolTable, ) -> Result { match (self, left, params) { + // boolean (Binary::LazyOr, Term::Bool(true), []) => Ok(Term::Bool(true)), (Binary::LazyOr, Term::Bool(false), []) => { let e = Expression { ops: right.clone() }; @@ -109,6 +115,8 @@ impl Binary { let e = Expression { ops: right.clone() }; e.evaluate(values, symbols) } + + // set (Binary::All, Term::Set(set_values), [param]) => { for value in set_values.iter() { values.insert(*param, value.clone()); @@ -137,6 +145,76 @@ impl Binary { } Ok(Term::Bool(false)) } + + // array + (Binary::All, Term::Array(array), [param]) => { + for value in array.iter() { + values.insert(*param, value.clone()); + let e = Expression { ops: right.clone() }; + let result = e.evaluate(values, symbols); + values.remove(param); + match result? { + Term::Bool(true) => {} + Term::Bool(false) => return Ok(Term::Bool(false)), + _ => return Err(error::Expression::InvalidType), + }; + } + Ok(Term::Bool(true)) + } + (Binary::Any, Term::Array(array), [param]) => { + for value in array.iter() { + values.insert(*param, value.clone()); + let e = Expression { ops: right.clone() }; + let result = e.evaluate(values, symbols); + values.remove(param); + match result? { + Term::Bool(false) => {} + Term::Bool(true) => return Ok(Term::Bool(true)), + _ => return Err(error::Expression::InvalidType), + }; + } + Ok(Term::Bool(false)) + } + + //map + (Binary::All, Term::Map(map), [param]) => { + for (key, value) in map.iter() { + let key = match key { + MapKey::Integer(i) => Term::Integer(*i), + MapKey::Str(i) => Term::Str(*i), + }; + values.insert(*param, Term::Array(vec![key, value.clone()])); + + let e = Expression { ops: right.clone() }; + let result = e.evaluate(values, symbols); + values.remove(param); + match result? { + Term::Bool(true) => {} + Term::Bool(false) => return Ok(Term::Bool(false)), + _ => return Err(error::Expression::InvalidType), + }; + } + Ok(Term::Bool(true)) + } + (Binary::Any, Term::Map(map), [param]) => { + for (key, value) in map.iter() { + let key = match key { + MapKey::Integer(i) => Term::Integer(*i), + MapKey::Str(i) => Term::Str(*i), + }; + values.insert(*param, Term::Array(vec![key, value.clone()])); + + let e = Expression { ops: right.clone() }; + let result = e.evaluate(values, symbols); + values.remove(param); + match result? { + Term::Bool(false) => {} + Term::Bool(true) => return Ok(Term::Bool(true)), + _ => return Err(error::Expression::InvalidType), + }; + } + Ok(Term::Bool(false)) + } (_, _, _) => Err(error::Expression::InvalidType), } } @@ -303,6 +381,46 @@ impl Binary { (Binary::HeterogeneousNotEqual, Term::Null, _) => Ok(Term::Bool(true)), (Binary::HeterogeneousNotEqual, _, Term::Null) => Ok(Term::Bool(true)), + // array + (Binary::Equal | Binary::HeterogeneousEqual, Term::Array(i), Term::Array(j)) => { + Ok(Term::Bool(i == j)) + } + (Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Array(i), Term::Array(j)) => { + Ok(Term::Bool(i != j)) + } + (Binary::Contains, Term::Array(i), j) => { + Ok(Term::Bool(i.iter().any(|elem| elem == &j))) + } + (Binary::Prefix, Term::Array(i), Term::Array(j)) => Ok(Term::Bool(i.starts_with(&j))), + (Binary::Suffix, Term::Array(i), Term::Array(j)) => Ok(Term::Bool(i.ends_with(&j))), + (Binary::Get, Term::Array(i), Term::Integer(index)) => Ok(TryFrom::try_from(index) + .ok() + .and_then(|index: usize| i.get(index).cloned()) + .unwrap_or(Term::Null)), + + // map + (Binary::Equal | Binary::HeterogeneousEqual, Term::Map(i), Term::Map(j)) => { + Ok(Term::Bool(i == j)) + } + (Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Map(i), Term::Map(j)) => { + Ok(Term::Bool(i != j)) + } + (Binary::Contains, Term::Map(i), j) => { + Ok(Term::Bool(i.iter().any(|elem| match (elem.0, &j) { + (super::MapKey::Integer(k), Term::Integer(l)) => k == l, + (super::MapKey::Str(k), Term::Str(l)) => k == l, + _ => false, + }))) + } + (Binary::Get, Term::Map(m), Term::Integer(i)) => match m.get(&MapKey::Integer(i)) { + Some(term) => Ok(term.clone()), + None => Ok(Term::Null), + }, + (Binary::Get, Term::Map(m), Term::Str(i)) => match m.get(&MapKey::Str(i)) { + Some(term) => Ok(term.clone()), + None => Ok(Term::Null), + }, + (Binary::HeterogeneousEqual, _, _) => Ok(Term::Bool(false)), (Binary::HeterogeneousNotEqual, _, _) => Ok(Term::Bool(true)), @@ -342,6 +460,7 @@ impl Binary { Binary::LazyOr => format!("{left} || {right}"), Binary::All => format!("{left}.all({right})"), Binary::Any => format!("{left}.any({right})"), + Binary::Get => format!("{left}.get({right})"), } } } @@ -477,7 +596,7 @@ mod tests { use std::collections::BTreeSet; use super::*; - use crate::datalog::{SymbolTable, TemporarySymbolTable}; + use crate::datalog::{MapKey, SymbolTable, TemporarySymbolTable}; #[test] fn negate() { @@ -1024,4 +1143,416 @@ mod tests { let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols); assert_eq!(res2, Err(error::Expression::ShadowedVariable)); } + + #[test] + fn array() { + let symbols = SymbolTable::new(); + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + let ops = vec![ + Op::Value(Term::Array(vec![Term::Integer(0), Term::Integer(1)])), + Op::Value(Term::Array(vec![Term::Integer(0), Term::Integer(1)])), + Op::Binary(Binary::Equal), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + + let ops = vec![ + Op::Value(Term::Array(vec![Term::Integer(0), Term::Integer(1)])), + Op::Value(Term::Array(vec![Term::Integer(0)])), + Op::Binary(Binary::Equal), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(false))); + + let ops = vec![ + Op::Value(Term::Array(vec![Term::Integer(0), Term::Integer(1)])), + Op::Value(Term::Integer(1)), + Op::Binary(Binary::Contains), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + + let ops = vec![ + Op::Value(Term::Array(vec![Term::Integer(0), Term::Integer(1)])), + Op::Value(Term::Integer(2)), + Op::Binary(Binary::Contains), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(false))); + + let ops = vec![ + Op::Value(Term::Array(vec![ + Term::Integer(0), + Term::Integer(1), + Term::Integer(2), + ])), + Op::Value(Term::Array(vec![Term::Integer(0), Term::Integer(1)])), + Op::Binary(Binary::Prefix), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + + let ops = vec![ + Op::Value(Term::Array(vec![ + Term::Integer(0), + Term::Integer(1), + Term::Integer(2), + ])), + Op::Value(Term::Array(vec![Term::Integer(2), Term::Integer(1)])), + Op::Binary(Binary::Prefix), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(false))); + + let ops = vec![ + Op::Value(Term::Array(vec![ + Term::Integer(0), + Term::Integer(1), + Term::Integer(2), + ])), + Op::Value(Term::Array(vec![Term::Integer(1), Term::Integer(2)])), + Op::Binary(Binary::Suffix), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + + let ops = vec![ + Op::Value(Term::Array(vec![ + Term::Integer(0), + Term::Integer(1), + Term::Integer(2), + ])), + Op::Value(Term::Array(vec![Term::Integer(0), Term::Integer(2)])), + Op::Binary(Binary::Suffix), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(false))); + + // get + let ops = vec![ + Op::Value(Term::Array(vec![ + Term::Integer(0), + Term::Integer(1), + Term::Integer(2), + ])), + Op::Value(Term::Integer(1)), + Op::Binary(Binary::Get), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Integer(1))); + + // get out of bounds + let ops = vec![ + Op::Value(Term::Array(vec![ + Term::Integer(0), + Term::Integer(1), + Term::Integer(2), + ])), + Op::Value(Term::Integer(3)), + Op::Binary(Binary::Get), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Null)); + + // all + let p = tmp_symbols.insert("param") as u32; + let ops1 = vec![ + Op::Value(Term::Array([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::GreaterThan), + ], + ), + Op::Binary(Binary::All), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(true)); + + // any + let ops1 = vec![ + Op::Value(Term::Array([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Equal), + ], + ), + Op::Binary(Binary::Any), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(false)); + } + + #[test] + fn map() { + let mut symbols = SymbolTable::new(); + let p = symbols.insert("param") as u32; + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Map( + [ + (MapKey::Str(2), Term::Integer(1)), + (MapKey::Str(1), Term::Integer(0)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Binary(Binary::Equal), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Map( + [(MapKey::Str(1), Term::Integer(0))] + .iter() + .cloned() + .collect(), + )), + Op::Binary(Binary::Equal), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(false))); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Str(1)), + Op::Binary(Binary::Contains), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Contains), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(false))); + + // get + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Integer(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Str(1)), + Op::Binary(Binary::Get), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Integer(0))); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Integer(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Integer(2)), + Op::Binary(Binary::Get), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Integer(1))); + + // get non existing key + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Get), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Null)); + + let ops = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Value(Term::Str(3)), + Op::Binary(Binary::Get), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Null)); + + // all + let ops1 = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(1)), + Op::Binary(Binary::Get), + Op::Value(Term::Integer(2)), + Op::Binary(Binary::LessThan), + ], + ), + Op::Binary(Binary::All), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(true)); + + // any + let ops1 = vec![ + Op::Value(Term::Map( + [ + (MapKey::Str(1), Term::Integer(0)), + (MapKey::Str(2), Term::Integer(1)), + ] + .iter() + .cloned() + .collect(), + )), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Get), + Op::Value(Term::Str(1)), + Op::Binary(Binary::Equal), + ], + ), + Op::Binary(Binary::Any), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(true)); + } } diff --git a/biscuit-auth/src/datalog/mod.rs b/biscuit-auth/src/datalog/mod.rs index 8b954e6b..068787be 100644 --- a/biscuit-auth/src/datalog/mod.rs +++ b/biscuit-auth/src/datalog/mod.rs @@ -4,7 +4,7 @@ use crate::error::Execution; use crate::time::Instant; use crate::token::{Scope, MIN_SCHEMA_VERSION}; use crate::{builder, error}; -use std::collections::{BTreeSet, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::convert::AsRef; use std::fmt; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -26,6 +26,14 @@ pub enum Term { Bool(bool), Set(BTreeSet), Null, + Array(Vec), + Map(BTreeMap), +} + +#[derive(Debug, Clone, PartialEq, Hash, Eq, PartialOrd, Ord)] +pub enum MapKey { + Integer(i64), + Str(SymbolIndex), } impl From<&Term> for Term { @@ -39,6 +47,8 @@ impl From<&Term> for Term { Term::Bool(ref b) => Term::Bool(*b), Term::Set(ref s) => Term::Set(s.clone()), Term::Null => Term::Null, + Term::Array(ref a) => Term::Array(a.clone()), + Term::Map(m) => Term::Map(m.clone()), } } } @@ -564,6 +574,8 @@ pub fn match_preds(rule_pred: &Predicate, fact_pred: &Predicate) -> bool { (Term::Bool(i), Term::Bool(j)) => i == j, (Term::Null, Term::Null) => true, (Term::Set(i), Term::Set(j)) => i == j, + (Term::Array(i), Term::Array(j)) => i == j, + (Term::Map(i), Term::Map(j)) => i == j, _ => false, }) } diff --git a/biscuit-auth/src/datalog/symbol.rs b/biscuit-auth/src/datalog/symbol.rs index ea1d10a5..e9080cf7 100644 --- a/biscuit-auth/src/datalog/symbol.rs +++ b/biscuit-auth/src/datalog/symbol.rs @@ -202,9 +202,34 @@ impl SymbolTable { .iter() .map(|term| self.print_term(term)) .collect::>(); - format!("[{}]", terms.join(", ")) + format!("{{{}}}", terms.join(", ")) } Term::Null => "null".to_string(), + Term::Array(a) => { + let terms = a + .iter() + .map(|term| self.print_term(term)) + .collect::>(); + format!("[{}]", terms.join(", ")) + } + Term::Map(m) => { + let terms = m + .iter() + .map(|(key, term)| match key { + crate::datalog::MapKey::Integer(i) => { + format!("{}: {}", i, self.print_term(term)) + } + crate::datalog::MapKey::Str(s) => { + format!( + "\"{}\": {}", + self.print_symbol_default(*s as u64), + self.print_term(term) + ) + } + }) + .collect::>(); + format!("{{{}}}", terms.join(", ")) + } } } pub fn print_fact(&self, f: &Fact) -> String { diff --git a/biscuit-auth/src/format/convert.rs b/biscuit-auth/src/format/convert.rs index 590e0f7c..3798fc45 100644 --- a/biscuit-auth/src/format/convert.rs +++ b/biscuit-auth/src/format/convert.rs @@ -312,8 +312,10 @@ pub mod v2 { use crate::datalog::*; use crate::error; use crate::format::schema::Empty; + use crate::format::schema::MapEntry; use crate::token::Scope; use crate::token::MIN_SCHEMA_VERSION; + use std::collections::BTreeMap; use std::collections::BTreeSet; pub fn token_fact_to_proto_fact(input: &Fact) -> schema::FactV2 { @@ -523,6 +525,32 @@ pub mod v2 { Term::Null => schema::TermV2 { content: Some(Content::Null(Empty {})), }, + Term::Array(a) => schema::TermV2 { + content: Some(Content::Array(schema::Array { + array: a.iter().map(token_term_to_proto_id).collect(), + })), + }, + Term::Map(m) => schema::TermV2 { + content: Some(Content::Map(schema::Map { + entries: m + .iter() + .map(|(key, term)| { + let key = match key { + MapKey::Integer(i) => schema::MapKey { + content: Some(schema::map_key::Content::Integer(*i)), + }, + MapKey::Str(s) => schema::MapKey { + content: Some(schema::map_key::Content::String(*s)), + }, + }; + schema::MapEntry { + key, + value: token_term_to_proto_id(term), + } + }) + .collect(), + })), + }, } } @@ -561,6 +589,8 @@ pub mod v2 { )); } Some(Content::Null(_)) => 8, + Some(Content::Array(_)) => 9, + Some(Content::Map(_)) => 10, None => { return Err(error::Format::DeserializationError( "deserialization error: ID content enum is empty".to_string(), @@ -585,6 +615,34 @@ pub mod v2 { Ok(Term::Set(set)) } Some(Content::Null(_)) => Ok(Term::Null), + Some(Content::Array(a)) => { + let array = a + .array + .iter() + .map(proto_id_to_token_term) + .collect::>()?; + + Ok(Term::Array(array)) + } + Some(Content::Map(m)) => { + let mut map = BTreeMap::new(); + + for MapEntry { key, value } in m.entries.iter() { + let key = match key.content { + Some(schema::map_key::Content::Integer(i)) => MapKey::Integer(i), + Some(schema::map_key::Content::String(s)) => MapKey::Str(s), + None => { + return Err(error::Format::DeserializationError( + "deserialization error: ID content enum is empty".to_string(), + )) + } + }; + + map.insert(key, proto_id_to_token_term(&value)?); + } + + Ok(Term::Map(map)) + } } } @@ -634,6 +692,7 @@ pub mod v2 { Binary::LazyOr => Kind::LazyOr, Binary::All => Kind::All, Binary::Any => Kind::Any, + Binary::Get => Kind::Get, } as i32, }) } @@ -698,6 +757,7 @@ pub mod v2 { Some(op_binary::Kind::LazyOr) => Op::Binary(Binary::LazyOr), Some(op_binary::Kind::All) => Op::Binary(Binary::All), Some(op_binary::Kind::Any) => Op::Binary(Binary::Any), + Some(op_binary::Kind::Get) => Op::Binary(Binary::Get), None => { return Err(error::Format::DeserializationError( "deserialization error: binary operation is empty".to_string(), diff --git a/biscuit-auth/src/format/schema.proto b/biscuit-auth/src/format/schema.proto index 349bfb41..dc7fa0b9 100644 --- a/biscuit-auth/src/format/schema.proto +++ b/biscuit-auth/src/format/schema.proto @@ -99,6 +99,8 @@ message TermV2 { bool bool = 6; TermSet set = 7; Empty null = 8; + Array array = 9; + Map map = 10; } } @@ -106,6 +108,26 @@ message TermSet { repeated TermV2 set = 1; } +message Array { + repeated TermV2 array = 1; +} + +message Map { + repeated MapEntry entries = 1; +} + +message MapEntry { + required MapKey key = 1; + required TermV2 value = 2; +} + +message MapKey { + oneof Content { + int64 integer = 1; + uint64 string = 2; + } +} + message ExpressionV2 { repeated Op ops = 1; } @@ -158,6 +180,7 @@ message OpBinary { LazyOr = 24; All = 25; Any = 26; + Get = 27; } required Kind kind = 1; diff --git a/biscuit-auth/src/format/schema.rs b/biscuit-auth/src/format/schema.rs index 58e7769a..8ad23ff5 100644 --- a/biscuit-auth/src/format/schema.rs +++ b/biscuit-auth/src/format/schema.rs @@ -139,7 +139,7 @@ pub struct PredicateV2 { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TermV2 { - #[prost(oneof="term_v2::Content", tags="1, 2, 3, 4, 5, 6, 7, 8")] + #[prost(oneof="term_v2::Content", tags="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")] pub content: ::core::option::Option, } /// Nested message and enum types in `TermV2`. @@ -162,6 +162,10 @@ pub mod term_v2 { Set(super::TermSet), #[prost(message, tag="8")] Null(super::Empty), + #[prost(message, tag="9")] + Array(super::Array), + #[prost(message, tag="10")] + Map(super::Map), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -170,6 +174,38 @@ pub struct TermSet { pub set: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct Array { + #[prost(message, repeated, tag="1")] + pub array: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Map { + #[prost(message, repeated, tag="1")] + pub entries: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MapEntry { + #[prost(message, required, tag="1")] + pub key: MapKey, + #[prost(message, required, tag="2")] + pub value: TermV2, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MapKey { + #[prost(oneof="map_key::Content", tags="1, 2")] + pub content: ::core::option::Option, +} +/// Nested message and enum types in `MapKey`. +pub mod map_key { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Content { + #[prost(int64, tag="1")] + Integer(i64), + #[prost(uint64, tag="2")] + String(u64), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ExpressionV2 { #[prost(message, repeated, tag="1")] pub ops: ::prost::alloc::vec::Vec, @@ -245,6 +281,7 @@ pub mod op_binary { LazyOr = 24, All = 25, Any = 26, + Get = 27, } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index e63fc72e..44042121 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -7,6 +7,7 @@ use crate::token::builder_ext::BuilderExt; use biscuit_parser::parser::parse_block_source; use nom::Finish; use rand_core::{CryptoRng, RngCore}; +use std::collections::BTreeMap; use std::str::FromStr; use std::{ collections::{BTreeSet, HashMap}, @@ -430,6 +431,94 @@ pub enum Term { Set(BTreeSet), Parameter(String), Null, + Array(Vec), + Map(BTreeMap), +} + +impl Term { + fn extract_parameters(&self, parameters: &mut HashMap>) { + match self { + Term::Parameter(name) => { + parameters.insert(name.to_string(), None); + } + Term::Set(s) => { + for term in s { + term.extract_parameters(parameters); + } + } + Term::Array(a) => { + for term in a { + term.extract_parameters(parameters); + } + } + Term::Map(m) => { + for (key, term) in m { + if let MapKey::Parameter(name) = key { + parameters.insert(name.to_string(), None); + } + term.extract_parameters(parameters); + } + } + _ => {} + } + } + + fn apply_parameters(self, parameters: &HashMap>) -> Term { + match self { + Term::Parameter(name) => { + if let Some(Some(term)) = parameters.get(&name) { + term.clone() + } else { + Term::Parameter(name) + } + } + Term::Map(m) => Term::Map( + m.into_iter() + .map(|(key, term)| { + println!("will try to apply parameters on {key:?} -> {term:?}"); + ( + match key { + MapKey::Parameter(name) => { + if let Some(Some(key_term)) = parameters.get(&name) { + println!("found key term: {key_term}"); + match key_term { + Term::Integer(i) => MapKey::Integer(*i), + Term::Str(s) => MapKey::Str(s.clone()), + //FIXME: we should return an error + _ => MapKey::Parameter(name), + } + } else { + MapKey::Parameter(name) + } + } + _ => key, + }, + term.apply_parameters(parameters), + ) + }) + .collect(), + ), + Term::Array(array) => Term::Array( + array + .into_iter() + .map(|term| term.apply_parameters(parameters)) + .collect(), + ), + Term::Set(set) => Term::Set( + set.into_iter() + .map(|term| term.apply_parameters(parameters)) + .collect(), + ), + _ => self, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum MapKey { + Integer(i64), + Str(String), + Parameter(String), } impl Convert for Term { @@ -446,6 +535,20 @@ impl Convert for Term { // The error is caught in the `add_xxx` functions, so this should // not happen™ Term::Parameter(s) => panic!("Remaining parameter {}", &s), + Term::Array(a) => datalog::Term::Array(a.iter().map(|i| i.convert(symbols)).collect()), + Term::Map(m) => datalog::Term::Map( + m.iter() + .map(|(key, term)| { + let key = match key { + MapKey::Integer(i) => datalog::MapKey::Integer(*i), + MapKey::Str(s) => datalog::MapKey::Str(symbols.insert(s)), + MapKey::Parameter(s) => panic!("Remaining parameter {}", &s), + }; + + (key, term.convert(symbols)) + }) + .collect(), + ), } } @@ -463,6 +566,23 @@ impl Convert for Term { .collect::, error::Format>>()?, ), datalog::Term::Null => Term::Null, + datalog::Term::Array(a) => Term::Array( + a.iter() + .map(|i| Term::convert_from(i, symbols)) + .collect::, error::Format>>()?, + ), + datalog::Term::Map(m) => Term::Map( + m.iter() + .map(|(key, term)| { + let key = match key { + datalog::MapKey::Integer(i) => Ok(MapKey::Integer(*i)), + datalog::MapKey::Str(s) => symbols.print_symbol(*s).map(MapKey::Str), + }; + + key.and_then(|k| Term::convert_from(term, symbols).map(|term| (k, term))) + }) + .collect::, error::Format>>()?, + ), }) } } @@ -479,6 +599,8 @@ impl From<&Term> for Term { Term::Set(ref s) => Term::Set(s.clone()), Term::Parameter(ref p) => Term::Parameter(p.clone()), Term::Null => Term::Null, + Term::Array(ref a) => Term::Array(a.clone()), + Term::Map(m) => Term::Map(m.clone()), } } } @@ -497,6 +619,25 @@ impl From for Term { } biscuit_parser::builder::Term::Null => Term::Null, biscuit_parser::builder::Term::Parameter(ref p) => Term::Parameter(p.clone()), + biscuit_parser::builder::Term::Array(a) => { + Term::Array(a.into_iter().map(|t| t.into()).collect()) + } + biscuit_parser::builder::Term::Map(a) => Term::Map( + a.into_iter() + .map(|(key, term)| { + ( + match key { + biscuit_parser::builder::MapKey::Parameter(s) => { + MapKey::Parameter(s) + } + biscuit_parser::builder::MapKey::Integer(i) => MapKey::Integer(i), + biscuit_parser::builder::MapKey::Str(s) => MapKey::Str(s), + }, + term.into(), + ) + }) + .collect(), + ), } } } @@ -534,12 +675,27 @@ impl fmt::Display for Term { } Term::Set(s) => { let terms = s.iter().map(|term| term.to_string()).collect::>(); - write!(f, "[{}]", terms.join(", ")) + write!(f, "{{{}}}", terms.join(", ")) } Term::Parameter(s) => { write!(f, "{{{}}}", s) } Term::Null => write!(f, "null"), + Term::Array(a) => { + let terms = a.iter().map(|term| term.to_string()).collect::>(); + write!(f, "[{}]", terms.join(", ")) + } + Term::Map(m) => { + let terms = m + .iter() + .map(|(key, term)| match key { + MapKey::Integer(i) => format!("{i}: {}", term.to_string()), + MapKey::Str(s) => format!("\"{s}\": {}", term.to_string()), + MapKey::Parameter(s) => format!("{{{s}}}: {}", term.to_string()), + }) + .collect::>(); + write!(f, "{{{}}}", terms.join(", ")) + } } } } @@ -696,9 +852,7 @@ impl Fact { let terms: Vec = terms.into(); for term in &terms { - if let Term::Parameter(name) = &term { - parameters.insert(name.to_string(), None); - } + term.extract_parameters(&mut parameters); } Fact { predicate: Predicate::new(name, terms), @@ -802,14 +956,7 @@ impl Fact { .predicate .terms .drain(..) - .map(|t| { - if let Term::Parameter(name) = &t { - if let Some(Some(term)) = parameters.get(name) { - return term.clone(); - } - } - t - }) + .map(|t| t.apply_parameters(¶meters)) .collect(); } } @@ -996,6 +1143,7 @@ impl From for Binary { biscuit_parser::builder::Binary::LazyOr => Binary::LazyOr, biscuit_parser::builder::Binary::All => Binary::All, biscuit_parser::builder::Binary::Any => Binary::Any, + biscuit_parser::builder::Binary::Get => Binary::Get, } } } @@ -1021,23 +1169,19 @@ impl Rule { let mut parameters = HashMap::new(); let mut scope_parameters = HashMap::new(); for term in &head.terms { - if let Term::Parameter(name) = &term { - parameters.insert(name.to_string(), None); - } + term.extract_parameters(&mut parameters); } for predicate in &body { for term in &predicate.terms { - if let Term::Parameter(name) = &term { - parameters.insert(name.to_string(), None); - } + term.extract_parameters(&mut parameters); } } for expression in &expressions { for op in &expression.ops { - if let Op::Value(Term::Parameter(name)) = &op { - parameters.insert(name.to_string(), None); + if let Op::Value(term) = &op { + term.extract_parameters(&mut parameters); } } } @@ -1904,6 +2048,13 @@ pub trait ToAnyParam { fn to_any_param(&self) -> AnyParam; } +#[cfg(feature = "datalog-macro")] +impl ToAnyParam for Term { + fn to_any_param(&self) -> AnyParam { + AnyParam::Term(self.clone()) + } +} + impl From for Term { fn from(i: i64) -> Self { Term::Integer(i) @@ -2094,6 +2245,37 @@ impl> TryFrom for BTreeSet } } +// TODO: From and ToAnyParam for arrays and maps +impl TryFrom for Term { + type Error = &'static str; + + fn try_from(value: serde_json::Value) -> Result { + match value { + serde_json::Value::Null => Ok(Term::Null), + serde_json::Value::Bool(b) => Ok(Term::Bool(b)), + serde_json::Value::Number(i) => match i.as_i64() { + Some(i) => Ok(Term::Integer(i)), + None => Err("Biscuit values do not support floating point numbers"), + }, + serde_json::Value::String(s) => Ok(Term::Str(s)), + serde_json::Value::Array(array) => Ok(Term::Array( + array + .into_iter() + .map(|v| v.try_into()) + .collect::>()?, + )), + serde_json::Value::Object(o) => Ok(Term::Map( + o.into_iter() + .map(|(key, value)| { + let value: Term = value.try_into()?; + Ok::<_, &'static str>((MapKey::Str(key), value)) + }) + .collect::>()?, + )), + } + } +} + macro_rules! tuple_try_from( ($ty1:ident, $ty2:ident, $($ty:ident),*) => ( tuple_try_from!(__impl $ty1, $ty2; $($ty),*); @@ -2379,7 +2561,7 @@ mod tests { rule.set("p5", term_set).unwrap(); let s = rule.to_string(); - assert_eq!(s, "fact($var1, \"hello\", [0]) <- f1($var1, $var3), f2(\"hello\", $var3, 1), $var3.starts_with(\"hello\")"); + assert_eq!(s, "fact($var1, \"hello\", {0}) <- f1($var1, $var3), f2(\"hello\", $var3, 1), $var3.starts_with(\"hello\")"); } #[test] diff --git a/biscuit-auth/src/token/mod.rs b/biscuit-auth/src/token/mod.rs index a93fc3e3..5f8bda3e 100644 --- a/biscuit-auth/src/token/mod.rs +++ b/biscuit-auth/src/token/mod.rs @@ -1333,14 +1333,14 @@ mod tests { let mut block2 = BlockBuilder::new(); block2 - .add_rule("has_bytes($0) <- bytes($0), [ hex:00000000, hex:0102AB ].contains($0)") + .add_rule("has_bytes($0) <- bytes($0), { hex:00000000, hex:0102AB }.contains($0)") .unwrap(); let keypair2 = KeyPair::new_with_rng(&mut rng); let biscuit2 = biscuit1.append_with_keypair(&keypair2, block2).unwrap(); let mut authorizer = biscuit2.authorizer().unwrap(); authorizer - .add_check("check if bytes($0), [ hex:00000000, hex:0102AB ].contains($0)") + .add_check("check if bytes($0), { hex:00000000, hex:0102AB }.contains($0)") .unwrap(); authorizer.allow().unwrap(); diff --git a/biscuit-auth/tests/macros.rs b/biscuit-auth/tests/macros.rs index f38a14ea..b2146294 100644 --- a/biscuit-auth/tests/macros.rs +++ b/biscuit-auth/tests/macros.rs @@ -1,20 +1,24 @@ -use biscuit_auth::builder; +use biscuit_auth::{builder, KeyPair}; use biscuit_quote::{ authorizer, authorizer_merge, biscuit, biscuit_merge, block, block_merge, check, fact, policy, rule, }; -use std::collections::BTreeSet; +use serde_json::json; +use std::{collections::BTreeSet, convert::TryInto}; #[test] fn block_macro() { let mut term_set = BTreeSet::new(); term_set.insert(builder::int(0i64)); let my_key = "my_value"; + let array_param = 2; + let mapkey = "hello"; + let mut b = block!( - r#"fact("test", hex:aabbcc, [true], {my_key}, {term_set}); + r#"fact("test", hex:aabbcc, [1, {array_param}], {my_key}, {term_set}, {"a": 1, 2 : "abcd", {mapkey}: 0 }); rule($0, true) <- fact($0, $1, $2, {my_key}), true || false; check if {my_key}.starts_with("my"); - check if [true,false].any($p -> true); + check if {true,false}.any($p -> true); "#, ); @@ -23,11 +27,11 @@ fn block_macro() { assert_eq!( b.to_string(), - r#"fact("test", hex:aabbcc, [true], "my_value", [0]); + r#"fact("test", hex:aabbcc, [1, 2], "my_value", {0}, {2: "abcd", "a": 1, "hello": 0}); appended(true); rule($0, true) <- fact($0, $1, $2, "my_value"), true || false; check if "my_value".starts_with("my"); -check if [false, true].any($p -> true); +check if {false, true}.any($p -> true); "#, ); } @@ -167,7 +171,7 @@ fn rule_macro() { assert_eq!( r.to_string(), - r#"rule($0, true) <- fact($0, $1, $2, "my_value", [0]) trusting ed25519/6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db"#, + r#"rule($0, true) <- fact($0, $1, $2, "my_value", {0}) trusting ed25519/6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db"#, ); } @@ -177,7 +181,7 @@ fn fact_macro() { term_set.insert(builder::int(0i64)); let f = fact!(r#"fact({my_key}, {term_set})"#, my_key = "my_value",); - assert_eq!(f.to_string(), r#"fact("my_value", [0])"#,); + assert_eq!(f.to_string(), r#"fact("my_value", {0})"#,); } #[test] @@ -196,7 +200,7 @@ fn check_macro() { assert_eq!( c.to_string(), - r#"check if fact("my_value", [0]) trusting ed25519/6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db"#, + r#"check if fact("my_value", {0}) trusting ed25519/6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db"#, ); } @@ -216,6 +220,32 @@ fn policy_macro() { assert_eq!( p.to_string(), - r#"allow if fact("my_value", [0]) trusting ed25519/6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db"#, + r#"allow if fact("my_value", {0}) trusting ed25519/6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db"#, ); } + +#[test] +fn json() { + let key_pair = KeyPair::new(); + let biscuit = biscuit!(r#"user(123)"#).build(&key_pair).unwrap(); + + let value: serde_json::Value = json!( + { + "id": 123, + "roles": ["admin"] + } + ); + let json_value: biscuit_auth::builder::Term = value.try_into().unwrap(); + + let authorizer = authorizer!( + r#" + user_roles({json_value}); + allow if + user($id), + user_roles($value), + $value.get("id") == $id, + $value.get("roles").contains("admin");"# + ); + + assert!(biscuit.authorize(&authorizer).is_ok()); +} diff --git a/biscuit-parser/src/builder.rs b/biscuit-parser/src/builder.rs index bf632b31..ea18e95f 100644 --- a/biscuit-parser/src/builder.rs +++ b/biscuit-parser/src/builder.rs @@ -1,6 +1,6 @@ //! helper functions and structure to create tokens and blocks use std::{ - collections::{BTreeSet, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, time::{SystemTime, UNIX_EPOCH}, }; @@ -19,6 +19,44 @@ pub enum Term { Set(BTreeSet), Parameter(String), Null, + Array(Vec), + Map(BTreeMap), +} + +impl Term { + fn extract_parameters(&self, parameters: &mut HashMap>) { + match self { + Term::Parameter(name) => { + parameters.insert(name.to_string(), None); + } + Term::Set(s) => { + for term in s { + term.extract_parameters(parameters); + } + } + Term::Array(a) => { + for term in a { + term.extract_parameters(parameters); + } + } + Term::Map(m) => { + for (key, term) in m { + if let MapKey::Parameter(name) = key { + parameters.insert(name.to_string(), None); + } + term.extract_parameters(parameters); + } + } + _ => {} + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum MapKey { + Parameter(String), + Integer(i64), + Str(String), } impl From<&Term> for Term { @@ -33,6 +71,8 @@ impl From<&Term> for Term { Term::Set(ref s) => Term::Set(s.clone()), Term::Parameter(ref p) => Term::Parameter(p.clone()), Term::Null => Term::Null, + Term::Array(ref a) => Term::Array(a.clone()), + Term::Map(ref m) => Term::Map(m.clone()), } } } @@ -61,11 +101,47 @@ impl ToTokens for Term { }} } Term::Null => quote! { ::biscuit_auth::builder::Term::Null }, - + Term::Array(v) => { + quote! {{ + use std::iter::FromIterator; + ::biscuit_auth::builder::Term::Array(::std::vec::Vec::from_iter(<[::biscuit_auth::builder::Term]>::into_vec( Box::new([ #(#v),*])))) + }} + } + Term::Map(m) => { + let it = m.iter().map(|(key, term)| MapEntry {key, term }); + quote! {{ + use std::iter::FromIterator; + ::biscuit_auth::builder::Term::Map(::std::collections::BTreeMap::from_iter(<[(::biscuit_auth::builder::MapKey,::biscuit_auth::builder::Term)]>::into_vec(Box::new([ #(#it),*])))) + }} + } }) } } +#[cfg(feature = "datalog-macro")] +struct MapEntry<'a> { + key: &'a MapKey, + term: &'a Term, +} + +#[cfg(feature = "datalog-macro")] +impl<'a> ToTokens for MapEntry<'a> { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let term = self.term; + tokens.extend(match self.key { + MapKey::Parameter(p) => { + quote! { (::biscuit_auth::builder::MapKey::Parameter(#p.to_string()) , #term )} + } + MapKey::Integer(i) => { + quote! { (::biscuit_auth::builder::MapKey::Integer(#i) , #term )} + } + MapKey::Str(s) => { + quote! { (::biscuit_auth::builder::MapKey::Str(#s.to_string()) , #term )} + } + }); + } +} + #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub enum Scope { Authority, @@ -136,9 +212,7 @@ impl Fact { let terms: Vec = terms.into(); for term in &terms { - if let Term::Parameter(name) = &term { - parameters.insert(name.to_string(), None); - } + term.extract_parameters(&mut parameters); } Fact { predicate: Predicate::new(name, terms), @@ -224,6 +298,7 @@ pub enum Binary { LazyOr, All, Any, + Get, } #[cfg(feature = "datalog-macro")] @@ -289,6 +364,7 @@ impl ToTokens for Binary { Binary::LazyOr => quote! { ::biscuit_auth::datalog::Binary::LazyOr }, Binary::All => quote! { ::biscuit_auth::datalog::Binary::All }, Binary::Any => quote! { ::biscuit_auth::datalog::Binary::Any }, + Binary::Get => quote! { ::biscuit_auth::datalog::Binary::Get }, }); } } @@ -317,23 +393,19 @@ impl Rule { let mut scope_parameters = HashMap::new(); for term in &head.terms { - if let Term::Parameter(name) = &term { - parameters.insert(name.to_string(), None); - } + term.extract_parameters(&mut parameters); } for predicate in &body { for term in &predicate.terms { - if let Term::Parameter(name) = &term { - parameters.insert(name.to_string(), None); - } + term.extract_parameters(&mut parameters); } } for expression in &expressions { for op in &expression.ops { - if let Op::Value(Term::Parameter(name)) = &op { - parameters.insert(name.to_string(), None); + if let Op::Value(term) = &op { + term.extract_parameters(&mut parameters); } } } @@ -605,6 +677,16 @@ pub fn null() -> Term { Term::Null } +/// creates an array +pub fn array(a: Vec) -> Term { + Term::Array(a) +} + +/// creates a map +pub fn map(m: BTreeMap) -> Term { + Term::Map(m) +} + /// creates a parameter pub fn parameter(p: &str) -> Term { Term::Parameter(p.to_string()) diff --git a/biscuit-parser/src/parser.rs b/biscuit-parser/src/parser.rs index 43c37987..fda67dab 100644 --- a/biscuit-parser/src/parser.rs +++ b/biscuit-parser/src/parser.rs @@ -3,16 +3,19 @@ use nom::{ branch::alt, bytes::complete::{escaped_transform, tag, tag_no_case, take_until, take_while, take_while1}, character::{ - complete::{char, digit1, multispace0 as space0}, - is_alphanumeric, + complete::{char, digit1, multispace0 as space0, satisfy}, + is_alphabetic, is_alphanumeric, }, combinator::{consumed, cut, eof, map, map_res, opt, recognize, value}, error::{ErrorKind, FromExternalError, ParseError}, multi::{many0, separated_list0, separated_list1}, - sequence::{delimited, pair, preceded, terminated, tuple}, + sequence::{delimited, pair, preceded, separated_pair, terminated, tuple}, IResult, Offset, }; -use std::{collections::BTreeSet, convert::TryInto}; +use std::{ + collections::{BTreeMap, BTreeSet}, + convert::TryInto, +}; use thiserror::Error; /// parse a Datalog fact @@ -506,6 +509,7 @@ fn binary_op_8(i: &str) -> IResult<&str, builder::Binary, Error> { value(Binary::Union, tag("union")), value(Binary::All, tag("all")), value(Binary::Any, tag("any")), + value(Binary::Get, tag("get")), ))(i) } @@ -728,6 +732,21 @@ fn name(i: &str) -> IResult<&str, &str, Error> { reduce(take_while1(is_name_char), " ,:(\n;")(i) } +fn parameter_name(i: &str) -> IResult<&str, &str, Error> { + let is_name_char = |c: char| is_alphanumeric(c as u8) || c == '_' || c == ':'; + + error( + recognize(preceded( + satisfy(|c: char| is_alphabetic(c as u8)), + take_while(is_name_char), + )), + |_| { + "invalid parameter name: it must start with an alphabetic character, followed by alphanumeric characters, underscores or colons".to_string() + }, + " ,:(\n;", + )(i) +} + fn printable(i: &str) -> IResult<&str, &str, Error> { take_while1(|c: char| c != '\\' && c != '"')(i) } @@ -766,7 +785,9 @@ fn integer(i: &str) -> IResult<&str, builder::Term, Error> { fn parse_date(i: &str) -> IResult<&str, u64, Error> { map_res( map_res( - take_while1(|c: char| c != ',' && c != ' ' && c != ')' && c != ']' && c != ';'), + take_while1(|c: char| { + c != ',' && c != ' ' && c != ')' && c != ']' && c != ';' && c != '}' + }), |s| time::OffsetDateTime::parse(s, &time::format_description::well_known::Rfc3339), ), |t| t.unix_timestamp().try_into(), @@ -800,7 +821,10 @@ fn variable(i: &str) -> IResult<&str, builder::Term, Error> { } fn parameter(i: &str) -> IResult<&str, builder::Term, Error> { - map(delimited(char('{'), name, char('}')), builder::parameter)(i) + map( + delimited(char('{'), parameter_name, char('}')), + builder::parameter, + )(i) } fn parse_bool(i: &str) -> IResult<&str, bool, Error> { @@ -816,8 +840,7 @@ fn null(i: &str) -> IResult<&str, builder::Term, Error> { } fn set(i: &str) -> IResult<&str, builder::Term, Error> { - //println!("set:\t{}", i); - let (i, _) = preceded(space0, char('['))(i)?; + let (i, _) = preceded(space0, char('{'))(i)?; let (i, mut list) = cut(separated_list0(preceded(space0, char(',')), term_in_set))(i)?; let mut set = BTreeSet::new(); @@ -846,6 +869,8 @@ fn set(i: &str) -> IResult<&str, builder::Term, Error> { } builder::Term::Parameter(_) => 7, builder::Term::Null => 8, + builder::Term::Array(_) => 9, + builder::Term::Map(_) => 10, }; if let Some(k) = kind { @@ -863,16 +888,55 @@ fn set(i: &str) -> IResult<&str, builder::Term, Error> { set.insert(term); } - let (i, _) = preceded(space0, char(']'))(i)?; + let (i, _) = preceded(space0, char('}'))(i)?; Ok((i, builder::set(set))) } +fn array(i: &str) -> IResult<&str, builder::Term, Error> { + let (i, _) = preceded(space0, char('['))(i)?; + let (i, array) = cut(separated_list0(preceded(space0, char(',')), term_in_fact))(i)?; + let (i, _) = preceded(space0, char(']'))(i)?; + + Ok((i, builder::array(array))) +} + +fn parse_map(i: &str) -> IResult<&str, builder::Term, Error> { + let (i, _) = preceded(space0, char('{'))(i)?; + let (i, mut list) = cut(separated_list0( + preceded(space0, char(',')), + separated_pair(map_key, preceded(space0, char(':')), term_in_fact), + ))(i)?; + + let mut map = BTreeMap::new(); + + for (key, term) in list.drain(..) { + map.insert(key, term); + } + + let (i, _) = preceded(space0, char('}'))(i)?; + + Ok((i, builder::map(map))) +} + +fn map_key(i: &str) -> IResult<&str, builder::MapKey, Error> { + preceded( + space0, + alt(( + map(delimited(char('{'), parameter_name, char('}')), |s| { + builder::MapKey::Parameter(s.to_string()) + }), + map(parse_string, |s| builder::MapKey::Str(s.to_string())), + map(parse_integer, builder::MapKey::Integer), + )), + )(i) +} + fn term(i: &str) -> IResult<&str, builder::Term, Error> { preceded( space0, alt(( - parameter, string, date, variable, integer, bytes, boolean, null, set, + parameter, string, date, variable, integer, bytes, boolean, null, set, array, parse_map, )), )(i) } @@ -881,7 +945,9 @@ fn term_in_fact(i: &str) -> IResult<&str, builder::Term, Error> { preceded( space0, error( - alt((parameter, string, date, integer, bytes, boolean, null, set)), + alt(( + parameter, string, date, integer, bytes, boolean, null, set, array, parse_map, + )), |input| match input.chars().next() { None | Some(',') | Some(')') => "missing term".to_string(), Some('$') => "variables are not allowed in facts".to_string(), @@ -896,13 +962,15 @@ fn term_in_set(i: &str) -> IResult<&str, builder::Term, Error> { preceded( space0, error( - alt((parameter, string, date, integer, bytes, boolean, null)), + alt(( + parameter, string, date, integer, bytes, boolean, null, parse_map, + )), |input| match input.chars().next() { - None | Some(',') | Some(']') => "missing term".to_string(), + None | Some(',') | Some('}') => "missing term".to_string(), Some('$') => "variables are not allowed in sets".to_string(), _ => "expected a valid term".to_string(), }, - " ,]\n;", + " ,}\n;", ), )(i) } @@ -1235,7 +1303,7 @@ where #[cfg(test)] mod tests { - use crate::builder::{self, Unary}; + use crate::builder::{self, array, int, var, Binary, Op, Unary}; #[test] fn name() { @@ -1283,6 +1351,17 @@ mod tests { super::parameter("{param}"), Ok(("", builder::parameter("param"))) ); + + assert_eq!( + super::parameter("{1param}"), + Err(nom::Err::Error(crate::parser::Error { + input: "1param}", + code: nom::error::ErrorKind::Satisfy, + message: Some("invalid parameter name: it must start with an alphabetic character, followed by alphanumeric characters, underscores or colons".to_string()) + })) + ); + + assert_eq!(super::parameter("{p}"), Ok(("", builder::parameter("p")))); } #[test] @@ -1639,7 +1718,7 @@ mod tests { let h = [int(1), int(2)].iter().cloned().collect::>(); assert_eq!( - super::expr("[1, 2].contains($0)").map(|(i, o)| (i, o.opcodes())), + super::expr("{1, 2}.contains($0)").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -1651,7 +1730,7 @@ mod tests { ); assert_eq!( - super::expr("![1, 2].contains($0)").map(|(i, o)| (i, o.opcodes())), + super::expr("!{ 1, 2}.contains($0)").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -1663,6 +1742,25 @@ mod tests { )) ); + let h = [ + builder::Term::Date(1575452801), + builder::Term::Date(1607075201), + ] + .iter() + .cloned() + .collect::>(); + assert_eq!( + super::expr("{2020-12-04T09:46:41+00:00, 2019-12-04T09:46:41+00:00}.contains(2020-12-04T09:46:41+00:00)").map(|(i, o)| (i, o.opcodes())), + Ok(( + "", + vec![ + Op::Value(set(h)), + Op::Value(builder::Term::Date(1607075201)), + Op::Binary(Binary::Contains), + ], + )) + ); + assert_eq!( super::expr("$0 === \"abc\"").map(|(i, o)| (i, o.opcodes())), Ok(( @@ -1752,7 +1850,7 @@ mod tests { .cloned() .collect::>(); assert_eq!( - super::expr("[\"abc\", \"def\"].contains($0)").map(|(i, o)| (i, o.opcodes())), + super::expr("{\"abc\", \"def\"}.contains($0)").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -1764,7 +1862,7 @@ mod tests { ); assert_eq!( - super::expr("![\"abc\", \"def\"].contains($0)").map(|(i, o)| (i, o.opcodes())), + super::expr("!{\"abc\", \"def\"}.contains($0)").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -1781,7 +1879,7 @@ mod tests { .cloned() .collect::>(); assert_eq!( - super::expr("[\"abc\", \"def\"].contains($0)").map(|(i, o)| (i, o.opcodes())), + super::expr("{\"abc\", \"def\"}.contains($0)").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -1793,7 +1891,7 @@ mod tests { ); assert_eq!( - super::expr("![\"abc\", \"def\"].contains($0)").map(|(i, o)| (i, o.opcodes())), + super::expr("!{\"abc\", \"def\"}.contains($0)").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -2453,7 +2551,7 @@ mod tests { use builder::{int, set, Binary, Op}; assert_eq!( - super::expr("[1].intersection([2]).contains(3)").map(|(i, o)| (i, o.opcodes())), + super::expr("{1}.intersection({2}).contains(3)").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -2467,7 +2565,7 @@ mod tests { ); assert_eq!( - super::expr("[1].intersection([2]).union([3]).length()").map(|(i, o)| (i, o.opcodes())), + super::expr("{1}.intersection({2}).union({3}).length()").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -2482,7 +2580,7 @@ mod tests { ); assert_eq!( - super::expr("[1].intersection([2]).length().union([3])").map(|(i, o)| (i, o.opcodes())), + super::expr("{1}.intersection({2}).length().union({3})").map(|(i, o)| (i, o.opcodes())), Ok(( "", vec![ @@ -2496,4 +2594,20 @@ mod tests { )) ); } + + #[test] + fn arrays() { + let h = vec![int(1), int(2)]; + assert_eq!( + super::expr("[1, 2].contains($0)").map(|(i, o)| (i, o.opcodes())), + Ok(( + "", + vec![ + Op::Value(array(h.clone())), + Op::Value(var("0")), + Op::Binary(Binary::Contains), + ], + )) + ); + } }