Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive value type #1720

Merged
merged 24 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sea-orm-macros/src/derives/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod related_entity;
mod relation;
mod try_getable_from_json;
mod util;
mod value_type;

pub use active_enum::*;
pub use active_model::*;
Expand All @@ -31,3 +32,4 @@ pub use primary_key::*;
pub use related_entity::*;
pub use relation::*;
pub use try_getable_from_json::*;
pub use value_type::*;
240 changes: 240 additions & 0 deletions sea-orm-macros/src/derives/value_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
use proc_macro2::TokenStream;
use quote::{quote, quote_spanned};
use syn::{spanned::Spanned, Lit, LitStr, Type};

struct DeriveValueType {
name: syn::Ident,
ty: Type,
column_type: TokenStream,
array_type: TokenStream,
}

impl DeriveValueType {
pub fn new(input: syn::DeriveInput) -> Option<Self> {
let dat = input.data;
let fields: Option<syn::punctuated::Punctuated<syn::Field, syn::token::Comma>> = match dat {
syn::Data::Struct(syn::DataStruct {
fields: syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }),
..
}) => Some(unnamed),
_ => None,
};

let field = fields
.expect("This derive accept only struct")
.first()
.expect("The struct should contain one value field")
.to_owned();

let name = input.ident;
let mut col_type = None;
let mut arr_type = None;

for attr in input.attrs.iter() {
if !attr.path().is_ident("sea_orm") {
continue;
}

attr.parse_nested_meta(|meta| {
if meta.path.is_ident("column_type") {
let lit = meta.value()?.parse()?;
if let Lit::Str(litstr) = lit {
let ty: TokenStream = syn::parse_str(&litstr.value())?;
col_type = Some(ty);
} else {
return Err(meta.error(format!("Invalid column_type {:?}", lit)));
}
} else if meta.path.is_ident("array_type") {
let lit = meta.value()?.parse()?;
if let Lit::Str(litstr) = lit {
let ty: TokenStream = syn::parse_str(&litstr.value())?;
arr_type = Some(ty);
} else {
return Err(meta.error(format!("Invalid array_type {:?}", lit)));
}
} else {
// received other attribute
return Err(meta.error(format!("Invalid attribute {:?}", meta.path)));
}

Ok(())
})
.unwrap_or(());
}

