diff --git a/Cargo.lock b/Cargo.lock index 049d70df..6f053a17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] @@ -101,17 +101,6 @@ version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -120,9 +109,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" dependencies = [ "addr2line", "cc", @@ -191,9 +180,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "block-buffer" @@ -404,13 +393,13 @@ checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "colored" -version = "2.0.0" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3616f750b84d8f0de8a58bda93e08e2a81ad3f523089b05f1dffecab48c6cbd" +checksum = "2674ec482fbc38012cf31e6c42ba0177b431a0cb6f15fe40efa5aab1bda516f6" dependencies = [ - "atty", + "is-terminal", "lazy_static", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -755,7 +744,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", - "rustix 0.38.1", + "rustix 0.38.19", "windows-sys 0.48.0", ] @@ -932,9 +921,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.27.3" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" [[package]] name = "glob" @@ -982,15 +971,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - [[package]] name = "hermit-abi" version = "0.3.1" @@ -1151,7 +1131,7 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" 
dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", "windows-sys 0.48.0", ] @@ -1164,12 +1144,12 @@ checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" [[package]] name = "is-terminal" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24fddda5af7e54bf7da53067d6e802dbcc381d0a8eef629df528e3ebf68755cb" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.3.1", - "rustix 0.38.1", + "hermit-abi", + "rustix 0.38.19", "windows-sys 0.48.0", ] @@ -1229,9 +1209,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libloading" @@ -1251,9 +1231,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.4.3" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" +checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" [[package]] name = "llm" @@ -1262,6 +1242,7 @@ dependencies = [ "bytesize", "clap", "llm-base", + "llm-bert", "llm-bloom", "llm-falcon", "llm-gpt2", @@ -1297,6 +1278,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "llm-bert" +version = "0.2.0-dev" +dependencies = [ + "bytemuck", + "llm-base", + "tracing", +] + [[package]] name = "llm-bloom" version = "0.2.0-dev" @@ -1374,9 +1364,9 @@ dependencies = [ [[package]] name = "llm-samplers" -version = "0.0.6" +version = "0.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7553f60d113c9cdc6a5402456a31cd9a273bef79f6f16d8a4f7b4bedf5f754b2" +checksum = "7e85df656cd89e7702cb56171d75aa77c7bec828af7d2054d9987c34411cf896" dependencies = [ "anyhow", "num-traits", @@ -1601,7 +1591,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", ] @@ -1613,9 +1603,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.31.1" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" dependencies = [ "memchr", ] @@ -2015,9 +2005,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.37.21" +version = "0.37.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25693a73057a1b4cb56179dd3c7ea21a7c6c5ee7d85781f5749b46f34b79c" +checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" dependencies = [ "bitflags 1.3.2", "errno", @@ -2029,14 +2019,14 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.1" +version = "0.38.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc6396159432b5c8490d4e301d8c705f61860b8b6c863bf79942ce5401968f3" +checksum = 
"745ecfa778e66b2b63c88a61cb36e0eea109e803b0b86bf9879fbc77c70e86ed" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "errno", "libc", - "linux-raw-sys 0.4.3", + "linux-raw-sys 0.4.11", "windows-sys 0.48.0", ] @@ -2256,9 +2246,9 @@ dependencies = [ [[package]] name = "spinoff" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fee259f96b31e7a18657d11741fe30d63f98e07de70e7a19d2b705ab9b331cdc" +checksum = "20aa2ed67fbb202e7b716ff8bfc6571dd9301617767380197d701c31124e88f6" dependencies = [ "colored", "once_cell", @@ -2344,7 +2334,7 @@ dependencies = [ "cfg-if", "fastrand", "redox_syscall 0.3.5", - "rustix 0.37.21", + "rustix 0.37.27", "windows-sys 0.48.0", ] diff --git a/Cargo.toml b/Cargo.toml index ae5b22f7..045ecc9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ "crates/llm", "crates/llm-base", "crates/models/*", - "binaries/*" + "binaries/*", ] resolver = "2" default-members = ["binaries/llm-cli", "crates/llm"] @@ -27,12 +27,12 @@ anyhow = "1.0" rustyline = { version = "11.0.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } -spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] } +spinoff = { version = "0.8.0", default-features = false, features = ["dots2"] } clap = { version = "4.1.8", features = ["derive"] } memmap2 = "0.5.10" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing = { version = "0.1", features = ["log"] } -llm-samplers = "=0.0.6" +llm-samplers = "=0.0.7" # Config for 'cargo dist' [workspace.metadata.dist] @@ -45,7 +45,12 @@ ci = ["github"] # The installers to generate for each app installers = ["shell", "powershell"] # Target platforms to build apps for (Rust target-triple syntax) -targets = ["x86_64-unknown-linux-gnu", "x86_64-apple-darwin", "x86_64-pc-windows-msvc", "aarch64-apple-darwin"] +targets = [ + "x86_64-unknown-linux-gnu", + "x86_64-apple-darwin", + "x86_64-pc-windows-msvc", + "aarch64-apple-darwin", +] # The profile that 'cargo dist' will build with [profile.dist] diff --git a/README.md b/README.md index 8a0bd7ae..b027cfdd 100644 --- a/README.md +++ b/README.md @@ -287,6 +287,7 @@ Absolutely! Please see the [contributing guide](./doc/CONTRIBUTING.md). inference API on your local machine using `llm`. - [secondbrain](https://github.com/juliooa/secondbrain): Desktop app to download and run LLMs locally in your computer using `llm`. - [floneum](https://floneum.com/): A graph editor for local AI workflows. +- [poly](https://github.com/pixelspark/poly): A versatile LLM serving back-end with tasks, streaming completion, memory retrieval, and more. #### Libraries diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 21b4a897..e158db68 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -290,6 +290,15 @@ pub struct Generate { /// top_p - The probability for the top tokens are added until the result is greater or equal to P and at least min_keep tokens have been seen. /// p(0.95): The cumulative probability after which no more tokens are kept for sampling. /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// top_a (default: disabled) - This sampler prunes tokens that don't meet a threshold based on the most probable token. The formula is `a1 * pow(max_prob, a2)`. See https://github.com/BlinkDL/RWKV-LM#the-top-a-sampling-method for more information. 
+ /// a1(0.0): Threshold scale. A reasonable value is 0.2. Setting either a1 or a2 to 0 disables the sampler. + /// a2(0.0): Threshold power. A reasonable value is 2. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. + /// + /// min_p (default: disabled) - This sampler prunes tokens that don't meet a certain percentage of the most probable token. For example, if `p` is `0.05`, then after `min_keep` is satisfied, other tokens must be at least 5% of the most probable token. See https://github.com/ggerganov/llama.cpp/issues/3483 for more information. + /// p(0.0): Probability threshold. 0.05 to 0.2 are good starting values to try. Setting this to 0 disables the sampler. + /// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended. #[arg(long = "sampler", short = 's', verbatim_doc_comment)] pub sampler_options: Vec<String>, @@ -533,7 +542,7 @@ impl ModelLoad { let tokenizer_source = match self.model_and_tokenizer.to_source() { Ok(vs) => vs, Err(err) => { - if let Some(sp) = sp.take() { + if let Some(mut sp) = sp.take() { sp.fail(&format!("Failed to load tokenizer: {}", err)); } return Err(err); @@ -586,7 +595,7 @@ impl ModelLoad { file_size, tensor_count, } => { - if let Some(sp) = sp.take() { + if let Some(mut sp) = sp.take() { sp.success(&format!( "Loaded {tensor_count} tensors ({}) after {}ms", bytesize::to_string(file_size, false), @@ -601,7 +610,7 @@ if model.is_err() { // If we've failed at loading the model, we probably haven't stopped the spinner yet. // Cancel it now if needed. - if let Some(sp) = sp { + if let Some(mut sp) = sp { sp.fail("Failed to load model") } } diff --git a/binaries/llm-cli/src/interactive.rs b/binaries/llm-cli/src/interactive.rs index 4657bc9d..3ad7e486 100644 --- a/binaries/llm-cli/src/interactive.rs +++ b/binaries/llm-cli/src/interactive.rs @@ -141,7 +141,7 @@ fn feed_prompt_with_spinner( prompt.insert(0, '\n'); } - let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); + let mut sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); let result = session.feed_prompt( model, &prompt, diff --git a/binaries/llm-test/src/inference.rs b/binaries/llm-test/src/inference.rs index a9ace889..3666167e 100644 --- a/binaries/llm-test/src/inference.rs +++ b/binaries/llm-test/src/inference.rs @@ -92,14 +92,14 @@ fn run_inference( // Takes the most likely element from the logits, except if they've appeared in `previous_tokens` // at all #[derive(Debug, Default)] -struct DeterministicSampler(SampleGreedy<TokenId>); +struct DeterministicSampler(SampleGreedy); -impl Sampler<TokenId, f32> for DeterministicSampler { +impl Sampler for DeterministicSampler { fn sample<'a>( &mut self, - res: &mut dyn HasSamplerResources<TokenId = TokenId>, - logits: &'a mut Logits<TokenId, f32>, - ) -> anyhow::Result<&'a mut Logits<TokenId, f32>> { + res: &mut dyn HasSamplerResources, + logits: &'a mut Logits, + ) -> anyhow::Result<&'a mut Logits> { let mut flat_bias = Default::default(); // This might look a little weird, but it's necessary because the resource diff --git a/crates/ggml/src/context.rs b/crates/ggml/src/context.rs index 11c35682..96f81b4f 100644 --- a/crates/ggml/src/context.rs +++ b/crates/ggml/src/context.rs @@ -266,6 +266,12 @@ impl Context { pub fn storage(&self) -> &ContextStorage { self.storage.as_ref().unwrap() } + + /// Sets all values of the tensor `a` to the specified value `x`.
+ pub fn set_f32(&self, a: &Tensor, x: f32) -> Tensor { + let raw = unsafe { sys::ggml_set_f32(a.ptr.as_ptr(), x) }; + self.new_tensor_raw(raw) + } } // Operations impl Context { @@ -618,6 +624,30 @@ }; self.new_tensor_raw(tensor) } + + /// Creates a new tensor with the square of `a` + pub fn op_sqr(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sqr(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Creates a new tensor with the square-root of `a` + pub fn op_sqrt(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sqrt(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Creates a new single-element tensor containing the sum of all elements of `a` + pub fn op_sum(&self, a: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_sum(self.as_ptr(), a.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } + + /// Creates a new tensor with the element-wise division of `a` by `b` + pub fn op_div(&self, a: &Tensor, b: &Tensor) -> Tensor { + let tensor = unsafe { sys::ggml_div(self.as_ptr(), a.ptr.as_ptr(), b.ptr.as_ptr()) }; + self.new_tensor_raw(tensor) + } } // Public to this crate methods impl Context { diff --git a/crates/ggml/sys/build.rs b/crates/ggml/sys/build.rs index 799f1671..ba7e876b 100644 --- a/crates/ggml/sys/build.rs +++ b/crates/ggml/sys/build.rs @@ -82,9 +82,9 @@ fn main() { if compiler.is_like_clang() || compiler.is_like_gnu() { if target_os == "macos" { build.flag("-mcpu=apple-m1"); - build.flag("-mfpu=neon"); } else if std::env::var("HOST") == std::env::var("TARGET") { build.flag("-mcpu=native"); + build.flag("-mfpu=neon"); } build.flag("-pthread"); } diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index c86ea4b0..4a8bfec1 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -429,7 +429,7 @@ impl InferenceSession { } // Remove the tokens from self.tokens. - let token_start = self.n_past - num; + let token_start = self.tokens.len() - num; let deleted_tokens: Vec<_> = self.tokens.drain(token_start..).collect(); // Remove the corresponding chars from decoded @@ -677,7 +677,7 @@ impl InferenceSession { npast: self.n_past, config: self.config, tokens: self.tokens.clone(), - logits: self.last_logits.clone(), + last_logits: self.last_logits.clone(), memory_k, memory_v, } @@ -815,7 +815,7 @@ pub struct InferenceSnapshotRef<'a> { /// All tokens generated by this inference session. pub tokens: Vec<TokenId>, /// The vector of logits that was produced after the last inference. - pub logits: Vec<f32>, + pub last_logits: Vec<f32>, /// The contents of the 'key' memory tensor. #[serde(with = "serde_bytes")] pub memory_k: &'a [u8], @@ -832,7 +832,7 @@ impl InferenceSnapshotRef<'_> { npast: self.npast, config: self.config, tokens: self.tokens.clone(), - last_logits: self.logits.clone(), + last_logits: self.last_logits.clone(), memory_k: self.memory_k.to_vec(), memory_v: self.memory_v.to_vec(), } diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index e07c8852..f0a88a8a 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -60,7 +60,7 @@ pub struct InferenceParameters { /// This can be anything that implements [Sampler]. Refer to /// the `llm-samplers` documentation for possible samplers and suggested /// combinations: <https://docs.rs/llm-samplers> - pub sampler: Arc<Mutex<dyn Sampler<TokenId, f32>>>, + pub sampler: Arc<Mutex<dyn Sampler>>, } //Since Sampler implements Send and Sync, InferenceParameters should too.
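A note for reviewers on the `top_a` and `min_p` samplers documented in `cli_args.rs` above: both derive their pruning cutoff from the probability of the most likely token (`a1 * pow(max_prob, a2)` for top-a, `p * max_prob` for min-p). Below is a minimal plain-Rust sketch of that arithmetic with made-up probabilities and the "reasonable" parameter values from the doc comments; it is illustrative only and is not the `llm-samplers` implementation.

```rust
// Worked example of the top-a and min-p pruning rules on a toy
// probability distribution. Values and thresholds are illustrative.
fn main() {
    let probs = [0.50_f32, 0.20, 0.04, 0.01];
    let max_prob = probs.iter().cloned().fold(f32::MIN, f32::max);

    // top-a: threshold = a1 * max_prob^a2, with a1 = 0.2 and a2 = 2.
    // Here: 0.2 * 0.5^2 = 0.05, so only 0.50 and 0.20 survive.
    let (a1, a2) = (0.2_f32, 2.0_f32);
    let top_a_threshold = a1 * max_prob.powf(a2);
    let kept_top_a: Vec<f32> = probs.iter().copied().filter(|p| *p >= top_a_threshold).collect();
    assert_eq!(kept_top_a, vec![0.50, 0.20]);

    // min-p: threshold = p * max_prob, with p = 0.05.
    // Here: 0.05 * 0.5 = 0.025, so 0.04 also survives but 0.01 does not.
    let min_p = 0.05_f32;
    let min_p_threshold = min_p * max_prob;
    let kept_min_p: Vec<f32> = probs.iter().copied().filter(|p| *p >= min_p_threshold).collect();
    assert_eq!(kept_min_p, vec![0.50, 0.20, 0.04]);
}
```

In both cases a flatter distribution (smaller `max_prob`) yields a lower cutoff and keeps more candidates, which is why these samplers adapt better to model confidence than a fixed top-k.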
diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index f049a0cd..2e80495c 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -29,10 +29,10 @@ pub struct FileType { /// The quantization version. pub quantization_version: u32, } -impl From<FileType> for u32 { +impl From<FileType> for i32 { fn from(value: FileType) -> Self { - (value.quantization_version * ggml::QNT_VERSION_FACTOR) as u32 - + ggml::sys::llama::llama_ftype::from(value.format) + (value.quantization_version * ggml::QNT_VERSION_FACTOR) as i32 + + ggml::sys::llama::llama_ftype::from(value.format) as i32 } } impl TryFrom<i32> for FileType { diff --git a/crates/llm-base/src/samplers.rs b/crates/llm-base/src/samplers.rs index 7a179f0b..f0b07b9e 100644 --- a/crates/llm-base/src/samplers.rs +++ b/crates/llm-base/src/samplers.rs @@ -59,7 +59,7 @@ pub enum SamplingError { /// to ensure a valid configuration. pub struct ConfiguredSamplers { /// A builder from the `llm-samplers` crate. - pub builder: SamplerChainBuilder<TokenId, f32>, + pub builder: SamplerChainBuilder, /// Mirostat 1 is present. pub mirostat1: bool, /// Mirostat 2 is present. @@ -74,15 +74,17 @@ pub struct ConfiguredSamplers { /// We call a configuration of samplers that run in a certain order a "chain". /// Here is a description of the default chain `llm` uses: /// -/// 1. Repetition (present by default, multiple allowed) -/// 2. Frequency/Presence (optional, multiple allowed) -/// 3. Sequence Repetition (optional, multiple allowed) -/// 4. Top-K (present by default - incompatible with Mirostat) -/// 5. Tail Free (optional - incompatible with Mirostat) -/// 6. Locally Typical (optional - incompatible with Mirostat) -/// 7. Top-P (present by default - incompatible with Mirostat) -/// 8. Temperature (present by default) -/// 9. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution. +/// 1. Repetition (present by default, multiple allowed) +/// 2. Frequency/Presence (optional, multiple allowed) +/// 3. Sequence Repetition (optional, multiple allowed) +/// 4. Top-K (present by default - incompatible with Mirostat) +/// 5. Tail Free (optional - incompatible with Mirostat) +/// 6. Locally Typical (optional - incompatible with Mirostat) +/// 7. Top-P (present by default - incompatible with Mirostat) +/// 8. Top-A (optional - incompatible with Mirostat) +/// 9. Min-P (optional - incompatible with Mirostat) +/// 10. Temperature (present by default) +/// 11. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution. /// /// Samplers listed as "present by default" but incompatible with Mirostat will /// only be enabled by default if there is no Mirostat sampler enabled. @@ -142,6 +144,20 @@ impl Default for ConfiguredSamplers { Option::<SampleTopP>::None, ), ), + ( + "topa", + SamplerSlot::new_single( + || Box::new(SampleTopA::default().a1(0.0).a2(0.0)), + Option::<SampleTopA>::None, + ), + ), + ( + "minp", + SamplerSlot::new_single( + || Box::new(SampleMinP::default().p(0.0)), + Option::<SampleMinP>::None, + ), + ), ( "temperature", SamplerSlot::new_single( @@ -203,7 +219,7 @@ impl ConfiguredSamplers { ))? } else if (self.mirostat1 || self.mirostat2) && self.incompat_mirostat { Err(SamplerConfigurationError::SamplerCombinationError( - "Cannot enable top-p, top-k, locally typical or tail free samplers with Mirostat 1 or 2".to_string(), + "Cannot enable top-p, top-k, top-a, min-p, locally typical or tail free samplers with Mirostat 1 or 2".to_string(), ))?
} Ok(()) @@ -245,7 +261,9 @@ impl FromStr for ConfiguredSamplers { .inspect(|(name, _slot)| match name.as_str() { "mirostat1" => result.mirostat1 = true, "mirostat2" => result.mirostat2 = true, - "topp" | "topk" | "locallytypical" | "tailfree" => result.incompat_mirostat = true, + "topa" | "minp" | "topp" | "topk" | "locallytypical" | "tailfree" => { + result.incompat_mirostat = true + } _ => (), }) .collect::>(); @@ -269,7 +287,7 @@ impl FromStr for ConfiguredSamplers { /// Sample a token. This convenience function handles building /// the sampler resources and logits objects the sampler needs. pub fn sample_token( - mut sampler: impl Sampler, + mut sampler: impl Sampler, rng: &mut impl rand::Rng, previous_tokens: &[TokenId], last_logits: impl IntoIterator, @@ -297,7 +315,7 @@ pub fn build_sampler( n_vocab: usize, bias: &[(TokenId, f32)], args: &[impl AsRef], -) -> Result>>, SamplerConfigurationError> { +) -> Result>, SamplerConfigurationError> { let mut samplers = SamplerChain::new(); if !bias.is_empty() { @@ -326,7 +344,7 @@ pub fn build_sampler( } /// Get the default sampler chain. -pub fn default_samplers() -> Arc>> { +pub fn default_samplers() -> Arc> { let mut result = ConfiguredSamplers::default(); result.ensure_default_slots(); Arc::new(Mutex::new(result.builder.into_chain())) @@ -349,8 +367,6 @@ impl<'pt, 'r> fmt::Debug for SamplerResources<'pt, 'r> { } impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { - type TokenId = TokenId; - fn with_rng_mut( &mut self, fun: &mut dyn FnMut(&mut dyn rand::RngCore), @@ -359,7 +375,7 @@ impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> { Ok(()) } - fn with_last_tokens(&self, fun: &mut dyn FnMut(&[Self::TokenId])) -> Result<(), SamplerError> { + fn with_last_tokens(&self, fun: &mut dyn FnMut(&[TokenId])) -> Result<(), SamplerError> { fun(self.previous_tokens); Ok(()) } diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index 0f395f5a..efff39e5 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -16,6 +16,7 @@ llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" } llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" } llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" } llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" } +llm-bert = { path = "../models/bert", optional = true, version = "0.2.0-dev" } serde = { workspace = true } tracing = { workspace = true } @@ -34,13 +35,14 @@ default = ["models", "tokenizers-remote"] tokenizers-remote = ["llm-base/tokenizers-remote"] -models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"] +models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt", "bert"] llama = ["dep:llm-llama"] gpt2 = ["dep:llm-gpt2"] gptj = ["dep:llm-gptj"] bloom = ["dep:llm-bloom"] gptneox = ["dep:llm-gptneox"] mpt = ["dep:llm-mpt"] +bert = ["dep:llm-bert"] # Falcon is off by default. See `llm_falcon`'s module documentation for more information. falcon = ["dep:llm-falcon"] diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index febe2441..14800686 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -172,6 +172,7 @@ macro_rules! 
define_models { } define_models!( + (bert, "bert", Bert, llm_bert, "Bert"), (bloom, "bloom", Bloom, llm_bloom, "BLOOM"), (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"), (gptj, "gptj", GptJ, llm_gptj, "GPT-J"), diff --git a/crates/models/bert/Cargo.toml b/crates/models/bert/Cargo.toml new file mode 100644 index 00000000..0be81b40 --- /dev/null +++ b/crates/models/bert/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "llm-bert" +version = "0.2.0-dev" +license = { workspace = true } +repository = { workspace = true } +description = "An implementation of BERT for the `llm` ecosystem." +edition = "2021" +readme = "../../../README.md" + +[dependencies] +bytemuck.workspace = true +llm-base = { path = "../../llm-base", version = "0.2.0-dev" } +tracing = { version = "0.1", features = ["log"] } + diff --git a/crates/models/bert/src/lib.rs b/crates/models/bert/src/lib.rs new file mode 100644 index 00000000..9a8daf6e --- /dev/null +++ b/crates/models/bert/src/lib.rs @@ -0,0 +1,464 @@ +//! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem. +#![deny(missing_docs)] + +use std::error::Error; + +use llm_base::{ + ggml, + model::{common, HyperparametersWriteError}, + util, FileType, GraphOutputs, InferenceSession, InferenceSessionConfig, KnownModel, LoadError, + ModelContext, ModelParameters, OutputRequest, Regex, TensorLoader, TokenId, Tokenizer, +}; + +/// The BERT model. +/// +/// # Safety +/// This implements [Send] and [Sync] as it is immutable after construction. +pub struct Bert { + params: ModelParameters, + hyperparameters: Hyperparameters, + tokenizer: Tokenizer, + + word_embeddings: ggml::Tensor, + token_type_embeddings: ggml::Tensor, + position_embeddings: ggml::Tensor, + ln_e_w: ggml::Tensor, + ln_e_b: ggml::Tensor, + + // weights for the model + layers: Vec, + + // must be kept alive for the model + context: ModelContext, +} + +unsafe impl Send for Bert {} +unsafe impl Sync for Bert {} + +/// BERT [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +pub struct Hyperparameters { + /// Size of the model's vocabulary + pub n_vocab: usize, + + /// Maximum number of tokens + pub n_max_tokens: usize, + + /// Size of the model's embedding layer + pub n_embd: usize, + + /// n_head + pub n_intermediate: usize, + + /// Number of attention heads + pub n_head: usize, + + /// Number of layers in the model + pub n_layer: usize, + + /// file_type + pub file_type: FileType, +} + +impl KnownModel for Bert { + type Hyperparameters = Hyperparameters; + + fn new( + hyperparameters: Self::Hyperparameters, + params: ModelParameters, + tokenizer: Tokenizer, + tensor_loader: impl TensorLoader, + ) -> Result { + let mut tl = tensor_loader; + + let word_embeddings = tl.load("embeddings.word_embeddings.weight")?; + let token_type_embeddings = tl.load("embeddings.token_type_embeddings.weight")?; + let position_embeddings = tl.load("embeddings.position_embeddings.weight")?; + + let ln_e_w = tl.load("embeddings.LayerNorm.weight")?; + let ln_e_b = tl.load("embeddings.LayerNorm.bias")?; + + let mut layers = Vec::new(); + + for i in 0..hyperparameters.n_layer { + let backend = params.backend(i); + + let layer = Layer { + ln_att_w: tl + .load(&format!( + "encoder.layer.{i}.attention.output.LayerNorm.weight" + ))? + .transfer_to(backend), + ln_att_b: tl + .load(&format!( + "encoder.layer.{i}.attention.output.LayerNorm.bias" + ))? 
+ .transfer_to(backend), + + // attention + q_w: tl + .load(&format!("encoder.layer.{i}.attention.self.query.weight"))? + .transfer_to(backend), + q_b: tl + .load(&format!("encoder.layer.{i}.attention.self.query.bias"))? + .transfer_to(backend), + k_w: tl + .load(&format!("encoder.layer.{i}.attention.self.key.weight"))? + .transfer_to(backend), + k_b: tl + .load(&format!("encoder.layer.{i}.attention.self.key.bias"))? + .transfer_to(backend), + v_w: tl + .load(&format!("encoder.layer.{i}.attention.self.value.weight"))? + .transfer_to(backend), + v_b: tl + .load(&format!("encoder.layer.{i}.attention.self.value.bias"))? + .transfer_to(backend), + + o_w: tl + .load(&format!("encoder.layer.{i}.attention.output.dense.weight"))? + .transfer_to(backend), + o_b: tl + .load(&format!("encoder.layer.{i}.attention.output.dense.bias"))? + .transfer_to(backend), + + // ff + ff_i_w: tl + .load(&format!("encoder.layer.{i}.intermediate.dense.weight"))? + .transfer_to(backend), + ff_i_b: tl + .load(&format!("encoder.layer.{i}.intermediate.dense.bias"))? + .transfer_to(backend), + + ln_out_w: tl + .load(&format!("encoder.layer.{i}.output.LayerNorm.weight"))? + .transfer_to(backend), + ln_out_b: tl + .load(&format!("encoder.layer.{i}.output.LayerNorm.bias"))? + .transfer_to(backend), + ff_o_w: tl + .load(&format!("encoder.layer.{i}.output.dense.weight"))? + .transfer_to(backend), + ff_o_b: tl + .load(&format!("encoder.layer.{i}.output.dense.bias"))? + .transfer_to(backend), + }; + + layers.push(layer); + } + let context = tl.finish(); + + Ok(Self { + ln_e_b, + ln_e_w, + position_embeddings, + token_type_embeddings, + word_embeddings, + hyperparameters, + params, + tokenizer, + layers, + context, + }) + } + + /// Starts a new `InferenceSession` for this model. + fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { + InferenceSession::new( + config, + &self.params, + self.hyperparameters.n_layer, + self.hyperparameters.n_embd, + self.hyperparameters.n_vocab, + ) + } + + #[tracing::instrument(level = "trace", skip_all)] + fn evaluate( + &self, + session: &mut InferenceSession, + input_tokens: &[TokenId], + output_request: &mut OutputRequest, + ) { + let input_len = input_tokens.len(); + let _ctx_size = self.params.context_size; + + let Hyperparameters { + n_vocab, + n_max_tokens: _, + n_embd, + n_intermediate: _, + n_head, + n_layer, + file_type: _, + } = self.hyperparameters; + + let d_head = n_embd / n_head; + + let outputs = session.compute(self.context.clone(), input_tokens, |builder| { + let mut ctx0 = builder.ctx0.borrow_mut(); + let gf = ctx0.create_compute_graph(); + + let embd = builder.embd; + + let mut input_layer = ctx0.op_get_rows(&self.word_embeddings, embd); + + // IL = word_embeddings + token_types + position_embeddings + { + // token-types: a zero tensor + let mut token_types = ctx0.new_tensor_1d(llm_base::ElementType::I32, input_len); + token_types.zero_data(); + + // position embeddings: one position index per token + let position_buf: Vec<i32> = (0..input_len as i32).collect(); + let mut positions = ctx0.new_tensor_1d(llm_base::ElementType::I32, input_len); + unsafe { positions.write_data(bytemuck::cast_slice(&position_buf)) }; + + // IL += token_types + input_layer = ctx0.op_add( + &input_layer, + &ctx0.op_get_rows(&self.token_type_embeddings, &token_types), + ); + + // IL += position_embeddings + input_layer = ctx0.op_add( + &input_layer, + &ctx0.op_get_rows(&self.position_embeddings, &positions), + ); + } + + // embd norm + { + input_layer = ctx0.op_norm(&input_layer); + input_layer
= ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_e_w), &self.ln_e_b); + } + + for il in 0..n_layer { + ctx0.set_offloading(self.params.should_offload(il)); + + let mut current = input_layer.share(); + + // self-attention + { + let q_current = ctx0.op_reshape_3d( + &ctx0.op_add( + &ctx0.op_mul_mat(&self.layers[il].q_w, &current), + &self.layers[il].q_b, + ), + d_head, + n_head, + input_len, + ); + let q = ctx0.op_permute(&q_current, (0, 2, 1, 3)); + + let k_current = ctx0.op_reshape_3d( + &ctx0.op_add( + &ctx0.op_mul_mat(&self.layers[il].k_w, &current), + &self.layers[il].k_b, + ), + d_head, + n_head, + input_len, + ); + let k = ctx0.op_permute(&k_current, (0, 2, 1, 3)); + + let v_current = ctx0.op_reshape_3d( + &ctx0.op_add( + &ctx0.op_mul_mat(&self.layers[il].v_w, &current), + &self.layers[il].v_b, + ), + d_head, + n_head, + input_len, + ); + let mut v = ctx0.op_permute(&v_current, (0, 2, 1, 3)); + + let mut kq = ctx0.op_mul_mat(&k, &q); + + // TODO: look into op_scale_inplace and op_soft_max_inplace + kq = ctx0.op_scale( + &kq, + &ctx0.new_f32(1.0 / ((n_embd as f32 / n_head as f32).sqrt())), + ); + kq = ctx0.op_soft_max(&kq); + + v = ctx0.op_cont(&ctx0.op_transpose(&v)); + + let mut kqv = ctx0.op_mul_mat(&v, &kq); + kqv = ctx0.op_permute(&kqv, (0, 2, 1, 3)); + + current = ctx0.op_cpy( + &kqv, + &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, input_len), + ); + } + + // attention output + current = ctx0.op_add( + &ctx0.op_mul_mat(&self.layers[il].o_w, &current), + &self.layers[il].o_b, + ); + + // re-add the layer input + current = ctx0.op_add(&current, &input_layer); + + // attention norm + { + current = ctx0.op_norm(&current); + current = ctx0.op_add( + &ctx0.op_mul(&current, &self.layers[il].ln_att_w), + &self.layers[il].ln_att_b, + ); + } + + let att_output = current.share(); + + // intermediate output + current = ctx0.op_mul_mat(&self.layers[il].ff_i_w, &current); + current = ctx0.op_add(&current, &self.layers[il].ff_i_b); + current = ctx0.op_gelu(&current); + + // layer output + current = ctx0.op_mul_mat(&self.layers[il].ff_o_w, &current); + current = ctx0.op_add(&current, &self.layers[il].ff_o_b); + + // attentions bypass the intermediate layer + current = ctx0.op_add(&att_output, &current); + + // output norm + { + current = ctx0.op_norm(&current); + current = ctx0.op_add( + &ctx0.op_mul(&current, &self.layers[il].ln_out_w), + &self.layers[il].ln_out_b, + ); + } + + // input for next layer + input_layer = current; + } + input_layer = ctx0.op_cont(&ctx0.op_transpose(&input_layer)); + + ctx0.set_offloading(false); + // pooler + let mut sum = ctx0.new_tensor_2d(llm_base::ElementType::F32, input_len, 1); + sum = ctx0.set_f32(&sum, 1.0 / (input_len as f32)); + input_layer = ctx0.op_mul_mat(&input_layer, &sum); + + // normalizer + let length = ctx0.op_sqrt(&ctx0.op_sum(&ctx0.op_sqr(&input_layer))); + + input_layer = ctx0.op_scale(&input_layer, &ctx0.op_div(&ctx0.new_f32(1.0), &length)); + + ( + gf, + GraphOutputs { + result: input_layer.share(), + embedding_result: input_layer.share(), + output_length: input_len, + }, + ) + }); + + // finish evaluation + common::read_last_token(session, &outputs.result, n_vocab, input_len); + common::extract_logits(output_request, &outputs.result, n_vocab, input_len); + common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, 1); + } + + fn hyperparameters(&self) -> &Self::Hyperparameters { + &self.hyperparameters + } + + fn tokenizer(&self) -> &Tokenizer { + &self.tokenizer + } + + fn context_size(&self) -> usize { + self.params.context_size + } + + fn bot_token_id(&self) -> Option<TokenId> {
self.tokenizer.id("[PAD]".as_bytes()) + } + + fn eot_token_id(&self) -> TokenId { + self.tokenizer.id("</s>".as_bytes()).unwrap_or(2) + } + + fn quantize_tensors() -> Vec<Regex> { + vec![Regex::new(".*weight").unwrap()] + } + + fn skip_quantize_tensors() -> Vec<Regex> { + vec![] + } + + fn supports_rewind(&self) -> bool { + true + } +} + +impl llm_base::Hyperparameters for Hyperparameters { + fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> { + Ok(Hyperparameters { + n_vocab: util::read_i32(reader)?.try_into()?, + n_max_tokens: util::read_i32(reader)?.try_into()?, + n_embd: util::read_i32(reader)?.try_into()?, + n_intermediate: util::read_i32(reader)?.try_into()?, + n_head: util::read_i32(reader)?.try_into()?, + n_layer: util::read_i32(reader)?.try_into()?, + file_type: util::read_filetype(reader)?, + }) + } + + fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { + util::write_i32(writer, self.n_vocab.try_into()?)?; + util::write_i32(writer, self.n_max_tokens.try_into()?)?; + util::write_i32(writer, self.n_embd.try_into()?)?; + util::write_i32(writer, self.n_intermediate.try_into()?)?; + util::write_i32(writer, self.n_head.try_into()?)?; + util::write_i32(writer, self.n_layer.try_into()?)?; + util::write_i32(writer, self.file_type.into())?; + Ok(()) + } + + fn n_vocabulary(&self) -> usize { + self.n_vocab + } + + fn file_type(&self) -> Option<FileType> { + Some(self.file_type) + } + + fn file_type_mut(&mut self) -> Option<&mut FileType> { + Some(&mut self.file_type) + } +} + +struct Layer { + // normalization + ln_att_w: ggml::Tensor, + ln_att_b: ggml::Tensor, + + ln_out_w: ggml::Tensor, + ln_out_b: ggml::Tensor, + + // attention + q_w: ggml::Tensor, + q_b: ggml::Tensor, + k_w: ggml::Tensor, + k_b: ggml::Tensor, + v_w: ggml::Tensor, + v_b: ggml::Tensor, + + o_w: ggml::Tensor, + o_b: ggml::Tensor, + + // ff + ff_i_w: ggml::Tensor, + ff_i_b: ggml::Tensor, + + ff_o_w: ggml::Tensor, + ff_o_b: ggml::Tensor, +}
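A closing note for reviewers on the pooler at the end of `Bert::evaluate` above, which is what the new ggml ops (`set_f32`, `op_sqr`, `op_sum`, `op_sqrt`, `op_div`) exist to support: multiplying the token-embedding matrix by a row vector whose entries are all `1.0 / input_len` is mean pooling over tokens, and the `op_sqrt(op_sum(op_sqr(..)))` chain computes the L2 norm used to scale the pooled vector to unit length. A plain-Rust sketch of the same computation (illustrative only, not the ggml graph; the helper name is made up):

```rust
// Sketch of the BERT pooler: mean-pool token embeddings, then
// L2-normalize, mirroring set_f32 + op_mul_mat and the
// op_sqrt(op_sum(op_sqr(..))) / op_div / op_scale chain above.
fn pool_and_normalize(token_embeddings: &[Vec<f32>]) -> Vec<f32> {
    let n_tokens = token_embeddings.len() as f32;
    let n_embd = token_embeddings[0].len();

    // Mean over tokens: equivalent to multiplying by a row vector
    // filled with 1.0 / input_len (built with set_f32 in the graph).
    let mut pooled = vec![0.0_f32; n_embd];
    for token in token_embeddings {
        for (acc, x) in pooled.iter_mut().zip(token) {
            *acc += x / n_tokens;
        }
    }

    // L2 norm: sqrt of the sum of squares, then scale to unit length.
    let length = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
    pooled.iter().map(|x| x / length).collect()
}

fn main() {
    let tokens = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let embedding = pool_and_normalize(&tokens);
    // mean = [0.5, 0.5]; length = sqrt(0.5); result ≈ [0.7071, 0.7071]
    assert!((embedding[0] - embedding[1]).abs() < 1e-6);
    assert!((embedding[0] - 0.70710678).abs() < 1e-5);
}
```

The unit-length output is what makes the extracted embeddings directly usable for cosine-similarity comparisons, since the dot product of two normalized vectors is their cosine similarity.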