diff --git a/scaneo.go b/scaneo.go index f69b87e..44db4b8 100644 --- a/scaneo.go +++ b/scaneo.go @@ -74,8 +74,9 @@ NOTES ) type fieldToken struct { - Name string - Type string + Name string + Type string + IsArray bool } type structToken struct { @@ -260,6 +261,7 @@ func parseCode(source string, commaList string) ([]structToken, error) { } var fieldType string + var isArray bool // get field type switch typeToken := fieldLine.Type.(type) { @@ -272,6 +274,7 @@ func parseCode(source string, commaList string) ([]structToken, error) { case *ast.ArrayType: // arrays fieldType = parseArray(typeToken) + isArray = true case *ast.StarExpr: // pointers fieldType = parseStar(typeToken) @@ -284,6 +287,7 @@ func parseCode(source string, commaList string) ([]structToken, error) { // apply type to all variables declared in this line for i := range fieldToks { fieldToks[i].Type = fieldType + fieldToks[i].IsArray = isArray } structTok.Fields = append(structTok.Fields, fieldToks...) @@ -362,18 +366,29 @@ func genFile(outFile, pkg string, unexport bool, toks []structToken, genFuncs bo } defer fout.Close() + hasArrays := false + for _, tok := range toks { + for _, field := range tok.Fields { + if field.IsArray { + hasArrays = true + } + } + } + data := struct { - PackageName string - Tokens []structToken - Visibility string - Funcs bool - ImportPkg string + PackageName string + Tokens []structToken + Visibility string + Funcs bool + ImportPkg string + NeedDriverPkg bool }{ - PackageName: pkg, - Visibility: "S", - Tokens: toks, - Funcs: genFuncs, - ImportPkg: pkgImport, + PackageName: pkg, + Visibility: "S", + Tokens: toks, + Funcs: genFuncs, + ImportPkg: pkgImport, + NeedDriverPkg: hasArrays, } if unexport { diff --git a/scaneo_test.go b/scaneo_test.go index 38f6f6a..4c43d52 100644 --- a/scaneo_test.go +++ b/scaneo_test.go @@ -153,10 +153,10 @@ var ( { Name: "slices", Fields: []fieldToken{ - {Name: "a", Type: "[]bool"}, - {Name: "b", Type: "[]time.Time"}, - {Name: "c", Type: "[]*byte"}, - {Name: "d", Type: "[]*sql.NullString"}, + {Name: "a", Type: "[]bool", IsArray: true}, + {Name: "b", Type: "[]time.Time", IsArray: true}, + {Name: "c", Type: "[]*byte", IsArray: true}, + {Name: "d", Type: "[]*sql.NullString", IsArray: true}, }, }, { @@ -284,6 +284,13 @@ func TestParseCode(t *testing.T) { t.FailNow() } + if structToks[i].Fields[j].IsArray != toks[i].Fields[j].IsArray { + t.Error("unexpected isArray") + t.Error("file:", fPath) + t.Error("struct:", structToks[i].Name) + t.Errorf("expected: %v; found: %v\n", structToks[i].Fields[j].IsArray, toks[i].Fields[j].IsArray) + } + if structToks[i].Fields[j].Type != toks[i].Fields[j].Type { t.Error("unexpected struct field type") t.Error("file:", fPath) @@ -319,66 +326,66 @@ func TestGenFile(t *testing.T) { expectedFuncs []string }{ { - "no tokens", - true, - []structToken{}, - true, - false, - "", - func(t *testing.T, err error) { + name: "no tokens", + outFile: true, + tokens: []structToken{}, + unexport: true, + funcs: false, + pkgImport: "", + assert: func(t *testing.T, err error) { if err == nil { t.Error("no struct tokens passed") t.Error("should be error") t.FailNow() } }, - expectedFuncNames, + expectedFuncs: expectedFuncNames, }, { - "no output file", - false, - toks, - true, - false, - "", - func(t *testing.T, err error) { + name: "no output file", + outFile: false, + tokens: toks, + unexport: true, + funcs: false, + pkgImport: "", + assert: func(t *testing.T, err error) { if err == nil { t.Error("no output file path passed") t.Error("should be error") t.FailNow() } }, - expectedFuncNames, + expectedFuncs: expectedFuncNames, }, { - "scan funcs unexported", - true, - toks, - true, - false, - "", - func(t *testing.T, err error) { + name: "scan funcs unexported", + outFile: true, + tokens: toks, + unexport: true, + funcs: false, + pkgImport: "", + assert: func(t *testing.T, err error) { if err != nil { t.Error(err) t.FailNow() } }, - expectedFuncNames, + expectedFuncs: expectedFuncNames, }, { - "sql helper funcs", - true, - toks, - true, - true, - "", - func(t *testing.T, err error) { + name: "sql helper funcs", + outFile: true, + tokens: toks, + unexport: true, + funcs: true, + pkgImport: "", + assert: func(t *testing.T, err error) { if err != nil { t.Error(err) t.FailNow() } }, - []string{ + expectedFuncs: []string{ "scanExported", "scanExporteds", "SelectExported", @@ -396,22 +403,39 @@ func TestGenFile(t *testing.T) { }, }, { - "pkg import", - true, - toks, - true, - false, - "testsvc/storage/user", - func(t *testing.T, err error) { + name: "pkg import", + outFile: true, + tokens: toks, + unexport: true, + funcs: false, + pkgImport: "testsvc/storage/user", + assert: func(t *testing.T, err error) { if err != nil { t.Error(err) t.FailNow() } }, - expectedFuncNames, + expectedFuncs: expectedFuncNames, + }, + { + name: "arrays", + outFile: true, + tokens: fileStructsMap[testFiles[2]][4:5], + unexport: true, + funcs: false, + pkgImport: "", + assert: func(t *testing.T, err error) { + if err != nil { + t.Error(err) + t.FailNow() + } + }, + expectedFuncs: []string{ + "scanSlices", + "scanSlicess", + }, }, } - for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { var outFile string diff --git a/tmpl.go b/tmpl.go index 32abbd4..cbd7115 100644 --- a/tmpl.go +++ b/tmpl.go @@ -7,13 +7,18 @@ package {{.PackageName}} import ( "database/sql"{{ if .ImportPkg }} - "{{.ImportPkg}}"{{end}} + "{{.ImportPkg}}"{{end}}{{ if .NeedDriverPkg }} + "github.com/lib/pq"{{end}} ) {{range .Tokens}} func {{$.Visibility}}can{{title .Name}}(r *sql.Row) (*{{pkg .Name}}, error) { s := &{{pkg .Name}}{} if err := r.Scan({{range .Fields}} - &s.{{.Name}},{{end}} + {{if .IsArray -}} + pq.Array(&s.{{.Name}}) + {{- else -}} + &s.{{.Name}} + {{- end -}},{{end}} ); err != nil { return &{{pkg .Name}}{}, err } @@ -26,7 +31,11 @@ func {{$.Visibility}}can{{title .Name}}s(rs *sql.Rows) ([]*{{pkg .Name}}, error) for rs.Next() { s := &{{pkg .Name}}{} if err = rs.Scan({{range .Fields}} - &s.{{.Name}},{{end}} + {{if .IsArray -}} + pq.Array(&s.{{.Name}}) + {{- else -}} + &s.{{.Name}} + {{- end -}},{{end}} ); err != nil { return nil, err }