Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update reflectx to allow for optional nested structs #950

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
8 changes: 8 additions & 0 deletions convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package sqlx

import (
_ "unsafe"
)

//go:linkname convertAssign database/sql.convertAssign
func convertAssign(dest, src interface{}) error
3 changes: 1 addition & 2 deletions reflectx/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
v = reflect.Indirect(v).Field(i)
// if this is a pointer and it's nil, allocate a new value and set it
if v.Kind() == reflect.Ptr && v.IsNil() {
alloc := reflect.New(Deref(v.Type()))
v.Set(alloc)
v.Set(reflect.New(v.Type().Elem()))
}
if v.Kind() == reflect.Map && v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
Expand Down
39 changes: 27 additions & 12 deletions sqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ func (r *Rows) StructScan(dest interface{}) error {
r.started = true
}

err := fieldsByTraversal(v, r.fields, r.values, true)
err := fieldsByTraversal(v, r.fields, r.values)
if err != nil {
return err
}
Expand Down Expand Up @@ -784,7 +784,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error {
}
values := make([]interface{}, len(columns))

err = fieldsByTraversal(v, fields, values, true)
err = fieldsByTraversal(v, fields, values)
if err != nil {
return err
}
Expand Down Expand Up @@ -957,7 +957,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
vp = reflect.New(base)
v = reflect.Indirect(vp)

err = fieldsByTraversal(v, fields, values, true)
err = fieldsByTraversal(v, fields, values)
if err != nil {
return err
}
Expand Down Expand Up @@ -1023,7 +1023,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
// when iterating over many rows. Empty traversals will get an interface pointer.
// Because of the necessity of requesting ptrs or values, it's considered a bit too
// specialized for inclusion in reflectx itself.
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
Expand All @@ -1032,23 +1032,38 @@ func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}
for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else if len(traversal) == 1 {
values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface()
} else {
values[i] = f.Interface()
// reflectx.FieldByIndexes initializes pointer fields, including pointers to nested structs.
// Use optDest to delay it until the first non-NULL value is scanned into a field of a nested struct.
// That way we can support LEFT JOINs with optional nested structs.
traversal := traversal
values[i] = optDest(func() interface{} {
return reflectx.FieldByIndexes(v, traversal).Addr().Interface()
})
}
}
return nil
}

func missingFields(transversals [][]int) (field int, err error) {
for i, t := range transversals {
func missingFields(traversals [][]int) (field int, err error) {
for i, t := range traversals {
if len(t) == 0 {
return i, errors.New("missing field")
}
}
return 0, nil
}

// optDest will only forward the Scan to the nested value if
// the database value is not nil.
type optDest func() interface{}

// Scan implements sql.Scanner.
func (dest optDest) Scan(src interface{}) error {
if src == nil {
return nil
}
return convertAssign(dest(), src)
}
228 changes: 228 additions & 0 deletions sqlx_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,17 @@ func TestNamedQueryContext(t *testing.T) {
"FIRST" text NULL,
last_name text NULL,
"EMAIL" text NULL
);
CREATE TABLE persondetails (
email text NULL,
notes text NULL
);`,
drop: `
drop table person;
drop table jsperson;
drop table place;
drop table placeperson;
drop table persondetails;
`,
}

Expand Down Expand Up @@ -643,6 +648,229 @@ func TestNamedQueryContext(t *testing.T) {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID)
}
}

rows.Close()

type Owner struct {
Email *string `db:"email"`
FirstName string `db:"first_name"`
LastName string `db:"last_name"`
}

// Test optional nested structs with left join
type PlaceOwner struct {
Place Place `db:"place"`
Owner *Owner `db:"owner"`
}

pl = Place{
Name: sql.NullString{String: "the-house", Valid: true},
}

q4 := `INSERT INTO place (id, name) VALUES (2, :name)`
_, err = db.NamedExecContext(ctx, q4, pl)
if err != nil {
log.Fatal(err)
}

id = 2
pp.Place.ID = id

q5 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`
_, err = db.NamedExecContext(ctx, q5, pp)
if err != nil {
log.Fatal(err)
}

pp3 := &PlaceOwner{}
rows, err = db.NamedQueryContext(ctx, `
SELECT
place.id AS "place.id",
place.name AS "place.name",
placeperson.first_name "owner.first_name",
placeperson.last_name "owner.last_name",
placeperson.email "owner.email"
FROM place
LEFT JOIN placeperson ON false -- null left join
WHERE
place.id=:place.id`, pp)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
err = rows.StructScan(pp3)
if err != nil {
t.Error(err)
}
if pp3.Owner != nil {
t.Error("Expected `Owner` to be nil")
}
if pp3.Place.Name.String != "the-house" {
t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String)
}
if pp3.Place.ID != pp.Place.ID {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID)
}
}