let ty = field.clone().ty;
let field_type = quote! { #ty }
.to_string() //E.g.: "Option < String >"
.replace(' ', ""); // Remove spaces
let field_type = if field_type.starts_with("Option<") {
&field_type[7..(field_type.len() - 1)] // Extract `T` out of `Option<T>`
} else {
field_type.as_str()
};

let column_type = match col_type {
Some(t) => quote! { sea_orm::sea_query::ColumnType::#t },
None => {
let col_type = match field_type {
"char" => quote! { Char(None) },
"String" | "&str" => quote! { String(None) },
"i8" => quote! { TinyInteger },
tyt2y3 marked this conversation as resolved.
Show resolved Hide resolved
"u8" => quote! { TinyUnsigned },
"i16" => quote! { SmallInteger },
"u16" => quote! { SmallUnsigned },
"i32" => quote! { Integer },
"u32" => quote! { Unsigned },
"i64" => quote! { BigInteger },
"u64" => quote! { BigUnsigned },
"f32" => quote! { Float },
"f64" => quote! { Double },
"bool" => quote! { Boolean },
"Date" | "NaiveDate" => quote! { Date },
"Time" | "NaiveTime" => quote! { Time },
"DateTime" | "NaiveDateTime" => {
quote! { DateTime }
}
"DateTimeUtc" | "DateTimeLocal" | "DateTimeWithTimeZone" => {
quote! { TimestampWithTimeZone }
}
"Uuid" => quote! { Uuid },
"Json" => quote! { Json },
"Decimal" => quote! { Decimal(None) },
"Vec<u8>" => {
quote! { Binary(sea_orm::sea_query::BlobSize::Blob(None)) }
}
_ => {
// Assumed it's ActiveEnum if none of the above type matches
quote! {}
}
};
if col_type.is_empty() {
let field_span = field.span();
let ty: Type = LitStr::new(field_type, field_span)
.parse()
.expect("field type error");
let def = quote_spanned! { field_span =>
std::convert::Into::<sea_orm::ColumnType>::into(
<#ty as sea_orm::sea_query::ValueType>::column_type()
)
};
quote! { #def }
} else {
quote! { sea_orm::sea_query::ColumnType::#col_type }
}
}
};

let array_type = match arr_type {
Some(t) => quote! { sea_orm::sea_query::ArrayType::#t },
None => {
let arr_type = match field_type {
"char" => quote! { Char },
"String" | "&str" => quote! { String },
"i8" => quote! { TinyInt },
"u8" => quote! { TinyUnsigned },
"i16" => quote! { SmallInt },
"u16" => quote! { SmallUnsigned },
"i32" => quote! { Int },
"u32" => quote! { Unsigned },
"i64" => quote! { BigInt },
"u64" => quote! { BigUnsigned },
"f32" => quote! { Float },
"f64" => quote! { Double },
"bool" => quote! { Bool },
"Date" | "NaiveDate" => quote! { ChronoDate },
"Time" | "NaiveTime" => quote! { ChronoTime },
"DateTime" | "NaiveDateTime" => {
quote! { ChronoDateTime }
}
"DateTimeUtc" | "DateTimeLocal" | "DateTimeWithTimeZone" => {
quote! { ChronoDateTimeWithTimeZone }
}
"Uuid" => quote! { Uuid },
"Json" => quote! { Json },
"Decimal" => quote! { Decimal },
_ => {
// Assumed it's ActiveEnum if none of the above type matches
quote! {}
}
};
if arr_type.is_empty() {
let field_span = field.span();
let ty: Type = LitStr::new(field_type, field_span)
.parse()
.expect("field type error");
let def = quote_spanned! { field_span =>
std::convert::Into::<sea_orm::ArrayType>::into(
<#ty as sea_orm::sea_query::ValueType>::array_type()
)
};
quote! { #def }
} else {
quote! { sea_orm::sea_query::ArrayType::#arr_type }
}
}
};

Some(DeriveValueType {
name,
ty,
column_type,
array_type,
})
}

fn expand(&self) -> syn::Result<TokenStream> {
let expanded_impl_value_type: TokenStream = self.impl_value_type();
Ok(expanded_impl_value_type)
}

fn impl_value_type(&self) -> TokenStream {
let name = &self.name;
let field_type = &self.ty;
let column_type = &self.column_type;
let array_type = &self.array_type;

quote!(
#[automatically_derived]
impl From<#name> for Value {
fn from(source: #name) -> Self {
source.0.into()
}
}

#[automatically_derived]
impl sea_orm::TryGetable for #name {
fn try_get_by<I: sea_orm::ColIdx>(res: &QueryResult, idx: I) -> Result<Self, sea_orm::TryGetError> {
<#field_type as sea_orm::TryGetable>::try_get_by(res, idx).map(|v| #name(v))
}
}

#[automatically_derived]
impl sea_orm::sea_query::ValueType for #name {
fn try_from(v: Value) -> Result<Self, sea_orm::sea_query::ValueTypeErr> {
<#field_type as sea_orm::sea_query::ValueType>::try_from(v).map(|v| #name(v))
}

fn type_name() -> String {
stringify!(#name).to_owned()
}

fn array_type() -> sea_orm::sea_query::ArrayType {
#array_type
}

fn column_type() -> sea_orm::sea_query::ColumnType {
#column_type
}
}
)
}
}

pub fn expand_derive_value_type(input: syn::DeriveInput) -> syn::Result<TokenStream> {
let input_span = input.span();
match DeriveValueType::new(input) {
Some(model) => model.expand(),
None => Err(syn::Error::new(input_span, "error")),
}
}
10 changes: 10 additions & 0 deletions sea-orm-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -832,3 +832,13 @@ pub fn enum_iter(input: TokenStream) -> TokenStream {
.unwrap_or_else(Error::into_compile_error)
.into()
}

#[cfg(feature = "derive")]
#[proc_macro_derive(DeriveValueType, attributes(sea_orm))]
pub fn derive_value_type(input: TokenStream) -> TokenStream {
let derive_input = parse_macro_input!(input as DeriveInput);
match derives::expand_derive_value_type(derive_input) {
Ok(token_stream) => token_stream.into(),
Err(e) => e.to_compile_error().into(),
}
}
1 change: 1 addition & 0 deletions tests/common/features/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod self_join;
pub mod teas;
pub mod transaction_log;
pub mod uuid_fmt;
pub mod value_type;

pub use active_enum::Entity as ActiveEnum;
pub use active_enum_child::Entity as ActiveEnumChild;
Expand Down
21 changes: 21 additions & 0 deletions tests/common/features/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> {
create_binary_table(db).await?;
create_bits_table(db).await?;
create_dyn_table_name_lazy_static_table(db).await?;
create_value_type_table(db).await?;

if DbBackend::Postgres == db_backend {
create_collection_table(db).await?;
Expand Down Expand Up @@ -634,3 +635,23 @@ pub async fn create_dyn_table_name_lazy_static_table(db: &DbConn) -> Result<(),

Ok(())
}

pub async fn create_value_type_table(db: &DbConn) -> Result<ExecResult, DbErr> {
let stmt = sea_query::Table::create()
.table(value_type::Entity)
.col(
ColumnDef::new(value_type::Column::Id)
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(
ColumnDef::new(value_type::Column::Number)
.integer()
.not_null(),
)
.to_owned();

create_table(db, &stmt, value_type::Entity).await
}
26 changes: 26 additions & 0 deletions tests/common/features/value_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use sea_orm::entity::prelude::*;
use sea_orm_macros::DeriveValueType;

#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
#[sea_orm(table_name = "value_type")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub number: Integer,
Copy link
Member

@tyt2y3 tyt2y3 Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The strange bit here is that the StringVec is not included in the Model for testing.
Albeit it only works on Postgres.

May be we can simply change the test case https://github.com/SeaQL/sea-orm/blob/master/tests/common/features/json_vec.rs, because now DeriveValueType is the new, recommended API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the conversion of DeriveValueType rely on the pre-defined conversion from the field to Value.
I think Vec<String> currently does not have a defined conversion to any variant of Value, which I suppose it should.
(in json_vec, it convert itself to formatted String as Value, which I suppose isn't a good generic approach)
I'm wondering if I should implement From<Vec<T>> for Value or adjust the DeriveValueType implementation only.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Sadly then it means that we can't 'eat our own dog food' for this test case.

}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl ActiveModelBehavior for ActiveModel {}

#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)]
pub struct Integer(pub i32);

#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)]
#[sea_orm(column_type = "String(Some(1))", array_type = "String")]
pub struct StringVec(pub Vec<String>);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work. I think it looks really good. One question, is array_type required here? What if we omit it?

Can we actually change the test case in tests/common/features/json_vec.rs to use this DeriveValueType instead of the manual implementation?

Just want to make sure we have had end to end tests.


#[derive(Clone, Debug, PartialEq, Eq, DeriveValueType)]
#[sea_orm(column_type = "Boolean", array_type = "Bool")]
pub struct Boolbean(pub String);
Loading