diff --git a/.gitignore b/.gitignore index e747356..1517e86 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target .DS_Store /.vscode +*.dump.rdb diff --git a/dump.rdb b/dump.rdb deleted file mode 100644 index 9ed305f..0000000 Binary files a/dump.rdb and /dev/null differ diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 427ca32..1de940f 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,9 +1,7 @@ use crate::RespFrame; -use dashmap::DashMap; +use dashmap::{DashMap, DashSet}; use std::ops::Deref; use std::sync::Arc; - -// region: --- Enums and Structs #[derive(Debug, Clone)] pub struct Backend(Arc); @@ -11,31 +9,7 @@ pub struct Backend(Arc); pub struct BackendInner { pub(crate) map: DashMap, pub(crate) hmap: DashMap>, -} -// endregion: --- Enums and Structs - -// region: --- impls -impl Deref for Backend { - type Target = BackendInner; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Default for BackendInner { - fn default() -> Self { - Self { - map: DashMap::new(), - hmap: DashMap::new(), - } - } -} - -impl Default for Backend { - fn default() -> Self { - Self(Arc::new(BackendInner::default())) - } + pub(crate) set: DashMap>, // DashSet 中的元素要求实现 Eq, RespFrame 不能实现 Eq, 因此这里使用 String } impl Backend { @@ -65,5 +39,63 @@ impl Backend { pub fn hgetall(&self, key: &str) -> Option> { self.hmap.get(key).map(|v| v.clone()) } + + pub fn sadd(&self, key: String, members: impl Into>) -> i64 { + let set = self.set.entry(key).or_default(); + let mut cnt = 0; + for member in members.into() { + if set.insert(member) { + cnt += 1; + } + } + cnt + } + + pub fn sismember(&self, key: &str, value: &str) -> bool { + self.set + .get(key) + .and_then(|v| v.get(value).map(|_| true)) + .unwrap_or(false) + } + pub fn insert_set(&self, key: String, values: Vec) { + let set = self.set.get_mut(&key); + match set { + Some(set) => { + for value in values { + (*set).insert(value); + } + } + None => { + let new_set = DashSet::new(); + for value in values { + new_set.insert(value); + } + self.set.insert(key.to_string(), new_set); + } + } + } +} + +impl Deref for Backend { + type Target = BackendInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Default for BackendInner { + fn default() -> Self { + Self { + map: DashMap::new(), + hmap: DashMap::new(), + set: DashMap::new(), + } + } +} + +impl Default for Backend { + fn default() -> Self { + Self(Arc::new(BackendInner::default())) + } } -// endregion: --- impls diff --git a/src/cmd/command.rs b/src/cmd/command.rs new file mode 100644 index 0000000..3b08e7e --- /dev/null +++ b/src/cmd/command.rs @@ -0,0 +1,75 @@ +use enum_dispatch::enum_dispatch; +use thiserror::Error; + +use crate::{RespArray, RespError, RespFrame}; + +use super::{ + echo::Echo, + hmap::{HGet, HGetAll, HMGet, HSet}, + map::{Get, Set}, + set::{SAdd, SIsMember}, + unrecognized::Unrecognized, +}; + +#[enum_dispatch(CommandExecutor)] +#[derive(Debug)] +pub enum Command { + Get(Get), + Set(Set), + HGet(HGet), + HSet(HSet), + HGetAll(HGetAll), + Echo(Echo), + HMGet(HMGet), + SAdd(SAdd), + SIsMember(SIsMember), // S 表示 Set + // unrecognized command + Unrecognized(Unrecognized), +} + +#[derive(Error, Debug)] +pub enum CommandError { + #[error("Invalid command: {0}")] + InvalidCommand(String), + #[error("Invalid argument: {0}")] + InvalidArgument(String), + #[error("{0}")] + RespError(#[from] RespError), + #[error("Utf8 error: {0}")] + Utf8Error(#[from] std::string::FromUtf8Error), +} + +impl TryFrom for Command { + type Error = CommandError; + fn try_from(v: RespArray) -> Result { + match v.first() { + Some(RespFrame::BulkString(ref cmd)) => match cmd.as_ref() { + b"get" => Ok(Get::try_from(v)?.into()), + b"set" => Ok(Set::try_from(v)?.into()), + b"hget" => Ok(HGet::try_from(v)?.into()), + b"hset" => Ok(HSet::try_from(v)?.into()), + b"hgetall" => Ok(HGetAll::try_from(v)?.into()), + b"echo" => Ok(Echo::try_from(v)?.into()), + b"hmget" => Ok(HMGet::try_from(v)?.into()), + b"sadd" => Ok(SAdd::try_from(v)?.into()), + b"sismember" => Ok(SIsMember::try_from(v)?.into()), + _ => Ok(Unrecognized.into()), + }, + _ => Err(CommandError::InvalidCommand( + "Command must have a BulkString as the first argument".to_string(), + )), + } + } +} + +impl TryFrom for Command { + type Error = CommandError; + fn try_from(v: RespFrame) -> Result { + match v { + RespFrame::Array(array) => array.try_into(), + _ => Err(CommandError::InvalidCommand( + "Command must be an Array".to_string(), + )), + } + } +} diff --git a/src/cmd/echo.rs b/src/cmd/echo.rs new file mode 100644 index 0000000..9fb0124 --- /dev/null +++ b/src/cmd/echo.rs @@ -0,0 +1,76 @@ +use crate::{Backend, RespArray, RespFrame}; + +use super::{extract_args, validate_command, CommandError, CommandExecutor}; + +// echo: https://redis.io/docs/latest/commands/echo/ + +#[derive(Debug)] +pub struct Echo { + message: String, +} + +impl CommandExecutor for Echo { + fn execute(self, _backend: &Backend) -> RespFrame { + RespFrame::BulkString(self.message.into()) + } +} + +impl TryFrom for Echo { + type Error = CommandError; + fn try_from(value: RespArray) -> Result { + validate_command(&value, &["echo"], 1)?; // validate get + + let mut args = extract_args(value, 1)?.into_iter(); + match args.next() { + Some(RespFrame::BulkString(message)) => Ok(Echo { + message: String::from_utf8(message.0)?, + }), + _ => Err(CommandError::InvalidArgument("Invalid message".to_string())), + } + } +} + +#[cfg(test)] +mod tests { + use crate::{BulkString, RespDecode}; + + use super::*; + use anyhow::Result; + use bytes::BytesMut; + + #[test] + fn test_echo_from_resp_array() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*2\r\n$4\r\necho\r\n$5\r\nhello\r\n"); + + let frame = RespArray::decode(&mut buf)?; + + let result: Echo = frame.try_into()?; + assert_eq!(result.message, "hello"); + + Ok(()) + } + + #[test] + fn test_echo_command() -> Result<()> { + // let backend = Backend::new(); + // let cmd = Echo { + // message: "hello world".to_string(), + // }; + // let result = cmd.execute(&backend); + // assert_eq!(result, RespFrame::BulkString(b"hello world".into())); + + // Ok(()) + + let command = Echo::try_from(RespArray::new([ + BulkString::new("echo").into(), + BulkString::new("hello").into(), + ]))?; + assert_eq!(command.message, "hello"); + + let backend = Backend::new(); + let result = command.execute(&backend); + assert_eq!(result, RespFrame::BulkString(b"hello".into())); + Ok(()) + } +} diff --git a/src/cmd/hmap.rs b/src/cmd/hmap.rs deleted file mode 100644 index 1642e19..0000000 --- a/src/cmd/hmap.rs +++ /dev/null @@ -1,198 +0,0 @@ -use crate::{BulkString, RespArray, RespFrame}; - -use super::{ - extract_args, validate_command, CommandError, CommandExecutor, HGet, HGetAll, HSet, RESP_OK, -}; - -// region: --- impls - -// endregion: --- impls - -impl CommandExecutor for HGet { - fn execute(self, backend: &crate::Backend) -> RespFrame { - match backend.hget(&self.key, &self.field) { - Some(value) => value, - None => RespFrame::Null(crate::RespNull), - } - } -} - -impl CommandExecutor for HGetAll { - fn execute(self, backend: &crate::Backend) -> RespFrame { - let hmap = backend.hmap.get(&self.key); - - match hmap { - Some(hmap) => { - let mut data = Vec::with_capacity(hmap.len()); - for v in hmap.iter() { - let key = v.key().to_owned(); - data.push((key, v.value().clone())); - } - if self.sort { - data.sort_by(|a, b| a.0.cmp(&b.0)); - } - let ret = data - .into_iter() - .flat_map(|(k, v)| vec![BulkString::from(k).into(), v]) - .collect::>(); - - RespArray::new(ret).into() - } - None => RespArray::new([]).into(), - } - } -} - -impl CommandExecutor for HSet { - fn execute(self, backend: &crate::Backend) -> RespFrame { - backend.hset(self.key, self.field, self.value); - RESP_OK.clone() - } -} - -impl TryFrom for HGet { - type Error = CommandError; - fn try_from(value: RespArray) -> Result { - validate_command(&value, &["hget"], 2)?; - - let mut args = extract_args(value, 1)?.into_iter(); - match (args.next(), args.next()) { - (Some(RespFrame::BulkString(key)), Some(RespFrame::BulkString(field))) => Ok(HGet { - key: String::from_utf8(key.0)?, - field: String::from_utf8(field.0)?, - }), - _ => Err(CommandError::InvalidArgument( - "Invalid key or field".to_string(), - )), - } - } -} - -impl TryFrom for HGetAll { - type Error = CommandError; - fn try_from(value: RespArray) -> Result { - validate_command(&value, &["hgetall"], 1)?; - - let mut args = extract_args(value, 1)?.into_iter(); - match args.next() { - Some(RespFrame::BulkString(key)) => Ok(HGetAll { - key: String::from_utf8(key.0)?, - sort: false, - }), - _ => Err(CommandError::InvalidArgument("Invalid key".to_string())), - } - } -} - -impl TryFrom for HSet { - type Error = CommandError; - fn try_from(value: RespArray) -> Result { - validate_command(&value, &["hset"], 3)?; - - let mut args = extract_args(value, 1)?.into_iter(); - match (args.next(), args.next(), args.next()) { - (Some(RespFrame::BulkString(key)), Some(RespFrame::BulkString(field)), Some(value)) => { - Ok(HSet { - key: String::from_utf8(key.0)?, - field: String::from_utf8(field.0)?, - value, - }) - } - _ => Err(CommandError::InvalidArgument( - "Invalid key, field or value".to_string(), - )), - } - } -} - -#[cfg(test)] -mod tests { - use crate::RespDecode; - - use super::*; - use anyhow::Result; - use bytes::BytesMut; - - #[test] - fn test_hget_from_resp_array() -> Result<()> { - let mut buf = BytesMut::new(); - buf.extend_from_slice(b"*3\r\n$4\r\nhget\r\n$3\r\nmap\r\n$5\r\nhello\r\n"); - - let frame = RespArray::decode(&mut buf)?; - - let result: HGet = frame.try_into()?; - assert_eq!(result.key, "map"); - assert_eq!(result.field, "hello"); - - Ok(()) - } - - #[test] - fn test_hgetall_from_resp_array() -> Result<()> { - let mut buf = BytesMut::new(); - buf.extend_from_slice(b"*2\r\n$7\r\nhgetall\r\n$3\r\nmap\r\n"); - - let frame = RespArray::decode(&mut buf)?; - - let result: HGetAll = frame.try_into()?; - assert_eq!(result.key, "map"); - - Ok(()) - } - - #[test] - fn test_hset_from_resp_array() -> Result<()> { - let mut buf = BytesMut::new(); - buf.extend_from_slice(b"*4\r\n$4\r\nhset\r\n$3\r\nmap\r\n$5\r\nhello\r\n$5\r\nworld\r\n"); - - let frame = RespArray::decode(&mut buf)?; - - let result: HSet = frame.try_into()?; - assert_eq!(result.key, "map"); - assert_eq!(result.field, "hello"); - assert_eq!(result.value, RespFrame::BulkString(b"world".into())); - - Ok(()) - } - - #[test] - fn test_hset_hget_hgetall_commands() -> Result<()> { - let backend = crate::Backend::new(); - let cmd = HSet { - key: "map".to_string(), - field: "hello".to_string(), - value: RespFrame::BulkString(b"world".into()), - }; - let result = cmd.execute(&backend); - assert_eq!(result, RESP_OK.clone()); - - let cmd = HSet { - key: "map".to_string(), - field: "hello1".to_string(), - value: RespFrame::BulkString(b"world1".into()), - }; - cmd.execute(&backend); - - let cmd = HGet { - key: "map".to_string(), - field: "hello".to_string(), - }; - let result = cmd.execute(&backend); - assert_eq!(result, RespFrame::BulkString(b"world".into())); - - let cmd = HGetAll { - key: "map".to_string(), - sort: true, - }; - let result = cmd.execute(&backend); - - let expected = RespArray::new([ - BulkString::from("hello").into(), - BulkString::from("world").into(), - BulkString::from("hello1").into(), - BulkString::from("world1").into(), - ]); - assert_eq!(result, expected.into()); - Ok(()) - } -} diff --git a/src/cmd/hmap/hget.rs b/src/cmd/hmap/hget.rs new file mode 100644 index 0000000..df77f5d --- /dev/null +++ b/src/cmd/hmap/hget.rs @@ -0,0 +1,59 @@ +use crate::{ + cmd::{extract_args, validate_command, CommandError, CommandExecutor}, + RespArray, RespFrame, +}; +#[derive(Debug)] +pub struct HGet { + key: String, + field: String, +} + +impl CommandExecutor for HGet { + fn execute(self, backend: &crate::Backend) -> RespFrame { + match backend.hget(&self.key, &self.field) { + Some(value) => value, + None => RespFrame::Null(crate::RespNull), + } + } +} + +impl TryFrom for HGet { + type Error = CommandError; + fn try_from(value: RespArray) -> Result { + validate_command(&value, &["hget"], 2)?; + + let mut args = extract_args(value, 1)?.into_iter(); + match (args.next(), args.next()) { + (Some(RespFrame::BulkString(key)), Some(RespFrame::BulkString(field))) => Ok(HGet { + key: String::from_utf8(key.0)?, + field: String::from_utf8(field.0)?, + }), + _ => Err(CommandError::InvalidArgument( + "Invalid key or field".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use crate::RespDecode; + + use super::*; + use anyhow::Result; + use bytes::BytesMut; + + #[test] + fn test_hget_from_resp_array() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*3\r\n$4\r\nhget\r\n$3\r\nmap\r\n$5\r\nhello\r\n"); + + let frame = RespArray::decode(&mut buf)?; + + let result: HGet = frame.try_into()?; + assert_eq!(result.key, "map"); + assert_eq!(result.field, "hello"); + + Ok(()) + } +} diff --git a/src/cmd/hmap/hgetall.rs b/src/cmd/hmap/hgetall.rs new file mode 100644 index 0000000..8e5e56b --- /dev/null +++ b/src/cmd/hmap/hgetall.rs @@ -0,0 +1,74 @@ +use crate::{ + cmd::{extract_args, validate_command, CommandError, CommandExecutor}, + BulkString, RespArray, RespFrame, +}; +#[derive(Debug)] +pub struct HGetAll { + key: String, + sort: bool, // for test +} + +impl CommandExecutor for HGetAll { + fn execute(self, backend: &crate::Backend) -> RespFrame { + let hmap = backend.hmap.get(&self.key); + + match hmap { + Some(hmap) => { + let mut data = Vec::with_capacity(hmap.len()); + for v in hmap.iter() { + let key = v.key().to_owned(); + data.push((key, v.value().clone())); + } + // sort: because the order of the fields in a hash is not guaranteed + if self.sort { + data.sort_by(|a, b| a.0.cmp(&b.0)); + } + let ret = data + .into_iter() + .flat_map(|(k, v)| vec![BulkString::from(k).into(), v]) + .collect::>(); + + RespArray::new(ret).into() + } + None => RespArray::new([]).into(), + } + } +} + +impl TryFrom for HGetAll { + type Error = CommandError; + fn try_from(value: RespArray) -> Result { + validate_command(&value, &["hgetall"], 1)?; + + let mut args = extract_args(value, 1)?.into_iter(); + match args.next() { + Some(RespFrame::BulkString(key)) => Ok(HGetAll { + key: String::from_utf8(key.0)?, + sort: false, + }), + _ => Err(CommandError::InvalidArgument("Invalid key".to_string())), + } + } +} + +#[cfg(test)] +mod tests { + use crate::RespDecode; + + use super::*; + use anyhow::Result; + use bytes::BytesMut; + + #[test] + fn test_hgetall_from_resp_array() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*2\r\n$7\r\nhgetall\r\n$3\r\nmap\r\n"); + + let frame = RespArray::decode(&mut buf)?; + + let result: HGetAll = frame.try_into()?; + assert_eq!(result.key, "map"); + + Ok(()) + } +} diff --git a/src/cmd/hmap/hmget.rs b/src/cmd/hmap/hmget.rs new file mode 100644 index 0000000..8d5116f --- /dev/null +++ b/src/cmd/hmap/hmget.rs @@ -0,0 +1,98 @@ +use crate::{ + cmd::{extract_args, validate_command, CommandError, CommandExecutor}, + Backend, RespArray, RespFrame, RespNull, +}; +#[derive(Debug)] +pub struct HMGet { + key: String, + fields: Vec, +} + +impl CommandExecutor for HMGet { + fn execute(self, backend: &Backend) -> RespFrame { + let hmap = backend.hgetall(&self.key); + match hmap { + Some(hmap) => { + let mut data = Vec::with_capacity(self.fields.len()); + for field in self.fields.iter() { + let value = hmap.get(field); + match value { + Some(value) => data.push(value.clone()), + None => data.push(RespNull.into()), + } + } + RespArray::new(data).into() + } + None => RespArray::new([]).into(), + } + } +} + +impl TryFrom for HMGet { + type Error = CommandError; + fn try_from(value: RespArray) -> Result { + validate_command(&value, &["hmget"], usize::MAX)?; + let args = extract_args(value, 1)?.into_iter(); + let mut data = Vec::with_capacity(args.len()); + for arg in args { + match arg { + RespFrame::BulkString(s) => { + let s = String::from_utf8(s.0)?; + data.push(s); + } + _ => { + return Err(CommandError::InvalidArgument( + "Invalid key or field".to_string(), + )); + } + } + } + Ok(HMGet { + key: data.remove(0), + fields: data, + }) + } +} + +// #[cfg(test)] +// mod test { +// use super::*; +// use crate::cmd::{HSet, RESP_OK}; +// use crate::resp::BulkString; + +// #[test] +// fn test_hmget_from_resp_array() -> anyhow::Result<()> { +// let backend = Backend::new(); +// let cmd = HSet { +// key: "myhash".to_string(), +// field: "field1".to_string(), +// value: RespFrame::BulkString(b"hello".into()), +// }; +// let result = cmd.execute(&backend); +// assert_eq!(result, RESP_OK.clone()); + +// let cmd = HSet { +// key: "myhash".to_string(), +// field: "field2".to_string(), +// value: RespFrame::BulkString(b"world".into()), +// }; +// let result = cmd.execute(&backend); +// assert_eq!(result, RESP_OK.clone()); + +// let cmd = HMGet::try_from(RespArray::new(vec![ +// RespFrame::BulkString("HMGET".into()), +// RespFrame::BulkString("myhash".into()), +// RespFrame::BulkString("field1".into()), +// RespFrame::BulkString("field2".into()), +// RespFrame::BulkString("nofield".into()), +// ]))?; +// let result = cmd.execute(&backend); +// let expected = RespArray::new(vec![ +// BulkString::from("hello").into(), +// BulkString::from("world").into(), +// RespNull.into(), +// ]); +// assert_eq!(result, expected.into()); +// Ok(()) +// } +// } diff --git a/src/cmd/hmap/hset.rs b/src/cmd/hmap/hset.rs new file mode 100644 index 0000000..db0e88b --- /dev/null +++ b/src/cmd/hmap/hset.rs @@ -0,0 +1,62 @@ +use crate::{ + cmd::{extract_args, validate_command, CommandError, CommandExecutor, RESP_OK}, + RespArray, RespFrame, +}; +#[derive(Debug)] +pub struct HSet { + key: String, + field: String, + value: RespFrame, +} + +impl CommandExecutor for HSet { + fn execute(self, backend: &crate::Backend) -> RespFrame { + backend.hset(self.key, self.field, self.value); + RESP_OK.clone() + } +} + +impl TryFrom for HSet { + type Error = CommandError; + fn try_from(value: RespArray) -> Result { + validate_command(&value, &["hset"], 3)?; + + let mut args = extract_args(value, 1)?.into_iter(); + match (args.next(), args.next(), args.next()) { + (Some(RespFrame::BulkString(key)), Some(RespFrame::BulkString(field)), Some(value)) => { + Ok(HSet { + key: String::from_utf8(key.0)?, + field: String::from_utf8(field.0)?, + value, + }) + } + _ => Err(CommandError::InvalidArgument( + "Invalid key, field or value".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use crate::RespDecode; + + use super::*; + use anyhow::Result; + use bytes::BytesMut; + + #[test] + fn test_hset_from_resp_array() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*4\r\n$4\r\nhset\r\n$3\r\nmap\r\n$5\r\nhello\r\n$5\r\nworld\r\n"); + + let frame = RespArray::decode(&mut buf)?; + + let result: HSet = frame.try_into()?; + assert_eq!(result.key, "map"); + assert_eq!(result.field, "hello"); + assert_eq!(result.value, RespFrame::BulkString(b"world".into())); + + Ok(()) + } +} diff --git a/src/cmd/hmap/mod.rs b/src/cmd/hmap/mod.rs new file mode 100644 index 0000000..77dcbec --- /dev/null +++ b/src/cmd/hmap/mod.rs @@ -0,0 +1,9 @@ +mod hget; +mod hgetall; +mod hmget; +mod hset; + +pub use hget::HGet; +pub use hgetall::HGetAll; +pub use hmget::HMGet; +pub use hset::HSet; diff --git a/src/cmd/map.rs b/src/cmd/map/get.rs similarity index 51% rename from src/cmd/map.rs rename to src/cmd/map/get.rs index c189d74..ee6d464 100644 --- a/src/cmd/map.rs +++ b/src/cmd/map/get.rs @@ -1,9 +1,12 @@ use crate::{ - cmd::{extract_args, validate_command, Get, Set}, + cmd::{extract_args, validate_command, CommandError, CommandExecutor}, RespArray, RespFrame, RespNull, }; -use super::{CommandError, CommandExecutor, RESP_OK}; +#[derive(Debug)] +pub struct Get { + key: String, +} impl CommandExecutor for Get { fn execute(self, backend: &crate::Backend) -> RespFrame { @@ -14,13 +17,6 @@ impl CommandExecutor for Get { } } -impl CommandExecutor for Set { - fn execute(self, backend: &crate::Backend) -> RespFrame { - backend.set(self.key, self.value); - RESP_OK.clone() - } -} - impl TryFrom for Get { type Error = CommandError; fn try_from(value: RespArray) -> Result { @@ -36,31 +32,14 @@ impl TryFrom for Get { } } -impl TryFrom for Set { - type Error = CommandError; - fn try_from(value: RespArray) -> Result { - validate_command(&value, &["set"], 2)?; - - let mut args = extract_args(value, 1)?.into_iter(); - match (args.next(), args.next()) { - (Some(RespFrame::BulkString(key)), Some(value)) => Ok(Set { - key: String::from_utf8(key.0)?, - value, - }), - _ => Err(CommandError::InvalidArgument( - "Invalid key or value".to_string(), - )), - } - } -} - #[cfg(test)] mod tests { + use anyhow::Result; + use bytes::BytesMut; + use crate::RespDecode; use super::*; - use anyhow::Result; - use bytes::BytesMut; #[test] fn test_get_from_resp_array() -> Result<()> { @@ -76,18 +55,4 @@ mod tests { Ok(()) } - - #[test] - fn test_set_from_resp_array() -> Result<()> { - let mut buf = BytesMut::new(); - buf.extend_from_slice(b"*3\r\n$3\r\nset\r\n$5\r\nhello\r\n$5\r\nworld\r\n"); - - let frame = RespArray::decode(&mut buf)?; - - let result: Set = frame.try_into()?; - assert_eq!(result.key, "hello"); - assert_eq!(result.value, RespFrame::BulkString(b"world".into())); - - Ok(()) - } } diff --git a/src/cmd/map/mod.rs b/src/cmd/map/mod.rs new file mode 100644 index 0000000..db4bd8a --- /dev/null +++ b/src/cmd/map/mod.rs @@ -0,0 +1,5 @@ +mod get; +mod set; + +pub use get::Get; +pub use set::Set; diff --git a/src/cmd/map/set.rs b/src/cmd/map/set.rs new file mode 100644 index 0000000..df0cf1f --- /dev/null +++ b/src/cmd/map/set.rs @@ -0,0 +1,58 @@ +use crate::{ + cmd::{extract_args, validate_command, CommandError, CommandExecutor, RESP_OK}, + RespArray, RespFrame, +}; + +#[derive(Debug)] +pub struct Set { + key: String, + value: RespFrame, +} + +impl CommandExecutor for Set { + fn execute(self, backend: &crate::Backend) -> RespFrame { + backend.set(self.key, self.value); + RESP_OK.clone() + } +} + +impl TryFrom for Set { + type Error = CommandError; + fn try_from(value: RespArray) -> Result { + validate_command(&value, &["set"], 2)?; + + let mut args = extract_args(value, 1)?.into_iter(); + match (args.next(), args.next()) { + (Some(RespFrame::BulkString(key)), Some(value)) => Ok(Set { + key: String::from_utf8(key.0)?, + value, + }), + _ => Err(CommandError::InvalidArgument( + "Invalid key or value".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use crate::RespDecode; + + use super::*; + use anyhow::Result; + use bytes::BytesMut; + + #[test] + fn test_set_from_resp_array() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*3\r\n$3\r\nset\r\n$5\r\nhello\r\n$5\r\nworld\r\n"); + + let frame = RespArray::decode(&mut buf)?; + + let result: Set = frame.try_into()?; + assert_eq!(result.key, "hello"); + assert_eq!(result.value, RespFrame::BulkString(b"world".into())); + + Ok(()) + } +} diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs index 82d83d8..5243de3 100644 --- a/src/cmd/mod.rs +++ b/src/cmd/mod.rs @@ -1,12 +1,25 @@ use std::sync::LazyLock; +use command::CommandError; use enum_dispatch::enum_dispatch; -use thiserror::Error; -use crate::{Backend, RespArray, RespError, RespFrame, SimpleString}; +use crate::{Backend, RespArray, RespFrame, SimpleString}; +mod command; +mod echo; mod hmap; mod map; +mod set; +mod unrecognized; + +pub use { + command::Command, + echo::Echo, + hmap::{HGet, HGetAll, HMGet, HSet}, + map::{Get, Set}, + set::{SAdd, SIsMember}, + unrecognized::Unrecognized, +}; // NOTE: you could also use once_cell instead of lazy_static // lazy_static: @@ -21,119 +34,18 @@ mod map; // https://blog.rust-lang.org/2024/07/25/Rust-1.80.0.html static RESP_OK: LazyLock = LazyLock::new(|| SimpleString::new("OK").into()); -// region: --- Traits #[enum_dispatch] pub trait CommandExecutor { // fn execute(&self) -> RespFrame; fn execute(self, backend: &Backend) -> RespFrame; } -// endregion: --- Traits -// region: --- Enum and Structs -#[enum_dispatch(CommandExecutor)] -#[derive(Debug)] -pub enum Command { - Get(Get), - Set(Set), - HGet(HGet), - HSet(HSet), - HGetAll(HGetAll), - // unrecognized command - Unrecognized(Unrecognized), -} - -#[derive(Error, Debug)] -pub enum CommandError { - #[error("Invalid command: {0}")] - InvalidCommand(String), - #[error("Invalid argument: {0}")] - InvalidArgument(String), - #[error("{0}")] - RespError(#[from] RespError), - #[error("Utf8 error: {0}")] - Utf8Error(#[from] std::string::FromUtf8Error), -} - -#[derive(Debug)] -pub struct Get { - key: String, -} - -#[derive(Debug)] -pub struct Set { - key: String, - value: RespFrame, -} - -#[derive(Debug)] -pub struct HGet { - key: String, - field: String, -} - -#[derive(Debug)] -pub struct HSet { - key: String, - field: String, - value: RespFrame, -} - -#[derive(Debug)] -pub struct HGetAll { - key: String, - sort: bool, -} - -#[derive(Debug)] -pub struct Unrecognized; -// endregion: --- Enum and Structs - -// region: --- impls -impl CommandExecutor for Unrecognized { - fn execute(self, _: &Backend) -> RespFrame { - RESP_OK.clone() - } -} - -impl TryFrom for Command { - type Error = CommandError; - fn try_from(v: RespArray) -> Result { - match v.first() { - Some(RespFrame::BulkString(ref cmd)) => match cmd.as_ref() { - b"get" => Ok(Get::try_from(v)?.into()), - b"set" => Ok(Set::try_from(v)?.into()), - b"hget" => Ok(HGet::try_from(v)?.into()), - b"hset" => Ok(HSet::try_from(v)?.into()), - b"hgetall" => Ok(HGetAll::try_from(v)?.into()), - _ => Ok(Unrecognized.into()), - }, - _ => Err(CommandError::InvalidCommand( - "Command must have a BulkString as the first argument".to_string(), - )), - } - } -} - -impl TryFrom for Command { - type Error = CommandError; - fn try_from(v: RespFrame) -> Result { - match v { - RespFrame::Array(array) => array.try_into(), - _ => Err(CommandError::InvalidCommand( - "Command must be an Array".to_string(), - )), - } - } -} -// endregion: --- impls - -// region: --- functions fn validate_command( value: &RespArray, names: &[&'static str], n_args: usize, ) -> Result<(), CommandError> { - if value.len() != n_args + names.len() { + if n_args != usize::MAX && value.len() != n_args + names.len() { return Err(CommandError::InvalidArgument(format!( "{} command must have exactly {} argument", names.join(" "), @@ -165,4 +77,3 @@ fn validate_command( fn extract_args(value: RespArray, start: usize) -> Result, CommandError> { Ok(value.0.into_iter().skip(start).collect::>()) } -// endregion: --- functions diff --git a/src/cmd/set/mod.rs b/src/cmd/set/mod.rs new file mode 100644 index 0000000..2732bd3 --- /dev/null +++ b/src/cmd/set/mod.rs @@ -0,0 +1,5 @@ +mod sadd; +mod sismember; + +pub use sadd::SAdd; +pub use sismember::SIsMember; diff --git a/src/cmd/set/sadd.rs b/src/cmd/set/sadd.rs new file mode 100644 index 0000000..feb1ff5 --- /dev/null +++ b/src/cmd/set/sadd.rs @@ -0,0 +1,80 @@ +use crate::{ + cmd::{command::CommandError, extract_args, validate_command, CommandExecutor}, + RespArray, RespFrame, +}; + +#[derive(Debug)] +pub struct SAdd { + key: String, + members: Vec, +} + +impl CommandExecutor for SAdd { + fn execute(self, backend: &crate::Backend) -> RespFrame { + let (key, members) = (self.key, self.members); + let cnt = backend.sadd(key, members); + // RespFrame::Integer(cnt) + // RespFrame::BulkString(format!("{}(integer)", cnt).into()) + cnt.into() + } +} + +impl TryFrom for SAdd { + type Error = CommandError; + fn try_from(value: RespArray) -> Result { + validate_command(&value, &["sadd"], usize::MAX)?; + let args = extract_args(value, 1)?.into_iter(); + let mut data = Vec::with_capacity(args.len()); + for arg in args { + match arg { + RespFrame::BulkString(s) => { + let s = String::from_utf8(s.0)?; + data.push(s); + } + _ => { + return Err(CommandError::InvalidArgument( + "Invalid key or member".to_string(), + )); + } + } + } + Ok(SAdd { + key: data.remove(0), + members: data, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::RespDecode; + + use super::*; + use anyhow::Result; + use bytes::BytesMut; + + #[test] + fn test_sadd_from_resp_array() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*4\r\n$4\r\nsadd\r\n$3\r\nset\r\n$5\r\nhello\r\n$5\r\nworld\r\n"); + + let frame = RespArray::decode(&mut buf)?; + + let result: SAdd = frame.try_into()?; + assert_eq!(result.key, "set"); + assert_eq!(result.members, vec!["hello", "world"]); + + Ok(()) + } + + #[test] + fn test_sadd_command() { + let backend = crate::Backend::new(); + let cmd = SAdd { + key: "set".to_string(), + members: vec!["hello".to_string(), "world".to_string()], + }; + let frame = cmd.execute(&backend); + assert_eq!(frame, RespFrame::Integer(2)); + } +} diff --git a/src/cmd/set/sismember.rs b/src/cmd/set/sismember.rs new file mode 100644 index 0000000..a108cbb --- /dev/null +++ b/src/cmd/set/sismember.rs @@ -0,0 +1,38 @@ +use crate::{ + cmd::{command::CommandError, extract_args, validate_command, CommandExecutor}, + RespArray, RespFrame, +}; + +#[derive(Debug)] +pub struct SIsMember { + key: String, + member: String, +} + +impl CommandExecutor for SIsMember { + fn execute(self, backend: &crate::Backend) -> RespFrame { + let (key, member) = (self.key, self.member); + let res = backend.sismember(&key, &member); + (res as i64).into() + } +} + +impl TryFrom for SIsMember { + type Error = CommandError; + fn try_from(value: RespArray) -> Result { + validate_command(&value, &["sismember"], 2)?; + + let mut args = extract_args(value, 1)?.into_iter(); + match (args.next(), args.next()) { + (Some(RespFrame::BulkString(key)), Some(RespFrame::BulkString(member))) => { + Ok(SIsMember { + key: String::from_utf8(key.0)?, + member: String::from_utf8(member.0)?, + }) + } + _ => Err(CommandError::InvalidArgument( + "Invalid key or member".to_string(), + )), + } + } +} diff --git a/src/cmd/unrecognized.rs b/src/cmd/unrecognized.rs new file mode 100644 index 0000000..5edc054 --- /dev/null +++ b/src/cmd/unrecognized.rs @@ -0,0 +1,40 @@ +use crate::{Backend, RespArray, RespFrame}; + +use super::{command::CommandError, CommandExecutor, RESP_OK}; + +#[derive(Debug)] +pub struct Unrecognized; + +impl CommandExecutor for Unrecognized { + fn execute(self, _: &Backend) -> RespFrame { + RESP_OK.clone() + } +} + +impl TryFrom for Unrecognized { + type Error = CommandError; + fn try_from(_value: RespArray) -> Result { + Ok(Unrecognized) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{RespArray, RespDecode}; + use anyhow::Result; + use bytes::BytesMut; + + #[test] + fn test_unrecognized_command() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*1\r\n$3\r\nfoo\r\n"); + + let frame = RespArray::decode(&mut buf)?; + + let result: Unrecognized = frame.try_into()?; + assert_eq!(result.execute(&Backend::new()), RESP_OK.clone()); + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index d5776ea..1b8dbce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,8 @@ mod backend; - mod resp; pub mod cmd; pub mod network; pub use backend::*; -// pub use cmd::*; -// pub use network::*; pub use resp::*; diff --git a/src/network.rs b/src/network/mod.rs similarity index 100% rename from src/network.rs rename to src/network/mod.rs diff --git a/src/resp/array.rs b/src/resp/array.rs index 730a57f..a83280a 100644 --- a/src/resp/array.rs +++ b/src/resp/array.rs @@ -42,7 +42,7 @@ impl RespDecode for RespArray { return Err(RespError::NotComplete); } - buf.advance(end + CRLF_LEN); + buf.advance(end + CRLF_LEN); // skip the prefix and CRLF let mut frames = Vec::with_capacity(len); for _ in 0..len { diff --git a/src/resp/frame.rs b/src/resp/frame.rs index 48bceea..1a1dd72 100644 --- a/src/resp/frame.rs +++ b/src/resp/frame.rs @@ -30,7 +30,8 @@ pub enum RespFrame { Map(RespMap), Set(RespSet), } - +// NOTE: 这里需要 impl RespDecode, RespEncode 不需要是因为使用 enum_dispatch 宏的时候, 会自动实现这些 trait +// RespDecode 不能使用 enum_dispatch, 因为不支持 trait 中带有 associated type/ const 的情况 impl RespDecode for RespFrame { const PREFIX: &'static str = ""; fn decode(buf: &mut BytesMut) -> Result { @@ -136,5 +137,73 @@ impl From<&[u8; N]> for RespFrame { #[cfg(test)] mod tests { - // TODO: Add tests + use super::*; + use anyhow::Result; + + #[test] + fn test_resp_frame_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"+OK\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, SimpleString::new("OK".to_string()).into()); + + buf.extend_from_slice(b"-Error message\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, SimpleError::new("Error message".to_string()).into()); + + buf.extend_from_slice(b":1000\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, 1000i64.into()); + + buf.extend_from_slice(b"$5\r\nhello\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, BulkString::new(b"hello").into()); + + buf.extend_from_slice(b"$-1\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespNullBulkString.into()); + + buf.extend_from_slice(b"*2\r\n$4\r\necho\r\n$5\r\nhello\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!( + frame, + RespArray::new([ + BulkString::new("echo").into(), + BulkString::new("hello").into() + ]) + .into() + ); + + buf.extend_from_slice(b"*-1\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespNullArray.into()); + + buf.extend_from_slice(b"_\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespNull.into()); + + buf.extend_from_slice(b"#t\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, true.into()); + + buf.extend_from_slice(b"#f\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, false.into()); + + buf.extend_from_slice(b",1.23\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, 1.23f64.into()); + + buf.extend_from_slice(b"%2\r\n+hello\r\n$5\r\nworld\r\n+foo\r\n$3\r\nbar\r\n"); + let frame = RespMap::decode(&mut buf)?; + let mut map = RespMap::new(); + map.insert( + "hello".to_string(), + BulkString::new(b"world".to_vec()).into(), + ); + map.insert("foo".to_string(), BulkString::new(b"bar".to_vec()).into()); + assert_eq!(frame, map); + + Ok(()) + } } diff --git a/src/resp/integer.rs b/src/resp/integer.rs index 037f8c4..55bf719 100644 --- a/src/resp/integer.rs +++ b/src/resp/integer.rs @@ -3,10 +3,12 @@ use bytes::BytesMut; use super::{extract_simple_frame_data, RespDecode, RespEncode, RespError, CRLF_LEN}; // - integer: ":[<+|->]\r\n" +// NOTE: 实际测试正数不需要+号,负数需要-号 impl RespEncode for i64 { fn encode(self) -> Vec { - let sign = if self < 0 { "" } else { "+" }; // -1 => -1, 1 => +1 - format!(":{}{}\r\n", sign, self).into_bytes() + // let sign = if self < 0 { "" } else { "+" }; // -1 => -1, 1 => +1 + // format!(":{}{}\r\n", sign, self).into_bytes() + format!(":{}\r\n", self).into_bytes() } } @@ -35,8 +37,11 @@ mod tests { #[test] fn test_integer_encode() { + let frame: RespFrame = 1.into(); + assert_eq!(frame.encode(), b":1\r\n"); + let frame: RespFrame = 123.into(); - assert_eq!(frame.encode(), b":+123\r\n"); + assert_eq!(frame.encode(), b":123\r\n"); let frame: RespFrame = (-123).into(); assert_eq!(frame.encode(), b":-123\r\n"); @@ -45,7 +50,7 @@ mod tests { #[test] fn test_integer_decode() -> Result<()> { let mut buf = BytesMut::new(); - buf.extend_from_slice(b":+123\r\n"); + buf.extend_from_slice(b":123\r\n"); let frame = i64::decode(&mut buf)?; assert_eq!(frame, 123); diff --git a/src/resp/set.rs b/src/resp/set.rs index 5f0871b..c7ebc85 100644 --- a/src/resp/set.rs +++ b/src/resp/set.rs @@ -78,7 +78,7 @@ mod tests { .into(); assert_eq!( frame.encode(), - b"~2\r\n*2\r\n:+1234\r\n#t\r\n$5\r\nworld\r\n" + b"~2\r\n*2\r\n:1234\r\n#t\r\n$5\r\nworld\r\n" ); }