rows.Close()

pp4 := &PlaceOwner{}
rows, err = db.NamedQueryContext(ctx, `
SELECT
place.id AS "place.id",
place.name AS "place.name",
placeperson.first_name "owner.first_name",
placeperson.last_name "owner.last_name",
placeperson.email "owner.email"
FROM place
LEFT JOIN placeperson ON placeperson.place_id = place.id
WHERE
place.id=:place.id`, pp)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
err = rows.StructScan(pp4)
if err != nil {
t.Error(err)
}
if pp4.Owner == nil {
t.Error("Expected `Owner` to not be nil")
}
if pp4.Owner.FirstName != "ben" {
t.Error("Expected first name of `ben`, got " + pp4.Owner.FirstName)
}
if pp4.Owner.LastName != "doe" {
t.Error("Expected first name of `doe`, got " + pp4.Owner.LastName)
}
if pp4.Place.Name.String != "the-house" {
t.Error("Expected place name of `the-house`, got " + pp4.Place.Name.String)
}
if pp4.Place.ID != pp.Place.ID {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp4.Place.ID)
}
}

type Details struct {
Email string `db:"email"`
Notes string `db:"notes"`
}

type OwnerDetails struct {
Email *string `db:"email"`
FirstName string `db:"first_name"`
LastName string `db:"last_name"`
Details *Details `db:"details"`
}

type PlaceOwnerDetails struct {
Place Place `db:"place"`
Owner *OwnerDetails `db:"owner"`
}

pp5 := &PlaceOwnerDetails{}
rows, err = db.NamedQueryContext(ctx, `
SELECT
place.id AS "place.id",
place.name AS "place.name",
placeperson.first_name "owner.first_name",
placeperson.last_name "owner.last_name",
placeperson.email "owner.email",
persondetails.email "owner.details.email",
persondetails.notes "owner.details.notes"
FROM place
LEFT JOIN placeperson ON placeperson.place_id = place.id
LEFT JOIN persondetails ON false
WHERE
place.id=:place.id`, pp)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
err = rows.StructScan(pp5)
if err != nil {
t.Error(err)
}
if pp5.Owner == nil {
t.Error("Expected `Owner`, to not be nil")
}
if pp5.Owner.FirstName != "ben" {
t.Error("Expected first name of `ben`, got " + pp5.Owner.FirstName)
}
if pp5.Owner.LastName != "doe" {
t.Error("Expected first name of `doe`, got " + pp5.Owner.LastName)
}
if pp5.Place.Name.String != "the-house" {
t.Error("Expected place name of `the-house`, got " + pp5.Place.Name.String)
}
if pp5.Place.ID != pp.Place.ID {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp5.Place.ID)
}
if pp5.Owner.Details != nil {
t.Error("Expected `Details` to be nil")
}
}

details := Details{
Email: pp.Email.String,
Notes: "this is a test person",
}

q6 := `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)`
_, err = db.NamedExecContext(ctx, q6, details)
if err != nil {
log.Fatal(err)
}

pp6 := &PlaceOwnerDetails{}
rows, err = db.NamedQueryContext(ctx, `
SELECT
place.id AS "place.id",
place.name AS "place.name",
placeperson.first_name "owner.first_name",
placeperson.last_name "owner.last_name",
placeperson.email "owner.email",
persondetails.email "owner.details.email",
persondetails.notes "owner.details.notes"
FROM place
LEFT JOIN placeperson ON placeperson.place_id = place.id
LEFT JOIN persondetails ON persondetails.email = placeperson.email
WHERE
place.id=:place.id`, pp)
if err != nil {
log.Fatal(err)
}
for rows.Next() {
err = rows.StructScan(pp6)
if err != nil {
t.Error(err)
}
if pp6.Owner == nil {
t.Error("Expected `Owner` to not be nil")
}
if pp6.Owner.FirstName != "ben" {
t.Error("Expected first name of `ben`, got " + pp6.Owner.FirstName)
}
if pp6.Owner.LastName != "doe" {
t.Error("Expected first name of `doe`, got " + pp6.Owner.LastName)
}
if pp6.Place.Name.String != "the-house" {
t.Error("Expected place name of `the-house`, got " + pp6.Place.Name.String)
}
if pp6.Place.ID != pp.Place.ID {
t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp6.Place.ID)
}
if pp6.Owner.Details == nil {
t.Error("Expected `Details` to not be nil")
}
if pp6.Owner.Details.Email != details.Email {
t.Errorf("Expected details email of %v, got %v", details.Email, pp6.Owner.Details.Email)
}
if pp6.Owner.Details.Notes != details.Notes {
t.Errorf("Expected details notes of %v, got %v", details.Notes, pp6.Owner.Details.Notes)
}
}
})
}

Expand Down