diff --git a/sea-orm-macros/src/derives/entity_model.rs b/sea-orm-macros/src/derives/entity_model.rs index 67cb27e87..3ce2801e4 100644 --- a/sea-orm-macros/src/derives/entity_model.rs +++ b/sea-orm-macros/src/derives/entity_model.rs @@ -1,10 +1,9 @@ use super::util::{escape_rust_keyword, trim_starting_raw_identifier}; use heck::{ToSnakeCase, ToUpperCamelCase}; use proc_macro2::{Ident, Span, TokenStream}; -use quote::{quote, quote_spanned}; +use quote::quote; use syn::{ punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, Data, Expr, Fields, Lit, - LitStr, Type, }; /// Method to derive an Model @@ -245,57 +244,12 @@ pub fn expand_derive_entity_model(data: Data, attrs: Vec) -> syn::Res } else { field_type.as_str() }; + let field_span = field.span(); + + let sea_query_col_type = crate::derives::sql_type_match::col_type_match( + sql_type, field_type, field_span, + ); - let sea_query_col_type = match sql_type { - Some(t) => quote! { sea_orm::prelude::ColumnType::#t }, - None => { - let col_type = match field_type { - "char" => quote! { Char(None) }, - "String" | "&str" => quote! { String(None) }, - "i8" => quote! { TinyInteger }, - "u8" => quote! { TinyUnsigned }, - "i16" => quote! { SmallInteger }, - "u16" => quote! { SmallUnsigned }, - "i32" => quote! { Integer }, - "u32" => quote! { Unsigned }, - "i64" => quote! { BigInteger }, - "u64" => quote! { BigUnsigned }, - "f32" => quote! { Float }, - "f64" => quote! { Double }, - "bool" => quote! { Boolean }, - "Date" | "NaiveDate" => quote! { Date }, - "Time" | "NaiveTime" => quote! { Time }, - "DateTime" | "NaiveDateTime" => { - quote! { DateTime } - } - "DateTimeUtc" | "DateTimeLocal" | "DateTimeWithTimeZone" => { - quote! { TimestampWithTimeZone } - } - "Uuid" => quote! { Uuid }, - "Json" => quote! { Json }, - "Decimal" => quote! { Decimal(None) }, - "Vec" => { - quote! { Binary(sea_orm::sea_query::BlobSize::Blob(None)) } - } - _ => { - // Assumed it's ActiveEnum if none of the above type matches - quote! {} - } - }; - if col_type.is_empty() { - let field_span = field.span(); - let ty: Type = LitStr::new(field_type, field_span).parse()?; - let def = quote_spanned! { field_span => - std::convert::Into::::into( - <#ty as sea_orm::sea_query::ValueType>::column_type() - ) - }; - quote! { #def } - } else { - quote! { sea_orm::prelude::ColumnType::#col_type } - } - } - }; let col_def = quote! { sea_orm::prelude::ColumnTypeTrait::def(#sea_query_col_type) }; diff --git a/sea-orm-macros/src/derives/mod.rs b/sea-orm-macros/src/derives/mod.rs index 29f69b053..6d71d1a51 100644 --- a/sea-orm-macros/src/derives/mod.rs +++ b/sea-orm-macros/src/derives/mod.rs @@ -13,8 +13,10 @@ mod partial_model; mod primary_key; mod related_entity; mod relation; +mod sql_type_match; mod try_getable_from_json; mod util; +mod value_type; pub use active_enum::*; pub use active_model::*; @@ -31,3 +33,4 @@ pub use primary_key::*; pub use related_entity::*; pub use relation::*; pub use try_getable_from_json::*; +pub use value_type::*; diff --git a/sea-orm-macros/src/derives/sql_type_match.rs b/sea-orm-macros/src/derives/sql_type_match.rs new file mode 100644 index 000000000..65465d12a --- /dev/null +++ b/sea-orm-macros/src/derives/sql_type_match.rs @@ -0,0 +1,116 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{quote, quote_spanned}; +use syn::{LitStr, Type}; + +pub fn col_type_match( + col_type: Option, + field_type: &str, + field_span: Span, +) -> TokenStream { + match col_type { + Some(t) => quote! { sea_orm::prelude::ColumnType::#t }, + None => { + let col_type = match field_type { + "char" => quote! { Char(None) }, + "String" | "&str" => quote! { String(None) }, + "i8" => quote! { TinyInteger }, + "u8" => quote! { TinyUnsigned }, + "i16" => quote! { SmallInteger }, + "u16" => quote! { SmallUnsigned }, + "i32" => quote! { Integer }, + "u32" => quote! { Unsigned }, + "i64" => quote! { BigInteger }, + "u64" => quote! { BigUnsigned }, + "f32" => quote! { Float }, + "f64" => quote! { Double }, + "bool" => quote! { Boolean }, + "Date" | "NaiveDate" => quote! { Date }, + "Time" | "NaiveTime" => quote! { Time }, + "DateTime" | "NaiveDateTime" => { + quote! { DateTime } + } + "DateTimeUtc" | "DateTimeLocal" | "DateTimeWithTimeZone" => { + quote! { TimestampWithTimeZone } + } + "Uuid" => quote! { Uuid }, + "Json" => quote! { Json }, + "Decimal" => quote! { Decimal(None) }, + "Vec" => { + quote! { Binary(sea_orm::sea_query::BlobSize::Blob(None)) } + } + _ => { + // Assumed it's ActiveEnum if none of the above type matches + quote! {} + } + }; + if col_type.is_empty() { + let ty: Type = LitStr::new(field_type, field_span) + .parse() + .expect("field type error"); + let def = quote_spanned! { field_span => + std::convert::Into::::into( + <#ty as sea_orm::sea_query::ValueType>::column_type() + ) + }; + quote! { #def } + } else { + quote! { sea_orm::prelude::ColumnType::#col_type } + } + } + } +} + +pub fn arr_type_match( + arr_type: Option, + field_type: &str, + field_span: Span, +) -> TokenStream { + match arr_type { + Some(t) => quote! { sea_orm::sea_query::ArrayType::#t }, + None => { + let arr_type = match field_type { + "char" => quote! { Char }, + "String" | "&str" => quote! { String }, + "i8" => quote! { TinyInt }, + "u8" => quote! { TinyUnsigned }, + "i16" => quote! { SmallInt }, + "u16" => quote! { SmallUnsigned }, + "i32" => quote! { Int }, + "u32" => quote! { Unsigned }, + "i64" => quote! { BigInt }, + "u64" => quote! { BigUnsigned }, + "f32" => quote! { Float }, + "f64" => quote! { Double }, + "bool" => quote! { Bool }, + "Date" | "NaiveDate" => quote! { ChronoDate }, + "Time" | "NaiveTime" => quote! { ChronoTime }, + "DateTime" | "NaiveDateTime" => { + quote! { ChronoDateTime } + } + "DateTimeUtc" | "DateTimeLocal" | "DateTimeWithTimeZone" => { + quote! { ChronoDateTimeWithTimeZone } + } + "Uuid" => quote! { Uuid }, + "Json" => quote! { Json }, + "Decimal" => quote! { Decimal }, + _ => { + // Assumed it's ActiveEnum if none of the above type matches + quote! {} + } + }; + if arr_type.is_empty() { + let ty: Type = LitStr::new(field_type, field_span) + .parse() + .expect("field type error"); + let def = quote_spanned! { field_span => + std::convert::Into::::into( + <#ty as sea_orm::sea_query::ValueType>::array_type() + ) + }; + quote! { #def } + } else { + quote! { sea_orm::sea_query::ArrayType::#arr_type } + } + } + } +} diff --git a/sea-orm-macros/src/derives/value_type.rs b/sea-orm-macros/src/derives/value_type.rs new file mode 100644 index 000000000..2a1ee606e --- /dev/null +++ b/sea-orm-macros/src/derives/value_type.rs @@ -0,0 +1,144 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::{spanned::Spanned, Lit, Type}; + +struct DeriveValueType { + name: syn::Ident, + ty: Type, + column_type: TokenStream, + array_type: TokenStream, +} + +impl DeriveValueType { + pub fn new(input: syn::DeriveInput) -> Option { + let dat = input.data; + let fields: Option> = match dat { + syn::Data::Struct(syn::DataStruct { + fields: syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }), + .. + }) => Some(unnamed), + _ => None, + }; + + let field = fields + .expect("This derive accept only struct") + .first() + .expect("The struct should contain one value field") + .to_owned(); + + let name = input.ident; + let mut col_type = None; + let mut arr_type = None; + + for attr in input.attrs.iter() { + if !attr.path().is_ident("sea_orm") { + continue; + } + + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("column_type") { + let lit = meta.value()?.parse()?; + if let Lit::Str(litstr) = lit { + let ty: TokenStream = syn::parse_str(&litstr.value())?; + col_type = Some(ty); + } else { + return Err(meta.error(format!("Invalid column_type {:?}", lit))); + } + } else if meta.path.is_ident("array_type") { + let lit = meta.value()?.parse()?; + if let Lit::Str(litstr) = lit { + let ty: TokenStream = syn::parse_str(&litstr.value())?; + arr_type = Some(ty); + } else { + return Err(meta.error(format!("Invalid array_type {:?}", lit))); + } + } else { + // received other attribute + return Err(meta.error(format!("Invalid attribute {:?}", meta.path))); + } + + Ok(()) + }) + .unwrap_or(()); + } + + let ty = field.clone().ty; + let field_type = quote! { #ty } + .to_string() //E.g.: "Option < String >" + .replace(' ', ""); // Remove spaces + let field_type = if field_type.starts_with("Option<") { + &field_type[7..(field_type.len() - 1)] // Extract `T` out of `Option` + } else { + field_type.as_str() + }; + let field_span = field.span(); + + let column_type = + crate::derives::sql_type_match::col_type_match(col_type, field_type, field_span); + + let array_type = + crate::derives::sql_type_match::arr_type_match(arr_type, field_type, field_span); + + Some(DeriveValueType { + name, + ty, + column_type, + array_type, + }) + } + + fn expand(&self) -> syn::Result { + let expanded_impl_value_type: TokenStream = self.impl_value_type(); + Ok(expanded_impl_value_type) + } + + fn impl_value_type(&self) -> TokenStream { + let name = &self.name; + let field_type = &self.ty; + let column_type = &self.column_type; + let array_type = &self.array_type; + + quote!( + #[automatically_derived] + impl From<#name> for Value { + fn from(source: #name) -> Self { + source.0.into() + } + } + + #[automatically_derived] + impl sea_orm::TryGetable for #name { + fn try_get_by(res: &QueryResult, idx: I) -> Result { + <#field_type as sea_orm::TryGetable>::try_get_by(res, idx).map(|v| #name(v)) + } + } + + #[automatically_derived] + impl sea_orm::sea_query::ValueType for #name { + fn try_from(v: Value) -> Result { + <#field_type as sea_orm::sea_query::ValueType>::try_from(v).map(|v| #name(v)) + } + + fn type_name() -> String { + stringify!(#name).to_owned() + } + + fn array_type() -> sea_orm::sea_query::ArrayType { + #array_type + } + + fn column_type() -> sea_orm::sea_query::ColumnType { + #column_type + } + } + ) + } +} + +pub fn expand_derive_value_type(input: syn::DeriveInput) -> syn::Result { + let input_span = input.span(); + match DeriveValueType::new(input) { + Some(model) => model.expand(), + None => Err(syn::Error::new(input_span, "error")), + } +} diff --git a/sea-orm-macros/src/lib.rs b/sea-orm-macros/src/lib.rs index 0e9c7b1b8..07a26f053 100644 --- a/sea-orm-macros/src/lib.rs +++ b/sea-orm-macros/src/lib.rs @@ -832,3 +832,13 @@ pub fn enum_iter(input: TokenStream) -> TokenStream { .unwrap_or_else(Error::into_compile_error) .into() } + +#[cfg(feature = "derive")] +#[proc_macro_derive(DeriveValueType, attributes(sea_orm))] +pub fn derive_value_type(input: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(input as DeriveInput); + match derives::expand_derive_value_type(derive_input) { + Ok(token_stream) => token_stream.into(), + Err(e) => e.to_compile_error().into(), + } +} diff --git a/tests/common/features/mod.rs b/tests/common/features/mod.rs index 8507b4e38..407a59917 100644 --- a/tests/common/features/mod.rs +++ b/tests/common/features/mod.rs @@ -23,6 +23,7 @@ pub mod self_join; pub mod teas; pub mod transaction_log; pub mod uuid_fmt; +pub mod value_type; pub use active_enum::Entity as ActiveEnum; pub use active_enum_child::Entity as ActiveEnumChild; diff --git a/tests/common/features/schema.rs b/tests/common/features/schema.rs index 12897ca89..058474e0e 100644 --- a/tests/common/features/schema.rs +++ b/tests/common/features/schema.rs @@ -48,8 +48,10 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> { create_binary_table(db).await?; create_bits_table(db).await?; create_dyn_table_name_lazy_static_table(db).await?; + create_value_type_table(db).await?; if DbBackend::Postgres == db_backend { + create_value_type_postgres_table(db).await?; create_collection_table(db).await?; create_event_trigger_table(db).await?; } @@ -634,3 +636,47 @@ pub async fn create_dyn_table_name_lazy_static_table(db: &DbConn) -> Result<(), Ok(()) } + +pub async fn create_value_type_table(db: &DbConn) -> Result { + let general_stmt = sea_query::Table::create() + .table(value_type::value_type_general::Entity) + .col( + ColumnDef::new(value_type::value_type_general::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col( + ColumnDef::new(value_type::value_type_general::Column::Number) + .integer() + .not_null(), + ) + .to_owned(); + + create_table(db, &general_stmt, value_type::value_type_general::Entity).await +} +pub async fn create_value_type_postgres_table(db: &DbConn) -> Result { + let postgres_stmt = sea_query::Table::create() + .table(value_type::value_type_pg::Entity) + .col( + ColumnDef::new(value_type::value_type_pg::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col( + ColumnDef::new(value_type::value_type_pg::Column::Number) + .integer() + .not_null(), + ) + .col( + ColumnDef::new(json_vec::Column::StrVec) + .array(sea_query::ColumnType::String(None)) + .not_null(), + ) + .to_owned(); + + create_table(db, &postgres_stmt, value_type::value_type_pg::Entity).await +} diff --git a/tests/common/features/value_type.rs b/tests/common/features/value_type.rs new file mode 100644 index 000000000..f0efb9dc2 --- /dev/null +++ b/tests/common/features/value_type.rs @@ -0,0 +1,59 @@ +pub mod value_type_general { + use super::*; + use sea_orm::entity::prelude::*; + + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "value_type")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub number: Integer, + } + + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] + pub enum Relation {} + + impl ActiveModelBehavior for ActiveModel {} +} + +pub mod value_type_pg { + use super::*; + use sea_orm::entity::prelude::*; + + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] + #[sea_orm(table_name = "value_type_postgres")] + pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub number: Integer, + pub str_vec: StringVec, + } + + #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] + pub enum Relation {} + + impl ActiveModelBehavior for ActiveModel {} +} + +use sea_orm::entity::prelude::*; +use sea_orm_macros::DeriveValueType; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] +#[sea_orm(array_type = "Int")] +pub struct Integer(i32); + +impl From for Integer +where + T: Into, +{ + fn from(v: T) -> Integer { + Integer(v.into()) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] +#[sea_orm(column_type = "Boolean", array_type = "Bool")] +pub struct Boolbean(pub String); + +#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)] +pub struct StringVec(pub Vec); diff --git a/tests/value_type_tests.rs b/tests/value_type_tests.rs new file mode 100644 index 000000000..227186606 --- /dev/null +++ b/tests/value_type_tests.rs @@ -0,0 +1,110 @@ +pub mod common; + +use std::sync::Arc; +use std::vec; + +pub use common::{ + features::{ + value_type::{value_type_general, value_type_pg, Boolbean, Integer, StringVec}, + *, + }, + setup::*, + TestContext, +}; +use pretty_assertions::assert_eq; +use sea_orm::{entity::prelude::*, entity::*, DatabaseConnection}; +use sea_query::{ArrayType, ColumnType, Value, ValueType, ValueTypeErr}; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +async fn main() -> Result<(), DbErr> { + let ctx = TestContext::new("value_type_tests").await; + create_tables(&ctx.db).await?; + insert_value(&ctx.db).await?; + ctx.delete().await; + + if cfg!(feature = "sqlx-postgres") { + let ctx = TestContext::new("value_type_postgres_tests").await; + create_tables(&ctx.db).await?; + postgres_insert_value(&ctx.db).await?; + ctx.delete().await; + } + + type_test(); + conversion_test(); + + Ok(()) +} + +pub async fn insert_value(db: &DatabaseConnection) -> Result<(), DbErr> { + let model = value_type_general::Model { + id: 1, + number: 48.into(), + }; + let result = model.clone().into_active_model().insert(db).await?; + assert_eq!(result, model); + + Ok(()) +} + +pub async fn postgres_insert_value(db: &DatabaseConnection) -> Result<(), DbErr> { + let model = value_type_pg::Model { + id: 1, + number: 48.into(), + str_vec: StringVec(vec!["ab".to_string(), "cd".to_string()]), + }; + let result = model.clone().into_active_model().insert(db).await?; + assert_eq!(result, model); + + Ok(()) +} + +pub fn type_test() { + assert_eq!(StringVec::type_name(), "StringVec"); + + // custom types + assert_eq!(Integer::array_type(), ArrayType::Int); + assert_eq!(Integer::array_type(), ArrayType::Int); + assert_eq!(Boolbean::column_type(), ColumnType::Boolean); + assert_eq!(Boolbean::array_type(), ArrayType::Bool); + // self implied + assert_eq!( + StringVec::column_type(), + ColumnType::Array(Arc::new(ColumnType::String(None))) + ); + assert_eq!(StringVec::array_type(), ArrayType::String); +} + +pub fn conversion_test() { + let stringvec = StringVec(vec!["ab".to_string(), "cd".to_string()]); + let string: Value = stringvec.into(); + assert_eq!( + string, + Value::Array( + ArrayType::String, + Some(Box::new(vec![ + "ab".to_string().into(), + "cd".to_string().into() + ])) + ) + ); + + let value_random_int = Value::Int(Some(523)); + let unwrap_int = Integer::unwrap(value_random_int.clone()); + let try_from_int = + ::try_from(value_random_int).expect("should be ok to convert"); + + // tests for unwrap and try_from + let direct_int: Integer = 523.into(); + assert_eq!(direct_int, unwrap_int); + assert_eq!(direct_int, try_from_int); + + // test for error + let try_from_string_vec = ::try_from(Value::Char(Some('a'))) + .expect_err("should not be ok to convert char to stringvec"); + assert_eq!(try_from_string_vec.to_string(), ValueTypeErr.to_string()); +}