From 41d268ae1860dff9d45ad4c352c92218b6027f42 Mon Sep 17 00:00:00 2001 From: benolt Date: Wed, 31 Jul 2024 16:53:03 +0200 Subject: [PATCH 01/12] GraphRAG Global Search --- Cargo.lock | 874 +++++++++++++++++- Cargo.toml | 3 +- shinkai-libs/shinkai-graphrag/Cargo.toml | 13 + .../src/context_builder/context_builder.rs | 18 + .../src/context_builder/mod.rs | 1 + shinkai-libs/shinkai-graphrag/src/lib.rs | 3 + shinkai-libs/shinkai-graphrag/src/llm/llm.rs | 35 + shinkai-libs/shinkai-graphrag/src/llm/mod.rs | 1 + .../src/search/global_search.rs | 255 +++++ .../shinkai-graphrag/src/search/mod.rs | 1 + 10 files changed, 1163 insertions(+), 41 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/Cargo.toml create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/lib.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/llm/llm.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/llm/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/global_search.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/mod.rs diff --git a/Cargo.lock b/Cargo.lock index b51e07de2..dd8b827ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -224,6 +224,21 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "argminmax" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52424b59d69d69d5056d508b260553afd91c57e21849579cd1f50ee8b8b88eaa" +dependencies = [ + "num-traits", +] + +[[package]] +name = "array-init-cursor" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" + [[package]] name = "arrayref" version = "0.3.7" @@ -375,7 +390,7 @@ dependencies = [ "arrow-schema 51.0.0", "arrow-select 51.0.0", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -396,7 +411,7 @@ dependencies = [ "arrow-schema 52.2.0", "arrow-select 52.2.0", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -875,6 +890,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atoi_simd" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" + [[package]] name = "atomic-waker" version = "1.1.1" @@ -1435,9 +1456,9 @@ checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "base64" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64-simd" @@ -1700,6 +1721,20 @@ name = "bytemuck" version = "1.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +dependencies = [ + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 2.0.66", +] [[package]] name = "byteorder" @@ -1915,6 +1950,17 @@ dependencies = [ "parse-zoneinfo", ] +[[package]] +name = 
"chrono-tz" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +dependencies = [ + "chrono", + "chrono-tz-build 0.2.1", + "phf 0.11.2", +] + [[package]] name = "chrono-tz" version = "0.9.0" @@ -1922,8 +1968,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" dependencies = [ "chrono", - "chrono-tz-build", + "chrono-tz-build 0.3.0", + "phf 0.11.2", +] + +[[package]] +name = "chrono-tz-build" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +dependencies = [ + "parse-zoneinfo", "phf 0.11.2", + "phf_codegen 0.11.2", ] [[package]] @@ -2153,6 +2210,7 @@ version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ + "crossterm", "strum 0.26.3", "strum_macros 0.26.4", "unicode-width", @@ -2451,6 +2509,28 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.4.0", + "crossterm_winapi", + "libc", + "parking_lot 0.12.1", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -2819,7 +2899,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a369332afd0ef5bd565f6db2139fb9f1dfdd0afa75a7f70f000b74208d76994f" dependencies = [ "arrow 52.2.0", - "base64 0.22.0", + "base64 0.22.1", "chrono", "datafusion-common", "datafusion-execution", @@ -2906,7 +2986,7 @@ dependencies = [ "arrow-ord 52.2.0", "arrow-schema 52.2.0", "arrow-string 52.2.0", - "base64 0.22.0", + "base64 0.22.1", "chrono", "datafusion-common", "datafusion-execution", @@ -3278,6 +3358,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + [[package]] name = "ecdsa" version = "0.14.8" @@ -3337,9 +3423,9 @@ checksum = "3a68a4904193147e0a8dec3314640e6db742afd5f6e634f428a6af230d9b3591" [[package]] name = "either" -version = "1.9.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "elliptic-curve" @@ -3422,6 +3508,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" +[[package]] +name = "enum_dispatch" +version = "0.3.13" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" +dependencies = [ + "once_cell", + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 2.0.66", +] + [[package]] name = "env_logger" version = "0.9.3" @@ -3778,6 +3876,12 @@ dependencies = [ "yansi", ] +[[package]] +name = "ethnum" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" + [[package]] name = "event-listener" version = "2.5.3" @@ -3821,6 +3925,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -3831,6 +3941,12 @@ dependencies = [ "regex", ] +[[package]] +name = "fast-float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" + [[package]] name = "fastdivide" version = "0.4.1" @@ -3993,6 +4109,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -4441,6 +4563,8 @@ checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash 0.8.11", "allocator-api2", + "rayon", + "serde", ] [[package]] @@ -4521,9 +4645,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -4854,7 +4978,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows", + "windows 0.48.0", ] [[package]] @@ -5055,7 +5179,7 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "libc", "windows-sys 0.48.0", ] @@ -5088,7 +5212,7 @@ version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "rustix 0.38.32", "windows-sys 0.48.0", ] @@ -5141,6 +5265,12 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "itoap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" + [[package]] name = "jetscii" version = "0.5.3" @@ -6028,11 +6158,21 @@ version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9106e1d747ffd48e6be5bb2d97fa706ed25b144fbee4d5c02eae110cd8d6badd" +[[package]] +name = "lz4" +version = "1.26.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958b4caa893816eea05507c20cfe47574a43d9a697138a7872990bba8a0ece68" +dependencies = [ + "libc", + "lz4-sys", +] + [[package]] name = "lz4-sys" -version = "1.9.4" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d27b317e207b10f69f5e75494119e391a96f48861ae870d1da6edac98ca900" +checksum = "109de74d5d2353660401699a4174a4ff23fcc649caf553df71933c7fb45ad868" dependencies = [ "cc", "libc", @@ -6072,6 +6212,12 @@ dependencies = [ "libc", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "markup5ever" version = "0.10.1" @@ -6181,6 +6327,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + [[package]] name = "memmap2" version = "0.9.4" @@ -6297,13 +6452,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ + "hermit-abi 0.3.9", "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -6409,6 +6565,28 @@ dependencies = [ "twoway", ] +[[package]] +name = "multiversion" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" +dependencies = [ + "multiversion-macros", + "target-features", +] + +[[package]] +name = "multiversion-macros" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" +dependencies = [ + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 1.0.109", + "target-features", +] + [[package]] name = "murmurhash32" version = "0.3.1" @@ -6524,6 +6702,24 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" +[[package]] +name = "now" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" +dependencies = [ + "chrono", +] + +[[package]] +name = "ntapi" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +dependencies = [ + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -6642,7 +6838,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.2", + "hermit-abi 0.3.9", "libc", ] @@ -6683,7 +6879,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" dependencies = 
[ "async-trait", - "base64 0.22.0", + "base64 0.22.1", "bytes", "chrono", "futures", @@ -7088,6 +7284,12 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "parquet-format-safe" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" + [[package]] name = "parse-zoneinfo" version = "0.3.0" @@ -7174,6 +7376,28 @@ dependencies = [ "hmac 0.12.1", ] +[[package]] +name = "pcre2" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be55c43ac18044541d58d897e8f4c55157218428953ebd39d86df3ba0286b2b" +dependencies = [ + "libc", + "log 0.4.21", + "pcre2-sys", +] + +[[package]] +name = "pcre2-sys" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "550f5d18fb1b90c20b87e161852c10cde77858c3900c5059b5ad2a1449f11d8a" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "pddl-ish-parser" version = "0.0.4" @@ -7554,6 +7778,15 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "planus" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1691dd09e82f428ce8d6310bd6d5da2557c82ff17694d2a32cad7242aea89f" +dependencies = [ + "array-init-cursor", +] + [[package]] name = "platforms" version = "3.2.0" @@ -7608,6 +7841,402 @@ dependencies = [ "miniz_oxide 0.7.1", ] +[[package]] +name = "polars" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e3351ea4570e54cd556e6755b78fe7a2c85368d820c0307cca73c96e796a7ba" +dependencies = [ + "getrandom 0.2.10", + "polars-arrow", + "polars-core", + "polars-error", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-parquet", + "polars-sql", + "polars-time", + "polars-utils", + "version_check 0.9.4", +] + +[[package]] +name = "polars-arrow" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba65fc4bcabbd64fca01fd30e759f8b2043f0963c57619e331d4b534576c0b47" +dependencies = [ + "ahash 0.8.11", + "atoi", + "atoi_simd", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "dyn-clone", + "either", + "ethnum", + "fast-float", + "foreign_vec", + "getrandom 0.2.10", + "hashbrown 0.14.5", + "itoa 1.0.9", + "itoap", + "lz4", + "multiversion", + "num-traits", + "polars-arrow-format", + "polars-error", + "polars-utils", + "ryu", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "version_check 0.9.4", + "zstd 0.13.2", +] + +[[package]] +name = "polars-arrow-format" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b0ef2474af9396b19025b189d96e992311e6a47f90c53cd998b36c4c64b84c" +dependencies = [ + "planus", + "serde", +] + +[[package]] +name = "polars-compute" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f099516af30ac9ae4b4480f4ad02aa017d624f2f37b7a16ad4e9ba52f7e5269" +dependencies = [ + "bytemuck", + "either", + "num-traits", + "polars-arrow", + "polars-error", + "polars-utils", + "strength_reduce", + "version_check 0.9.4", +] + +[[package]] +name = "polars-core" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2439484be228b8c302328e2f953e64cfd93930636e5c7ceed90339ece7fef6c" +dependencies = [ + "ahash 0.8.11", + "bitflags 
2.4.0", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "comfy-table", + "either", + "hashbrown 0.14.5", + "indexmap 2.1.0", + "num-traits", + "once_cell", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-row", + "polars-utils", + "rand 0.8.5", + "rand_distr", + "rayon", + "regex", + "smartstring", + "thiserror", + "version_check 0.9.4", + "xxhash-rust", +] + +[[package]] +name = "polars-error" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c9b06dfbe79cabe50a7f0a90396864b5ee2c0e0f8d6a9353b2343c29c56e937" +dependencies = [ + "polars-arrow-format", + "regex", + "simdutf8", + "thiserror", +] + +[[package]] +name = "polars-expr" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c630385a56a867c410a20f30772d088f90ec3d004864562b84250b35268f97" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "once_cell", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", +] + +[[package]] +name = "polars-io" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d7363cd14e4696a28b334a56bd11013ff49cc96064818ab3f91a126e453462d" +dependencies = [ + "ahash 0.8.11", + "atoi_simd", + "bytes", + "chrono", + "fast-float", + "home", + "itoa 1.0.9", + "memchr", + "memmap2 0.7.1", + "num-traits", + "once_cell", + "percent-encoding 2.3.0", + "polars-arrow", + "polars-core", + "polars-error", + "polars-time", + "polars-utils", + "rayon", + "regex", + "ryu", + "simdutf8", + "smartstring", +] + +[[package]] +name = "polars-lazy" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03877e74e42b5340ae52ded705f6d5d14563d90554c9177b01b91ed2412a56ed" +dependencies = [ + "ahash 0.8.11", + "bitflags 2.4.0", + "glob", + "memchr", + "once_cell", + "polars-arrow", + "polars-core", + "polars-expr", + "polars-io", + "polars-mem-engine", + "polars-ops", + "polars-pipe", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", + "smartstring", + "version_check 0.9.4", +] + +[[package]] +name = "polars-mem-engine" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea9e17771af750c94bf959885e4b3f5b14149576c62ef3ec1c9ef5827b2a30f" +dependencies = [ + "polars-arrow", + "polars-core", + "polars-error", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-time", + "polars-utils", + "rayon", +] + +[[package]] +name = "polars-ops" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6066552eb577d43b307027fb38096910b643ffb2c89a21628c7e41caf57848d0" +dependencies = [ + "ahash 0.8.11", + "argminmax", + "base64 0.22.1", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "either", + "hashbrown 0.14.5", + "hex", + "indexmap 2.1.0", + "memchr", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-error", + "polars-utils", + "rayon", + "regex", + "smartstring", + "unicode-reverse", + "version_check 0.9.4", +] + +[[package]] +name = "polars-parquet" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b35b2592a2e7ef7ce9942dc2120dc4576142626c0e661668e4c6b805042e461" +dependencies = [ + "ahash 0.8.11", + "base64 0.22.1", + "ethnum", + "num-traits", + "parquet-format-safe", + "polars-arrow", + "polars-compute", + "polars-error", + "polars-utils", + 
"simdutf8", + "streaming-decompression", +] + +[[package]] +name = "polars-pipe" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021bce7768c330687d735340395a77453aa18dd70d57c184cbb302311e87c1b9" +dependencies = [ + "crossbeam-channel", + "crossbeam-queue", + "enum_dispatch", + "hashbrown 0.14.5", + "num-traits", + "polars-arrow", + "polars-compute", + "polars-core", + "polars-expr", + "polars-io", + "polars-ops", + "polars-plan", + "polars-row", + "polars-utils", + "rayon", + "smartstring", + "uuid 1.8.0", + "version_check 0.9.4", +] + +[[package]] +name = "polars-plan" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "220d0d7c02d1c4375802b2813dbedcd1a184df39c43b74689e729ede8d5c2921" +dependencies = [ + "ahash 0.8.11", + "bytemuck", + "chrono-tz 0.8.6", + "either", + "hashbrown 0.14.5", + "once_cell", + "percent-encoding 2.3.0", + "polars-arrow", + "polars-core", + "polars-io", + "polars-ops", + "polars-time", + "polars-utils", + "rayon", + "recursive", + "regex", + "smartstring", + "strum_macros 0.26.4", + "version_check 0.9.4", +] + +[[package]] +name = "polars-row" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1d70d87a2882a64a43b431aea1329cb9a2c4100547c95c417cc426bb82408b3" +dependencies = [ + "bytemuck", + "polars-arrow", + "polars-error", + "polars-utils", +] + +[[package]] +name = "polars-sql" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6fc1c9b778862f09f4a347f768dfdd3d0ba9957499d306d83c7103e0fa8dc5b" +dependencies = [ + "hex", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-lazy", + "polars-ops", + "polars-plan", + "polars-time", + "rand 0.8.5", + "serde", + "serde_json", + "sqlparser", +] + +[[package]] +name = "polars-time" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "179f98313a15c0bfdbc8cc0f1d3076d08d567485b9952d46439f94fbc3085df5" +dependencies = [ + "atoi", + "bytemuck", + "chrono", + "chrono-tz 0.8.6", + "now", + "once_cell", + "polars-arrow", + "polars-core", + "polars-error", + "polars-ops", + "polars-utils", + "regex", + "smartstring", +] + +[[package]] +name = "polars-utils" +version = "0.41.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53e6dd89fcccb1ec1a62f752c9a9f2d482a85e9255153f46efecc617b4996d50" +dependencies = [ + "ahash 0.8.11", + "bytemuck", + "hashbrown 0.14.5", + "indexmap 2.1.0", + "num-traits", + "once_cell", + "polars-error", + "raw-cpuid 11.0.1", + "rayon", + "smartstring", + "stacker", + "sysinfo", + "version_check 0.9.4", +] + [[package]] name = "polling" version = "2.8.0" @@ -7923,6 +8552,15 @@ dependencies = [ "prost 0.12.6", ] +[[package]] +name = "psm" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5787f7cda34e3033a72192c018bc5883100330f362ef279a8cbccfce8bb4e874" +dependencies = [ + "cc", +] + [[package]] name = "ptr_meta" version = "0.1.4" @@ -8449,6 +9087,26 @@ dependencies = [ "rand_core 0.3.1", ] +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote 1.0.36", + "syn 2.0.66", +] + [[package]] name = "redox_syscall" version = "0.3.5" @@ -8596,7 +9254,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "bytes", "futures-channel", "futures-core", @@ -8894,6 +9552,16 @@ dependencies = [ "serde_json", ] +[[package]] +name = "rust_decimal_macros" +version = "1.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a05bf7103af0797dbce0667c471946b29b9eaea34652eff67324f360fec027de" +dependencies = [ + "quote 1.0.36", + "rust_decimal", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -9014,7 +9682,7 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "rustls-pki-types", ] @@ -9366,11 +10034,11 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.8.1" +version = "3.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" +checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" dependencies = [ - "base64 0.22.0", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -9384,9 +10052,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.8.1" +version = "3.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" +checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" dependencies = [ "darling 0.20.10", "proc-macro2", @@ -9495,6 +10163,19 @@ dependencies = [ "dirs", ] +[[package]] +name = "shinkai-graphrag" +version = "0.1.0" +dependencies = [ + "async-trait", + "futures", + "polars", + "serde", + "serde_json", + "tiktoken", + "tokio", +] + [[package]] name = "shinkai_crypto_identities" version = "0.1.1" @@ -9537,7 +10218,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-channel", - "base64 0.22.0", + "base64 0.22.1", "chrono", "chrono-tz 0.5.3", "clap 3.2.25", @@ -9717,7 +10398,7 @@ dependencies = [ "aes-gcm", "anyhow", "async-channel", - "base64 0.22.0", + "base64 0.22.1", "chrono", "chrono-tz 0.5.3", "clap 3.2.25", @@ -9973,6 +10654,17 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smartstring" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" +dependencies = [ + "autocfg 1.1.0", + "static_assertions", + "version_check 0.9.4", +] + [[package]] name = "snafu" version = "0.7.5" @@ -10100,6 +10792,19 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "stacker" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + 
"winapi", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -10118,6 +10823,27 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978" +[[package]] +name = "streaming-decompression" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" +dependencies = [ + "fallible-streaming-iterator", +] + +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "string_cache" version = "0.8.7" @@ -10295,6 +11021,20 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "sysinfo" +version = "0.30.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" +dependencies = [ + "cfg-if", + "core-foundation-sys", + "libc", + "ntapi", + "once_cell", + "windows 0.52.0", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -10343,7 +11083,7 @@ checksum = "f8d0582f186c0a6d55655d24543f15e43607299425c5ad8352c242b914b31856" dependencies = [ "aho-corasick", "arc-swap", - "base64 0.22.0", + "base64 0.22.1", "bitpacking", "byteorder", "census", @@ -10360,7 +11100,7 @@ dependencies = [ "lru 0.12.3", "lz4_flex", "measure_time", - "memmap2", + "memmap2 0.9.4", "num_cpus", "once_cell", "oneshot", @@ -10493,6 +11233,12 @@ dependencies = [ "xattr", ] +[[package]] +name = "target-features" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" + [[package]] name = "target-lexicon" version = "0.12.14" @@ -10622,6 +11368,21 @@ dependencies = [ "weezl", ] +[[package]] +name = "tiktoken" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0d8d54b2ba3a0f0b14743f9f0e0f6884ee03ad8789f1ce6a4b5650c0d74fad8" +dependencies = [ + "anyhow", + "base64 0.21.7", + "lazy_static", + "maplit", + "pcre2", + "rust_decimal", + "rust_decimal_macros", +] + [[package]] name = "time" version = "0.1.45" @@ -10722,22 +11483,21 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.36.0" +version = "1.39.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2 0.5.6", "tokio-macros", "tracing", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -10752,9 +11512,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", @@ 
-10794,9 +11554,9 @@ dependencies = [

 [[package]]
 name = "tokio-stream"
-version = "0.1.14"
+version = "0.1.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
+checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
 dependencies = [
  "futures-core",
  "pin-project-lite",
@@ -11250,6 +12010,15 @@ dependencies = [
  "tinyvec",
 ]

+[[package]]
+name = "unicode-reverse"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76"
+dependencies = [
+ "unicode-segmentation",
+]
+
 [[package]]
 name = "unicode-segmentation"
 version = "1.10.1"
@@ -11758,6 +12527,25 @@ dependencies = [
  "windows-targets 0.48.5",
 ]

+[[package]]
+name = "windows"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be"
+dependencies = [
+ "windows-core",
+ "windows-targets 0.52.3",
+]
+
+[[package]]
+name = "windows-core"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
+dependencies = [
+ "windows-targets 0.52.3",
+]
+
 [[package]]
 name = "windows-sys"
 version = "0.48.0"
@@ -11993,6 +12781,12 @@ version = "0.13.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4"

+[[package]]
+name = "xxhash-rust"
+version = "0.8.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984"
+
 [[package]]
 name = "yaml-rust"
 version = "0.4.5"
diff --git a/Cargo.toml b/Cargo.toml
index b4fdec92c..7e4bb91c6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,13 +4,14 @@ members = [
     "shinkai-libs/shinkai-dsl",
     "shinkai-libs/shinkai-sheet",
     "shinkai-libs/shinkai-fs-mirror",
+    "shinkai-libs/shinkai-graphrag",
     "shinkai-libs/shinkai-message-primitives",
     "shinkai-libs/shinkai-ocr",
     "shinkai-libs/shinkai-tcp-relayer",
     "shinkai-libs/shinkai-vector-resources",
     "shinkai-bin/*", "shinkai-cli-tools/*"
-]
+, "shinkai-libs/shinkai-graphrag"]
 resolver = "2"

 [workspace.package]
diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml
new file mode 100644
index 000000000..591d7d48c
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "shinkai-graphrag"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+async-trait = "0.1.74"
+futures = "0.3.30"
+polars = "0.41.3"
+serde = { version = "1.0.188", features = ["derive"] }
+serde_json = "1.0.117"
+tiktoken = "1.0.1"
+tokio = { version = "1.36", features = ["full"] }
\ No newline at end of file
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs
new file mode 100644
index 000000000..705818a77
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs
@@ -0,0 +1,18 @@
+use async_trait::async_trait;
+// use polars::prelude::*;
+use std::collections::HashMap;
+
+// TODO: Serialize and Deserialize polars::frame::DataFrame
+type DataFrame = Vec<String>;
+
+#[async_trait]
+pub trait GlobalContextBuilder {
+    /// Build the context for the global search mode.
+    async fn build_context(
+        &self,
+        conversation_history: Option<ConversationHistory>,
+        context_builder_params: Option<HashMap<String, String>>,
+    ) -> (Vec<String>, HashMap<String, DataFrame>);
+}
+
+pub struct ConversationHistory {}
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs
new file mode 100644
index 000000000..709d766d9
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs
@@ -0,0 +1 @@
+pub mod context_builder;
diff --git a/shinkai-libs/shinkai-graphrag/src/lib.rs b/shinkai-libs/shinkai-graphrag/src/lib.rs
new file mode 100644
index 000000000..08bc3d655
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/lib.rs
@@ -0,0 +1,3 @@
+pub mod context_builder;
+pub mod llm;
+pub mod search;
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
new file mode 100644
index 000000000..de33da38b
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
@@ -0,0 +1,35 @@
+use async_trait::async_trait;
+
+pub struct BaseLLMCallback {
+    response: Vec<String>,
+}
+
+impl BaseLLMCallback {
+    pub fn new() -> Self {
+        BaseLLMCallback { response: Vec::new() }
+    }
+
+    pub fn on_llm_new_token(&mut self, token: &str) {
+        self.response.push(token.to_string());
+    }
+}
+
+#[async_trait]
+pub trait BaseLLM {
+    async fn generate(&self, messages: Vec<String>, streaming: bool, callbacks: Option<Vec<BaseLLMCallback>>)
+        -> String;
+
+    async fn agenerate(
+        &self,
+        messages: Vec<String>,
+        streaming: bool,
+        callbacks: Option<Vec<BaseLLMCallback>>,
+    ) -> String;
+}
+
+#[async_trait]
+pub trait BaseTextEmbedding {
+    async fn embed(&self, text: &str) -> Vec<f64>;
+
+    async fn aembed(&self, text: &str) -> Vec<f64>;
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
new file mode 100644
index 000000000..214bbef7c
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
@@ -0,0 +1 @@
+pub mod llm;
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
new file mode 100644
index 000000000..0a18ec55a
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
@@ -0,0 +1,255 @@
+use futures::future::join_all;
+//use polars::frame::DataFrame;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::time::Instant;
+use tiktoken::encoding::Encoding;
+use tokio::sync::Semaphore;
+
+use crate::context_builder::context_builder::{ConversationHistory, GlobalContextBuilder};
+use crate::llm::llm::BaseLLM;
+
+// TODO: Serialize and Deserialize polars::frame::DataFrame
+type DataFrame = Vec<String>;
+
+#[derive(Debug, Serialize, Deserialize)]
+struct SearchResult {
+    response: ResponseType,
+    context_data: ContextData,
+    context_text: ContextText,
+    completion_time: f64,
+    llm_calls: u32,
+    prompt_tokens: u32,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum ResponseType {
+    String(String),
+    Dictionary(HashMap<String, serde_json::Value>),
+    Dictionaries(Vec<HashMap<String, serde_json::Value>>),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum ContextData {
+    String(String),
+    DataFrames(Vec<DataFrame>),
+    Dictionary(HashMap<String, DataFrame>),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum ContextText {
+    String(String),
+    Strings(Vec<String>),
+    Dictionary(HashMap<String, String>),
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct GlobalSearchResult {
+    response: ResponseType,
+    context_data: ContextData,
+    context_text: ContextText,
+    completion_time: f64,
+    llm_calls: i32,
+    prompt_tokens: i32,
+    map_responses: Vec<SearchResult>,
+    reduce_context_data: ContextData,
+    reduce_context_text: ContextText,
+}
+
+struct GlobalSearchLLMCallback {
+    map_response_contexts: Vec<String>,
+    map_response_outputs: Vec<SearchResult>,
+}
+
+impl GlobalSearchLLMCallback {
+    pub fn new() -> Self {
+        GlobalSearchLLMCallback {
+            map_response_contexts: Vec::new(),
+            map_response_outputs: Vec::new(),
+        }
+    }
+
+    pub fn on_map_response_start(&mut self, map_response_contexts: Vec<String>) {
+        self.map_response_contexts = map_response_contexts;
+    }
+
+    pub fn on_map_response_end(&mut self, map_response_outputs: Vec<SearchResult>) {
+        self.map_response_outputs = map_response_outputs;
+    }
+}
+
+pub struct GlobalSearch {
+    llm: Box<dyn BaseLLM>,
+    context_builder: Box<dyn GlobalContextBuilder>,
+    token_encoder: Option<Encoding>,
+    llm_params: Option<HashMap<String, serde_json::Value>>,
+    context_builder_params: Option<HashMap<String, String>>,
+    map_system_prompt: String,
+    reduce_system_prompt: String,
+    response_type: String,
+    allow_general_knowledge: bool,
+    general_knowledge_inclusion_prompt: String,
+    callbacks: Option<Vec<GlobalSearchLLMCallback>>,
+    max_data_tokens: usize,
+    map_llm_params: HashMap<String, serde_json::Value>,
+    reduce_llm_params: HashMap<String, serde_json::Value>,
+    semaphore: Semaphore,
+}
+
+impl GlobalSearch {
+    pub fn new(
+        llm: Box<dyn BaseLLM>,
+        context_builder: Box<dyn GlobalContextBuilder>,
+        token_encoder: Option<Encoding>,
+        map_system_prompt: String,
+        reduce_system_prompt: String,
+        response_type: String,
+        allow_general_knowledge: bool,
+        general_knowledge_inclusion_prompt: String,
+        json_mode: bool,
+        callbacks: Option<Vec<GlobalSearchLLMCallback>>,
+        max_data_tokens: usize,
+        map_llm_params: HashMap<String, serde_json::Value>,
+        reduce_llm_params: HashMap<String, serde_json::Value>,
+        context_builder_params: Option<HashMap<String, String>>,
+        concurrent_coroutines: usize,
+    ) -> Self {
+        let mut map_llm_params = map_llm_params;
+
+        if json_mode {
+            map_llm_params.insert(
+                "response_format".to_string(),
+                serde_json::json!({"type": "json_object"}),
+            );
+        } else {
+            map_llm_params.remove("response_format");
+        }
+
+        let semaphore = Semaphore::new(concurrent_coroutines);
+
+        GlobalSearch {
+            llm,
+            context_builder,
+            token_encoder,
+            llm_params: None,
+            context_builder_params,
+            map_system_prompt,
+            reduce_system_prompt,
+            response_type,
+            allow_general_knowledge,
+            general_knowledge_inclusion_prompt,
+            callbacks,
+            max_data_tokens,
+            map_llm_params,
+            reduce_llm_params,
+            semaphore,
+        }
+    }
+
+    pub async fn asearch(
+        &self,
+        query: String,
+        conversation_history: Option<ConversationHistory>,
+    ) -> GlobalSearchResult {
+        // Step 1: Generate answers for each batch of community short summaries
+        let start_time = Instant::now();
+        let (context_chunks, context_records) = self
+            .context_builder
+            .build_context(conversation_history, self.context_builder_params)
+            .await;
+
+        if let Some(callbacks) = &self.callbacks {
+            for callback in callbacks {
+                callback.on_map_response_start(context_chunks);
+            }
+        }
+
+        let map_responses: Vec<_> = join_all(
+            context_chunks
+                .iter()
+                .map(|data| self._map_response_single_batch(data, &query, &self.map_llm_params)),
+        )
+        .await;
+
+        if let Some(callbacks) = &self.callbacks {
+            for callback in callbacks {
+                callback.on_map_response_end(&map_responses);
+            }
+        }
+
+        let map_llm_calls: usize = map_responses.iter().map(|response| response.llm_calls).sum();
+        let map_prompt_tokens: usize = map_responses.iter().map(|response| response.prompt_tokens).sum();
+
+        // Step 2: Combine the intermediate answers from step 2 to generate the final answer
+        let reduce_response = self
+            ._reduce_response(&map_responses, &query, self.reduce_llm_params)
+            .await;
+
+        GlobalSearchResult {
+            response: reduce_response.response,
+            context_data: ContextData::Dictionary(context_records),
+            context_text: ContextText::Strings(context_chunks),
+            completion_time: start_time.elapsed().as_secs_f64(),
+            llm_calls: map_llm_calls + reduce_response.llm_calls,
+            prompt_tokens: map_prompt_tokens + reduce_response.prompt_tokens,
+            map_responses,
+            reduce_context_data: reduce_response.context_data,
+            reduce_context_text: reduce_response.context_text,
+        }
+    }
+
+    async fn _reduce_response(
+        &self,
+        map_responses: Vec<SearchResult>,
+        query: &str,
+        reduce_llm_params: HashMap<String, serde_json::Value>,
+    ) -> SearchResult {
+        let start_time = Instant::now();
+        let mut key_points = Vec::new();
+
+        for (index, response) in map_responses.iter().enumerate() {
+            if let ResponseType::Dictionaries(response_list) = response.response {
+                for element in response_list {
+                    if let (Some(answer), Some(score)) = (element.get("answer"), element.get("score")) {
+                        key_points.push((index, answer.clone(), score.clone()));
+                    }
+                }
+            }
+        }
+
+        let filtered_key_points: Vec<_> = key_points
+            .into_iter()
+            .filter(|(_, _, score)| score.as_f64().unwrap_or(0.0) > 0.0)
+            .collect();
+
+        if filtered_key_points.is_empty() && !self.allow_general_knowledge {
+            return SearchResult {
+                response: ResponseType::String("NO_DATA_ANSWER".to_string()),
+                context_data: ContextData::String("".to_string()),
+                context_text: ContextText::String("".to_string()),
+                completion_time: start_time.elapsed().as_secs_f64(),
+                llm_calls: 0,
+                prompt_tokens: 0,
+            };
+        }
+
+        let mut sorted_key_points = filtered_key_points;
+        sorted_key_points.sort_by(|a, b| {
+            b.2.as_f64()
+                .unwrap_or(0.0)
+                .partial_cmp(&a.2.as_f64().unwrap_or(0.0))
+                .unwrap()
+        });
+
+        // TODO: Implement rest of the function
+
+        SearchResult {
+            response: ResponseType::String("Combined response".to_string()),
+            context_data: ContextData::String("".to_string()),
+            context_text: ContextText::String("".to_string()),
+            completion_time: start_time.elapsed().as_secs_f64(),
+            llm_calls: 0,
+            prompt_tokens: 0,
+        }
+    }
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/mod.rs
new file mode 100644
index 000000000..a12441830
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/search/mod.rs
@@ -0,0 +1 @@
+pub mod global_search;

From 6696bb1af32308a92d5034290b65bada5cab91da Mon Sep 17 00:00:00 2001
From: benolt
Date: Thu, 1 Aug 2024 15:12:24 +0200
Subject: [PATCH 02/12] global search, open ai chat

---
 Cargo.lock                                    | 197 ++++++++++---
 shinkai-libs/shinkai-graphrag/Cargo.toml      |   4 +-
 shinkai-libs/shinkai-graphrag/src/llm/llm.rs  |  23 +-
 shinkai-libs/shinkai-graphrag/src/llm/mod.rs  |   2 +
 .../shinkai-graphrag/src/llm/openai.rs        | 107 +++++++
 .../shinkai-graphrag/src/llm/utils.rs         |   7 +
 .../src/search/global_search.rs               | 260 ++++++++++++++----
 7 files changed, 489 insertions(+), 111 deletions(-)
 create mode 100644 shinkai-libs/shinkai-graphrag/src/llm/openai.rs
 create mode 100644 shinkai-libs/shinkai-graphrag/src/llm/utils.rs

diff --git a/Cargo.lock b/Cargo.lock
index dd8b827ba..49261bfe2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -724,6 +724,15 @@ dependencies = [
  "futures-core",
 ]

+[[package]]
+name = "async-convert"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6d416feee97712e43152cd42874de162b8f9b77295b1c85e5d92725cc8310bae"
+dependencies = [
+ "async-trait",
+]
+
 [[package]]
 name = "async-executor"
 version = "1.5.1"
@@ -1405,6 +1414,20 @@ dependencies = [
  "tower-service",
 ]

+[[package]]
+name = "backoff"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
+dependencies = [
+ "futures-core",
+ "getrandom 0.2.10",
+ "instant",
+ "pin-project-lite",
+ "rand 0.8.5",
+ "tokio",
+]
+
 [[package]]
 name = "backtrace"
version = "0.3.69" @@ -1654,6 +1677,17 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "bstr" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +dependencies = [ + "memchr", + "regex-automata 0.4.7", + "serde", +] + [[package]] name = "buf_redux" version = "0.8.4" @@ -2223,7 +2257,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0436149c9f6a1935b13306206c739b1ba84fa81f551b5eb87fc2ca7a13700af" dependencies = [ "clap 4.5.4", - "derive_builder", + "derive_builder 0.12.0", "entities", "memchr", "once_cell", @@ -3162,7 +3196,16 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" dependencies = [ - "derive_builder_macro", + "derive_builder_macro 0.12.0", +] + +[[package]] +name = "derive_builder" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +dependencies = [ + "derive_builder_macro 0.20.0", ] [[package]] @@ -3177,16 +3220,38 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "derive_builder_core" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +dependencies = [ + "darling 0.20.10", + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 2.0.66", +] + [[package]] name = "derive_builder_macro" version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" dependencies = [ - "derive_builder_core", + "derive_builder_core 0.12.0", "syn 1.0.109", ] +[[package]] +name = "derive_builder_macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +dependencies = [ + "derive_builder_core 0.20.0", + "syn 2.0.66", +] + [[package]] name = "derive_more" version = "0.99.18" @@ -3899,6 +3964,38 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "event-listener" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b5fb89194fa3cad959b833185b3063ba881dbfc7030680b314250779fb4cc91" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" +dependencies = [ + "event-listener 5.2.0", + "pin-project-lite", +] + +[[package]] +name = "eventsource-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +dependencies = [ + "futures-core", + "nom", + "pin-project-lite", +] + [[package]] name = "exr" version = "1.72.0" @@ -3941,6 +4038,16 @@ dependencies = [ "regex", ] +[[package]] +name = "fancy-regex" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fast-float" version = "0.2.0" @@ -7376,28 +7483,6 @@ dependencies = [ "hmac 0.12.1", ] 
-[[package]] -name = "pcre2" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be55c43ac18044541d58d897e8f4c55157218428953ebd39d86df3ba0286b2b" -dependencies = [ - "libc", - "log 0.4.21", - "pcre2-sys", -] - -[[package]] -name = "pcre2-sys" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "550f5d18fb1b90c20b87e161852c10cde77858c3900c5059b5ad2a1449f11d8a" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "pddl-ish-parser" version = "0.0.4" @@ -9168,6 +9253,12 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" + [[package]] name = "regex-lite" version = "0.1.5" @@ -9295,6 +9386,22 @@ dependencies = [ "winreg 0.52.0", ] +[[package]] +name = "reqwest-eventsource" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" +dependencies = [ + "eventsource-stream", + "futures-core", + "futures-timer", + "mime 0.3.17", + "nom", + "pin-project-lite", + "reqwest 0.12.5", + "thiserror", +] + [[package]] name = "rfc6979" version = "0.3.1" @@ -9552,16 +9659,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "rust_decimal_macros" -version = "1.35.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a05bf7103af0797dbce0667c471946b29b9eaea34652eff67324f360fec027de" -dependencies = [ - "quote 1.0.36", - "rust_decimal", -] - [[package]] name = "rustc-demangle" version = "0.1.23" @@ -9897,6 +9994,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secrecy" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bd1c54ea06cfd2f6b63219704de0b9b4f72dcc2b8fdef820be6cd799780e91e" +dependencies = [ + "serde", + "zeroize", +] + [[package]] name = "security-framework" version = "2.9.2" @@ -10167,12 +10274,14 @@ dependencies = [ name = "shinkai-graphrag" version = "0.1.0" dependencies = [ + "anyhow", + "async-openai", "async-trait", "futures", "polars", "serde", "serde_json", - "tiktoken", + "tiktoken-rs", "tokio", ] @@ -11006,7 +11115,7 @@ checksum = "874dcfa363995604333cf947ae9f751ca3af4522c60886774c4963943b4746b1" dependencies = [ "bincode", "bitflags 1.3.2", - "fancy-regex", + "fancy-regex 0.11.0", "flate2", "fnv", "once_cell", @@ -11369,18 +11478,18 @@ dependencies = [ ] [[package]] -name = "tiktoken" -version = "1.0.1" +name = "tiktoken-rs" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d8d54b2ba3a0f0b14743f9f0e0f6884ee03ad8789f1ce6a4b5650c0d74fad8" +checksum = "c314e7ce51440f9e8f5a497394682a57b7c323d0f4d0a6b1b13c429056e0e234" dependencies = [ "anyhow", "base64 0.21.7", + "bstr", + "fancy-regex 0.12.0", "lazy_static", - "maplit", - "pcre2", - "rust_decimal", - "rust_decimal_macros", + "parking_lot 0.12.1", + "rustc-hash", ] [[package]] diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 591d7d48c..9057f8ea0 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -4,10 +4,12 @@ version = "0.1.0" edition = "2021" [dependencies] +anyhow = "1.0.86" +async-openai = "0.23.4" async-trait = "0.1.74" futures = "0.3.30" polars = "0.41.3" serde = { version = "1.0.188", 
features = ["derive"] }
 serde_json = "1.0.117"
-tiktoken = "1.0.1"
+tiktoken-rs = "0.5.9"
 tokio = { version = "1.36", features = ["full"] }
\ No newline at end of file
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
index de33da38b..4c8f0f68e 100644
--- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
+++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
@@ -1,7 +1,11 @@
+use std::collections::HashMap;
+
 use async_trait::async_trait;
+use serde::{Deserialize, Serialize};

+#[derive(Debug, Clone)]
 pub struct BaseLLMCallback {
-    response: Vec<String>,
+    pub response: Vec<String>,
 }

 impl BaseLLMCallback {
@@ -14,22 +18,25 @@ impl BaseLLMCallback {
     }
 }

+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum MessageType {
+    String(String),
+    Strings(Vec<String>),
+    Dictionary(Vec<HashMap<String, String>>),
+}
+
 #[async_trait]
 pub trait BaseLLM {
-    async fn generate(&self, messages: Vec<String>, streaming: bool, callbacks: Option<Vec<BaseLLMCallback>>)
-        -> String;
-
     async fn agenerate(
         &self,
-        messages: Vec<String>,
+        messages: MessageType,
         streaming: bool,
         callbacks: Option<Vec<BaseLLMCallback>>,
-    ) -> String;
+        llm_params: HashMap<String, serde_json::Value>,
+    ) -> anyhow::Result<String>;
 }

 #[async_trait]
 pub trait BaseTextEmbedding {
-    async fn embed(&self, text: &str) -> Vec<f64>;
-
     async fn aembed(&self, text: &str) -> Vec<f64>;
 }
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
index 214bbef7c..00cb1d9e1 100644
--- a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
+++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
@@ -1 +1,3 @@
 pub mod llm;
+pub mod openai;
+pub mod utils;
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs b/shinkai-libs/shinkai-graphrag/src/llm/openai.rs
new file mode 100644
index 000000000..a0b3986b6
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/llm/openai.rs
@@ -0,0 +1,107 @@
+use std::collections::HashMap;
+
+use async_openai::{
+    config::OpenAIConfig,
+    types::{ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, CreateChatCompletionRequestArgs},
+    Client,
+};
+
+use super::llm::{BaseLLMCallback, MessageType};
+
+pub struct ChatOpenAI {
+    pub api_key: Option<String>,
+    pub model: String,
+    pub max_retries: usize,
+}
+
+impl ChatOpenAI {
+    pub fn new(api_key: Option<String>, model: String, max_retries: usize) -> Self {
+        ChatOpenAI {
+            api_key,
+            model,
+            max_retries,
+        }
+    }
+
+    pub async fn agenerate(
+        &self,
+        messages: MessageType,
+        streaming: bool,
+        callbacks: Option<Vec<BaseLLMCallback>>,
+        llm_params: HashMap<String, serde_json::Value>,
+    ) -> anyhow::Result<String> {
+        let mut retry_count = 0;
+
+        loop {
+            match self
+                ._agenerate(messages.clone(), streaming, callbacks.clone(), llm_params.clone())
+                .await
+            {
+                Ok(response) => return Ok(response),
+                Err(e) => {
+                    if retry_count < self.max_retries {
+                        retry_count += 1;
+                        continue;
+                    }
+                    return Err(e);
+                }
+            }
+        }
+    }
+
+    async fn _agenerate(
+        &self,
+        messages: MessageType,
+        streaming: bool,
+        callbacks: Option<Vec<BaseLLMCallback>>,
+        llm_params: HashMap<String, serde_json::Value>,
+    ) -> anyhow::Result<String> {
+        let client = match &self.api_key {
+            Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)),
+            None => Client::new(),
+        };
+
+        let messages = match messages {
+            MessageType::String(message) => vec![message],
+            MessageType::Strings(messages) => messages,
+            MessageType::Dictionary(messages) => {
+                let messages = messages
+                    .iter()
+                    .map(|message_map| {
+                        message_map
+                            .iter()
+                            .map(|(key, value)| format!("{}: {}", key, value))
+                            .collect::<Vec<String>>()
+                            .join("\n")
+                    })
+                    .collect();
+                messages
+            }
+        };
+
+        let request_messages = messages
+            .into_iter()
+            .map(|m| ChatCompletionRequestSystemMessageArgs::default().content(m).build())
+            .collect::<Vec<_>>();
+
+        let request_messages: Result<Vec<_>, _> = request_messages.into_iter().collect();
+        let request_messages = request_messages?;
+        let request_messages = request_messages
+            .into_iter()
+            .map(|m| Into::<ChatCompletionRequestMessage>::into(m.clone()))
+            .collect::<Vec<ChatCompletionRequestMessage>>();
+
+        let request = CreateChatCompletionRequestArgs::default()
+            .model(self.model.clone())
+            .messages(request_messages)
+            .build()?;
+
+        let response = client.chat().create(request).await?;
+
+        if let Some(choice) = response.choices.get(0) {
+            return Ok(choice.message.content.clone().unwrap_or_default());
+        }
+
+        return Ok(String::new());
+    }
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs
new file mode 100644
index 000000000..a6b4dfc54
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs
@@ -0,0 +1,7 @@
+use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};
+
+pub fn num_tokens(text: &str, token_encoder: Option<Tokenizer>) -> usize {
+    let token_encoder = token_encoder.unwrap_or_else(|| Tokenizer::Cl100kBase);
+    let bpe = get_bpe_from_tokenizer(token_encoder).unwrap();
+    bpe.encode_ordinary(text).len()
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
index 0a18ec55a..0dde10c0c 100644
--- a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
@@ -1,62 +1,72 @@
 use futures::future::join_all;
 //use polars::frame::DataFrame;
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use std::collections::HashMap;
 use std::time::Instant;
-use tiktoken::encoding::Encoding;
-use tokio::sync::Semaphore;
+use tiktoken_rs::tokenizer::Tokenizer;

 use crate::context_builder::context_builder::{ConversationHistory, GlobalContextBuilder};
-use crate::llm::llm::BaseLLM;
+use crate::llm::llm::{BaseLLM, BaseLLMCallback, MessageType};
+use crate::llm::utils::num_tokens;

 // TODO: Serialize and Deserialize polars::frame::DataFrame
 type DataFrame = Vec<String>;

-#[derive(Debug, Serialize, Deserialize)]
-struct SearchResult {
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SearchResult {
     response: ResponseType,
     context_data: ContextData,
     context_text: ContextText,
     completion_time: f64,
-    llm_calls: u32,
-    prompt_tokens: u32,
+    llm_calls: usize,
+    prompt_tokens: usize,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub enum ResponseType {
     String(String),
     Dictionary(HashMap<String, serde_json::Value>),
     Dictionaries(Vec<HashMap<String, serde_json::Value>>),
+    KeyPoints(Vec<KeyPoint>),
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub enum ContextData {
     String(String),
     DataFrames(Vec<DataFrame>),
     Dictionary(HashMap<String, DataFrame>),
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub enum ContextText {
     String(String),
     Strings(Vec<String>),
     Dictionary(HashMap<String, String>),
 }

+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct KeyPoint {
+    answer: String,
+    score: i32,
+}
+
 #[derive(Serialize, Deserialize)]
 pub struct GlobalSearchResult {
     response: ResponseType,
     context_data: ContextData,
     context_text: ContextText,
     completion_time: f64,
-    llm_calls: i32,
-    prompt_tokens: i32,
+    llm_calls: usize,
+    prompt_tokens: usize,
     map_responses: Vec<SearchResult>,
     reduce_context_data: ContextData,
     reduce_context_text: ContextText,
 }

-struct GlobalSearchLLMCallback {
+#[derive(Debug, Clone)]
+pub struct GlobalSearchLLMCallback {
+    response: Vec<String>,
     map_response_contexts: Vec<String>,
     map_response_outputs: Vec<SearchResult>,
 }
@@ -64,6 +74,7 @@ impl GlobalSearchLLMCallback {
     pub fn new() -> Self {
         GlobalSearchLLMCallback {
+            response: Vec::new(),
             map_response_contexts: Vec::new(),
             map_response_outputs: Vec::new(),
         }
@@ -81,10 +92,8 @@ impl GlobalSearchLLMCallback {
 pub struct GlobalSearch {
     llm: Box<dyn BaseLLM>,
     context_builder: Box<dyn GlobalContextBuilder>,
-    token_encoder: Option<Encoding>,
-    llm_params: Option<HashMap<String, serde_json::Value>>,
+    token_encoder: Option<Tokenizer>,
     context_builder_params: Option<HashMap<String, String>>,
-    map_system_prompt: String,
     reduce_system_prompt: String,
     response_type: String,
     allow_general_knowledge: bool,
@@ -93,15 +102,13 @@ pub struct GlobalSearch {
     max_data_tokens: usize,
     map_llm_params: HashMap<String, serde_json::Value>,
     reduce_llm_params: HashMap<String, serde_json::Value>,
-    semaphore: Semaphore,
 }

 impl GlobalSearch {
     pub fn new(
         llm: Box<dyn BaseLLM>,
         context_builder: Box<dyn GlobalContextBuilder>,
-        token_encoder: Option<Encoding>,
-        map_system_prompt: String,
+        token_encoder: Option<Tokenizer>,
         reduce_system_prompt: String,
         response_type: String,
         allow_general_knowledge: bool,
@@ -112,7 +119,6 @@ impl GlobalSearch {
         map_llm_params: HashMap<String, serde_json::Value>,
         reduce_llm_params: HashMap<String, serde_json::Value>,
         context_builder_params: Option<HashMap<String, String>>,
-        concurrent_coroutines: usize,
     ) -> Self {
         let mut map_llm_params = map_llm_params;

@@ -125,15 +131,11 @@ impl GlobalSearch {
             map_llm_params.remove("response_format");
         }

-        let semaphore = Semaphore::new(concurrent_coroutines);
-
         GlobalSearch {
             llm,
             context_builder,
             token_encoder,
-            llm_params: None,
             context_builder_params,
-            map_system_prompt,
             reduce_system_prompt,
             response_type,
             allow_general_knowledge,
@@ -142,7 +144,6 @@ impl GlobalSearch {
             max_data_tokens,
             map_llm_params,
             reduce_llm_params,
-            semaphore,
         }
     }

@@ -150,42 +151,59 @@ impl GlobalSearch {
         &self,
         query: String,
         conversation_history: Option<ConversationHistory>,
-    ) -> GlobalSearchResult {
+    ) -> anyhow::Result<GlobalSearchResult> {
         // Step 1: Generate answers for each batch of community short summaries
         let start_time = Instant::now();
         let (context_chunks, context_records) = self
             .context_builder
-            .build_context(conversation_history, self.context_builder_params)
+            .build_context(conversation_history, self.context_builder_params.clone())
             .await;

-        if let Some(callbacks) = &self.callbacks {
-            for callback in callbacks {
-                callback.on_map_response_start(context_chunks);
+        let mut callbacks = match &self.callbacks {
+            Some(callbacks) => {
+                let mut llm_callbacks = Vec::new();
+                for callback in callbacks {
+                    let mut callback = callback.clone();
+                    callback.on_map_response_start(context_chunks.clone());
+                    llm_callbacks.push(callback);
+                }
+                Some(llm_callbacks)
             }
-        }
+            None => None,
+        };

         let map_responses: Vec<_> = join_all(
             context_chunks
                 .iter()
-                .map(|data| self._map_response_single_batch(data, &query, &self.map_llm_params)),
+                .map(|data| self._map_response_single_batch(data, &query, self.map_llm_params.clone())),
         )
         .await;

-        if let Some(callbacks) = &self.callbacks {
-            for callback in callbacks {
-                callback.on_map_response_end(&map_responses);
+        let map_responses: Result<Vec<_>, _> = map_responses.into_iter().collect();
+        let map_responses = map_responses?;
+
+        callbacks = match &callbacks {
+            Some(callbacks) => {
+                let mut llm_callbacks = Vec::new();
+                for callback in callbacks {
+                    let mut callback = callback.clone();
+                    callback.on_map_response_end(map_responses.clone());
+                    llm_callbacks.push(callback);
+                }
+                Some(llm_callbacks)
             }
-        }
+            None => None,
+        };

         let map_llm_calls: usize = map_responses.iter().map(|response| response.llm_calls).sum();
         let map_prompt_tokens: usize = map_responses.iter().map(|response| response.prompt_tokens).sum();

         // Step 2: Combine the intermediate answers from step 2 to generate the final answer
         let reduce_response = self
-            ._reduce_response(&map_responses, &query, self.reduce_llm_params)
-            .await;
+            ._reduce_response(map_responses.clone(), &query, callbacks, self.reduce_llm_params.clone())
+            .await?;

-        GlobalSearchResult {
+        Ok(GlobalSearchResult {
             response: reduce_response.response,
             context_data: ContextData::Dictionary(context_records),
             context_text: ContextText::Strings(context_chunks),
@@ -195,61 +213,187 @@ impl GlobalSearch {
             map_responses,
             reduce_context_data: reduce_response.context_data,
             reduce_context_text: reduce_response.context_text,
+        })
+    }
+
+    async fn _map_response_single_batch(
+        &self,
+        context_data: &str,
+        query: &str,
+        llm_params: HashMap<String, Value>,
+    ) -> anyhow::Result<SearchResult> {
+        let start_time = Instant::now();
+        let search_prompt = String::new();
+        let mut search_messages = Vec::new();
+        search_messages.push(HashMap::from([
+            ("role".to_string(), "system".to_string()),
+            ("content".to_string(), search_prompt.clone()),
+        ]));
+        search_messages.push(HashMap::from([
+            ("role".to_string(), "user".to_string()),
+            ("content".to_string(), query.to_string()),
+        ]));
+
+        let search_response = self
+            .llm
+            .agenerate(MessageType::Dictionary(search_messages), false, None, llm_params)
+            .await?;
+
+        let processed_response = self.parse_search_response(&search_response);
+
+        Ok(SearchResult {
+            response: ResponseType::KeyPoints(processed_response),
+            context_data: ContextData::String(context_data.to_string()),
+            context_text: ContextText::String(context_data.to_string()),
+            completion_time: start_time.elapsed().as_secs_f64(),
+            llm_calls: 1,
+            prompt_tokens: num_tokens(&search_prompt, self.token_encoder),
+        })
+    }
+
+    fn parse_search_response(&self, search_response: &str) -> Vec<KeyPoint> {
+        let parsed_elements: Value = serde_json::from_str(search_response).unwrap_or_default();
+
+        if let Some(points) = parsed_elements.get("points") {
+            if let Some(points) = points.as_array() {
+                return points
+                    .iter()
+                    .map(|element| KeyPoint {
+                        answer: element
+                            .get("description")
+                            .unwrap_or(&Value::String("".to_string()))
+                            .to_string(),
+                        score: element
+                            .get("score")
+                            .unwrap_or(&Value::Number(serde_json::Number::from(0)))
+                            .as_i64()
+                            .unwrap_or(0) as i32,
+                    })
+                    .collect::<Vec<KeyPoint>>();
+            }
+        }
+
+        Vec::new()
     }

     async fn _reduce_response(
         &self,
         map_responses: Vec<SearchResult>,
         query: &str,
+        callbacks: Option<Vec<GlobalSearchLLMCallback>>,
         reduce_llm_params: HashMap<String, Value>,
-    ) -> SearchResult {
+    ) -> anyhow::Result<SearchResult> {
         let start_time = Instant::now();
-        let mut key_points = Vec::new();
+        let mut key_points: Vec<HashMap<String, String>> = Vec::new();

         for (index, response) in map_responses.iter().enumerate() {
-            if let ResponseType::Dictionaries(response_list) = response.response {
+            if let ResponseType::Dictionaries(response_list) = &response.response {
                 for element in response_list {
                     if let (Some(answer), Some(score)) = (element.get("answer"), element.get("score")) {
-                        key_points.push((index, answer.clone(), score.clone()));
+                        let mut point = HashMap::new();
+                        point.insert("analyst".to_string(), (index + 1).to_string());
+                        point.insert("answer".to_string(), answer.to_string());
+                        point.insert("score".to_string(), score.to_string());
+                        key_points.push(point);
                     }
                 }
             }
         }

-        let filtered_key_points: Vec<_> = key_points
+        let filtered_key_points: Vec<HashMap<String, String>> = key_points
             .into_iter()
-            .filter(|(_, _, score)| score.as_f64().unwrap_or(0.0) > 0.0)
+            .filter(|point| point.get("score").unwrap().parse::<i32>().unwrap() > 0)
             .collect();

         if filtered_key_points.is_empty() && !self.allow_general_knowledge {
-            return SearchResult {
+            return Ok(SearchResult {
                response:
ResponseType::String("NO_DATA_ANSWER".to_string()), context_data: ContextData::String("".to_string()), context_text: ContextText::String("".to_string()), completion_time: start_time.elapsed().as_secs_f64(), llm_calls: 0, prompt_tokens: 0, - }; + }); } let mut sorted_key_points = filtered_key_points; sorted_key_points.sort_by(|a, b| { - b.2.as_f64() - .unwrap_or(0.0) - .partial_cmp(&a.2.as_f64().unwrap_or(0.0)) + b.get("score") + .unwrap() + .parse::() .unwrap() + .cmp(&a.get("score").unwrap().parse::().unwrap()) }); - // TODO: Implement rest of the function + let mut data: Vec = Vec::new(); + let mut total_tokens = 0; + for point in sorted_key_points { + let mut formatted_response_data: Vec = Vec::new(); + formatted_response_data.push(format!("----Analyst {}----", point.get("analyst").unwrap())); + formatted_response_data.push(format!("Importance Score: {}", point.get("score").unwrap())); + formatted_response_data.push(point.get("answer").unwrap().to_string()); + let formatted_response_text = formatted_response_data.join("\n"); + if total_tokens + num_tokens(&formatted_response_text, self.token_encoder) > self.max_data_tokens { + break; + } + data.push(formatted_response_text.clone()); + total_tokens += num_tokens(&formatted_response_text, self.token_encoder); + } + let text_data = data.join("\n\n"); - SearchResult { - response: ResponseType::String("Combined response".to_string()), - context_data: ContextData::String("".to_string()), - context_text: ContextText::String("".to_string()), + let search_prompt = format!( + "{}\n{}", + self.reduce_system_prompt + .replace("{report_data}", &text_data) + .replace("{response_type}", &self.response_type), + if self.allow_general_knowledge { + self.general_knowledge_inclusion_prompt.clone() + } else { + String::new() + } + ); + + let search_messages = vec![ + HashMap::from([ + ("role".to_string(), "system".to_string()), + ("content".to_string(), search_prompt.clone()), + ]), + HashMap::from([ + ("role".to_string(), "user".to_string()), + ("content".to_string(), query.to_string()), + ]), + ]; + + let llm_callbacks = match callbacks { + Some(callbacks) => { + let mut llm_callbacks = Vec::new(); + for callback in callbacks { + llm_callbacks.push(BaseLLMCallback { + response: callback.response.clone(), + }); + } + Some(llm_callbacks) + } + None => None, + }; + + let search_response = self + .llm + .agenerate( + MessageType::Dictionary(search_messages), + true, + llm_callbacks, + reduce_llm_params, + ) + .await?; + + Ok(SearchResult { + response: ResponseType::String(search_response), + context_data: ContextData::String(text_data.clone()), + context_text: ContextText::String(text_data), completion_time: start_time.elapsed().as_secs_f64(), - llm_calls: 0, - prompt_tokens: 0, - } + llm_calls: 1, + prompt_tokens: num_tokens(&search_prompt, self.token_encoder), + }) } } From 26115547c65a15c14b86c951ae0045a6c495e8e4 Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 2 Aug 2024 14:19:04 +0200 Subject: [PATCH 03/12] read indexer entities, indexer reports --- Cargo.lock | 77 +++++++ shinkai-libs/shinkai-graphrag/Cargo.toml | 3 +- .../src/context_builder/indexer_entities.rs | 207 ++++++++++++++++++ .../src/context_builder/indexer_reports.rs | 135 ++++++++++++ .../src/context_builder/mod.rs | 2 + .../shinkai-graphrag/src/llm/openai.rs | 6 +- 6 files changed, 426 insertions(+), 4 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs 
diff --git a/Cargo.lock b/Cargo.lock index 49261bfe2..483ef12e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -123,6 +123,21 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.18" @@ -856,6 +871,28 @@ dependencies = [ "wasm-bindgen-futures", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2 1.0.84", + "quote 1.0.36", + "syn 2.0.66", +] + [[package]] name = "async-task" version = "4.4.0" @@ -1667,6 +1704,27 @@ dependencies = [ "syn_derive", ] +[[package]] +name = "brotli" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19483b140a7ac7174d34b5a581b406c64f84da5409d3e09cf4fff604f9270e67" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bs58" version = "0.5.0" @@ -7396,6 +7454,10 @@ name = "parquet-format-safe" version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1131c54b167dd4e4799ce762e1ab01549ebb94d5bdd13e6ec1b467491c378e1f" +dependencies = [ + "async-trait", + "futures", +] [[package]] name = "parse-zoneinfo" @@ -7963,6 +8025,7 @@ dependencies = [ "ethnum", "fast-float", "foreign_vec", + "futures", "getrandom 0.2.10", "hashbrown 0.14.5", "itoa 1.0.9", @@ -8078,10 +8141,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d7363cd14e4696a28b334a56bd11013ff49cc96064818ab3f91a126e453462d" dependencies = [ "ahash 0.8.11", + "async-trait", "atoi_simd", "bytes", "chrono", "fast-float", + "futures", "home", "itoa 1.0.9", "memchr", @@ -8092,6 +8157,7 @@ dependencies = [ "polars-arrow", "polars-core", "polars-error", + "polars-parquet", "polars-time", "polars-utils", "rayon", @@ -8099,6 +8165,8 @@ dependencies = [ "ryu", "simdutf8", "smartstring", + "tokio", + "tokio-util 0.7.11", ] [[package]] @@ -8182,8 +8250,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b35b2592a2e7ef7ce9942dc2120dc4576142626c0e661668e4c6b805042e461" dependencies = [ "ahash 0.8.11", + "async-stream", "base64 0.22.1", + "brotli", "ethnum", + "flate2", + "futures", + "lz4", "num-traits", "parquet-format-safe", "polars-arrow", @@ -8191,7 +8264,9 @@ dependencies = [ "polars-error", 
"polars-utils", "simdutf8", + "snap", "streaming-decompression", + "zstd 0.13.2", ] [[package]] @@ -8237,6 +8312,7 @@ dependencies = [ "polars-core", "polars-io", "polars-ops", + "polars-parquet", "polars-time", "polars-utils", "rayon", @@ -10279,6 +10355,7 @@ dependencies = [ "async-trait", "futures", "polars", + "polars-lazy", "serde", "serde_json", "tiktoken-rs", diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 9057f8ea0..5cd0f1549 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -8,7 +8,8 @@ anyhow = "1.0.86" async-openai = "0.23.4" async-trait = "0.1.74" futures = "0.3.30" -polars = "0.41.3" +polars = { version = "0.41.3", features = ["lazy", "parquet"] } +polars-lazy = "0.41.3" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.117" tiktoken-rs = "0.5.9" diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs new file mode 100644 index 000000000..0265a05f7 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -0,0 +1,207 @@ +use std::collections::HashMap; + +use polars::prelude::*; +use polars_lazy::dsl::col; +use serde::{Deserialize, Serialize}; + +use super::indexer_reports::filter_under_community_level; + +pub fn read_indexer_entities( + final_nodes: &DataFrame, + final_entities: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let mut entity_df = filter_under_community_level(&entity_df, community_level)?; + + let entity_df = entity_df.rename("title", "name")?.rename("degree", "rank")?; + + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").fill_null(lit(-1))) + .collect()?; + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").cast(DataType::Int32)) + .collect()?; + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("rank").cast(DataType::Int32)) + .collect()?; + + let entity_embedding_df = final_entities.clone(); + + let entity_df = entity_df + .clone() + .lazy() + .group_by([col("name"), col("rank")]) + .agg([col("community").max()]) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").cast(DataType::String)) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .join( + entity_embedding_df.clone().lazy(), + [col("name")], + [col("name")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .filter(len().over([col("name")]).gt(lit(1))) + .collect()?; + + let entities = read_entities( + &entity_df, + "id", + Some("human_readable_id"), + "name", + Some("type"), + Some("description"), + None, + Some("description_embedding"), + None, + Some("community"), + Some("text_unit_ids"), + None, + Some("rank"), + )?; + + Ok(entities) +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Entity { + pub id: String, + pub short_id: Option, + pub title: String, + pub entity_type: Option, + pub description: Option, + pub description_embedding: Option>, + pub name_embedding: Option>, + pub graph_embedding: Option>, + pub community_ids: Option>, + pub text_unit_ids: Option>, + pub document_ids: Option>, + pub rank: Option, + pub attributes: Option>, +} + +pub fn read_entities( + df: &DataFrame, + id_col: &str, + short_id_col: Option<&str>, + title_col: &str, + type_col: Option<&str>, + 
description_col: Option<&str>, + name_embedding_col: Option<&str>, + description_embedding_col: Option<&str>, + graph_embedding_col: Option<&str>, + community_col: Option<&str>, + text_unit_ids_col: Option<&str>, + document_ids_col: Option<&str>, + rank_col: Option<&str>, + // attributes_cols: Option>, +) -> anyhow::Result> { + let column_names = [ + id_col, + short_id_col.unwrap_or("short_id"), + title_col, + type_col.unwrap_or("type"), + description_col.unwrap_or("description"), + name_embedding_col.unwrap_or("name_embedding"), + description_embedding_col.unwrap_or("description_embedding"), + graph_embedding_col.unwrap_or("graph_embedding"), + community_col.unwrap_or("community_ids"), + text_unit_ids_col.unwrap_or("text_unit_ids"), + document_ids_col.unwrap_or("document_ids"), + rank_col.unwrap_or("degree"), + ]; + + let mut df = df.clone(); + df.as_single_chunk_par(); + let mut iters = df.columns(column_names)?.iter().map(|s| s.iter()).collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value.to_string()); + } + } + rows.push(row_values); + } + + let mut entities = Vec::new(); + for row in rows { + let report = Entity { + id: row.get(0).unwrap_or(&String::new()).to_string(), + short_id: Some(row.get(1).unwrap_or(&String::new()).to_string()), + title: row.get(2).unwrap_or(&String::new()).to_string(), + entity_type: Some(row.get(3).unwrap_or(&String::new()).to_string()), + description: Some(row.get(4).unwrap_or(&String::new()).to_string()), + name_embedding: Some( + row.get(5) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.parse::().unwrap_or(0.0)) + .collect(), + ), + description_embedding: Some( + row.get(6) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.parse::().unwrap_or(0.0)) + .collect(), + ), + graph_embedding: Some( + row.get(7) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.parse::().unwrap_or(0.0)) + .collect(), + ), + community_ids: Some( + row.get(8) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.to_string()) + .collect(), + ), + text_unit_ids: Some( + row.get(9) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.to_string()) + .collect(), + ), + document_ids: Some( + row.get(10) + .unwrap_or(&String::new()) + .split(',') + .map(|v| v.to_string()) + .collect(), + ), + rank: Some(row.get(11).and_then(|v| v.parse::().ok()).unwrap_or(0)), + attributes: None, + }; + entities.push(report); + } + + Ok(entities) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs new file mode 100644 index 000000000..07811ac8b --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -0,0 +1,135 @@ +use std::collections::HashMap; + +use polars::prelude::*; +use polars_lazy::dsl::col; +use serde::{Deserialize, Serialize}; + +pub fn read_indexer_reports( + final_community_reports: &DataFrame, + final_nodes: &DataFrame, + community_level: u32, +) -> anyhow::Result> { + let entity_df = final_nodes.clone(); + let entity_df = filter_under_community_level(&entity_df, community_level)?; + + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").fill_null(lit(-1))) + .collect()?; + let entity_df = entity_df + .clone() + .lazy() + .with_column(col("community").cast(DataType::Int32)) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + 
.with_column(col("community").cast(DataType::String)) + .collect()?; + + let entity_df = entity_df + .clone() + .lazy() + .group_by([col("title")]) + .agg([col("community").max()]) + .collect()?; + + let filtered_community_df = entity_df + .clone() + .lazy() + .filter(len().over([col("community")]).gt(lit(1))) + .collect()?; + + let report_df = final_community_reports.clone(); + let report_df = filter_under_community_level(&report_df, community_level)?; + + let report_df = report_df + .clone() + .lazy() + .join( + filtered_community_df.clone().lazy(), + [col("community")], + [col("community")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + + let reports = read_community_reports(&report_df, "community", Some("community"), None, None)?; + Ok(reports) +} + +pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result { + let mask = df.column("level")?.i32()?.lt_eq(community_level); + let result = df.filter(&mask)?; + + Ok(result) +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct CommunityReport { + pub id: String, + pub short_id: Option, + pub title: String, + pub community_id: String, + pub summary: String, + pub full_content: String, + pub rank: Option, + pub summary_embedding: Option>, + pub full_content_embedding: Option>, + pub attributes: Option>, +} + +pub fn read_community_reports( + df: &DataFrame, + _id_col: &str, + _short_id_col: Option<&str>, + // title_col: &str, + // community_col: &str, + // summary_col: &str, + // content_col: &str, + // rank_col: Option<&str>, + _summary_embedding_col: Option<&str>, + _content_embedding_col: Option<&str>, + // attributes_cols: Option<&[&str]>, +) -> anyhow::Result> { + let mut df = df.clone(); + df.as_single_chunk_par(); + let mut iters = df + .columns(["community", "title", "summary", "full_content", "rank"])? 
+ .iter() + .map(|s| s.iter()) + .collect::>(); + + let mut rows = Vec::new(); + for _row in 0..df.height() { + let mut row_values = Vec::new(); + for iter in &mut iters { + let value = iter.next(); + if let Some(value) = value { + row_values.push(value.to_string()); + } + } + rows.push(row_values); + } + + let mut reports = Vec::new(); + for row in rows { + let report = CommunityReport { + id: row.get(0).unwrap_or(&String::new()).to_string(), + short_id: Some(row.get(0).unwrap_or(&String::new()).to_string()), + title: row.get(1).unwrap_or(&String::new()).to_string(), + community_id: row.get(0).unwrap_or(&String::new()).to_string(), + summary: row.get(3).unwrap_or(&String::new()).to_string(), + full_content: row.get(4).unwrap_or(&String::new()).to_string(), + rank: Some(row.get(5).and_then(|v| v.parse::().ok()).unwrap_or(0.0)), + summary_embedding: None, + full_content_embedding: None, + attributes: None, + }; + reports.push(report); + } + + Ok(reports) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs index 709d766d9..08173d23d 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -1 +1,3 @@ pub mod context_builder; +pub mod indexer_entities; +pub mod indexer_reports; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs b/shinkai-libs/shinkai-graphrag/src/llm/openai.rs index a0b3986b6..644f1101f 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs +++ b/shinkai-libs/shinkai-graphrag/src/llm/openai.rs @@ -52,9 +52,9 @@ impl ChatOpenAI { async fn _agenerate( &self, messages: MessageType, - streaming: bool, - callbacks: Option>, - llm_params: HashMap, + _streaming: bool, + _callbacks: Option>, + _llm_params: HashMap, ) -> anyhow::Result { let client = match &self.api_key { Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)), From 36c0c43514e5657c228ac15998cf19ebe84ec8f4 Mon Sep 17 00:00:00 2001 From: benolt Date: Wed, 7 Aug 2024 16:40:27 +0200 Subject: [PATCH 04/12] build community context batch 1 --- Cargo.lock | 65 +++- shinkai-libs/shinkai-graphrag/Cargo.toml | 1 + .../src/context_builder/community_context.rs | 326 ++++++++++++++++++ .../src/context_builder/indexer_entities.rs | 2 +- .../src/context_builder/indexer_reports.rs | 2 +- .../src/context_builder/mod.rs | 1 + .../src/search/global_search.rs | 7 +- 7 files changed, 383 insertions(+), 21 deletions(-) create mode 100644 shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs diff --git a/Cargo.lock b/Cargo.lock index 483ef12e9..33dd7b2a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,6 +806,43 @@ dependencies = [ "event-listener 2.5.3", ] +[[package]] +name = "async-lock" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" +dependencies = [ + "event-listener 5.2.0", + "event-listener-strategy", + "pin-project-lite", +] + +[[package]] +name = "async-openai" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0e5ff98f9e7c605df4c88783a0439d1dc667ce86bd79e99d4164f8b0c05ccc" +dependencies = [ + "async-convert", + "backoff", + "base64 0.22.1", + "bytes", + "derive_builder 0.20.0", + "eventsource-stream", + "futures", + "rand 0.8.5", + "reqwest 0.12.5", + "reqwest-eventsource", + "secrecy", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", 
+ "tokio-util 0.7.11", + "tracing", +] + [[package]] name = "async-priority-channel" version = "0.2.0" @@ -6377,12 +6414,6 @@ dependencies = [ "libc", ] -[[package]] -name = "maplit" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" - [[package]] name = "markup5ever" version = "0.10.1" @@ -9329,12 +9360,6 @@ dependencies = [ "regex-syntax 0.8.2", ] -[[package]] -name = "regex-automata" -version = "0.4.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" - [[package]] name = "regex-lite" version = "0.1.5" @@ -9437,6 +9462,7 @@ dependencies = [ "js-sys", "log 0.4.21", "mime 0.3.17", + "mime_guess 2.0.4", "once_cell", "percent-encoding 2.3.1", "pin-project-lite", @@ -10217,9 +10243,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.9.0" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" +checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" dependencies = [ "base64 0.22.1", "chrono", @@ -10235,9 +10261,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.9.0" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" +checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" dependencies = [ "darling 0.20.10", "proc-macro2", @@ -10356,6 +10382,7 @@ dependencies = [ "futures", "polars", "polars-lazy", + "rand 0.8.5", "serde", "serde_json", "tiktoken-rs", @@ -10873,6 +10900,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.4.9" diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml index 5cd0f1549..5a5ceafb7 100644 --- a/shinkai-libs/shinkai-graphrag/Cargo.toml +++ b/shinkai-libs/shinkai-graphrag/Cargo.toml @@ -10,6 +10,7 @@ async-trait = "0.1.74" futures = "0.3.30" polars = { version = "0.41.3", features = ["lazy", "parquet"] } polars-lazy = "0.41.3" +rand = "0.8.5" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.117" tiktoken-rs = "0.5.9" diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs new file mode 100644 index 000000000..3b8350a3d --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -0,0 +1,326 @@ +use std::{ + collections::HashMap, + io::{Cursor, Read}, +}; + +use polars::{ + frame::DataFrame, + io::SerWriter, + prelude::{col, CsvWriter, DataType, IntoLazy, NamedFrom, SortMultipleOptions}, + series::Series, +}; +use rand::prelude::SliceRandom; +use tiktoken_rs::tokenizer::Tokenizer; + +use crate::llm::utils::num_tokens; + +use super::{context_builder::ConversationHistory, indexer_entities::Entity, indexer_reports::CommunityReport}; + +pub struct GlobalCommunityContext { + community_reports: Vec, + entities: Option>, + token_encoder: Option, + random_state: i32, +} + +impl GlobalCommunityContext { + pub fn new( + community_reports: Vec, + entities: Option>, 
+ token_encoder: Option, + random_state: Option, + ) -> Self { + Self { + community_reports, + entities, + token_encoder, + random_state: random_state.unwrap_or(86), + } + } + + pub async fn build_context( + &self, + conversation_history: Option, + context_builder_params: Option>, + ) -> (Vec, HashMap) { + (vec![], HashMap::new()) + } +} + +pub fn build_community_context( + community_reports: Vec, + entities: Option>, + token_encoder: Option, + use_community_summary: bool, + column_delimiter: &str, + shuffle_data: bool, + include_community_rank: bool, + min_community_rank: i32, + community_rank_name: &str, + include_community_weight: bool, + community_weight_name: &str, + normalize_community_weight: bool, + max_tokens: i32, + single_batch: bool, + context_name: &str, + random_state: i32, +) -> anyhow::Result<(Vec, HashMap)> { + let _is_included = |report: &CommunityReport| -> bool { + report.rank.is_some() && report.rank.unwrap() >= min_community_rank.into() + }; + + let _get_header = |attributes: Vec| -> Vec { + let mut header = vec!["id".to_string(), "title".to_string()]; + let mut filtered_attributes: Vec = attributes + .iter() + .filter(|&col| !header.contains(&col.to_string())) + .cloned() + .collect(); + + if !include_community_weight { + filtered_attributes.retain(|col| col != community_weight_name); + } + + header.extend(filtered_attributes.into_iter().map(|s| s.to_string())); + header.push(if use_community_summary { + "summary".to_string() + } else { + "content".to_string() + }); + + if include_community_rank { + header.push(community_rank_name.to_string()); + } + + header + }; + + let _report_context_text = |report: &CommunityReport, attributes: &[String]| -> (String, Vec) { + let mut context: Vec = vec![report.short_id.clone().unwrap_or_default(), report.title.clone()]; + + for field in attributes { + let value = report + .attributes + .as_ref() + .and_then(|attrs| attrs.get(field)) + .cloned() + .unwrap_or_default(); + context.push(value); + } + + context.push(if use_community_summary { + report.summary.clone() + } else { + report.full_content.clone() + }); + + if include_community_rank { + context.push(report.rank.unwrap_or_default().to_string()); + } + + let result = context.join(column_delimiter) + "\n"; + (result, context) + }; + + let compute_community_weights = entities.is_some() + && !community_reports.is_empty() + && include_community_weight + && (community_reports[0].attributes.is_none() + || !community_reports[0] + .attributes + .clone() + .unwrap() + .contains_key(community_weight_name)); + + let mut community_reports = community_reports; + if compute_community_weights { + community_reports = _compute_community_weights( + community_reports, + entities.clone(), + community_weight_name, + normalize_community_weight, + ); + } + + let mut selected_reports: Vec = community_reports + .iter() + .filter(|&report| _is_included(report)) + .cloned() + .collect(); + + if selected_reports.is_empty() { + return Ok((Vec::new(), HashMap::new())); + } + + if shuffle_data { + let mut rng = rand::thread_rng(); + selected_reports.shuffle(&mut rng); + } + + let attributes = if let Some(attributes) = &community_reports[0].attributes { + attributes.keys().cloned().collect::>() + } else { + Vec::new() + }; + + let header = _get_header(attributes); + let mut all_context_text: Vec = Vec::new(); + let mut all_context_records: Vec = Vec::new(); + + let mut batch_text = String::new(); + let mut batch_tokens = 0; + let mut batch_records: Vec> = Vec::new(); + + let mut _init_batch = || { + 
batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter)); + batch_tokens = num_tokens(&batch_text, token_encoder); + batch_records = Vec::new(); + }; + + let _cut_batch = |batch_records: Vec>, header: Vec| -> anyhow::Result<()> { + let weight_column = if include_community_weight && entities.is_some() { + Some(community_weight_name) + } else { + None + }; + let rank_column = if include_community_rank { + Some(community_rank_name) + } else { + None + }; + + let mut record_df = _convert_report_context_to_df(batch_records, header, weight_column, rank_column)?; + if record_df.is_empty() { + return Ok(()); + } + + let mut buffer = Cursor::new(Vec::new()); + CsvWriter::new(buffer.clone()).finish(&mut record_df).unwrap(); + + let mut current_context_text = String::new(); + buffer.read_to_string(&mut current_context_text)?; + + all_context_text.push(current_context_text); + all_context_records.push(record_df); + + Ok(()) + }; + + _init_batch(); + + Ok((vec![], HashMap::new())) +} + +fn _compute_community_weights( + community_reports: Vec, + entities: Option>, + weight_attribute: &str, + normalize: bool, +) -> Vec { + // Calculate a community's weight as the count of text units associated with entities within the community. + if let Some(entities) = entities { + let mut community_reports = community_reports.clone(); + let mut community_text_units = std::collections::HashMap::new(); + for entity in entities { + if let Some(community_ids) = entity.community_ids.clone() { + for community_id in community_ids { + community_text_units + .entry(community_id) + .or_insert_with(Vec::new) + .extend(entity.text_unit_ids.clone()); + } + } + } + for report in &mut community_reports { + if report.attributes.is_none() { + report.attributes = Some(std::collections::HashMap::new()); + } + if let Some(attributes) = &mut report.attributes { + attributes.insert( + weight_attribute.to_string(), + community_text_units + .get(&report.community_id) + .map(|text_units| text_units.len()) + .unwrap_or(0) + .to_string(), + ); + } + } + if normalize { + // Normalize by max weight + let all_weights: Vec = community_reports + .iter() + .filter_map(|report| { + report + .attributes + .as_ref() + .and_then(|attributes| attributes.get(weight_attribute)) + .map(|weight| weight.parse::().unwrap_or(0.0)) + }) + .collect(); + if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) { + for mut report in community_reports { + if let Some(attributes) = &mut report.attributes { + if let Some(weight) = attributes.get_mut(weight_attribute) { + *weight = (weight.parse::().unwrap_or(0.0) / max_weight).to_string(); + } + } + } + } + } + } + community_reports +} + +fn _convert_report_context_to_df( + context_records: Vec>, + header: Vec, + weight_column: Option<&str>, + rank_column: Option<&str>, +) -> anyhow::Result { + if context_records.is_empty() { + return Ok(DataFrame::empty()); + } + + let mut data_series = Vec::new(); + for (header, records) in header.iter().zip(context_records.iter()) { + let series = Series::new(header, records); + data_series.push(series); + } + + let record_df = DataFrame::new(data_series)?; + + return _rank_report_context(record_df, weight_column, rank_column); +} + +fn _rank_report_context( + report_df: DataFrame, + weight_column: Option<&str>, + rank_column: Option<&str>, +) -> anyhow::Result { + let weight_column = weight_column.unwrap_or("occurrence weight"); + let rank_column = rank_column.unwrap_or("rank"); + + let mut rank_attributes = 
Vec::new(); + rank_attributes.push(weight_column); + let report_df = report_df + .clone() + .lazy() + .with_column(col(weight_column).cast(DataType::Float64)) + .collect()?; + + rank_attributes.push(rank_column); + let report_df = report_df + .clone() + .lazy() + .with_column(col(rank_column).cast(DataType::Float64)) + .collect()?; + + let report_df = report_df + .clone() + .lazy() + .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true)) + .collect()?; + + Ok(report_df) +} diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs index 0265a05f7..337831da0 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs @@ -83,7 +83,7 @@ pub fn read_indexer_entities( Ok(entities) } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct Entity { pub id: String, pub short_id: Option, diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs index 07811ac8b..fcfca58a6 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs @@ -67,7 +67,7 @@ pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> any Ok(result) } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub struct CommunityReport { pub id: String, pub short_id: Option, diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs index 08173d23d..0abed5320 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/mod.rs @@ -1,3 +1,4 @@ +pub mod community_context; pub mod context_builder; pub mod indexer_entities; pub mod indexer_reports; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs index 0dde10c0c..55636214d 100644 --- a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs @@ -6,7 +6,8 @@ use std::collections::HashMap; use std::time::Instant; use tiktoken_rs::tokenizer::Tokenizer; -use crate::context_builder::context_builder::{ConversationHistory, GlobalContextBuilder}; +use crate::context_builder::community_context::GlobalCommunityContext; +use crate::context_builder::context_builder::ConversationHistory; use crate::llm::llm::{BaseLLM, BaseLLMCallback, MessageType}; use crate::llm::utils::num_tokens; @@ -91,7 +92,7 @@ impl GlobalSearchLLMCallback { pub struct GlobalSearch { llm: Box, - context_builder: Box, + context_builder: GlobalCommunityContext, token_encoder: Option, context_builder_params: Option>, reduce_system_prompt: String, @@ -107,7 +108,7 @@ pub struct GlobalSearch { impl GlobalSearch { pub fn new( llm: Box, - context_builder: Box, + context_builder: GlobalCommunityContext, token_encoder: Option, reduce_system_prompt: String, response_type: String, From 7d17d93ce7a32c9f217a059a3db7553dd2073205 Mon Sep 17 00:00:00 2001 From: benolt Date: Thu, 8 Aug 2024 17:27:05 +0200 Subject: [PATCH 05/12] build community context, global search test --- shinkai-libs/shinkai-graphrag/.gitignore | 1 + shinkai-libs/shinkai-graphrag/Cargo.toml | 7 
 .../src/context_builder/community_context.rs  | 584 +++++++++++-------
 .../src/context_builder/context_builder.rs    |  31 +-
 .../src/context_builder/indexer_entities.rs   |  29 +-
 .../src/context_builder/indexer_reports.rs    |   8 +-
 shinkai-libs/shinkai-graphrag/src/llm/llm.rs  |   9 +-
 shinkai-libs/shinkai-graphrag/src/llm/mod.rs  |   1 -
 .../src/search/global_search.rs               |  88 ++-
 .../tests/it/global_search_tests.rs           | 109 ++++
 .../shinkai-graphrag/tests/it/utils/mod.rs    |   1 +
 .../{src/llm => tests/it/utils}/openai.rs     |  45 +-
 shinkai-libs/shinkai-graphrag/tests/it_mod.rs |   4 +
 13 files changed, 602 insertions(+), 315 deletions(-)
 create mode 100644 shinkai-libs/shinkai-graphrag/.gitignore
 create mode 100644 shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs
 create mode 100644 shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs
 rename shinkai-libs/shinkai-graphrag/{src/llm => tests/it/utils}/openai.rs (70%)
 create mode 100644 shinkai-libs/shinkai-graphrag/tests/it_mod.rs

diff --git a/shinkai-libs/shinkai-graphrag/.gitignore b/shinkai-libs/shinkai-graphrag/.gitignore
new file mode 100644
index 000000000..122af2cf4
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/.gitignore
@@ -0,0 +1 @@
+dataset
\ No newline at end of file
diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml
index 5a5ceafb7..378f20e4f 100644
--- a/shinkai-libs/shinkai-graphrag/Cargo.toml
+++ b/shinkai-libs/shinkai-graphrag/Cargo.toml
@@ -5,13 +5,16 @@ edition = "2021"
 
 [dependencies]
 anyhow = "1.0.86"
-async-openai = "0.23.4"
 async-trait = "0.1.74"
 futures = "0.3.30"
-polars = { version = "0.41.3", features = ["lazy", "parquet"] }
+polars = { version = "0.41.3", features = ["dtype-struct", "lazy", "parquet"] }
 polars-lazy = "0.41.3"
 rand = "0.8.5"
 serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.117"
 tiktoken-rs = "0.5.9"
+tokio = { version = "1.36", features = ["full"] }
+
+[dev-dependencies]
+async-openai = "0.23.4"
 tokio = { version = "1.36", features = ["full"] }
\ No newline at end of file
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
index 3b8350a3d..e46219021 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
@@ -6,7 +6,7 @@ use std::{
 use polars::{
     frame::DataFrame,
     io::SerWriter,
-    prelude::{col, CsvWriter, DataType, IntoLazy, NamedFrom, SortMultipleOptions},
+    prelude::{col, concat, CsvWriter, DataType, IntoLazy, LazyFrame, NamedFrom, SortMultipleOptions, UnionArgs},
     series::Series,
 };
 use rand::prelude::SliceRandom;
@@ -14,13 +14,12 @@ use tiktoken_rs::tokenizer::Tokenizer;
 
 use crate::llm::utils::num_tokens;
 
-use super::{context_builder::ConversationHistory, indexer_entities::Entity, indexer_reports::CommunityReport};
+use super::{context_builder::ContextBuilderParams, indexer_entities::Entity, indexer_reports::CommunityReport};
 
 pub struct GlobalCommunityContext {
     community_reports: Vec<CommunityReport>,
     entities: Option<Vec<Entity>>,
     token_encoder: Option<Tokenizer>,
-    random_state: i32,
 }
 
 impl GlobalCommunityContext {
@@ -28,157 +27,346 @@ impl GlobalCommunityContext {
         community_reports: Vec<CommunityReport>,
         entities: Option<Vec<Entity>>,
         token_encoder: Option<Tokenizer>,
-        random_state: Option<i32>,
     ) -> Self {
         Self {
             community_reports,
             entities,
             token_encoder,
-            random_state: random_state.unwrap_or(86),
         }
     }
 
     pub async fn build_context(
         &self,
-        conversation_history: Option<ConversationHistory>,
-        context_builder_params: Option<HashMap<String, String>>,
-    ) -> (Vec<String>, HashMap<String, DataFrame>) {
-        (vec![], HashMap::new())
+        context_builder_params: ContextBuilderParams,
+    ) -> anyhow::Result<(Vec<String>, HashMap<String, DataFrame>)> {
+        let ContextBuilderParams {
+            use_community_summary,
+            column_delimiter,
+            shuffle_data,
+            include_community_rank,
+            min_community_rank,
+            community_rank_name,
+            include_community_weight,
+            community_weight_name,
+            normalize_community_weight,
+            max_tokens,
+            context_name,
+        } = context_builder_params;
+
+        let (community_context, community_context_data) = CommunityContext::build_community_context(
+            self.community_reports.clone(),
+            self.entities.clone(),
+            self.token_encoder.clone(),
+            use_community_summary,
+            &column_delimiter,
+            shuffle_data,
+            include_community_rank,
+            min_community_rank,
+            &community_rank_name,
+            include_community_weight,
+            &community_weight_name,
+            normalize_community_weight,
+            max_tokens,
+            false,
+            &context_name,
+        )?;
+
+        let final_context = community_context;
+        let final_context_data = community_context_data;
+
+        Ok((final_context, final_context_data))
     }
 }
 
-pub fn build_community_context(
-    community_reports: Vec<CommunityReport>,
-    entities: Option<Vec<Entity>>,
-    token_encoder: Option<Tokenizer>,
-    use_community_summary: bool,
-    column_delimiter: &str,
-    shuffle_data: bool,
-    include_community_rank: bool,
-    min_community_rank: i32,
-    community_rank_name: &str,
-    include_community_weight: bool,
-    community_weight_name: &str,
-    normalize_community_weight: bool,
-    max_tokens: i32,
-    single_batch: bool,
-    context_name: &str,
-    random_state: i32,
-) -> anyhow::Result<(Vec<String>, HashMap<String, DataFrame>)> {
-    let _is_included = |report: &CommunityReport| -> bool {
-        report.rank.is_some() && report.rank.unwrap() >= min_community_rank.into()
-    };
-
-    let _get_header = |attributes: Vec<String>| -> Vec<String> {
-        let mut header = vec!["id".to_string(), "title".to_string()];
-        let mut filtered_attributes: Vec<String> = attributes
+pub struct CommunityContext {}
+
+impl CommunityContext {
+    pub fn build_community_context(
+        community_reports: Vec<CommunityReport>,
+        entities: Option<Vec<Entity>>,
+        token_encoder: Option<Tokenizer>,
+        use_community_summary: bool,
+        column_delimiter: &str,
+        shuffle_data: bool,
+        include_community_rank: bool,
+        min_community_rank: u32,
+        community_rank_name: &str,
+        include_community_weight: bool,
+        community_weight_name: &str,
+        normalize_community_weight: bool,
+        max_tokens: usize,
+        single_batch: bool,
+        context_name: &str,
+    ) -> anyhow::Result<(Vec<String>, HashMap<String, DataFrame>)> {
+        let _is_included = |report: &CommunityReport| -> bool {
+            report.rank.is_some() && report.rank.unwrap() >= min_community_rank.into()
+        };
+
+        let _get_header = |attributes: Vec<String>| -> Vec<String> {
+            let mut header = vec!["id".to_string(), "title".to_string()];
+            let mut filtered_attributes: Vec<String> = attributes
+                .iter()
+                .filter(|&col| !header.contains(&col.to_string()))
+                .cloned()
+                .collect();
+
+            if !include_community_weight {
+                filtered_attributes.retain(|col| col != community_weight_name);
+            }
+
+            header.extend(filtered_attributes.into_iter().map(|s| s.to_string()));
+            header.push(if use_community_summary {
+                "summary".to_string()
+            } else {
+                "content".to_string()
+            });
+
+            if include_community_rank {
+                header.push(community_rank_name.to_string());
+            }
+
+            header
+        };
+
+        let _report_context_text = |report: &CommunityReport, attributes: &[String]| -> (String, Vec<String>) {
+            let mut context: Vec<String> = vec![report.short_id.clone().unwrap_or_default(), report.title.clone()];
+
+            for field in attributes {
+                let value = report
+                    .attributes
+                    .as_ref()
+                    .and_then(|attrs| attrs.get(field))
+                    .cloned()
+                    .unwrap_or_default();
+                context.push(value);
+            }
+
+            context.push(if use_community_summary {
+                report.summary.clone()
+            } else {
+                report.full_content.clone()
+            });
+
+            if include_community_rank {
+                context.push(report.rank.unwrap_or_default().to_string());
+            }
+
+            let result = context.join(column_delimiter) + "\n";
+            (result, context)
+        };
+
+        let compute_community_weights = entities.is_some()
+            && !community_reports.is_empty()
+            && include_community_weight
+            && (community_reports[0].attributes.is_none()
+                || !community_reports[0]
+                    .attributes
+                    .clone()
+                    .unwrap()
+                    .contains_key(community_weight_name));
+
+        let mut community_reports = community_reports;
+        if compute_community_weights {
+            community_reports = Self::_compute_community_weights(
+                community_reports,
+                entities.clone(),
+                community_weight_name,
+                normalize_community_weight,
+            );
+        }
+
+        let mut selected_reports: Vec<CommunityReport> = community_reports
             .iter()
-            .filter(|&col| !header.contains(&col.to_string()))
+            .filter(|&report| _is_included(report))
             .cloned()
             .collect();
 
-        if !include_community_weight {
-            filtered_attributes.retain(|col| col != community_weight_name);
+        if selected_reports.is_empty() {
+            return Ok((Vec::new(), HashMap::new()));
         }
 
-        header.extend(filtered_attributes.into_iter().map(|s| s.to_string()));
-        header.push(if use_community_summary {
-            "summary".to_string()
-        } else {
-            "content".to_string()
-        });
-
-        if include_community_rank {
-            header.push(community_rank_name.to_string());
+        if shuffle_data {
+            let mut rng = rand::thread_rng();
+            selected_reports.shuffle(&mut rng);
         }
 
-        header
-    };
+        let attributes = if let Some(attributes) = &community_reports[0].attributes {
+            attributes.keys().cloned().collect::<Vec<String>>()
+        } else {
+            Vec::new()
+        };
+
+        let header = _get_header(attributes.clone());
+        let mut all_context_text: Vec<String> = Vec::new();
+        let mut all_context_records: Vec<DataFrame> = Vec::new();
+
+        let mut batch = Batch::new();
+
+        batch.init_batch(context_name, &header, column_delimiter, token_encoder);
+
+        for report in selected_reports {
+            let (new_context_text, new_context) = _report_context_text(&report, &attributes);
+            let new_tokens = num_tokens(&new_context_text, token_encoder);
+
+            // add the current batch to the context data and start a new batch if we are in multi-batch mode
+            if batch.batch_tokens + new_tokens > max_tokens {
+                batch.cut_batch(
+                    &mut all_context_text,
+                    &mut all_context_records,
+                    entities.clone(),
+                    &header,
+                    community_weight_name,
+                    community_rank_name,
+                    include_community_weight,
+                    include_community_rank,
+                )?;
+
+                if single_batch {
+                    break;
+                }
 
-    let _report_context_text = |report: &CommunityReport, attributes: &[String]| -> (String, Vec<String>) {
-        let mut context: Vec<String> = vec![report.short_id.clone().unwrap_or_default(), report.title.clone()];
+                batch.init_batch(context_name, &header, column_delimiter, token_encoder);
+            }
 
-        for field in attributes {
-            let value = report
-                .attributes
-                .as_ref()
-                .and_then(|attrs| attrs.get(field))
-                .cloned()
-                .unwrap_or_default();
-            context.push(value);
+            batch.batch_text.push_str(&new_context_text);
+            batch.batch_tokens += new_tokens;
+            batch.batch_records.push(new_context);
         }
 
-        context.push(if use_community_summary {
-            report.summary.clone()
-        } else {
-            report.full_content.clone()
-        });
+        if !all_context_text.contains(&batch.batch_text) {
+            batch.cut_batch(
+                &mut all_context_text,
+                &mut all_context_records,
+                entities.clone(),
+                &header,
+                community_weight_name,
+                community_rank_name,
+                include_community_weight,
+                include_community_rank,
+            )?;
+        }
 
-        if include_community_rank {
-            context.push(report.rank.unwrap_or_default().to_string());
+        if all_context_records.is_empty() {
+            eprintln!("Warning: No community records added when building community context.");
+            return Ok((Vec::new(), HashMap::new()));
         }
 
-        let result = context.join(column_delimiter) + "\n";
-        (result, context)
-    };
+        let records_concat = concat(
+            all_context_records
+                .into_iter()
+                .map(|df| df.lazy())
+                .collect::<Vec<LazyFrame>>(),
+            UnionArgs::default(),
+        )?
+        .collect()?;
 
-    let compute_community_weights = entities.is_some()
-        && !community_reports.is_empty()
-        && include_community_weight
-        && (community_reports[0].attributes.is_none()
-            || !community_reports[0]
-                .attributes
-                .clone()
-                .unwrap()
-                .contains_key(community_weight_name));
+        Ok((
+            all_context_text,
+            HashMap::from([(context_name.to_lowercase(), records_concat)]),
+        ))
+    }
 
-    let mut community_reports = community_reports;
-    if compute_community_weights {
-        community_reports = _compute_community_weights(
-            community_reports,
-            entities.clone(),
-            community_weight_name,
-            normalize_community_weight,
-        );
+    fn _compute_community_weights(
+        community_reports: Vec<CommunityReport>,
+        entities: Option<Vec<Entity>>,
+        weight_attribute: &str,
+        normalize: bool,
+    ) -> Vec<CommunityReport> {
+        // Calculate a community's weight as the count of text units associated with entities within the community.
+        if let Some(entities) = entities {
+            let mut community_reports = community_reports.clone();
+            let mut community_text_units = std::collections::HashMap::new();
+            for entity in entities {
+                if let Some(community_ids) = entity.community_ids.clone() {
+                    for community_id in community_ids {
+                        community_text_units
+                            .entry(community_id)
+                            .or_insert_with(Vec::new)
+                            .extend(entity.text_unit_ids.clone());
+                    }
+                }
+            }
+            for report in &mut community_reports {
+                if report.attributes.is_none() {
+                    report.attributes = Some(std::collections::HashMap::new());
+                }
+                if let Some(attributes) = &mut report.attributes {
+                    attributes.insert(
+                        weight_attribute.to_string(),
+                        community_text_units
+                            .get(&report.community_id)
+                            .map(|text_units| text_units.len())
+                            .unwrap_or(0)
+                            .to_string(),
+                    );
+                }
+            }
+            if normalize {
+                // Normalize by max weight
+                let all_weights: Vec<f64> = community_reports
+                    .iter()
+                    .filter_map(|report| {
+                        report
+                            .attributes
+                            .as_ref()
+                            .and_then(|attributes| attributes.get(weight_attribute))
+                            .map(|weight| weight.parse::<f64>().unwrap_or(0.0))
+                    })
+                    .collect();
+                if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) {
+                    for mut report in community_reports {
+                        if let Some(attributes) = &mut report.attributes {
+                            if let Some(weight) = attributes.get_mut(weight_attribute) {
+                                *weight = (weight.parse::<f64>().unwrap_or(0.0) / max_weight).to_string();
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        community_reports
     }
+}
 
-    let mut selected_reports: Vec<CommunityReport> = community_reports
-        .iter()
-        .filter(|&report| _is_included(report))
-        .cloned()
-        .collect();
+struct Batch {
+    batch_text: String,
+    batch_tokens: usize,
+    batch_records: Vec<Vec<String>>,
+}
 
-    if selected_reports.is_empty() {
-        return Ok((Vec::new(), HashMap::new()));
+impl Batch {
+    fn new() -> Self {
+        Batch {
+            batch_text: String::new(),
+            batch_tokens: 0,
+            batch_records: Vec::new(),
+        }
     }
 
-    if shuffle_data {
-        let mut rng = rand::thread_rng();
-        selected_reports.shuffle(&mut rng);
+    fn init_batch(
+        &mut self,
+        context_name: &str,
+        header: &Vec<String>,
+        column_delimiter: &str,
+        token_encoder: Option<Tokenizer>,
+    ) {
+        self.batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter));
+        self.batch_tokens = num_tokens(&self.batch_text, token_encoder);
+        self.batch_records.clear();
     }
 
-    let attributes = if let Some(attributes) = &community_reports[0].attributes {
-        attributes.keys().cloned().collect::<Vec<String>>()
-    } else {
-        Vec::new()
-    };
-
-    let header = _get_header(attributes);
-    let mut all_context_text: Vec<String> = Vec::new();
-    let mut all_context_records: Vec<DataFrame> = Vec::new();
-
-    let mut batch_text = String::new();
-    let mut batch_tokens = 0;
-    let mut batch_records: Vec<Vec<String>> = Vec::new();
-
-    let mut _init_batch = || {
-        batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter));
-        batch_tokens = num_tokens(&batch_text, token_encoder);
-        batch_records = Vec::new();
-    };
-
-    let _cut_batch = |batch_records: Vec<Vec<String>>, header: Vec<String>| -> anyhow::Result<()> {
-        let weight_column = if include_community_weight && entities.is_some() {
+    fn cut_batch(
+        &mut self,
+        all_context_text: &mut Vec<String>,
+        all_context_records: &mut Vec<DataFrame>,
+        entities: Option<Vec<Entity>>,
+        header: &Vec<String>,
+        community_weight_name: &str,
+        community_rank_name: &str,
+        include_community_weight: bool,
+        include_community_rank: bool,
+    ) -> anyhow::Result<()> {
+        let weight_column = if include_community_weight && entities.is_some_and(|e| !e.is_empty()) {
             Some(community_weight_name)
         } else {
             None
@@ -189,7 +377,12 @@ pub fn build_community_context(
             None
         };
 
-        let mut record_df = _convert_report_context_to_df(batch_records, header, weight_column, rank_column)?;
+        let mut record_df = Self::_convert_report_context_to_df(
+            self.batch_records.clone(),
+            header.clone(),
+            weight_column,
+            rank_column,
+        )?;
         if record_df.is_empty() {
             return Ok(());
         }
@@ -204,123 +397,64 @@ pub fn build_community_context(
         all_context_records.push(record_df);
 
         Ok(())
-    };
-
-    _init_batch();
-
-    Ok((vec![], HashMap::new()))
-}
+    }
 
-fn _compute_community_weights(
-    community_reports: Vec<CommunityReport>,
-    entities: Option<Vec<Entity>>,
-    weight_attribute: &str,
-    normalize: bool,
-) -> Vec<CommunityReport> {
-    // Calculate a community's weight as the count of text units associated with entities within the community.
-    if let Some(entities) = entities {
-        let mut community_reports = community_reports.clone();
-        let mut community_text_units = std::collections::HashMap::new();
-        for entity in entities {
-            if let Some(community_ids) = entity.community_ids.clone() {
-                for community_id in community_ids {
-                    community_text_units
-                        .entry(community_id)
-                        .or_insert_with(Vec::new)
-                        .extend(entity.text_unit_ids.clone());
-                }
-            }
+    fn _convert_report_context_to_df(
+        context_records: Vec<Vec<String>>,
+        header: Vec<String>,
+        weight_column: Option<&str>,
+        rank_column: Option<&str>,
+    ) -> anyhow::Result<DataFrame> {
+        if context_records.is_empty() {
+            return Ok(DataFrame::empty());
         }
-        for report in &mut community_reports {
-            if report.attributes.is_none() {
-                report.attributes = Some(std::collections::HashMap::new());
-            }
-            if let Some(attributes) = &mut report.attributes {
-                attributes.insert(
-                    weight_attribute.to_string(),
-                    community_text_units
-                        .get(&report.community_id)
-                        .map(|text_units| text_units.len())
-                        .unwrap_or(0)
-                        .to_string(),
-                );
-            }
-        }
-        if normalize {
-            // Normalize by max weight
-            let all_weights: Vec<f64> = community_reports
-                .iter()
-                .filter_map(|report| {
-                    report
-                        .attributes
-                        .as_ref()
-                        .and_then(|attributes| attributes.get(weight_attribute))
-                        .map(|weight| weight.parse::<f64>().unwrap_or(0.0))
-                })
-                .collect();
-            if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) {
-                for mut report in community_reports {
-                    if let Some(attributes) = &mut report.attributes {
-                        if let Some(weight) = attributes.get_mut(weight_attribute) {
-                            *weight = (weight.parse::<f64>().unwrap_or(0.0) / max_weight).to_string();
-                        }
-                    }
-                }
-            }
+
+        let mut data_series = Vec::new();
+        for (header, records) in header.iter().zip(context_records.iter()) {
+            let series = Series::new(header, records);
+            data_series.push(series);
         }
-    }
-    community_reports
-}
 
-fn _convert_report_context_to_df(
-    context_records: Vec<Vec<String>>,
-    header: Vec<String>,
-    weight_column: Option<&str>,
-    rank_column: Option<&str>,
-) -> anyhow::Result<DataFrame> {
-    if context_records.is_empty() {
-        return Ok(DataFrame::empty());
-    }
+        let record_df = DataFrame::new(data_series)?;
 
-    let mut data_series = Vec::new();
-    for (header, records) in header.iter().zip(context_records.iter()) {
-        let series = Series::new(header, records);
-        data_series.push(series);
+        return Self::_rank_report_context(record_df, weight_column, rank_column);
     }
 
-    let record_df = DataFrame::new(data_series)?;
+    fn _rank_report_context(
+        report_df: DataFrame,
+        weight_column: Option<&str>,
+        rank_column: Option<&str>,
+    ) -> anyhow::Result<DataFrame> {
+        let mut rank_attributes = Vec::new();
 
-    return _rank_report_context(record_df, weight_column, rank_column);
-}
+        let mut report_df = report_df;
 
-fn _rank_report_context(
-    report_df: DataFrame,
-    weight_column: Option<&str>,
-    rank_column: Option<&str>,
-) -> anyhow::Result<DataFrame> {
-    let weight_column = weight_column.unwrap_or("occurrence weight");
-    let rank_column = rank_column.unwrap_or("rank");
-
-    let mut rank_attributes = Vec::new();
-    rank_attributes.push(weight_column);
-    let report_df = report_df
-        .clone()
-        .lazy()
-        .with_column(col(weight_column).cast(DataType::Float64))
-        .collect()?;
+        if let Some(weight_column) = weight_column {
+            rank_attributes.push(weight_column);
+            report_df = report_df
+                .clone()
+                .lazy()
+                .with_column(col(weight_column).cast(DataType::Float64))
+                .collect()?;
+        }
 
-    rank_attributes.push(rank_column);
-    let report_df = report_df
-        .clone()
-        .lazy()
-        .with_column(col(rank_column).cast(DataType::Float64))
-        .collect()?;
+        if
let Some(rank_column) = rank_column {
+            rank_attributes.push(rank_column);
+            report_df = report_df
+                .clone()
+                .lazy()
+                .with_column(col(rank_column).cast(DataType::Float64))
+                .collect()?;
+        }
 
-    let report_df = report_df
-        .clone()
-        .lazy()
-        .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true))
-        .collect()?;
+        if !rank_attributes.is_empty() {
+            report_df = report_df
+                .clone()
+                .lazy()
+                .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true))
+                .collect()?;
+        }
 
-    Ok(report_df)
+        Ok(report_df)
+    }
 }
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs
index 705818a77..87db20231 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs
@@ -1,18 +1,19 @@
-use async_trait::async_trait;
-// use polars::prelude::*;
-use std::collections::HashMap;
-
-// TODO: Serialize and Deserialize polars::frame::DataFrame
-type DataFrame = Vec<String>;
-
-#[async_trait]
-pub trait GlobalContextBuilder {
-    /// Build the context for the global search mode.
-    async fn build_context(
-        &self,
-        conversation_history: Option<ConversationHistory>,
-        context_builder_params: Option<HashMap<String, String>>,
-    ) -> (Vec<String>, HashMap<String, DataFrame>);
+#[derive(Debug, Clone)]
+pub struct ContextBuilderParams {
+    //conversation_history: Option<ConversationHistory>,
+    pub use_community_summary: bool,
+    pub column_delimiter: String,
+    pub shuffle_data: bool,
+    pub include_community_rank: bool,
+    pub min_community_rank: u32,
+    pub community_rank_name: String,
+    pub include_community_weight: bool,
+    pub community_weight_name: String,
+    pub normalize_community_weight: bool,
+    pub max_tokens: usize,
+    pub context_name: String,
+    // conversation_history_user_turns_only: bool,
+    // conversation_history_max_turns: Option<usize>,
 }
 
 pub struct ConversationHistory {}
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
index 337831da0..1548d9671 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
@@ -117,19 +117,22 @@ pub fn read_entities(
     // attributes_cols: Option<Vec<&str>>,
 ) -> anyhow::Result<Vec<Entity>> {
     let column_names = [
-        id_col,
-        short_id_col.unwrap_or("short_id"),
-        title_col,
-        type_col.unwrap_or("type"),
-        description_col.unwrap_or("description"),
-        name_embedding_col.unwrap_or("name_embedding"),
-        description_embedding_col.unwrap_or("description_embedding"),
-        graph_embedding_col.unwrap_or("graph_embedding"),
-        community_col.unwrap_or("community_ids"),
-        text_unit_ids_col.unwrap_or("text_unit_ids"),
-        document_ids_col.unwrap_or("document_ids"),
-        rank_col.unwrap_or("degree"),
-    ];
+        Some(id_col),
+        short_id_col,
+        Some(title_col),
+        type_col,
+        description_col,
+        name_embedding_col,
+        description_embedding_col,
+        graph_embedding_col,
+        community_col,
+        text_unit_ids_col,
+        document_ids_col,
+        rank_col,
+    ]
+    .iter()
+    .filter_map(|&v| v.map(|v| v.to_string()))
+    .collect::<Vec<String>>();
 
     let mut df = df.clone();
     df.as_single_chunk_par();
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
index fcfca58a6..9f8b9c507 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
@@ -61,7 +61,7 @@ pub fn read_indexer_reports(
 }
 
 pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> anyhow::Result<DataFrame> {
-    let mask = df.column("level")?.i32()?.lt_eq(community_level);
+    let mask = df.column("level")?.i64()?.lt_eq(community_level);
 
     let result = df.filter(&mask)?;
     Ok(result)
@@ -121,9 +121,9 @@ pub fn read_community_reports(
         short_id: Some(row.get(0).unwrap_or(&String::new()).to_string()),
         title: row.get(1).unwrap_or(&String::new()).to_string(),
         community_id: row.get(0).unwrap_or(&String::new()).to_string(),
-        summary: row.get(3).unwrap_or(&String::new()).to_string(),
-        full_content: row.get(4).unwrap_or(&String::new()).to_string(),
-        rank: Some(row.get(5).and_then(|v| v.parse::<f64>().ok()).unwrap_or(0.0)),
+        summary: row.get(2).unwrap_or(&String::new()).to_string(),
+        full_content: row.get(3).unwrap_or(&String::new()).to_string(),
+        rank: Some(row.get(4).and_then(|v| v.parse::<f64>().ok()).unwrap_or(0.0)),
         summary_embedding: None,
         full_content_embedding: None,
         attributes: None,
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
index 4c8f0f68e..0a8482144 100644
--- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
+++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
@@ -25,6 +25,13 @@ pub enum MessageType {
     Dictionary(Vec<HashMap<String, String>>),
 }
 
+#[derive(Debug, Clone)]
+pub struct LLMParams {
+    pub max_tokens: u32,
+    pub temperature: f32,
+    pub response_format: HashMap<String, String>,
+}
+
 #[async_trait]
 pub trait BaseLLM {
     async fn agenerate(
@@ -32,7 +39,7 @@ pub trait BaseLLM {
         messages: MessageType,
         streaming: bool,
         callbacks: Option<Vec<BaseLLMCallback>>,
-        llm_params: HashMap<String, Value>,
+        llm_params: LLMParams,
     ) -> anyhow::Result<String>;
 }
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
index 00cb1d9e1..247bfe098 100644
--- a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
+++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
@@ -1,3 +1,2 @@
 pub mod llm;
-pub mod openai;
 pub mod utils;
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
index 55636214d..3a12ee6cd 100644
--- a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
@@ -1,30 +1,26 @@
 use futures::future::join_all;
-//use polars::frame::DataFrame;
-use serde::{Deserialize, Serialize};
+use polars::frame::DataFrame;
 use serde_json::Value;
 use std::collections::HashMap;
 use std::time::Instant;
 use tiktoken_rs::tokenizer::Tokenizer;
 
 use crate::context_builder::community_context::GlobalCommunityContext;
-use crate::context_builder::context_builder::ConversationHistory;
-use crate::llm::llm::{BaseLLM, BaseLLMCallback, MessageType};
+use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory};
+use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
 use crate::llm::utils::num_tokens;
 
-// TODO: Serialize and Deserialize polars::frame::DataFrame
-type DataFrame = Vec<String>;
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone)]
 pub struct SearchResult {
-    response: ResponseType,
-    context_data: ContextData,
-    context_text: ContextText,
-    completion_time: f64,
-    llm_calls: usize,
-    prompt_tokens: usize,
+    pub response: ResponseType,
+    pub context_data: ContextData,
+    pub context_text: ContextText,
+    pub completion_time: f64,
+    pub llm_calls: usize,
+    pub prompt_tokens: usize,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone)]
 pub enum ResponseType {
     String(String),
Dictionary(HashMap), @@ -32,37 +28,36 @@ pub enum ResponseType { KeyPoints(Vec), } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub enum ContextData { String(String), DataFrames(Vec), Dictionary(HashMap), } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub enum ContextText { String(String), Strings(Vec), Dictionary(HashMap), } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] pub struct KeyPoint { - answer: String, - score: i32, + pub answer: String, + pub score: i32, } -#[derive(Serialize, Deserialize)] pub struct GlobalSearchResult { - response: ResponseType, - context_data: ContextData, - context_text: ContextText, - completion_time: f64, - llm_calls: usize, - prompt_tokens: usize, - map_responses: Vec, - reduce_context_data: ContextData, - reduce_context_text: ContextText, + pub response: ResponseType, + pub context_data: ContextData, + pub context_text: ContextText, + pub completion_time: f64, + pub llm_calls: usize, + pub prompt_tokens: usize, + pub map_responses: Vec, + pub reduce_context_data: ContextData, + pub reduce_context_text: ContextText, } #[derive(Debug, Clone)] @@ -94,15 +89,15 @@ pub struct GlobalSearch { llm: Box, context_builder: GlobalCommunityContext, token_encoder: Option, - context_builder_params: Option>, + context_builder_params: ContextBuilderParams, reduce_system_prompt: String, response_type: String, allow_general_knowledge: bool, general_knowledge_inclusion_prompt: String, callbacks: Option>, max_data_tokens: usize, - map_llm_params: HashMap, - reduce_llm_params: HashMap, + map_llm_params: LLMParams, + reduce_llm_params: LLMParams, } impl GlobalSearch { @@ -117,19 +112,18 @@ impl GlobalSearch { json_mode: bool, callbacks: Option>, max_data_tokens: usize, - map_llm_params: HashMap, - reduce_llm_params: HashMap, - context_builder_params: Option>, + map_llm_params: LLMParams, + reduce_llm_params: LLMParams, + context_builder_params: ContextBuilderParams, ) -> Self { let mut map_llm_params = map_llm_params; if json_mode { - map_llm_params.insert( - "response_format".to_string(), - serde_json::json!({"type": "json_object"}), - ); + map_llm_params + .response_format + .insert("type".to_string(), "json_object".to_string()); } else { - map_llm_params.remove("response_format"); + map_llm_params.response_format.remove("response_format"); } GlobalSearch { @@ -151,14 +145,14 @@ impl GlobalSearch { pub async fn asearch( &self, query: String, - conversation_history: Option, + _conversation_history: Option, ) -> anyhow::Result { // Step 1: Generate answers for each batch of community short summaries let start_time = Instant::now(); let (context_chunks, context_records) = self .context_builder - .build_context(conversation_history, self.context_builder_params.clone()) - .await; + .build_context(self.context_builder_params.clone()) + .await?; let mut callbacks = match &self.callbacks { Some(callbacks) => { @@ -221,7 +215,7 @@ impl GlobalSearch { &self, context_data: &str, query: &str, - llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result { let start_time = Instant::now(); let search_prompt = String::new(); @@ -282,7 +276,7 @@ impl GlobalSearch { map_responses: Vec, query: &str, callbacks: Option>, - reduce_llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result { let start_time = Instant::now(); let mut key_points: Vec> = Vec::new(); @@ -384,7 +378,7 @@ impl GlobalSearch { MessageType::Dictionary(search_messages), true, llm_callbacks, - reduce_llm_params, + llm_params, ) 
.await?; diff --git a/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs new file mode 100644 index 000000000..065d36246 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs @@ -0,0 +1,109 @@ +use polars::{io::SerReader, prelude::ParquetReader}; +use shinkai_graphrag::{ + context_builder::{ + community_context::GlobalCommunityContext, context_builder::ContextBuilderParams, + indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports, + }, + llm::llm::LLMParams, + search::global_search::GlobalSearch, +}; +use tiktoken_rs::tokenizer::Tokenizer; + +use crate::it::utils::openai::ChatOpenAI; + +#[tokio::test] +async fn global_search_test() -> Result<(), Box> { + let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); + let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); + + let llm = ChatOpenAI::new(Some(api_key), llm_model, 20); + let token_encoder = Tokenizer::Cl100kBase; + + // Load community reports + // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip + + let input_dir = "./dataset"; + let community_report_table = "create_final_community_reports"; + let entity_table = "create_final_nodes"; + let entity_embedding_table = "create_final_entities"; + + let community_level = 2; + + let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap(); + let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap(); + + let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap(); + let report_df = ParquetReader::new(&mut report_file).finish().unwrap(); + + let mut entity_embedding_file = + std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap(); + let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap(); + + let reports = read_indexer_reports(&report_df, &entity_df, community_level)?; + let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?; + + println!("Reports: {:?}", report_df.head(Some(5))); + + // Build global context based on community reports + + let context_builder = GlobalCommunityContext::new(reports, Some(entities), Some(token_encoder)); + + let context_builder_params = ContextBuilderParams { + use_community_summary: false, // False means using full community reports. True means using community short summaries. 
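+        // Note: shuffle_data below randomizes report order before batching, so the
+        // map-stage batches (and any max_tokens truncation) can vary between runs;
+        // set it to false when reproducible context batches are needed.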
+ shuffle_data: true, + include_community_rank: true, + min_community_rank: 0, + community_rank_name: String::from("rank"), + include_community_weight: true, + community_weight_name: String::from("occurrence weight"), + normalize_community_weight: true, + max_tokens: 12_000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000) + context_name: String::from("Reports"), + column_delimiter: String::from("|"), + }; + + let map_llm_params = LLMParams { + max_tokens: 1000, + temperature: 0.0, + response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]), + }; + + let reduce_llm_params = LLMParams { + max_tokens: 2000, + temperature: 0.0, + response_format: std::collections::HashMap::new(), + }; + + // Perform global search + + let search_engine = GlobalSearch::new( + Box::new(llm), + context_builder, + Some(token_encoder), + String::from(""), + String::from("multiple paragraphs"), + false, + String::from(""), + true, + None, + 12_000, + map_llm_params, + reduce_llm_params, + context_builder_params, + ); + + let result = search_engine + .asearch( + "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(), + None, + ) + .await?; + + println!("Response: {:?}", result.response); + + println!("Context: {:?}", result.context_data); + + println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens); + + Ok(()) +} diff --git a/shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs new file mode 100644 index 000000000..d8c308735 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs @@ -0,0 +1 @@ +pub mod openai; diff --git a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs b/shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs similarity index 70% rename from shinkai-libs/shinkai-graphrag/src/llm/openai.rs rename to shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs index 644f1101f..1325ef6aa 100644 --- a/shinkai-libs/shinkai-graphrag/src/llm/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs @@ -1,12 +1,13 @@ -use std::collections::HashMap; - use async_openai::{ config::OpenAIConfig, - types::{ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, CreateChatCompletionRequestArgs}, + types::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs, ChatCompletionResponseFormat, + ChatCompletionResponseFormatType, CreateChatCompletionRequestArgs, + }, Client, }; - -use super::llm::{BaseLLMCallback, MessageType}; +use async_trait::async_trait; +use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType}; pub struct ChatOpenAI { pub api_key: Option, @@ -28,7 +29,7 @@ impl ChatOpenAI { messages: MessageType, streaming: bool, callbacks: Option>, - llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result { let mut retry_count = 0; @@ -54,7 +55,7 @@ impl ChatOpenAI { messages: MessageType, _streaming: bool, _callbacks: Option>, - _llm_params: HashMap, + llm_params: LLMParams, ) -> anyhow::Result { let client = match &self.api_key { Some(api_key) => Client::with_config(OpenAIConfig::new().with_api_key(api_key)), @@ -91,7 +92,24 @@ impl ChatOpenAI { .map(|m| Into::::into(m.clone())) .collect::>(); + let response_format = if llm_params + .response_format + .get_key_value("type") + .is_some_and(|(_k, v)| v == "json_object") + { + ChatCompletionResponseFormat { + r#type: 
ChatCompletionResponseFormatType::JsonObject, + } + } else { + ChatCompletionResponseFormat { + r#type: ChatCompletionResponseFormatType::Text, + } + }; + let request = CreateChatCompletionRequestArgs::default() + .max_tokens(llm_params.max_tokens) + .temperature(llm_params.temperature) + //.response_format(response_format) .model(self.model.clone()) .messages(request_messages) .build()?; @@ -105,3 +123,16 @@ impl ChatOpenAI { return Ok(String::new()); } } + +#[async_trait] +impl BaseLLM for ChatOpenAI { + async fn agenerate( + &self, + messages: MessageType, + streaming: bool, + callbacks: Option>, + llm_params: LLMParams, + ) -> anyhow::Result { + self.agenerate(messages, streaming, callbacks, llm_params).await + } +} diff --git a/shinkai-libs/shinkai-graphrag/tests/it_mod.rs b/shinkai-libs/shinkai-graphrag/tests/it_mod.rs new file mode 100644 index 000000000..4c5c9ed27 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/tests/it_mod.rs @@ -0,0 +1,4 @@ +mod it { + mod global_search_tests; + mod utils; +} From 2b297fb684087063c5b44e2b696eb67fe452d0bc Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 9 Aug 2024 14:05:19 +0200 Subject: [PATCH 06/12] add prompts, global search adjustments --- Cargo.lock | 58 ++++++++++++------------------------------------------ 1 file changed, 13 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 33dd7b2a7..1f5d01fd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -806,17 +806,6 @@ dependencies = [ "event-listener 2.5.3", ] -[[package]] -name = "async-lock" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" -dependencies = [ - "event-listener 5.2.0", - "event-listener-strategy", - "pin-project-lite", -] - [[package]] name = "async-openai" version = "0.23.4" @@ -925,8 +914,8 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ - "proc-macro2 1.0.84", - "quote 1.0.36", + "proc-macro2", + "quote", "syn 2.0.66", ] @@ -1860,8 +1849,8 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" dependencies = [ - "proc-macro2 1.0.84", - "quote 1.0.36", + "proc-macro2", + "quote", "syn 2.0.66", ] @@ -3322,8 +3311,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" dependencies = [ "darling 0.20.10", - "proc-macro2 1.0.84", - "quote 1.0.36", + "proc-macro2", + "quote", "syn 2.0.66", ] @@ -3675,8 +3664,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa18ce2bc66555b3218614519ac839ddb759a7d6720732f979ef8d13be147ecd" dependencies = [ "once_cell", - "proc-macro2 1.0.84", - "quote 1.0.36", + "proc-macro2", + "quote", "syn 2.0.66", ] @@ -4059,27 +4048,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "event-listener" -version = "5.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b5fb89194fa3cad959b833185b3063ba881dbfc7030680b314250779fb4cc91" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" -dependencies 
= [ - "event-listener 5.2.0", - "pin-project-lite", -] - [[package]] name = "eventsource-stream" version = "0.2.3" @@ -6777,8 +6745,8 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" dependencies = [ - "proc-macro2 1.0.84", - "quote 1.0.36", + "proc-macro2", + "quote", "syn 1.0.109", "target-features", ] @@ -8184,7 +8152,7 @@ dependencies = [ "memmap2 0.7.1", "num-traits", "once_cell", - "percent-encoding 2.3.0", + "percent-encoding 2.3.1", "polars-arrow", "polars-core", "polars-error", @@ -8338,7 +8306,7 @@ dependencies = [ "either", "hashbrown 0.14.5", "once_cell", - "percent-encoding 2.3.0", + "percent-encoding 2.3.1", "polars-arrow", "polars-core", "polars-io", @@ -9295,7 +9263,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ - "quote 1.0.36", + "quote", "syn 2.0.66", ] From fc5289de5e8db59074892da970f93d872d0bf03c Mon Sep 17 00:00:00 2001 From: benolt Date: Fri, 9 Aug 2024 14:05:19 +0200 Subject: [PATCH 07/12] add prompts, global search adjustments --- shinkai-libs/shinkai-graphrag/.gitignore | 1 + .../src/context_builder/community_context.rs | 27 ++- .../{ => global_search}/global_search.rs | 91 +++++++--- .../src/search/global_search/mod.rs | 2 + .../src/search/global_search/prompts.rs | 164 ++++++++++++++++++ .../tests/{it => }/global_search_tests.rs | 30 ++-- shinkai-libs/shinkai-graphrag/tests/it_mod.rs | 4 - .../tests/{it => }/utils/mod.rs | 0 .../tests/{it => }/utils/openai.rs | 2 +- 9 files changed, 270 insertions(+), 51 deletions(-) rename shinkai-libs/shinkai-graphrag/src/search/{ => global_search}/global_search.rs (81%) create mode 100644 shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs create mode 100644 shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs rename shinkai-libs/shinkai-graphrag/tests/{it => }/global_search_tests.rs (85%) delete mode 100644 shinkai-libs/shinkai-graphrag/tests/it_mod.rs rename shinkai-libs/shinkai-graphrag/tests/{it => }/utils/mod.rs (100%) rename shinkai-libs/shinkai-graphrag/tests/{it => }/utils/openai.rs (98%) diff --git a/shinkai-libs/shinkai-graphrag/.gitignore b/shinkai-libs/shinkai-graphrag/.gitignore index 122af2cf4..74deb7343 100644 --- a/shinkai-libs/shinkai-graphrag/.gitignore +++ b/shinkai-libs/shinkai-graphrag/.gitignore @@ -1 +1,2 @@ +.vscode dataset \ No newline at end of file diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs index e46219021..1f6be4ffe 100644 --- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs +++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs @@ -155,13 +155,13 @@ impl CommunityContext { (result, context) }; - let compute_community_weights = entities.is_some() + let compute_community_weights = entities.as_ref().is_some_and(|e| !e.is_empty()) && !community_reports.is_empty() && include_community_weight && (community_reports[0].attributes.is_none() || !community_reports[0] .attributes - .clone() + .as_ref() .unwrap() .contains_key(community_weight_name)); @@ -219,6 +219,7 @@ impl CommunityContext { community_rank_name, include_community_weight, include_community_rank, + column_delimiter, )?; if single_batch { @@ -243,6 +244,7 @@ impl CommunityContext { community_rank_name, 
include_community_weight,
                     include_community_rank,
+                    column_delimiter,
                 )?;
             }
 
@@ -365,8 +367,9 @@ impl Batch {
         community_rank_name: &str,
         include_community_weight: bool,
         include_community_rank: bool,
+        column_delimiter: &str,
     ) -> anyhow::Result<()> {
-        let weight_column = if include_community_weight && entities.is_some_and(|e| !e.is_empty()) {
+        let weight_column = if include_community_weight && entities.as_ref().is_some_and(|e| !e.is_empty()) {
             Some(community_weight_name)
         } else {
             None
@@ -387,10 +390,20 @@ impl Batch {
             return Ok(());
         }
 
+        let column_delimiter = if column_delimiter.is_empty() {
+            b'|'
+        } else {
+            column_delimiter.as_bytes()[0]
+        };
+
         let mut buffer = Cursor::new(Vec::new());
-        CsvWriter::new(buffer.clone()).finish(&mut record_df).unwrap();
+        CsvWriter::new(&mut buffer)
+            .include_header(true)
+            .with_separator(column_delimiter)
+            .finish(&mut record_df)?;
 
         let mut current_context_text = String::new();
+        buffer.set_position(0);
         buffer.read_to_string(&mut current_context_text)?;
         all_context_text.push(current_context_text);
 
@@ -410,7 +423,11 @@ impl Batch {
         }
 
         let mut data_series = Vec::new();
-        for (header, records) in header.iter().zip(context_records.iter()) {
+        for (index, header) in header.iter().enumerate() {
+            let records = context_records
+                .iter()
+                .map(|r| r.get(index).unwrap_or(&String::new()).to_owned())
+                .collect::<Vec<String>>();
             let series = Series::new(header, records);
             data_series.push(series);
         }
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
similarity index 81%
rename from shinkai-libs/shinkai-graphrag/src/search/global_search.rs
rename to shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
index 3a12ee6cd..b5c14795e 100644
--- a/shinkai-libs/shinkai-graphrag/src/search/global_search.rs
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
@@ -9,6 +9,9 @@ use crate::context_builder::community_context::GlobalCommunityContext;
 use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory};
 use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
 use crate::llm::utils::num_tokens;
+use crate::search::global_search::prompts::NO_DATA_ANSWER;
+
+use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT};
 
 #[derive(Debug, Clone)]
 pub struct SearchResult {
@@ -90,6 +93,7 @@ pub struct GlobalSearch {
     context_builder: GlobalCommunityContext,
     token_encoder: Option<Tokenizer>,
     context_builder_params: ContextBuilderParams,
+    map_system_prompt: String,
     reduce_system_prompt: String,
     response_type: String,
     allow_general_knowledge: bool,
@@ -100,22 +104,42 @@ pub struct GlobalSearch {
     reduce_llm_params: LLMParams,
 }
 
+pub struct GlobalSearchParams {
+    pub llm: Box<dyn BaseLLM>,
+    pub context_builder: GlobalCommunityContext,
+    pub token_encoder: Option<Tokenizer>,
+    pub map_system_prompt: Option<String>,
+    pub reduce_system_prompt: Option<String>,
+    pub response_type: String,
+    pub allow_general_knowledge: bool,
+    pub general_knowledge_inclusion_prompt: Option<String>,
+    pub json_mode: bool,
+    pub callbacks: Option<Vec<BaseLLMCallback>>,
+    pub max_data_tokens: usize,
+    pub map_llm_params: LLMParams,
+    pub reduce_llm_params: LLMParams,
+    pub context_builder_params: ContextBuilderParams,
+}
+
 impl GlobalSearch {
-    pub fn new(
-        llm: Box<dyn BaseLLM>,
-        context_builder: GlobalCommunityContext,
-        token_encoder: Option<Tokenizer>,
-        reduce_system_prompt: String,
-        response_type: String,
-        allow_general_knowledge: bool,
-        general_knowledge_inclusion_prompt: String,
-        json_mode: bool,
-        callbacks: Option<Vec<BaseLLMCallback>>,
-        max_data_tokens: usize,
-        map_llm_params: LLMParams,
-        reduce_llm_params: LLMParams,
-        context_builder_params: ContextBuilderParams,
-    ) -> Self {
+    pub fn new(global_search_params: GlobalSearchParams) -> Self {
+        let GlobalSearchParams {
+            llm,
+            context_builder,
+            token_encoder,
+            map_system_prompt,
+            reduce_system_prompt,
+            response_type,
+            allow_general_knowledge,
+            general_knowledge_inclusion_prompt,
+            json_mode,
+            callbacks,
+            max_data_tokens,
+            map_llm_params,
+            reduce_llm_params,
+            context_builder_params,
+        } = global_search_params;
+
         let mut map_llm_params = map_llm_params;
 
         if json_mode {
@@ -126,11 +150,17 @@ impl GlobalSearch {
             map_llm_params.response_format.remove("response_format");
         }
 
+        let map_system_prompt = map_system_prompt.unwrap_or(MAP_SYSTEM_PROMPT.to_string());
+        let reduce_system_prompt = reduce_system_prompt.unwrap_or(REDUCE_SYSTEM_PROMPT.to_string());
+        let general_knowledge_inclusion_prompt =
+            general_knowledge_inclusion_prompt.unwrap_or(GENERAL_KNOWLEDGE_INSTRUCTION.to_string());
+
         GlobalSearch {
             llm,
             context_builder,
             token_encoder,
            context_builder_params,
+            map_system_prompt,
             reduce_system_prompt,
             response_type,
             allow_general_knowledge,
@@ -218,7 +248,8 @@ impl GlobalSearch {
         llm_params: LLMParams,
     ) -> anyhow::Result<SearchResult> {
         let start_time = Instant::now();
-        let search_prompt = String::new();
+        let search_prompt = self.map_system_prompt.replace("{context_data}", context_data);
+
         let mut search_messages = Vec::new();
         search_messages.push(HashMap::from([
             ("role".to_string(), "system".to_string()),
@@ -253,6 +284,7 @@ impl GlobalSearch {
         if let Some(points) = points.as_array() {
             return points
                 .iter()
+                .filter(|element| element.get("description").is_some() && element.get("score").is_some())
                 .map(|element| KeyPoint {
                     answer: element
                         .get("description")
@@ -268,7 +300,10 @@ impl GlobalSearch {
             }
         }
 
-        Vec::new()
+        vec![KeyPoint {
+            answer: "".to_string(),
+            score: 0,
+        }]
     }
 
     async fn _reduce_response(
@@ -282,15 +317,13 @@ impl GlobalSearch {
         let start_time = Instant::now();
         let mut key_points: Vec<HashMap<String, String>> = Vec::new();
 
         for (index, response) in map_responses.iter().enumerate() {
-            if let ResponseType::Dictionaries(response_list) = &response.response {
-                for element in response_list {
-                    if let (Some(answer), Some(score)) = (element.get("answer"), element.get("score")) {
-                        let mut point = HashMap::new();
-                        point.insert("analyst".to_string(), (index + 1).to_string());
-                        point.insert("answer".to_string(), answer.to_string());
-                        point.insert("score".to_string(), score.to_string());
-                        key_points.push(point);
-                    }
+            if let ResponseType::KeyPoints(response_list) = &response.response {
+                for key_point in response_list {
+                    let mut point = HashMap::new();
+                    point.insert("analyst".to_string(), (index + 1).to_string());
+                    point.insert("answer".to_string(), key_point.answer.clone());
+                    point.insert("score".to_string(), key_point.score.to_string());
+                    key_points.push(point);
                 }
             }
         }
@@ -301,8 +334,10 @@ impl GlobalSearch {
         .collect();
 
         if filtered_key_points.is_empty() && !self.allow_general_knowledge {
+            eprintln!("Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. 
You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations."); + return Ok(SearchResult { - response: ResponseType::String("NO_DATA_ANSWER".to_string()), + response: ResponseType::String(NO_DATA_ANSWER.to_string()), context_data: ContextData::String("".to_string()), context_text: ContextText::String("".to_string()), completion_time: start_time.elapsed().as_secs_f64(), @@ -328,9 +363,11 @@ impl GlobalSearch { formatted_response_data.push(format!("Importance Score: {}", point.get("score").unwrap())); formatted_response_data.push(point.get("answer").unwrap().to_string()); let formatted_response_text = formatted_response_data.join("\n"); + if total_tokens + num_tokens(&formatted_response_text, self.token_encoder) > self.max_data_tokens { break; } + data.push(formatted_response_text.clone()); total_tokens += num_tokens(&formatted_response_text, self.token_encoder); } diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs new file mode 100644 index 000000000..79f16f1e0 --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs @@ -0,0 +1,2 @@ +pub mod global_search; +pub mod prompts; diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs new file mode 100644 index 000000000..7c9fef5cb --- /dev/null +++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/prompts.rs @@ -0,0 +1,164 @@ +// Copyright (c) 2024 Microsoft Corporation. +// Licensed under the MIT License + +// System prompts for global search. + +pub const MAP_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. 
He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + + +---Data tables--- + +{context_data} + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +Points supported by data should list the relevant reports as references as follows: +"This is an example sentence supported by data references [Data: Reports (report ids)]" + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. + +Do not include information where the supporting evidence for it is not provided. + +The response should be JSON formatted as follows: +{{ + "points": [ + {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, + {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + ] +}} +"#; + +pub const REDUCE_SYSTEM_PROMPT: &str = r#" +---Role--- + +You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts. + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. 
+ +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + + +---Analyst Reports--- + +{report_data} + + +---Goal--- + +Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. + +Note that the analysts' reports provided below are ranked in the **descending order of importance**. + +If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". + +The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. + +**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. + +For example: + +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" + +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. + +Do not include information where the supporting evidence for it is not provided. + + +---Target response length and format--- + +{response_type} + +Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. +"#; + +pub const NO_DATA_ANSWER: &str = "I am sorry but I am unable to answer this question given the provided data."; + +pub const GENERAL_KNOWLEDGE_INSTRUCTION: &str = r#" +The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example: +"This is an example sentence supported by real-world knowledge [LLM: verify]." 
+"#; diff --git a/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs similarity index 85% rename from shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs rename to shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs index 065d36246..8aea91032 100644 --- a/shinkai-libs/shinkai-graphrag/tests/it/global_search_tests.rs +++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs @@ -5,18 +5,19 @@ use shinkai_graphrag::{ indexer_entities::read_indexer_entities, indexer_reports::read_indexer_reports, }, llm::llm::LLMParams, - search::global_search::GlobalSearch, + search::global_search::global_search::{GlobalSearch, GlobalSearchParams}, }; use tiktoken_rs::tokenizer::Tokenizer; +use utils::openai::ChatOpenAI; -use crate::it::utils::openai::ChatOpenAI; +mod utils; #[tokio::test] async fn global_search_test() -> Result<(), Box> { let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap(); let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap(); - let llm = ChatOpenAI::new(Some(api_key), llm_model, 20); + let llm = ChatOpenAI::new(Some(api_key), llm_model, 5); let token_encoder = Tokenizer::Cl100kBase; // Load community reports @@ -76,21 +77,22 @@ async fn global_search_test() -> Result<(), Box> { // Perform global search - let search_engine = GlobalSearch::new( - Box::new(llm), + let search_engine = GlobalSearch::new(GlobalSearchParams { + llm: Box::new(llm), context_builder, - Some(token_encoder), - String::from(""), - String::from("multiple paragraphs"), - false, - String::from(""), - true, - None, - 12_000, + token_encoder: Some(token_encoder), + map_system_prompt: None, + reduce_system_prompt: None, + response_type: String::from("multiple paragraphs"), + allow_general_knowledge: false, + general_knowledge_inclusion_prompt: None, + json_mode: true, + callbacks: None, + max_data_tokens: 12_000, map_llm_params, reduce_llm_params, context_builder_params, - ); + }); let result = search_engine .asearch( diff --git a/shinkai-libs/shinkai-graphrag/tests/it_mod.rs b/shinkai-libs/shinkai-graphrag/tests/it_mod.rs deleted file mode 100644 index 4c5c9ed27..000000000 --- a/shinkai-libs/shinkai-graphrag/tests/it_mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod it { - mod global_search_tests; - mod utils; -} diff --git a/shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs similarity index 100% rename from shinkai-libs/shinkai-graphrag/tests/it/utils/mod.rs rename to shinkai-libs/shinkai-graphrag/tests/utils/mod.rs diff --git a/shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs similarity index 98% rename from shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs rename to shinkai-libs/shinkai-graphrag/tests/utils/openai.rs index 1325ef6aa..7eab1460e 100644 --- a/shinkai-libs/shinkai-graphrag/tests/it/utils/openai.rs +++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs @@ -92,7 +92,7 @@ impl ChatOpenAI { .map(|m| Into::::into(m.clone())) .collect::>(); - let response_format = if llm_params + let _response_format = if llm_params .response_format .get_key_value("type") .is_some_and(|(_k, v)| v == "json_object") From 83535cf7182d0cd0f2355295adfac2710fb718bc Mon Sep 17 00:00:00 2001 From: benolt Date: Tue, 13 Aug 2024 15:08:22 +0200 Subject: [PATCH 08/12] read indexer entities and reports improvements, compute community weights --- shinkai-libs/shinkai-graphrag/Cargo.toml | 3 +- 
 .../src/context_builder/community_context.rs  |  10 +-
 .../src/context_builder/indexer_entities.rs   | 224 +++++++++++-------
 .../src/context_builder/indexer_reports.rs    | 123 ++++++----
 .../shinkai-graphrag/src/llm/utils.rs         |   2 +-
 5 files changed, 226 insertions(+), 136 deletions(-)

diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml
index 378f20e4f..9a9eee401 100644
--- a/shinkai-libs/shinkai-graphrag/Cargo.toml
+++ b/shinkai-libs/shinkai-graphrag/Cargo.toml
@@ -16,5 +16,4 @@ tiktoken-rs = "0.5.9"
 tokio = { version = "1.36", features = ["full"] }
 
 [dev-dependencies]
-async-openai = "0.23.4"
-tokio = { version = "1.36", features = ["full"] }
\ No newline at end of file
+async-openai = "0.23.4"
\ No newline at end of file
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
index 1f6be4ffe..c5b096ca1 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
@@ -1,5 +1,5 @@
 use std::{
-    collections::HashMap,
+    collections::{HashMap, HashSet},
     io::{Cursor, Read},
 };
 
@@ -276,7 +276,7 @@ impl CommunityContext {
     ) -> Vec<CommunityReport> {
         // Calculate a community's weight as the count of text units associated with entities within the community.
         if let Some(entities) = entities {
-            let mut community_reports = community_reports.clone();
+            let mut community_reports = community_reports;
             let mut community_text_units = std::collections::HashMap::new();
             for entity in entities {
                 if let Some(community_ids) = entity.community_ids.clone() {
@@ -297,7 +297,7 @@ impl CommunityContext {
                         weight_attribute.to_string(),
                         community_text_units
                             .get(&report.community_id)
-                            .map(|text_units| text_units.len())
+                            .map(|text_units| text_units.iter().flatten().cloned().collect::<HashSet<_>>().len())
                             .unwrap_or(0)
                             .to_string(),
                     );
@@ -316,7 +316,7 @@ impl CommunityContext {
                     })
                     .collect();
                 if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) {
-                    for mut report in community_reports {
+                    for report in &mut community_reports {
                         if let Some(attributes) = &mut report.attributes {
                             if let Some(weight) = attributes.get_mut(weight_attribute) {
                                 *weight = (weight.parse::<f64>().unwrap_or(0.0) / max_weight).to_string();
@@ -325,6 +325,8 @@ impl CommunityContext {
                 }
             }
         }
+
+        return community_reports;
     }
     community_reports
 }
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
index 1548d9671..f550a2aef 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 
 use polars::prelude::*;
 use polars_lazy::dsl::col;
@@ -12,44 +12,19 @@ pub fn read_indexer_entities(
     community_level: u32,
 ) -> anyhow::Result<Vec<Entity>> {
     let entity_df = final_nodes.clone();
-    let mut entity_df = filter_under_community_level(&entity_df, community_level)?;
+    let entity_df = filter_under_community_level(&entity_df, community_level)?;
 
-    let entity_df = entity_df.rename("title", "name")?.rename("degree", "rank")?;
+    let entity_embedding_df = final_entities.clone();
 
     let entity_df = entity_df
-        .clone()
         .lazy()
+        .rename(["title", "degree"], ["name", "rank"])
         .with_column(col("community").fill_null(lit(-1)))
-        .collect()?;
-    let entity_df =
entity_df - .clone() - .lazy() .with_column(col("community").cast(DataType::Int32)) - .collect()?; - let entity_df = entity_df - .clone() - .lazy() .with_column(col("rank").cast(DataType::Int32)) - .collect()?; - - let entity_embedding_df = final_entities.clone(); - - let entity_df = entity_df - .clone() - .lazy() .group_by([col("name"), col("rank")]) .agg([col("community").max()]) - .collect()?; - - let entity_df = entity_df - .clone() - .lazy() .with_column(col("community").cast(DataType::String)) - .collect()?; - - let entity_df = entity_df - .clone() - .lazy() .join( entity_embedding_df.clone().lazy(), [col("name")], @@ -58,12 +33,6 @@ pub fn read_indexer_entities( ) .collect()?; - let entity_df = entity_df - .clone() - .lazy() - .filter(len().over([col("name")]).gt(lit(1))) - .collect()?; - let entities = read_entities( &entity_df, "id", @@ -134,9 +103,15 @@ pub fn read_entities( .filter_map(|&v| v.map(|v| v.to_string())) .collect::>(); + let column_names = column_names.into_iter().collect::>().into_vec(); + let mut df = df.clone(); df.as_single_chunk_par(); - let mut iters = df.columns(column_names)?.iter().map(|s| s.iter()).collect::>(); + let mut iters = df + .columns(column_names.clone())? + .iter() + .map(|s| s.iter()) + .collect::>(); let mut rows = Vec::new(); for _row in 0..df.height() { @@ -144,67 +119,144 @@ pub fn read_entities( for iter in &mut iters { let value = iter.next(); if let Some(value) = value { - row_values.push(value.to_string()); + row_values.push(value); } } rows.push(row_values); } let mut entities = Vec::new(); - for row in rows { + for (idx, row) in rows.iter().enumerate() { let report = Entity { - id: row.get(0).unwrap_or(&String::new()).to_string(), - short_id: Some(row.get(1).unwrap_or(&String::new()).to_string()), - title: row.get(2).unwrap_or(&String::new()).to_string(), - entity_type: Some(row.get(3).unwrap_or(&String::new()).to_string()), - description: Some(row.get(4).unwrap_or(&String::new()).to_string()), - name_embedding: Some( - row.get(5) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.parse::().unwrap_or(0.0)) - .collect(), - ), - description_embedding: Some( - row.get(6) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.parse::().unwrap_or(0.0)) - .collect(), - ), - graph_embedding: Some( - row.get(7) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.parse::().unwrap_or(0.0)) - .collect(), - ), - community_ids: Some( - row.get(8) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.to_string()) - .collect(), - ), - text_unit_ids: Some( - row.get(9) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.to_string()) - .collect(), + id: get_field(&row, id_col, &column_names) + .map(|id| id.to_string()) + .unwrap_or(String::new()), + short_id: Some( + short_id_col + .map(|short_id| get_field(&row, short_id, &column_names)) + .flatten() + .map(|short_id| short_id.to_string()) + .unwrap_or(idx.to_string()), ), - document_ids: Some( - row.get(10) - .unwrap_or(&String::new()) - .split(',') - .map(|v| v.to_string()) - .collect(), - ), - rank: Some(row.get(11).and_then(|v| v.parse::().ok()).unwrap_or(0)), + title: get_field(&row, title_col, &column_names) + .map(|title| title.to_string()) + .unwrap_or(String::new()), + entity_type: type_col + .map(|type_col| get_field(&row, type_col, &column_names)) + .flatten() + .map(|entity_type| entity_type.to_string()), + description: description_col + .map(|description_col| get_field(&row, description_col, &column_names)) + .flatten() + .map(|description| description.to_string()), + 
name_embedding: name_embedding_col.map(|name_embedding_col| {
+                get_field(&row, name_embedding_col, &column_names)
+                    .map(|name_embedding| match name_embedding {
+                        AnyValue::List(series) => series
+                            .f64()
+                            .unwrap_or(&ChunkedArray::from_vec(name_embedding_col, vec![]))
+                            .iter()
+                            .map(|v| v.unwrap_or(0.0))
+                            .collect::<Vec<f64>>(),
+                        value => vec![value.to_string().parse::<f64>().unwrap_or(0.0)],
+                    })
+                    .unwrap_or_else(|| Vec::new())
+            }),
+            description_embedding: description_embedding_col.map(|description_embedding_col| {
+                get_field(&row, description_embedding_col, &column_names)
+                    .map(|description_embedding| match description_embedding {
+                        AnyValue::List(series) => series
+                            .f64()
+                            .unwrap_or(&ChunkedArray::from_vec(description_embedding_col, vec![]))
+                            .iter()
+                            .map(|v| v.unwrap_or(0.0))
+                            .collect::<Vec<f64>>(),
+                        value => vec![value.to_string().parse::<f64>().unwrap_or(0.0)],
+                    })
+                    .unwrap_or_else(|| Vec::new())
+            }),
+            graph_embedding: graph_embedding_col.map(|graph_embedding_col| {
+                get_field(&row, graph_embedding_col, &column_names)
+                    .map(|graph_embedding| match graph_embedding {
+                        AnyValue::List(series) => series
+                            .f64()
+                            .unwrap_or(&ChunkedArray::from_vec(graph_embedding_col, vec![]))
+                            .iter()
+                            .map(|v| v.unwrap_or(0.0))
+                            .collect::<Vec<f64>>(),
+                        value => vec![value.to_string().parse::<f64>().unwrap_or(0.0)],
+                    })
+                    .unwrap_or_else(|| Vec::new())
+            }),
+            community_ids: community_col.map(|community_col| {
+                get_field(&row, community_col, &column_names)
+                    .map(|community_ids| match community_ids {
+                        AnyValue::List(series) => series
+                            .str()
+                            .unwrap_or(&StringChunked::default())
+                            .iter()
+                            .map(|v| v.unwrap_or("").to_string())
+                            .collect::<Vec<String>>(),
+                        value => vec![value.to_string()],
+                    })
+                    .unwrap_or_else(|| Vec::new())
+            }),
+            text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| {
+                get_field(&row, text_unit_ids_col, &column_names)
+                    .map(|text_unit_ids| match text_unit_ids {
+                        AnyValue::List(series) => series
+                            .str()
+                            .unwrap_or(&StringChunked::default())
+                            .iter()
+                            .map(|v| v.unwrap_or("").to_string())
+                            .collect::<Vec<String>>(),
+                        value => vec![value.to_string()],
+                    })
+                    .unwrap_or_else(|| Vec::new())
+            }),
+            document_ids: document_ids_col.map(|document_ids_col| {
+                get_field(&row, document_ids_col, &column_names)
+                    .map(|document_ids| match document_ids {
+                        AnyValue::List(series) => series
+                            .str()
+                            .unwrap_or(&StringChunked::default())
+                            .iter()
+                            .map(|v| v.unwrap_or("").to_string())
+                            .collect::<Vec<String>>(),
+                        value => vec![value.to_string()],
+                    })
+                    .unwrap_or_else(|| Vec::new())
+            }),
+            rank: rank_col
+                .map(|rank_col| {
+                    get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::<i32>().unwrap_or(0))
+                })
+                .flatten(),
             attributes: None,
         };
         entities.push(report);
     }
 
-    Ok(entities)
+    let mut unique_entities: Vec<Entity> = Vec::new();
+    let mut entity_ids: HashSet<String> = HashSet::new();
+
+    for entity in entities {
+        if !entity_ids.contains(&entity.id) {
+            unique_entities.push(entity.clone());
+            entity_ids.insert(entity.id);
+        }
+    }
+
+    Ok(unique_entities)
+}
+
+pub fn get_field<'a>(
+    row: &'a Vec<AnyValue<'a>>,
+    column_name: &'a str,
+    column_names: &'a Vec<String>,
+) -> Option<AnyValue<'a>> {
+    match column_names.iter().position(|x| x == column_name) {
+        Some(index) => row.get(index).cloned(),
+        None => None,
+    }
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
index 9f8b9c507..cae8dc607 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
@@ -1,9 +1,11 @@
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 
 use polars::prelude::*;
 use polars_lazy::dsl::col;
 use serde::{Deserialize, Serialize};
 
+use super::indexer_entities::get_field;
+
 pub fn read_indexer_reports(
     final_community_reports: &DataFrame,
     final_nodes: &DataFrame,
@@ -12,33 +14,13 @@ pub fn read_indexer_reports(
     let entity_df = final_nodes.clone();
     let entity_df = filter_under_community_level(&entity_df, community_level)?;
 
-    let entity_df = entity_df
-        .clone()
+    let filtered_community_df = entity_df
         .lazy()
         .with_column(col("community").fill_null(lit(-1)))
-        .collect()?;
-    let entity_df = entity_df
-        .clone()
-        .lazy()
         .with_column(col("community").cast(DataType::Int32))
-        .collect()?;
-
-    let entity_df = entity_df
-        .clone()
-        .lazy()
-        .with_column(col("community").cast(DataType::String))
-        .collect()?;
-
-    let entity_df = entity_df
-        .clone()
-        .lazy()
         .group_by([col("title")])
         .agg([col("community").max()])
-        .collect()?;
-
-    let filtered_community_df = entity_df
-        .clone()
-        .lazy()
+        .with_column(col("community").cast(DataType::String))
         .filter(len().over([col("community")]).gt(lit(1)))
         .collect()?;
 
@@ -46,7 +28,6 @@ pub fn read_indexer_reports(
     let report_df = filter_under_community_level(&report_df, community_level)?;
 
     let report_df = report_df
-        .clone()
         .lazy()
         .join(
             filtered_community_df.clone().lazy(),
@@ -56,7 +37,18 @@ pub fn read_indexer_reports(
         )
         .collect()?;
 
-    let reports = read_community_reports(&report_df, "community", Some("community"), None, None)?;
+    let reports = read_community_reports(
+        &report_df,
+        "community",
+        Some("community"),
+        "title",
+        "community",
+        "summary",
+        "full_content",
+        Some("rank"),
+        None,
+        None,
+    )?;
 
     Ok(reports)
 }
@@ -83,21 +75,36 @@ pub struct CommunityReport {
 
 pub fn read_community_reports(
     df: &DataFrame,
-    _id_col: &str,
-    _short_id_col: Option<&str>,
-    // title_col: &str,
-    // community_col: &str,
-    // summary_col: &str,
-    // content_col: &str,
-    // rank_col: Option<&str>,
+    id_col: &str,
+    short_id_col: Option<&str>,
+    title_col: &str,
+    community_col: &str,
+    summary_col: &str,
+    content_col: &str,
+    rank_col: Option<&str>,
     _summary_embedding_col: Option<&str>,
     _content_embedding_col: Option<&str>,
     // attributes_cols: Option<&[&str]>,
 ) -> anyhow::Result<Vec<CommunityReport>> {
+    let column_names = [
+        Some(id_col),
+        short_id_col,
+        Some(title_col),
+        Some(community_col),
+        Some(summary_col),
+        Some(content_col),
+        rank_col,
+    ]
+    .iter()
+    .filter_map(|&v| v.map(|v| v.to_string()))
+    .collect::<Vec<String>>();
+
+    let column_names: Vec<String> = column_names.into_iter().collect::<HashSet<String>>().into_iter().collect();
+
     let mut df = df.clone();
     df.as_single_chunk_par();
     let mut iters = df
-        .columns(["community", "title", "summary", "full_content", "rank"])?
+        .columns(column_names.clone())?
         .iter()
         .map(|s| s.iter())
         .collect::<Vec<_>>();
@@ -108,22 +115,42 @@ pub fn read_community_reports(
         for iter in &mut iters {
             let value = iter.next();
             if let Some(value) = value {
-                row_values.push(value.to_string());
+                row_values.push(value);
             }
         }
         rows.push(row_values);
     }
 
     let mut reports = Vec::new();
-    for row in rows {
+    for (idx, row) in rows.iter().enumerate() {
         let report = CommunityReport {
-            id: row.get(0).unwrap_or(&String::new()).to_string(),
-            short_id: Some(row.get(0).unwrap_or(&String::new()).to_string()),
-            title: row.get(1).unwrap_or(&String::new()).to_string(),
-            community_id: row.get(0).unwrap_or(&String::new()).to_string(),
-            summary: row.get(2).unwrap_or(&String::new()).to_string(),
-            full_content: row.get(3).unwrap_or(&String::new()).to_string(),
-            rank: Some(row.get(4).and_then(|v| v.parse::<f64>().ok()).unwrap_or(0.0)),
+            id: get_field(&row, id_col, &column_names)
+                .map(|id| id.to_string())
+                .unwrap_or(String::new()),
+            short_id: Some(
+                short_id_col
+                    .map(|short_id| get_field(&row, short_id, &column_names))
+                    .flatten()
+                    .map(|short_id| short_id.to_string())
+                    .unwrap_or(idx.to_string()),
+            ),
+            title: get_field(&row, title_col, &column_names)
+                .map(|title| title.to_string())
+                .unwrap_or(String::new()),
+            community_id: get_field(&row, community_col, &column_names)
+                .map(|community| community.to_string())
+                .unwrap_or(String::new()),
+            summary: get_field(&row, summary_col, &column_names)
+                .map(|summary| summary.to_string())
+                .unwrap_or(String::new()),
+            full_content: get_field(&row, content_col, &column_names)
+                .map(|content| content.to_string())
+                .unwrap_or(String::new()),
+            rank: rank_col
+                .map(|rank_col| {
+                    get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::<f64>().unwrap_or(0.0))
+                })
+                .flatten(),
             summary_embedding: None,
             full_content_embedding: None,
             attributes: None,
@@ -131,5 +158,15 @@ pub fn read_community_reports(
         reports.push(report);
     }
 
-    Ok(reports)
+    let mut unique_reports: Vec<CommunityReport> = Vec::new();
+    let mut report_ids: HashSet<String> = HashSet::new();
+
+    for report in reports {
+        if !report_ids.contains(&report.id) {
+            unique_reports.push(report.clone());
+            report_ids.insert(report.id);
+        }
+    }
+
+    Ok(unique_reports)
 }
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs
index a6b4dfc54..1599ce78f 100644
--- a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs
+++ b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs
@@ -3,5 +3,5 @@ use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};
 pub fn num_tokens(text: &str, token_encoder: Option<Tokenizer>) -> usize {
     let token_encoder = token_encoder.unwrap_or_else(|| Tokenizer::Cl100kBase);
     let bpe = get_bpe_from_tokenizer(token_encoder).unwrap();
-    bpe.encode_ordinary(text).len()
+    bpe.encode_with_special_tokens(text).len()
 }

From 14714f743a770a94a3bc12a07ebf19e3ebae13a1 Mon Sep 17 00:00:00 2001
From: benolt
Date: Tue, 13 Aug 2024 18:36:37 +0200
Subject: [PATCH 09/12] improvements, disable global search test

---
 .../src/context_builder/community_context.rs       | 3 ---
 .../src/context_builder/indexer_entities.rs        | 8 ++++----
 .../src/context_builder/indexer_reports.rs         | 8 ++++----
 .../shinkai-graphrag/tests/global_search_tests.rs  | 2 +-
 4 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
index c5b096ca1..d823b0070 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
@@ -451,7 +451,6 @@ impl Batch {
         if let Some(weight_column) = weight_column {
             rank_attributes.push(weight_column);
             report_df = report_df
-                .clone()
                 .lazy()
                 .with_column(col(weight_column).cast(DataType::Float64))
                 .collect()?;
@@ -460,7 +459,6 @@ impl Batch {
         if let Some(rank_column) = rank_column {
             rank_attributes.push(rank_column);
             report_df = report_df
-                .clone()
                 .lazy()
                 .with_column(col(rank_column).cast(DataType::Float64))
                 .collect()?;
@@ -468,7 +466,6 @@ impl Batch {
 
         if !rank_attributes.is_empty() {
             report_df = report_df
-                .clone()
                 .lazy()
                 .sort(rank_attributes, SortMultipleOptions::new().with_order_descending(true))
                 .collect()?;
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
index f550a2aef..26d8566fd 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
@@ -26,7 +26,7 @@ pub fn read_indexer_entities(
         .agg([col("community").max()])
         .with_column(col("community").cast(DataType::String))
         .join(
-            entity_embedding_df.clone().lazy(),
+            entity_embedding_df.lazy(),
             [col("name")],
             [col("name")],
             JoinArgs::new(JoinType::Inner),
@@ -34,7 +34,7 @@ pub fn read_indexer_entities(
         .collect()?;
 
     let entities = read_entities(
-        &entity_df,
+        entity_df,
        "id",
         Some("human_readable_id"),
         "name",
@@ -70,7 +70,7 @@ pub struct Entity {
 }
 
 pub fn read_entities(
-    df: &DataFrame,
+    df: DataFrame,
     id_col: &str,
     short_id_col: Option<&str>,
     title_col: &str,
@@ -105,7 +105,7 @@ pub fn read_entities(
 
     let column_names = column_names.into_iter().collect::<HashSet<String>>().into_iter().collect::<Vec<String>>();
 
-    let mut df = df.clone();
+    let mut df = df;
     df.as_single_chunk_par();
     let mut iters = df
         .columns(column_names.clone())?
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
index cae8dc607..1b59b9c59 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs
@@ -30,7 +30,7 @@ pub fn read_indexer_reports(
     let report_df = report_df
         .lazy()
         .join(
-            filtered_community_df.clone().lazy(),
+            filtered_community_df.lazy(),
             [col("community")],
             [col("community")],
             JoinArgs::new(JoinType::Inner),
@@ -38,7 +38,7 @@ pub fn read_indexer_reports(
         .collect()?;
 
     let reports = read_community_reports(
-        &report_df,
+        report_df,
         "community",
         Some("community"),
         "title",
@@ -74,7 +74,7 @@ pub struct CommunityReport {
 }
 
 pub fn read_community_reports(
-    df: &DataFrame,
+    df: DataFrame,
     id_col: &str,
     short_id_col: Option<&str>,
     title_col: &str,
@@ -101,7 +101,7 @@ pub fn read_community_reports(
 
     let column_names: Vec<String> = column_names.into_iter().collect::<HashSet<String>>().into_iter().collect();
 
-    let mut df = df.clone();
+    let mut df = df;
     df.as_single_chunk_par();
     let mut iters = df
         .columns(column_names.clone())?
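The refactor in this patch is worth calling out: the earlier eager `.clone() ... .collect()?` staircases are collapsed into a single lazy pipeline that is materialized once, so the polars query optimizer sees the whole plan and the intermediate copies disappear. A minimal sketch of the pattern, assuming the polars/polars-lazy 0.41 API this crate pins and a frame with `title` and `community` columns:

use polars::prelude::*;
use polars_lazy::dsl::{col, len, lit};
use polars_lazy::prelude::IntoLazy;

// One lazy pipeline, one collect(): the fill, casts, aggregation, and
// window filter are fused instead of producing four intermediate frames.
fn max_community_per_title(df: DataFrame) -> PolarsResult<DataFrame> {
    df.lazy()
        .with_column(col("community").fill_null(lit(-1)))
        .with_column(col("community").cast(DataType::Int32))
        .group_by([col("title")])
        .agg([col("community").max()])
        .with_column(col("community").cast(DataType::String))
        .filter(len().over([col("community")]).gt(lit(1)))
        .collect()
}

Passing the `DataFrame` by value in `read_entities`/`read_community_reports` follows the same logic: the callee was cloning anyway, so moving ownership drops one copy per call.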
diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
index 8aea91032..c08c2548a 100644
--- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
@@ -12,7 +12,7 @@ use utils::openai::ChatOpenAI;
 
 mod utils;
 
-#[tokio::test]
+// #[tokio::test]
 async fn global_search_test() -> Result<(), Box<dyn std::error::Error>> {
     let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap();
     let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap();

From e79c597a2ef627cd9aa28b6aeb39d76a0e5b0492 Mon Sep 17 00:00:00 2001
From: benolt
Date: Thu, 15 Aug 2024 07:01:01 +0200
Subject: [PATCH 10/12] decouple openai tokenizer

---
 shinkai-libs/shinkai-graphrag/Cargo.toml           |  4 ++--
 .../src/context_builder/community_context.rs       | 23 ++++++++-----------
 shinkai-libs/shinkai-graphrag/src/llm/mod.rs       |  1 -
 .../shinkai-graphrag/src/llm/utils.rs              |  7 ------
 .../src/search/global_search/global_search.rs      | 20 +++++++----------
 .../tests/global_search_tests.rs                   |  8 +++----
 .../shinkai-graphrag/tests/utils/openai.rs         |  7 ++++++
 7 files changed, 30 insertions(+), 40 deletions(-)
 delete mode 100644 shinkai-libs/shinkai-graphrag/src/llm/utils.rs

diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml
index 9a9eee401..18650385c 100644
--- a/shinkai-libs/shinkai-graphrag/Cargo.toml
+++ b/shinkai-libs/shinkai-graphrag/Cargo.toml
@@ -12,8 +12,8 @@ polars-lazy = "0.41.3"
 rand = "0.8.5"
 serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.117"
-tiktoken-rs = "0.5.9"
 tokio = { version = "1.36", features = ["full"] }
 
 [dev-dependencies]
-async-openai = "0.23.4"
\ No newline at end of file
+async-openai = "0.23.4"
+tiktoken-rs = "0.5.9"
\ No newline at end of file
diff --git a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
index d823b0070..a6a1a7485 100644
--- a/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
+++ b/shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs
@@ -10,28 +10,25 @@ use polars::{
     series::Series,
 };
 use rand::prelude::SliceRandom;
-use tiktoken_rs::tokenizer::Tokenizer;
-
-use crate::llm::utils::num_tokens;
 
 use super::{context_builder::ContextBuilderParams, indexer_entities::Entity, indexer_reports::CommunityReport};
 
 pub struct GlobalCommunityContext {
     community_reports: Vec<CommunityReport>,
     entities: Option<Vec<Entity>>,
-    token_encoder: Option<Tokenizer>,
+    num_tokens_fn: fn(&str) -> usize,
 }
 
 impl GlobalCommunityContext {
     pub fn new(
         community_reports: Vec<CommunityReport>,
         entities: Option<Vec<Entity>>,
-        token_encoder: Option<Tokenizer>,
+        num_tokens_fn: fn(&str) -> usize,
     ) -> Self {
         Self {
             community_reports,
             entities,
-            token_encoder,
+            num_tokens_fn,
         }
     }
 
@@ -56,7 +53,7 @@ impl GlobalCommunityContext {
         let (community_context, community_context_data) = CommunityContext::build_community_context(
             self.community_reports.clone(),
             self.entities.clone(),
-            self.token_encoder.clone(),
+            self.num_tokens_fn,
             use_community_summary,
             &column_delimiter,
             shuffle_data,
@@ -84,7 +81,7 @@ impl CommunityContext {
     pub fn build_community_context(
         community_reports: Vec<CommunityReport>,
         entities: Option<Vec<Entity>>,
-        token_encoder: Option<Tokenizer>,
+        num_tokens_fn: fn(&str) -> usize,
         use_community_summary: bool,
         column_delimiter: &str,
         shuffle_data: bool,
@@ -202,11 +199,11 @@ impl CommunityContext {
 
         let mut batch = Batch::new();
 
-        batch.init_batch(context_name, &header, column_delimiter, token_encoder);
+        batch.init_batch(context_name, &header, column_delimiter, num_tokens_fn);
 
         for report in selected_reports {
             let (new_context_text, new_context) = _report_context_text(&report, &attributes);
-            let new_tokens = num_tokens(&new_context_text, token_encoder);
+            let new_tokens = num_tokens_fn(&new_context_text);
 
             // add the current batch to the context data and start a new batch if we are in multi-batch mode
             if batch.batch_tokens + new_tokens > max_tokens {
@@ -226,7 +223,7 @@ impl CommunityContext {
                     break;
                 }
 
-                batch.init_batch(context_name, &header, column_delimiter, token_encoder);
+                batch.init_batch(context_name, &header, column_delimiter, num_tokens_fn);
             }
 
             batch.batch_text.push_str(&new_context_text);
@@ -352,10 +349,10 @@ impl Batch {
         context_name: &str,
         header: &Vec<String>,
         column_delimiter: &str,
-        token_encoder: Option<Tokenizer>,
+        num_tokens_fn: fn(&str) -> usize,
     ) {
         self.batch_text = format!("-----{}-----\n{}\n", context_name, header.join(column_delimiter));
-        self.batch_tokens = num_tokens(&self.batch_text, token_encoder);
+        self.batch_tokens = num_tokens_fn(&self.batch_text);
         self.batch_records.clear();
     }
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
index 247bfe098..214bbef7c 100644
--- a/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
+++ b/shinkai-libs/shinkai-graphrag/src/llm/mod.rs
@@ -1,2 +1 @@
 pub mod llm;
-pub mod utils;
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs b/shinkai-libs/shinkai-graphrag/src/llm/utils.rs
deleted file mode 100644
index 1599ce78f..000000000
--- a/shinkai-libs/shinkai-graphrag/src/llm/utils.rs
+++ /dev/null
@@ -1,7 +0,0 @@
-use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};
-
-pub fn num_tokens(text: &str, token_encoder: Option<Tokenizer>) -> usize {
-    let token_encoder = token_encoder.unwrap_or_else(|| Tokenizer::Cl100kBase);
-    let bpe = get_bpe_from_tokenizer(token_encoder).unwrap();
-    bpe.encode_with_special_tokens(text).len()
-}
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
index b5c14795e..0368d4a59 100644
--- a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
@@ -3,12 +3,10 @@ use polars::frame::DataFrame;
 use serde_json::Value;
 use std::collections::HashMap;
 use std::time::Instant;
-use tiktoken_rs::tokenizer::Tokenizer;
 
 use crate::context_builder::community_context::GlobalCommunityContext;
 use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory};
 use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
-use crate::llm::utils::num_tokens;
 use crate::search::global_search::prompts::NO_DATA_ANSWER;
 
 use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT};
@@ -26,8 +24,6 @@ pub struct SearchResult {
 
 #[derive(Debug, Clone)]
 pub enum ResponseType {
     String(String),
-    Dictionary(HashMap<String, String>),
-    Dictionaries(Vec<HashMap<String, String>>),
     KeyPoints(Vec<HashMap<String, String>>),
 }
@@ -91,7 +87,7 @@ impl GlobalSearchLLMCallback {
 pub struct GlobalSearch {
     llm: Box<dyn BaseLLM>,
     context_builder: GlobalCommunityContext,
-    token_encoder: Option<Tokenizer>,
+    num_tokens_fn: fn(&str) -> usize,
     context_builder_params: ContextBuilderParams,
     map_system_prompt: String,
     reduce_system_prompt: String,
@@ -107,7 +103,7 @@ pub struct GlobalSearch {
 pub struct GlobalSearchParams {
     pub llm: Box<dyn BaseLLM>,
     pub context_builder: GlobalCommunityContext,
-    pub token_encoder: Option<Tokenizer>,
+    pub num_tokens_fn: fn(&str) -> usize,
     pub map_system_prompt: Option<String>,
     pub reduce_system_prompt: Option<String>,
     pub response_type: String,
@@ -126,7 +122,7 @@ impl GlobalSearch {
         let GlobalSearchParams {
             llm,
             context_builder,
-            token_encoder,
+            num_tokens_fn,
             map_system_prompt,
             reduce_system_prompt,
             response_type,
@@ -158,7 +154,7 @@ impl GlobalSearch {
         GlobalSearch {
             llm,
             context_builder,
-            token_encoder,
+            num_tokens_fn,
             context_builder_params,
             map_system_prompt,
             reduce_system_prompt,
@@ -273,7 +269,7 @@ impl GlobalSearch {
             context_text: ContextText::String(context_data.to_string()),
             completion_time: start_time.elapsed().as_secs_f64(),
             llm_calls: 1,
-            prompt_tokens: num_tokens(&search_prompt, self.token_encoder),
+            prompt_tokens: (self.num_tokens_fn)(&search_prompt),
         })
     }
 
@@ -364,12 +360,12 @@ impl GlobalSearch {
             formatted_response_data.push(point.get("answer").unwrap().to_string());
 
             let formatted_response_text = formatted_response_data.join("\n");
 
-            if total_tokens + num_tokens(&formatted_response_text, self.token_encoder) > self.max_data_tokens {
+            if total_tokens + (self.num_tokens_fn)(&formatted_response_text) > self.max_data_tokens {
                 break;
             }
 
             data.push(formatted_response_text.clone());
-            total_tokens += num_tokens(&formatted_response_text, self.token_encoder);
+            total_tokens += (self.num_tokens_fn)(&formatted_response_text);
         }
 
         let text_data = data.join("\n\n");
@@ -425,7 +421,7 @@ impl GlobalSearch {
             context_text: ContextText::String(text_data),
             completion_time: start_time.elapsed().as_secs_f64(),
             llm_calls: 1,
-            prompt_tokens: num_tokens(&search_prompt, self.token_encoder),
+            prompt_tokens: (self.num_tokens_fn)(&search_prompt),
         })
     }
 }
diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
index c08c2548a..42bcd5834 100644
--- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
@@ -7,8 +7,7 @@ use shinkai_graphrag::{
     llm::llm::LLMParams,
     search::global_search::global_search::{GlobalSearch, GlobalSearchParams},
 };
-use tiktoken_rs::tokenizer::Tokenizer;
-use utils::openai::ChatOpenAI;
+use utils::openai::{num_tokens, ChatOpenAI};
 
 mod utils;
 
@@ -18,7 +17,6 @@ async fn global_search_test() -> Result<(), Box<dyn std::error::Error>> {
     let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap();
 
     let llm = ChatOpenAI::new(Some(api_key), llm_model, 5);
-    let token_encoder = Tokenizer::Cl100kBase;
 
     // Load community reports
     // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip
@@ -47,7 +45,7 @@ async fn global_search_test() -> Result<(), Box<dyn std::error::Error>> {
 
     // Build global context based on community reports
 
-    let context_builder = GlobalCommunityContext::new(reports, Some(entities), Some(token_encoder));
+    let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens);
 
     let context_builder_params = ContextBuilderParams {
         use_community_summary: false, // False means using full community reports. True means using community short summaries.
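This patch reduces the library's tokenizer contract to a bare `fn(&str) -> usize`, so any deterministic counter can be injected. A sketch of the tiktoken-backed counter the tests use, mirroring `tests/utils/openai.rs`:

use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};

// cl100k_base matches the OpenAI chat models the tests target.
pub fn num_tokens(text: &str) -> usize {
    let bpe = get_bpe_from_tokenizer(Tokenizer::Cl100kBase).unwrap();
    bpe.encode_with_special_tokens(text).len()
}

// Wiring it in mirrors the updated test:
// let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens);

A cheaper heuristic such as `text.len() / 4` would also satisfy the signature, at the cost of looser batch budgeting.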
@@ -80,7 +78,7 @@ async fn global_search_test() -> Result<(), Box<dyn std::error::Error>> {
     let search_engine = GlobalSearch::new(GlobalSearchParams {
         llm: Box::new(llm),
         context_builder,
-        token_encoder: Some(token_encoder),
+        num_tokens_fn: num_tokens,
         map_system_prompt: None,
         reduce_system_prompt: None,
         response_type: String::from("multiple paragraphs"),
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
index 7eab1460e..95e7e7b80 100644
--- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
@@ -8,6 +8,7 @@ use async_openai::{
 };
 use async_trait::async_trait;
 use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
+use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};
 
 pub struct ChatOpenAI {
     pub api_key: Option<String>,
@@ -136,3 +137,9 @@ impl BaseLLM for ChatOpenAI {
         self.agenerate(messages, streaming, callbacks, llm_params).await
     }
 }
+
+pub fn num_tokens(text: &str) -> usize {
+    let token_encoder = Tokenizer::Cl100kBase;
+    let bpe = get_bpe_from_tokenizer(token_encoder).unwrap();
+    bpe.encode_with_special_tokens(text).len()
+}

From cf9eccccdf285ebaea94eea0b5fdb56c10d1a470 Mon Sep 17 00:00:00 2001
From: benolt
Date: Fri, 16 Aug 2024 12:13:53 +0200
Subject: [PATCH 11/12] test global search with llama 3.1

---
 Cargo.lock                                         |   1 +
 shinkai-libs/shinkai-graphrag/Cargo.toml           |   1 +
 shinkai-libs/shinkai-graphrag/src/llm/llm.rs       |   6 +
 .../src/search/global_search/global_search.rs      |  11 +-
 .../tests/global_search_tests.rs                   | 106 +++++++++++++++++-
 shinkai-libs/shinkai-graphrag/tests/utils/mod.rs   |   1 +
 .../shinkai-graphrag/tests/utils/ollama.rs         | 100 +++++++++++++++++
 .../shinkai-graphrag/tests/utils/openai.rs         |   3 +-
 8 files changed, 224 insertions(+), 5 deletions(-)
 create mode 100644 shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs

diff --git a/Cargo.lock b/Cargo.lock
index 1f5d01fd0..765736ed3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10351,6 +10351,7 @@ dependencies = [
  "polars",
  "polars-lazy",
  "rand 0.8.5",
+ "reqwest 0.11.27",
  "serde",
  "serde_json",
  "tiktoken-rs",
diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml
index 18650385c..7975fa77c 100644
--- a/shinkai-libs/shinkai-graphrag/Cargo.toml
+++ b/shinkai-libs/shinkai-graphrag/Cargo.toml
@@ -16,4 +16,5 @@ tokio = { version = "1.36", features = ["full"] }
 
 [dev-dependencies]
 async-openai = "0.23.4"
+reqwest = { version = "0.11.26", features = ["json"] }
 tiktoken-rs = "0.5.9"
\ No newline at end of file
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
index 0a8482144..5fa5cf633 100644
--- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
+++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
@@ -40,6 +40,7 @@ pub trait BaseLLM {
         streaming: bool,
         callbacks: Option<Vec<BaseLLMCallback>>,
         llm_params: LLMParams,
+        search_phase: Option<GlobalSearchPhase>,
     ) -> anyhow::Result<String>;
 }
 
@@ -47,3 +48,8 @@
 pub trait BaseTextEmbedding {
     async fn aembed(&self, text: &str) -> Vec<f64>;
 }
+
+pub enum GlobalSearchPhase {
+    Map,
+    Reduce,
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
index 0368d4a59..8bb60b955 100644
--- a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
@@ -6,7 +6,7 @@
 
 use crate::context_builder::community_context::GlobalCommunityContext;
 use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory};
-use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
+use crate::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};
 use crate::search::global_search::prompts::NO_DATA_ANSWER;
 
 use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT};
@@ -258,7 +258,13 @@ impl GlobalSearch {
 
         let search_response = self
             .llm
-            .agenerate(MessageType::Dictionary(search_messages), false, None, llm_params)
+            .agenerate(
+                MessageType::Dictionary(search_messages),
+                false,
+                None,
+                llm_params,
+                Some(GlobalSearchPhase::Map),
+            )
             .await?;
 
         let processed_response = self.parse_search_response(&search_response);
@@ -412,6 +418,7 @@ impl GlobalSearch {
                 true,
                 llm_callbacks,
                 llm_params,
+                Some(GlobalSearchPhase::Reduce),
             )
             .await?;
 
diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
index 42bcd5834..5c888a8df 100644
--- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
@@ -7,12 +7,114 @@ use shinkai_graphrag::{
     llm::llm::LLMParams,
     search::global_search::global_search::{GlobalSearch, GlobalSearchParams},
 };
-use utils::openai::{num_tokens, ChatOpenAI};
+use utils::{
+    ollama::Ollama,
+    openai::{num_tokens, ChatOpenAI},
+};
 
 mod utils;
 
 // #[tokio::test]
-async fn global_search_test() -> Result<(), Box<dyn std::error::Error>> {
+async fn ollama_global_search_test() -> Result<(), Box<dyn std::error::Error>> {
+    let base_url = "http://localhost:11434";
+    let model_type = "llama3.1";
+
+    let llm = Ollama::new(base_url.to_string(), model_type.to_string());
+
+    // Load community reports
+    // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip
+
+    let input_dir = "./dataset";
+    let community_report_table = "create_final_community_reports";
+    let entity_table = "create_final_nodes";
+    let entity_embedding_table = "create_final_entities";
+
+    let community_level = 2;
+
+    let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap();
+    let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap();
+
+    let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap();
+    let report_df = ParquetReader::new(&mut report_file).finish().unwrap();
+
+    let mut entity_embedding_file =
+        std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap();
+    let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap();
+
+    let reports = read_indexer_reports(&report_df, &entity_df, community_level)?;
+    let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?;
+
+    println!("Reports: {:?}", report_df.head(Some(5)));
+
+    // Build global context based on community reports
+
+    // Using tiktoken for token count estimation
+    let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens);
+
+    let context_builder_params = ContextBuilderParams {
+        use_community_summary: false, // False means using full community reports. True means using community short summaries.
+        shuffle_data: true,
+        include_community_rank: true,
+        min_community_rank: 0,
+        community_rank_name: String::from("rank"),
+        include_community_weight: true,
+        community_weight_name: String::from("occurrence weight"),
+        normalize_community_weight: true,
+        max_tokens: 5000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
+        context_name: String::from("Reports"),
+        column_delimiter: String::from("|"),
+    };
+
+    // LLM params are ignored for Ollama
+    let map_llm_params = LLMParams {
+        max_tokens: 1000,
+        temperature: 0.0,
+        response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]),
+    };
+
+    let reduce_llm_params = LLMParams {
+        max_tokens: 2000,
+        temperature: 0.0,
+        response_format: std::collections::HashMap::new(),
+    };
+
+    // Perform global search
+
+    let search_engine = GlobalSearch::new(GlobalSearchParams {
+        llm: Box::new(llm),
+        context_builder,
+        num_tokens_fn: num_tokens,
+        map_system_prompt: None,
+        reduce_system_prompt: None,
+        response_type: String::from("multiple paragraphs"),
+        allow_general_knowledge: false,
+        general_knowledge_inclusion_prompt: None,
+        json_mode: true,
+        callbacks: None,
+        max_data_tokens: 5000,
+        map_llm_params,
+        reduce_llm_params,
+        context_builder_params,
+    });
+
+    let result = search_engine
+        .asearch(
+            "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(),
+            None,
+        )
+        .await?;
+
+    println!("Response: {:?}", result.response);
+
+    println!("Context: {:?}", result.context_data);
+
+    println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens);
+
+    Ok(())
+}
+
+// #[tokio::test]
+async fn openai_global_search_test() -> Result<(), Box<dyn std::error::Error>> {
     let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap();
     let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap();
 
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
index d8c308735..3ef32f620 100644
--- a/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
@@ -1 +1,2 @@
+pub mod ollama;
 pub mod openai;
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs
new file mode 100644
index 000000000..41d3619b8
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs
@@ -0,0 +1,100 @@
+use async_trait::async_trait;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OllamaResponse {
+    pub model: String,
+    pub created_at: String,
+    pub message: OllamaMessage,
+}
+
+#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
+pub struct OllamaMessage {
+    pub role: String,
+    pub content: String,
+}
+
+pub struct Ollama {
+    base_url: String,
+    model_type: String,
+}
+
+impl Ollama {
+    pub fn new(base_url: String, model_type: String) -> Self {
+        Ollama { base_url, model_type }
+    }
+}
+
+#[async_trait]
+impl BaseLLM for Ollama {
+    async fn agenerate(
+        &self,
+        messages: MessageType,
+        _streaming: bool,
+        _callbacks: Option<Vec<BaseLLMCallback>>,
+        _llm_params: LLMParams,
+        search_phase: Option<GlobalSearchPhase>,
+    ) -> anyhow::Result<String> {
+        let client = Client::new();
+        let chat_url = format!("{}{}", &self.base_url, "/api/chat");
+
+        let messages_json = match messages {
+            MessageType::String(message) => json!(message),
+            MessageType::Strings(messages) => json!(messages),
+            MessageType::Dictionary(messages) => {
+                let messages = match search_phase {
+                    Some(GlobalSearchPhase::Map) => {
+                        // Keep only the system messages and send them with the user role
+                        messages
+                            .into_iter()
+                            .filter(|map| map.get_key_value("role").is_some_and(|(_, v)| v == "system"))
+                            .map(|map| {
+                                map.into_iter()
+                                    .map(|(key, value)| {
+                                        if key == "role" {
+                                            return (key, "user".to_string());
+                                        }
+                                        (key, value)
+                                    })
+                                    .collect()
+                            })
+                            .collect()
+                    }
+                    Some(GlobalSearchPhase::Reduce) => {
+                        // Convert roles to user
+                        messages
+                            .into_iter()
+                            .map(|map| {
+                                map.into_iter()
+                                    .map(|(key, value)| {
+                                        if key == "role" {
+                                            return (key, "user".to_string());
+                                        }
+                                        (key, value)
+                                    })
+                                    .collect()
+                            })
+                            .collect()
+                    }
+                    _ => messages,
+                };
+
+                json!(messages)
+            }
+        };
+
+        let payload = json!({
+            "model": self.model_type,
+            "messages": messages_json,
+            "stream": false,
+        });
+
+        let response = client.post(chat_url).json(&payload).send().await?;
+        let response = response.json::<OllamaResponse>().await?;
+
+        Ok(response.message.content)
+    }
+}
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
index 95e7e7b80..255d5b4e5 100644
--- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
@@ -7,7 +7,7 @@ use async_openai::{
     Client,
 };
 use async_trait::async_trait;
-use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
+use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};
 use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};
 
 pub struct ChatOpenAI {
@@ -133,6 +133,7 @@ impl BaseLLM for ChatOpenAI {
         streaming: bool,
         callbacks: Option<Vec<BaseLLMCallback>>,
         llm_params: LLMParams,
+        _search_phase: Option<GlobalSearchPhase>,
     ) -> anyhow::Result<String> {
         self.agenerate(messages, streaming, callbacks, llm_params).await
     }

From fa0386ff71d91ed53e931656ab6260af17f62dcd Mon Sep 17 00:00:00 2001
From: benolt
Date: Fri, 30 Aug 2024 15:12:29 +0200
Subject: [PATCH 12/12] update Cargo.lock

---
 Cargo.lock | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 765736ed3..6192b80b8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -816,7 +816,7 @@ dependencies = [
  "backoff",
  "base64 0.22.1",
  "bytes",
- "derive_builder 0.20.0",
+ "derive_builder 0.20.1",
 "eventsource-stream",
 "futures",
 "rand 0.8.5",
@@ -1845,9 +1845,9 @@
 [[package]]
 name = "bytemuck_derive"
-version = "1.7.0"
+version = "1.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b"
+checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -3285,11 +3285,11 @@
 [[package]]
 name = "derive_builder"
-version = "0.20.0"
+version = "0.20.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7"
+checksum = "cd33f37ee6a119146a1781d3356a7c26028f83d779b2e04ecd45fdc75c76877b"
 dependencies = [
- "derive_builder_macro 0.20.0",
+ "derive_builder_macro 0.20.1",
 ]
 
 [[package]]
 name = "derive_builder_core"
-version = "0.20.0"
+version = "0.20.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d"
"d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +checksum = "7431fa049613920234f22c47fdc33e6cf3ee83067091ea4277a3f8c4587aae38" dependencies = [ "darling 0.20.10", "proc-macro2", @@ -3328,11 +3328,11 @@ dependencies = [ [[package]] name = "derive_builder_macro" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc" dependencies = [ - "derive_builder_core 0.20.0", + "derive_builder_core 0.20.1", "syn 2.0.66", ] @@ -6616,9 +6616,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ "hermit-abi 0.3.9", "libc", @@ -11671,9 +11671,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.2" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes",