From 942f629ec88c5e45928a2a90843bba356392c7e9 Mon Sep 17 00:00:00 2001 From: Chris Agocs Date: Wed, 13 Jul 2016 14:50:02 -0600 Subject: [PATCH 01/21] Add Insert function --- query.go | 18 ++++++++++++++---- reflect.go | 37 +++++++++++-------------------------- 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/query.go b/query.go index a770416..4dc796d 100644 --- a/query.go +++ b/query.go @@ -27,14 +27,24 @@ type Query struct { } // New creates a new Query object. The passed engine is used to format variables. The passed string is used to prefix the query. -func New(engine dbengine, query string) *Query { +func New(query string) *Query { return &Query{ - Engine: engine, - SQL: query, - Args: []interface{}{}, + SQL: query, + Args: []interface{}{}, } } +func Insert(obj interface{}, values ...interface{}) *Query { + columns, _ := getFields(obj) + query := New("INSERT INTO " + GetTableName(object) + "(" + strings.Join(columns, ", ") + ") VALUES ") + + for _, v := range values { + _, fieldValuesSlice := getFields(v) + query.Include("("+VariableList(len(fieldValuesSlice))+")", fieldValuesSlice) + } + return query +} + // WrongNumberArgsError is thrown when a Query is evaluated whose Args does not match its Expressions. type WrongNumberArgsError struct { NumExpected int diff --git a/reflect.go b/reflect.go index 259dbfd..871977c 100644 --- a/reflect.go +++ b/reflect.go @@ -56,7 +56,7 @@ func toSnake(s string) string { return snake } -func getFieldColumn(f reflect.StructField, quote bool) string { +func getFieldColumn(f reflect.StructField) string { // Get the SQL column name, from the tag or infer it field := f.Tag.Get(TAG_NAME) if field == "-" { @@ -65,9 +65,6 @@ func getFieldColumn(f reflect.StructField, quote bool) string { if field == "" || !validTag(field) { field = toSnake(f.Name) } - if quote { - field = "`" + field + "`" - } return field } @@ -79,7 +76,7 @@ func getFieldExpression(f reflect.StructField) string { return field } -func getFields(s sqlTableNamer, quoted, full, expressions bool) (fields []interface{}, values []interface{}) { +func getFields(s sqlTableNamer) (fields []interface{}, values []interface{}) { t := reflect.TypeOf(s) v := reflect.ValueOf(s) k := t.Kind() @@ -96,29 +93,11 @@ func getFields(s sqlTableNamer, quoted, full, expressions bool) (fields []interf // skip unexported fields continue } - field := getFieldColumn(t.Field(i), quoted) - var expr bool - if field == "" { - if expressions { - expr = true - field = getFieldExpression(t.Field(i)) - } - } + field := getFieldColumn(t.Field(i)) if field == "" { continue } - if full && !expr { - var tablename string - if quoted { - tablename = "`" - } - tablename = tablename + s.GetSQLTableName() - if quoted { - tablename = tablename + "`" - } - tablename = tablename + "." - field = tablename + field - } + field = s.GetSQLTableName() + "." + field // Get the value of the field value := v.Field(i).Interface() @@ -141,7 +120,7 @@ func GetQuotedFieldsAndExpressions(s sqlTableNamer) (fields, values []interface{ } func GetFields(s sqlTableNamer) (fields []interface{}, values []interface{}) { - return getFields(s, false, false, false) + return getFields(s) } // GetAbsoluteFields returns a slice of the fields in the passed type, with their names @@ -179,6 +158,12 @@ func GetUnquotedAbsoluteColumn(s sqlTableNamer, property string) string { return fmt.Sprintf("%s.%s", s.GetSQLTableName(), getColumn(s, property, false)) } +type ColumnList []string + +func (c ColumnList) String() string { + strings.Join(c, ", ") +} + func getColumn(s interface{}, property string, quote bool) string { t := reflect.TypeOf(s) k := t.Kind() From fe415956fc9c3db195d078650548ed14d4fc6c09 Mon Sep 17 00:00:00 2001 From: Paddy Foran Date: Wed, 13 Jul 2016 15:56:13 -0600 Subject: [PATCH 02/21] Continue rewrite, get to buildable with public API. Refactor until the public API matches what we sketched out (see docs/gophercon2016hackdaynotes.md) and successfully builds. Drop the genreadme.sh script, which we're not using anymore. --- docs/gophercon2016hackdaynotes.md | 62 ++++++++++ genreadme.sh | 2 - query.go | 199 +++++++++++------------------- reflect.go | 174 +++----------------------- 4 files changed, 150 insertions(+), 287 deletions(-) create mode 100644 docs/gophercon2016hackdaynotes.md delete mode 100755 genreadme.sh diff --git a/docs/gophercon2016hackdaynotes.md b/docs/gophercon2016hackdaynotes.md new file mode 100644 index 0000000..09ccef4 --- /dev/null +++ b/docs/gophercon2016hackdaynotes.md @@ -0,0 +1,62 @@ +q := pan.New("SELECT "+append(pan.Columns(post), "COUNT(*) FROM blah")+" FROM "+pan.Table(post)) +pan.New("SELECT "+strings.Join(pan.Columns(post), ", ")+" FROM "+pan.Table(post)) +q.Where().Comparison(post, "ID", "=", post.ID).Expression(" OR ").Comparison(post, "Name = ?", post.Name) +q.Expression(" OR ") +q.In(post, "Name", names) +q.OrderBy(post, "Name").OrderByDesc(post, "ID").Limit(maxResults) +-------- +pan.Insert(post, ...values) +q := pan.NewExpression("INSERT flkfs into whaosh").Values(post, post, post) +.Values("klsdjglsjdg", 0, "alkghdskjghs") + +func New(expression string) *Query { +} + +func Insert(obj interface{}, values ...interface{}) *Query { +} + +func Table(obj interface{}) string {} + +func Column(obj interface{}, property string) string {} + +func Columns(obj interface{}) Columns {} + +type Columns []string +func (c Columns) String() string { + return strings.Join(c, ",") +} + +func (q *Query) Where() *Query { +} + +func (q *Query) Comparison(obj interface{}, property, operator string, value interface{}) *Query { +} + +func (q *Query) Expression(expression string, args ...interface{}) *Query { +} + +func (q *Query) In(obj interface{}, property string, values interface{}) *Query { +} + +func (q *Query) OrderBy(obj interface{}, property string) *Query { +} + +func (q *Query) OrderByDesc(obj interface{}, property string) *Query { +} + +func (q *Query) Limit(max int) *Query { +} + +func (q *Query) Offset(max int) *Query { +} + +func (q *Query) PostgreSQLString() string { +} + +func (q *Query) MySQLString() string { +} + +func (q *Query) Args() []interface{} + +func (q *Query) String() string { +} diff --git a/genreadme.sh b/genreadme.sh deleted file mode 100755 index 4969e90..0000000 --- a/genreadme.sh +++ /dev/null @@ -1,2 +0,0 @@ -echo "[![Build Status](https://travis-ci.org/secondbit/pan.png)](https://travis-ci.org/secondbit/pan)" > README.md -godoc2md github.com/DramaFever/pan >> README.md diff --git a/query.go b/query.go index 4dc796d..ff88378 100644 --- a/query.go +++ b/query.go @@ -3,44 +3,40 @@ package pan import ( "fmt" "math" - "reflect" "strings" "unicode/utf8" ) -type dbengine int - -const ( - MYSQL dbengine = iota - POSTGRES -) - // Query contains the data needed to perform a single SQL query. type Query struct { - SQL string - Args []interface{} - Expressions []string - IncludesWhere bool - IncludesOrder bool - IncludesLimit bool - Engine dbengine + sql string + args []interface{} + expressions []string + includesWhere bool + includesOrder bool +} + +type ColumnList []string + +func (c ColumnList) String() string { + return strings.Join(c, ", ") } // New creates a new Query object. The passed engine is used to format variables. The passed string is used to prefix the query. func New(query string) *Query { return &Query{ - SQL: query, - Args: []interface{}{}, + sql: query, + args: []interface{}{}, } } -func Insert(obj interface{}, values ...interface{}) *Query { - columns, _ := getFields(obj) - query := New("INSERT INTO " + GetTableName(object) + "(" + strings.Join(columns, ", ") + ") VALUES ") +func Insert(obj SQLTableNamer, values ...SQLTableNamer) *Query { + columns := Columns(obj) + query := New("INSERT INTO " + Table(obj) + "(" + columns.String() + ") VALUES ") for _, v := range values { - _, fieldValuesSlice := getFields(v) - query.Include("("+VariableList(len(fieldValuesSlice))+")", fieldValuesSlice) + columnValues := ColumnValues(v) + query.Expression("("+VariableList(len(columnValues))+")", columnValues) } return query } @@ -57,43 +53,21 @@ func (e WrongNumberArgsError) Error() string { } func (q *Query) checkCounts() error { - placeholders := strings.Count(q.SQL, "?") - args := len(q.Args) + placeholders := strings.Count(q.sql, "?") + args := len(q.args) if placeholders != args { return WrongNumberArgsError{NumExpected: placeholders, NumFound: args} } return nil } -// Generate creates a string from the Query, joining its SQL property and its Expressions. Expressions are joined -// using the join string supplied. -func (q *Query) Generate(join string) string { - if len(q.Expressions) > 0 { - q.FlushExpressions(join) - } - return q.String() -} - -// String fulfills the String interface for Queries, and returns the generated SQL query after every instance of ? -// has been replaced with a counter prefixed with $ (e.g., $1, $2, $3). There is no support for using ?, quoted or not, -// within an expression. All instances of the ? character that are not meant to be substitutions should be as arguments -// in prepared statements. func (q *Query) String() string { - if err := q.checkCounts(); err != nil { - return "" - } - var output string - switch q.Engine { - case POSTGRES: - output = q.postgresProcess() - case MYSQL: - output = q.mysqlProcess() - } - return output + // TODO(paddy): return the query with values injected + return "" } func (q *Query) mysqlProcess() string { - return q.SQL + ";" + return q.sql + ";" } func (q *Query) postgresProcess() string { @@ -104,8 +78,8 @@ func (q *Query) postgresProcess() string { replacementRune, _ := utf8.DecodeRune([]byte(replacementString)) terminatorString := ";" terminatorBytes := []byte(terminatorString) - toReplace := float64(strings.Count(q.SQL, replacementString)) - bytesNeeded := float64(len(q.SQL) + len(replacementString)) + toReplace := float64(strings.Count(q.sql, replacementString)) + bytesNeeded := float64(len(q.sql) + len(replacementString)) powerCounter := float64(1) powerMax := math.Pow(10, powerCounter) - 1 prevMax := float64(0) @@ -118,8 +92,8 @@ func (q *Query) postgresProcess() string { bytesNeeded += ((toReplace - prevMax) * powerCounter) output := make([]byte, int(bytesNeeded)) buffer := make([]byte, utf8.UTFMax) - for pos < len(q.SQL) { - r, width = utf8.DecodeRuneInString(q.SQL[pos:]) + for pos < len(q.sql) { + r, width = utf8.DecodeRuneInString(q.sql[pos:]) pos += width if r == replacementRune { newText := []byte(fmt.Sprintf("$%d", count)) @@ -142,100 +116,73 @@ func (q *Query) postgresProcess() string { return string(output) } -// FlushExpressions joins the Query's Expressions with the join string, then concatenates them -// to the Query's SQL. It then resets the Query's Expressions. This permits Expressions to be joined -// by different strings within a single Query. -func (q *Query) FlushExpressions(join string) *Query { - q.SQL = strings.TrimSpace(q.SQL) + " " - q.SQL += strings.TrimSpace(strings.Join(q.Expressions, join)) - q.Expressions = q.Expressions[0:0] +func (q *Query) Flush(join string) *Query { + q.sql = strings.TrimSpace(q.sql) + " " + q.sql += strings.TrimSpace(strings.Join(q.expressions, join)) + q.expressions = q.expressions[0:0] return q } -// IncludeIfNotNil adds the supplied key (which should be an expression) to the Query's Expressions if -// and only if the value parameter is not a nil value. If the key is added to the Query's Expressions, the -// value is added to the Query's Args. -func (q *Query) IncludeIfNotNil(key string, value interface{}) *Query { - val := reflect.ValueOf(value) - kind := val.Kind() - if kind == reflect.Chan || kind == reflect.Func { - return q - } - if kind != reflect.Map && kind != reflect.Slice && kind != reflect.Interface && kind != reflect.Ptr { - return q.IncludeIfNotEmpty(key, value) - } - if val.IsNil() { - return q - } - q.Expressions = append(q.Expressions, key) - q.Args = append(q.Args, value) +func (q *Query) Expression(key string, values ...interface{}) *Query { + q.expressions = append(q.expressions, key) + q.args = append(q.args, values...) return q } -// IncludeIfNotEmpty adds the supplied key (which should be an expression) to the Query's Expressions if -// and only if the value parameter is not the empty value for its type. If the key is added to the Query's -// Expressions, the value is added to the Query's Args. -func (q *Query) IncludeIfNotEmpty(key string, value interface{}) *Query { - if reflect.DeepEqual(value, reflect.Zero(reflect.TypeOf(value)).Interface()) { +func (q *Query) Where() *Query { + if q.includesWhere { return q } - q.Expressions = append(q.Expressions, key) - q.Args = append(q.Args, value) + q.Expression("WHERE") + q.Flush(" ") + q.includesWhere = true return q } -// Include adds the supplied key (which should be an expression) to the Query's Expressions and the value -// to the Query's Args. -func (q *Query) Include(key string, values ...interface{}) *Query { - q.Expressions = append(q.Expressions, key) - q.Args = append(q.Args, values...) - return q +func (q *Query) Comparison(obj SQLTableNamer, property, operator string, value interface{}) *Query { + return q.Expression(Column(obj, property)+" "+operator+" ?", value) } -// IncludeWhere includes the WHERE clause if the WHERE clause has not already been included in the Query. -// This cannot detect WHERE clauses that are manually added to the Query's SQL; it only tracks IncludeWhere(). -func (q *Query) IncludeWhere() *Query { - if q.IncludesWhere { - return q - } - q.Expressions = append(q.Expressions, "WHERE") - q.FlushExpressions(" ") - q.IncludesWhere = true - return q +func (q *Query) In(obj SQLTableNamer, property string, values ...interface{}) *Query { + return q.Expression(Column(obj, property)+" IN("+VariableList(len(values))+")", values...) } -// IncludeOrder includes the ORDER BY clause if the ORDER BY clause has not already been included in the Query. -// This cannot detect ORDER BY clauses that are manually added to the Query's SQL; it only tracks IncludeOrder(). -// The passed string is used as the expression to order by. -func (q *Query) IncludeOrder(orderClause string) *Query { - if q.IncludesOrder { - return q +func (q *Query) orderBy(orderClause, dir string) *Query { + exp := ", " + if !q.includesOrder { + exp = "ORDER BY " } - q.Expressions = append(q.Expressions, "ORDER BY "+orderClause) - q.IncludesOrder = true + q.Expression(exp + orderClause + dir) + q.includesOrder = true return q } -// IncludeLimit includes the LIMIT clause if the LIMIT clause has not already been included in the Query. -// This cannot detect LIMIT clauses that are manually added to the Query's SQL; it only tracks IncludeLimit(). -// The passed int is used as the limit in the resulting query. -func (q *Query) IncludeLimit(limit int64) *Query { - if q.IncludesLimit { - return q - } - q.Expressions = append(q.Expressions, "LIMIT ?") - q.Args = append(q.Args, limit) - q.IncludesLimit = true - return q +func (q *Query) OrderBy(column string) *Query { + return q.orderBy(column, "") } -func (q *Query) IncludeOffset(offset int64) *Query { - q.Expressions = append(q.Expressions, "OFFSET ?") - q.Args = append(q.Args, offset) - return q +func (q *Query) OrderByDesc(column string) *Query { + return q.orderBy(column, " DESC") } -func (q *Query) InnerJoin(table, expression string) *Query { - q.Expressions = append(q.Expressions, "INNER JOIN "+table+" ON "+expression) - return q +func (q *Query) Limit(limit int64) *Query { + return q.Expression("LIMIT ?", limit) +} + +func (q *Query) Offset(offset int64) *Query { + return q.Expression("OFFSET ?", offset) +} + +func (q *Query) PostgreSQLString() (string, error) { + // TODO(paddy): return the PostgreSQL formatted q.sql + return "", nil +} + +func (q *Query) MySQLString() (string, error) { + // TODO(paddy): return the MySQL formatted q.sql + return "", nil +} + +func (q *Query) Args() []interface{} { + return q.args } diff --git a/reflect.go b/reflect.go index 871977c..054b9a6 100644 --- a/reflect.go +++ b/reflect.go @@ -1,7 +1,6 @@ package pan import ( - "fmt" "reflect" "strings" "unicode" @@ -9,8 +8,7 @@ import ( ) const ( - TAG_NAME = "sql_column" // The tag that will be read - ExpressionTag = "sql_expression" + TAG_NAME = "sql_column" // The tag that will be read ) func validTag(s string) bool { @@ -68,15 +66,7 @@ func getFieldColumn(f reflect.StructField) string { return field } -func getFieldExpression(f reflect.StructField) string { - field := f.Tag.Get(ExpressionTag) - if field == "-" { - return "" - } - return field -} - -func getFields(s sqlTableNamer) (fields []interface{}, values []interface{}) { +func readStruct(s SQLTableNamer) (columns []string, values []interface{}) { t := reflect.TypeOf(s) v := reflect.ValueOf(s) k := t.Kind() @@ -101,70 +91,29 @@ func getFields(s sqlTableNamer) (fields []interface{}, values []interface{}) { // Get the value of the field value := v.Field(i).Interface() - fields = append(fields, field) + columns = append(columns, field) values = append(values, value) } return } -// GetQuotedFields returns a slice of the fields in the passed type, with their names -// drawn from tags or inferred from the property name (which will be lower-cased with underscores, -// e.g. CamelCase => camel_case) and a corresponding slice of interface{}s containing the values for -// those properties. Fields will be surrounding in ` marks. -func GetQuotedFields(s sqlTableNamer) (fields []interface{}, values []interface{}) { - return getFields(s, true, false, false) -} - -func GetQuotedFieldsAndExpressions(s sqlTableNamer) (fields, values []interface{}) { - return getFields(s, true, false, true) -} - -func GetFields(s sqlTableNamer) (fields []interface{}, values []interface{}) { - return getFields(s) -} - -// GetAbsoluteFields returns a slice of the fields in the passed type, with their names -// drawn from tags or inferred from the property name (which will be lower-cased with underscores, -// e.g. CamelCase => camel_case) and a corresponding slice of interface{}s containing the values for -// those properties. Fields will be surrounded in \` marks and prefixed with their table name, as -// determined by the passed type's GetSQLTableName. The format will be \`table_name\`.\`field_name\`. -func GetAbsoluteFields(s sqlTableNamer) (fields []interface{}, values []interface{}) { - return getFields(s, true, true, false) -} - -func GetAbsoluteFieldsAndExpressions(s sqlTableNamer) (fields, values []interface{}) { - return getFields(s, true, true, true) -} - -func GetUnquotedAbsoluteFields(s sqlTableNamer) (fields []interface{}, values []interface{}) { - return getFields(s, false, true, false) +func Columns(s SQLTableNamer) ColumnList { + columns, _ := readStruct(s) + return columns } // GetColumn returns the field name associated with the specified property in the passed value. // Property must correspond exactly to the name of the property in the type, or this function will // panic. -func GetColumn(s interface{}, property string) string { - return getColumn(s, property, true) -} - -// GetAbsoluteColumnName returns the field name associated with the specified property in the passed value. -// Property must correspond exactly to the name of the property in the type, or this function will -// panic. -func GetAbsoluteColumnName(s sqlTableNamer, property string) string { - return fmt.Sprintf("`%s`.%s", GetTableName(s), GetColumn(s, property)) +func Column(s SQLTableNamer, property string) string { + return getColumn(s, property) } -func GetUnquotedAbsoluteColumn(s sqlTableNamer, property string) string { - return fmt.Sprintf("%s.%s", s.GetSQLTableName(), getColumn(s, property, false)) +func ColumnValues(s SQLTableNamer) []interface{} { + return nil } -type ColumnList []string - -func (c ColumnList) String() string { - strings.Join(c, ", ") -} - -func getColumn(s interface{}, property string, quote bool) string { +func getColumn(s interface{}, property string) string { t := reflect.TypeOf(s) k := t.Kind() for k == reflect.Interface || k == reflect.Ptr { @@ -178,21 +127,17 @@ func getColumn(s interface{}, property string, quote bool) string { if !ok { panic("Field not found in type: " + property) } - return getFieldColumn(field, quote) + return getFieldColumn(field) } -func GetUnquotedColumn(s interface{}, property string) string { - return getColumn(s, property, false) -} - -type sqlTableNamer interface { +type SQLTableNamer interface { GetSQLTableName() string } // GetTableName returns the table name for any type that implements the `GetSQLTableName() string` // method signature. The returned string will be used as the name of the table to store the data // for all instances of the type. -func GetTableName(t sqlTableNamer) string { +func Table(t SQLTableNamer) string { return t.GetSQLTableName() } @@ -216,92 +161,6 @@ func QueryList(fields []interface{}) string { return strings.Join(strs, ", ") } -// GetM2MTableName returns a consistent table name for a many-to-many relationship between two tables. No -// matter what order the fields are passed in, the resulting table name will always be consistent. -func GetM2MTableName(t1, t2 sqlTableNamer) string { - name1 := t1.GetSQLTableName() - name2 := t2.GetSQLTableName() - if name2 < name1 { - name1, name2 = name2, name1 - } - return fmt.Sprintf("%s_%s", name1, name2) -} - -// GetM2MAbsoluteColumnName returns the column name for the supplied field in a many-to-many relationship table, -// including the table name. The field belongs to the first sqlTableNamer, the second sqlTableNamer is the other -// table in the many-to-many relationship. -func GetM2MAbsoluteColumnName(t sqlTableNamer, field string, t2 sqlTableNamer) string { - return fmt.Sprintf("`%s`.%s", GetM2MTableName(t, t2), GetM2MQuotedColumnName(t, field)) -} - -// GetM2MColumnName returns the column name for the supplied field in a many-to-many relationship table. -func GetM2MColumnName(t sqlTableNamer, field string) string { - return fmt.Sprintf("%s_%s", t.GetSQLTableName(), GetUnquotedColumn(t, field)) -} - -// GetM2MQuotedColumnName returns the column name for the supplied field in a many-to-many relationship table, -// including the quote marks around the column name. -func GetM2MQuotedColumnName(t sqlTableNamer, field string) string { - return fmt.Sprintf("`%s`", GetM2MColumnName(t, field)) -} - -// GetM2MFields returns a slice of the columns that should be in a table that maps the many-to-many relationship of -// the types supplied, with their corresponding values. The field parameters specify the primary keys used in -// the relationship table to map to that type. -func GetM2MFields(t1 sqlTableNamer, field1 string, t2 sqlTableNamer, field2 string) (columns, values []interface{}) { - type1 := reflect.TypeOf(t1) - type2 := reflect.TypeOf(t2) - value1 := reflect.ValueOf(t1) - value2 := reflect.ValueOf(t2) - kind1 := value1.Kind() - kind2 := value2.Kind() - for kind1 == reflect.Interface || kind1 == reflect.Ptr { - value1 = value1.Elem() - type1 = value1.Type() - kind1 = value1.Kind() - } - for kind2 == reflect.Interface || kind2 == reflect.Ptr { - value2 = value2.Elem() - type2 = value2.Type() - kind2 = value2.Kind() - } - if kind1 != reflect.Struct { - panic("Can't get fields of " + type1.Name()) - } - if kind2 != reflect.Struct { - panic("Can't get fields of " + type2.Name()) - } - v1 := value1.FieldByName(field1) - v2 := value2.FieldByName(field2) - if v1 == *new(reflect.Value) { - panic(`No "` + field1 + `" field found in ` + type1.Name()) - } - if v2 == *new(reflect.Value) { - panic(`No "` + field2 + `" field found in ` + type2.Name()) - } - column1 := GetM2MColumnName(t1, field1) - column2 := GetM2MColumnName(t2, field2) - if column2 < column1 { - type1, type2 = type2, type1 - value1, value2 = value2, value1 - kind1, kind2 = kind2, kind1 - v1, v2 = v2, v1 - column1, column2 = column2, column1 - } - columns = append(columns, column1, column2) - values = append(values, v1.Interface(), v2.Interface()) - return -} - -// GetM2MQuotedFields wraps the fields returned by GetM2MFields in quotes. -func GetM2MQuotedFields(t1 sqlTableNamer, field1 string, t2 sqlTableNamer, field2 string) (columns, values []interface{}) { - columns, values = GetM2MFields(t1, field1, t2, field2) - for pos, column := range columns { - columns[pos] = "`" + column.(string) + "`" - } - return -} - type Scannable interface { Scan(dest ...interface{}) error } @@ -324,10 +183,7 @@ func Unmarshal(s Scannable, dst interface{}) error { // skip unexported fields continue } - field := getFieldColumn(t.Field(i), true) - if field == "" { - field = getFieldExpression(t.Field(i)) - } + field := getFieldColumn(t.Field(i)) if field == "" { continue } From bde6be753c26b88509bd13e1f464529ea0f9761b Mon Sep 17 00:00:00 2001 From: Paddy Foran Date: Thu, 14 Jul 2016 09:59:46 -0600 Subject: [PATCH 03/21] Finish refactoring to the desired API. Finish refactoring to achieve our desired API, and update our tests to pass. So far, our only problems are the Unmarshal tests, because I'd like to think of a way to get around including sqlite, or at least make them an integration test. --- query.go | 49 +++++----- query_test.go | 253 +++++++++++++++++++++++++----------------------- reflect.go | 7 +- reflect_test.go | 236 +++++++------------------------------------- sql_test.go | 53 +++++----- 5 files changed, 227 insertions(+), 371 deletions(-) diff --git a/query.go b/query.go index ff88378..3785ab1 100644 --- a/query.go +++ b/query.go @@ -31,24 +31,27 @@ func New(query string) *Query { } func Insert(obj SQLTableNamer, values ...SQLTableNamer) *Query { + inserts := make([]SQLTableNamer, 0, len(values)+1) + inserts = append(inserts, obj) + inserts = append(inserts, values...) columns := Columns(obj) - query := New("INSERT INTO " + Table(obj) + "(" + columns.String() + ") VALUES ") + query := New("INSERT INTO " + Table(obj) + " (" + columns.String() + ") VALUES") - for _, v := range values { + for _, v := range inserts { columnValues := ColumnValues(v) - query.Expression("("+VariableList(len(columnValues))+")", columnValues) + query.Expression("("+VariableList(len(columnValues))+")", columnValues...) } - return query + return query.Flush(" ") } -// WrongNumberArgsError is thrown when a Query is evaluated whose Args does not match its Expressions. -type WrongNumberArgsError struct { +// ErrWrongNumberArgs is thrown when a Query is evaluated whose Args does not match its Expressions. +type ErrWrongNumberArgs struct { NumExpected int NumFound int } // Error fulfills the error interface, returning the expected number of arguments and the number supplied. -func (e WrongNumberArgsError) Error() string { +func (e ErrWrongNumberArgs) Error() string { return fmt.Sprintf("Expected %d arguments, got %d.", e.NumExpected, e.NumFound) } @@ -56,7 +59,7 @@ func (q *Query) checkCounts() error { placeholders := strings.Count(q.sql, "?") args := len(q.args) if placeholders != args { - return WrongNumberArgsError{NumExpected: placeholders, NumFound: args} + return ErrWrongNumberArgs{NumExpected: placeholders, NumFound: args} } return nil } @@ -66,11 +69,17 @@ func (q *Query) String() string { return "" } -func (q *Query) mysqlProcess() string { - return q.sql + ";" +func (q *Query) MySQLString() (string, error) { + if err := q.checkCounts(); err != nil { + return "", err + } + return q.sql + ";", nil } -func (q *Query) postgresProcess() string { +func (q *Query) PostgreSQLString() (string, error) { + if err := q.checkCounts(); err != nil { + return "", err + } var pos, width, outputPos int var r rune var count = 1 @@ -113,7 +122,7 @@ func (q *Query) postgresProcess() string { for i := 0; i < len(terminatorBytes); i++ { output[len(output)-(len(terminatorBytes)-i)] = terminatorBytes[i] } - return string(output) + return string(output), nil } func (q *Query) Flush(join string) *Query { @@ -147,13 +156,17 @@ func (q *Query) In(obj SQLTableNamer, property string, values ...interface{}) *Q return q.Expression(Column(obj, property)+" IN("+VariableList(len(values))+")", values...) } +func (q *Query) Assign(obj SQLTableNamer, property string, value interface{}) *Query { + return q.Expression(Column(obj, property)+" = ?", value) +} + func (q *Query) orderBy(orderClause, dir string) *Query { exp := ", " if !q.includesOrder { exp = "ORDER BY " + q.includesOrder = true } q.Expression(exp + orderClause + dir) - q.includesOrder = true return q } @@ -173,16 +186,6 @@ func (q *Query) Offset(offset int64) *Query { return q.Expression("OFFSET ?", offset) } -func (q *Query) PostgreSQLString() (string, error) { - // TODO(paddy): return the PostgreSQL formatted q.sql - return "", nil -} - -func (q *Query) MySQLString() (string, error) { - // TODO(paddy): return the MySQL formatted q.sql - return "", nil -} - func (q *Query) Args() []interface{} { return q.args } diff --git a/query_test.go b/query_test.go index d6592cb..e049e85 100644 --- a/query_test.go +++ b/query_test.go @@ -6,84 +6,128 @@ import ( ) type queryTest struct { - ExpectedResult string + ExpectedResult queryResult Query *Query } +type queryResult struct { + postgres string + mysql string + err error +} + var queryTests = []queryTest{ queryTest{ - ExpectedResult: "This query expects $1 one arg;", + ExpectedResult: queryResult{ + postgres: "This query expects $1 one arg;", + mysql: "This query expects ? one arg;", + err: nil, + }, Query: &Query{ - SQL: "This query expects ? one arg", - Args: []interface{}{0}, - Engine: POSTGRES, + sql: "This query expects ? one arg", + args: []interface{}{0}, }, }, queryTest{ - ExpectedResult: "", + ExpectedResult: queryResult{ + postgres: "", + mysql: "", + err: ErrWrongNumberArgs{ + NumExpected: 1, + NumFound: 0, + }, + }, Query: &Query{ - SQL: "This query expects ? one arg but won't get it;", - Args: []interface{}{}, - Engine: POSTGRES, + sql: "This query expects ? one arg but won't get it;", + args: []interface{}{}, }, }, queryTest{ - ExpectedResult: "", + ExpectedResult: queryResult{ + postgres: "", + mysql: "", + err: ErrWrongNumberArgs{ + NumExpected: 0, + NumFound: 1, + }, + }, Query: &Query{ - SQL: "This query expects no arguments but will get one;", - Args: []interface{}{0}, - Engine: POSTGRES, + sql: "This query expects no arguments but will get one;", + args: []interface{}{0}, }, }, queryTest{ - ExpectedResult: "", + ExpectedResult: queryResult{ + postgres: "", + mysql: "", + err: ErrWrongNumberArgs{ + NumExpected: 2, + NumFound: 1, + }, + }, Query: &Query{ - SQL: "This query expects ? two args ? but will get one;", - Args: []interface{}{0}, - Engine: POSTGRES, + sql: "This query expects ? two args ? but will get one;", + args: []interface{}{0}, }, }, queryTest{ - ExpectedResult: "", + ExpectedResult: queryResult{ + postgres: "", + mysql: "", + err: ErrWrongNumberArgs{ + NumExpected: 2, + NumFound: 3, + }, + }, Query: &Query{ - SQL: "This query expects ? ? two args but will get three;", - Args: []interface{}{0, 1, 2}, - Engine: POSTGRES, + sql: "This query expects ? ? two args but will get three;", + args: []interface{}{0, 1, 2}, }, }, queryTest{ - ExpectedResult: "Unicode test 世 $1;", + ExpectedResult: queryResult{ + postgres: "Unicode test 世 $1;", + mysql: "Unicode test 世 ?;", + err: nil, + }, Query: &Query{ - SQL: "Unicode test 世 ?", - Args: []interface{}{0}, - Engine: POSTGRES, + sql: "Unicode test 世 ?", + args: []interface{}{0}, }, }, queryTest{ - ExpectedResult: "Unicode boundary test $1 " + string(rune(0x80)) + ";", + ExpectedResult: queryResult{ + postgres: "Unicode boundary test $1 " + string(rune(0x80)) + ";", + mysql: "Unicode boundary test ? " + string(rune(0x80)) + ";", + err: nil, + }, Query: &Query{ - SQL: "Unicode boundary test ? " + string(rune(0x80)), - Args: []interface{}{0}, - Engine: POSTGRES, + sql: "Unicode boundary test ? " + string(rune(0x80)), + args: []interface{}{0}, }, }, } func init() { - expected := "lots of args" - SQL := "lots of args" + postgres := "lots of args" + mysql := "lots of args" + sql := "lots of args" args := []interface{}{} for i := 1; i < 1001; i++ { - SQL += " ?" - expected += fmt.Sprintf(" $%d", i) + sql += " ?" + mysql += " ?" + postgres += fmt.Sprintf(" $%d", i) args = append(args, false) if i == 10 || i == 100 || i == 1000 { queryTests = append(queryTests, queryTest{ - ExpectedResult: expected + ";", + ExpectedResult: queryResult{ + mysql: mysql + ";", + postgres: postgres + ";", + err: nil, + }, Query: &Query{ - SQL: SQL, - Args: args, - Engine: POSTGRES, + sql: sql, + args: args, }, }) } @@ -92,23 +136,33 @@ func init() { func TestQueriesFromTable(t *testing.T) { for pos, test := range queryTests { - result := test.Query.String() - if result != test.ExpectedResult { - t.Logf("Expected\n%d\ngot\n%d\n.", []byte(test.ExpectedResult), []byte(result)) - t.Errorf("Query test %d failed. Expected \"%s\", got \"%s\".", pos+1, test.ExpectedResult, result) + t.Logf("Testing: %s", test.Query.sql) + mysql, mErr := test.Query.MySQLString() + postgres, pErr := test.Query.PostgreSQLString() + if (mErr != nil && pErr == nil) || (pErr != nil && mErr == nil) || (mErr != nil && pErr != nil && mErr.Error() != pErr.Error()) { + t.Errorf("Expected %v and %v to be the same.\n", mErr, pErr) + } + if (mErr != nil && test.ExpectedResult.err == nil) || (mErr == nil && test.ExpectedResult.err != nil) || (mErr != nil && test.ExpectedResult.err != nil && mErr.Error() != test.ExpectedResult.err.Error()) { + t.Errorf("Expected error to be %v, got %v\n", mErr, test.ExpectedResult.err) + } + if mysql != test.ExpectedResult.mysql { + t.Errorf("Query test %d failed. Expected MySQL to be \"%s\", got \"%s\".\n", pos+1, test.ExpectedResult.mysql, mysql) + } + if postgres != test.ExpectedResult.postgres { + t.Errorf("Query test %d failed: Expected PostgreSQL to be \"%s\", got \"%s\".\n", pos+1, test.ExpectedResult.postgres, postgres) } } } -func TestWrongNumberArgsError(t *testing.T) { - q := New(POSTGRES, "?") - q.Args = append(q.Args, 1, 2, 3) +func TestErrWrongNumberArgs(t *testing.T) { + q := New("?") + q.args = append(q.args, 1, 2, 3) err := q.checkCounts() if err == nil { t.Errorf("Expected error.") } - if e, ok := err.(WrongNumberArgsError); !ok { - t.Errorf("Error was not a WrongNumberArgsError.") + if e, ok := err.(ErrWrongNumberArgs); !ok { + t.Errorf("Error was not an ErrWrongNumberArgs.") } else { if e.NumExpected != 1 { t.Errorf("Expected %d expectations, got %d", 1, e.NumExpected) @@ -122,99 +176,56 @@ func TestWrongNumberArgsError(t *testing.T) { } } -func TestIncludeIfNotNil(t *testing.T) { - q := New(POSTGRES, "") - q.IncludeIfNotNil("hello ?", "world") - if q.Generate("") != " hello $1;" { - t.Errorf("Expected `%s`, got `%s`", " hello $1;", q.Generate("")) - } - - var val *testType - q = New(POSTGRES, "") - q.IncludeIfNotNil("hello ?", val) - if q.Generate("") != ";" { - t.Errorf("Expected `%s`, got `%s`", ";", q.Generate("")) - } - - q = New(POSTGRES, "") - q.IncludeIfNotNil("hello ?", New) - if q.Generate("") != ";" { - t.Errorf("Expected `%s`, got `%s`", ";", q.Generate("")) - } -} - func TestRepeatedOrder(t *testing.T) { - q := New(POSTGRES, "SELECT * FROM test_data") - q.IncludeOrder("id DESC") - q.IncludeOrder("date DESC") - if q.Generate(" ") != "SELECT * FROM test_data ORDER BY id DESC;" { - t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data ORDER BY id DESC;", q.Generate(" ")) + q := New("SELECT * FROM test_data") + q.OrderBy("id") + q.OrderBy("name") + q.OrderByDesc("date") + res, err := q.Flush(" ").MySQLString() + if err != nil { + t.Errorf("Unexpected error: %+v\n", err) } -} - -func TestRepeatedLimit(t *testing.T) { - q := New(POSTGRES, "SELECT * FROM test_data") - q.IncludeLimit(10) - q.IncludeLimit(5) - if q.Generate(" ") != "SELECT * FROM test_data LIMIT $1;" { - t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data LIMIT $1;", q.Generate(" ")) + if res != "SELECT * FROM test_data ORDER BY id , name , date DESC;" { + t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data ORDER BY id , name , date DESC;", res) } } -// I'm worried that setting the SQL property in String() will cause a problem, so let's find out. -// I was right. Leaving this test here as penance. -func TestRepeatedString(t *testing.T) { - q := New(POSTGRES, "SELECT * FROM test_data") - q.IncludeWhere() - q.Include("x = ?", 1) - q.IncludeLimit(10) - q.FlushExpressions(" ") - q2 := *q - qstr := q.String() - if qstr != "SELECT * FROM test_data WHERE x = $1 LIMIT $2;" { - t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data WHERE x = $1 LIMIT $2;", qstr) +func TestOffset(t *testing.T) { + q := New("SELECT * FROM test_data") + q.Offset(10).Flush(" ") + mysql, err := q.MySQLString() + if err != nil { + t.Errorf("Unexpected error: %+v\n", err) } - qstr = q.String() - if qstr != "SELECT * FROM test_data WHERE x = $1 LIMIT $2;" { - t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data WHERE x = $1 LIMIT $2;", qstr) + postgres, err := q.PostgreSQLString() + if err != nil { + t.Errorf("Unexpected error: %+v\n", err) } - q2.Engine = MYSQL - q2str := q2.String() - if q2str != "SELECT * FROM test_data WHERE x = ? LIMIT ?;" { - t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data WHERE x = ? LIMIT ?;", q2str) + if err != nil { + t.Errorf("Unexpected error: %+v\n", err) } - q2str = q2.String() - if q2str != "SELECT * FROM test_data WHERE x = ? LIMIT ?;" { - t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data WHERE x = ? LIMIT ?;", q2str) + if mysql != "SELECT * FROM test_data OFFSET ?;" { + t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data OFFSET ?;", mysql) } -} - -func TestIncludeOffset(t *testing.T) { - q := New(POSTGRES, "SELECT * FROM test_data") - q.IncludeOffset(10) - if q.Generate(" ") != "SELECT * FROM test_data OFFSET $1;" { - t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data OFFSET $1;", q.Generate(" ")) + if postgres != "SELECT * FROM test_data OFFSET $1;" { + t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data OFFSET $1;", postgres) } } -func TestInnerJoin(t *testing.T) { - q := New(POSTGRES, "SELECT * FROM test_data") - q.InnerJoin("other_table", "`test_data`.`x` = `other_table`.`x`") - if q.Generate(" ") != "SELECT * FROM test_data INNER JOIN other_table ON `test_data`.`x` = `other_table`.`x`;" { - t.Errorf("Expected `%s`, got `%s`", "SELECT * FROM test_data INNER JOIN other_table ON `test_data`.`x` = `other_table`.`x`;", q.Generate(" ")) +func BenchmarkMySQLQueriesFromTable(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + test := queryTests[b.N%len(queryTests)] + b.StartTimer() + test.Query.MySQLString() } } -func BenchmarkQueriesFromTable(b *testing.B) { +func BenchmarkPostgreSQLQueriesFromTable(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() test := queryTests[b.N%len(queryTests)] b.StartTimer() - result := test.Query.String() - b.StopTimer() - if result != test.ExpectedResult { - b.Errorf("Query test %d failed. Expected \"%s\", got \"%s\".", (b.N%len(queryTests))+1, test.ExpectedResult, result) - } - b.StartTimer() + test.Query.PostgreSQLString() } } diff --git a/reflect.go b/reflect.go index 054b9a6..d5e1683 100644 --- a/reflect.go +++ b/reflect.go @@ -67,8 +67,8 @@ func getFieldColumn(f reflect.StructField) string { } func readStruct(s SQLTableNamer) (columns []string, values []interface{}) { - t := reflect.TypeOf(s) v := reflect.ValueOf(s) + t := reflect.TypeOf(s) k := t.Kind() for k == reflect.Interface || k == reflect.Ptr { v = v.Elem() @@ -106,11 +106,14 @@ func Columns(s SQLTableNamer) ColumnList { // Property must correspond exactly to the name of the property in the type, or this function will // panic. func Column(s SQLTableNamer, property string) string { + // BUG(paddy): return the table name as part of the column name + // when we put in a cache for reflect, we should refactor this to use readStruct return getColumn(s, property) } func ColumnValues(s SQLTableNamer) []interface{} { - return nil + _, values := readStruct(s) + return values } func getColumn(s interface{}, property string) string { diff --git a/reflect_test.go b/reflect_test.go index 5991457..69bdc0f 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -1,12 +1,6 @@ package pan -import ( - "database/sql" - "os" - "testing" - - _ "github.com/mattn/go-sqlite3" -) +import "testing" type testType struct { myInt int @@ -14,7 +8,6 @@ type testType struct { MyString string myTaggedString string `sql_column:"tagged_string"` OmittedColumn string `sql_column:"-"` - Expression string `sql_column:"-" sql_expression:"COUNT(*)"` } func (t testType) GetSQLTableName() string { @@ -35,38 +28,26 @@ func TestReflectedProperties(t *testing.T) { MyTaggedInt: 2, MyString: "hello", myTaggedString: "world", - Expression: "bar", } - fields, values := GetAbsoluteFields(foo) - if len(fields) != 2 { - t.Errorf("Fields should have length %d, has length %d", 2, len(fields)) + columns := Columns(foo) + if len(columns) != 2 { + t.Errorf("Columns should have length %d, has length %d", 2, len(columns)) } + values := ColumnValues(foo) if len(values) != 2 { t.Errorf("Values should have length %d, has length %d", 2, len(values)) } - for pos, field := range fields { - if field != "`test_types`.`tagged_int`" && field != "`test_types`.`my_string`" { - t.Errorf("Unknown field found: % v'", field) + for pos, column := range columns { + if column != "test_types.tagged_int" && column != "test_types.my_string" { + t.Errorf("Unknown column found: %v'", column) } - if field == "`test_types`.`tagged_int`" && values[pos].(int) != 2 { + if column == "test_types.tagged_int" && values[pos].(int) != 2 { t.Errorf("Expected tagged_int to be %d, got %v", 2, values[pos]) } - if field == "`test_types`.`my_string`" && values[pos].(string) != "hello" { + if column == "test_types.my_string" && values[pos].(string) != "hello" { t.Errorf("Expected my_string to be %s, got %v", "hello", values[pos]) } } - fields, values = GetAbsoluteFieldsAndExpressions(foo) - if len(fields) != 3 { - t.Errorf("Fields should have length %d, has length %d", 3, len(fields)) - } - if len(values) != 3 { - t.Errorf("Values should have length %d, has length %d", 3, len(values)) - } - for _, field := range fields { - if field != "`test_types`.`tagged_int`" && field != "`test_types`.`my_string`" && field != "COUNT(*)" { - t.Errorf("Unknown field found: % v'", field) - } - } } var tags = map[string]bool{ @@ -115,9 +96,10 @@ func (i invalidSqlFieldReflector) GetSQLTableName() string { } func TestInvalidFieldReflection(t *testing.T) { - fields, values := getFields(invalidSqlFieldReflector("test"), true, false, false) - if len(fields) != 0 { - t.Errorf("Expected %d fields, got %d.", 0, len(fields)) + columns := Columns(invalidSqlFieldReflector("test")) + values := ColumnValues(invalidSqlFieldReflector("test")) + if len(columns) != 0 { + t.Errorf("Expected %d columns, got %d.", 0, len(columns)) } if len(values) != 0 { t.Errorf("Expected %d values, got %d.", 0, len(values)) @@ -125,207 +107,55 @@ func TestInvalidFieldReflection(t *testing.T) { } func TestInterfaceOrPointerFieldReflection(t *testing.T) { - fields, values := getFields(&testType{}, false, false, false) - if len(fields) != 2 { - t.Errorf("Expected %d fields, but got %v", len(fields), fields) + columns := Columns(&testType{}) + if len(columns) != 2 { + t.Errorf("Expected %d columns, but got %v", len(columns), columns) } + values := ColumnValues(&testType{}) if len(values) != 2 { t.Errorf("Expected %d values, but got %v", len(values), values) } - var i sqlTableNamer + var i SQLTableNamer i = testType{} - fields, values = getFields(i, false, false, false) - if len(fields) != 2 { - t.Errorf("Expected %d fields, but got %v", len(fields), fields) + columns = Columns(i) + if len(columns) != 2 { + t.Errorf("Expected %d columns, but got %v", len(columns), columns) } + values = ColumnValues(i) if len(values) != 2 { t.Errorf("Expected %d values, but got %v", len(values), values) } i = &testType{} - fields, values = getFields(i, false, false, false) - if len(fields) != 2 { - t.Errorf("Expected %d fields, but got %v", len(fields), fields) + columns = Columns(i) + if len(columns) != 2 { + t.Errorf("Expected %d columns, but got %v", len(columns), columns) } + values = ColumnValues(i) if len(values) != 2 { t.Errorf("Expected %d values, but got %v", len(values), values) } } func TestInvalidColumnTypes(t *testing.T) { - result := GetColumn("", "test") - if result != "" { - t.Errorf("Expected column name to be `%s`, was `%s`.", "", result) - } - defer func() { t.Log(recover()) }() - result = GetColumn(&testType{}, "NotARealProperty") + result := Column(&testType{}, "NotARealProperty") t.Errorf("Expected a panic, got `%s` instead.", result) } func TestOmittedColumn(t *testing.T) { - fields, _ := GetQuotedFields(&testType{}) - for _, field := range fields { - if field.(string) == "`omitted_column`" { + columns := Columns(&testType{}) + for _, column := range columns { + if column == "omitted_column" { t.Errorf("omitted_column should not have shown up, but it did.") } } } -func TestGetM2MTableName(t *testing.T) { - tableName := GetM2MTableName(testType{}, testType2{}) - if tableName != "more_tests_test_types" { - t.Errorf("Expected `%s`, got `%s`", "more_tests_test_types", tableName) - } - tableName2 := GetM2MTableName(testType2{}, testType{}) - if tableName2 != "more_tests_test_types" { - t.Errorf("Expected `%s`, got `%s`", "more_tests_test_types", tableName2) - } - if tableName != tableName2 { - t.Errorf("`%s` is not equal to `%s`", tableName, tableName2) - } -} - -func TestGetAbsoluteColumnName(t *testing.T) { - columnName := GetAbsoluteColumnName(testType{}, "MyString") - if columnName != "`test_types`.`my_string`" { - t.Errorf("Expected `%s`, got `%s`", "`test_types`.`my_string`", columnName) - } -} - -func TestGetM2MColumnName(t *testing.T) { - columnName := GetM2MColumnName(testType{}, "MyString") - if columnName != "test_types_my_string" { - t.Errorf("Expected `%s`, got `%s`", "test_types_my_string", columnName) - } -} - -func TestGetM2MQuotedColumnName(t *testing.T) { - columnName := GetM2MQuotedColumnName(testType{}, "MyString") - if columnName != "`test_types_my_string`" { - t.Errorf("Expected %s, got %s", "`test_types_my_string`", columnName) - } -} - -func TestGetM2MAbsoluteColumnName(t *testing.T) { - columnName := GetM2MAbsoluteColumnName(testType{}, "MyString", testType2{}) - if columnName != "`more_tests_test_types`.`test_types_my_string`" { - t.Errorf("Expected %s, got %s", "`more_tests_test_types`.`test_types_my_string`", columnName) - } -} - -func TestGetM2MFields(t *testing.T) { - t1 := testType{MyString: "hello"} - t2 := testType2{ID: "world"} - columns, values := GetM2MFields(t1, "MyString", t2, "ID") - columns2, values2 := GetM2MFields(t2, "ID", t1, "MyString") - columns3, values3 := GetM2MFields(&t1, "MyString", &t2, "ID") - columns4, values4 := GetM2MFields(sqlTableNamer(t1), "MyString", sqlTableNamer(t2), "ID") - columns5, values5 := GetM2MFields(sqlTableNamer(&t1), "MyString", sqlTableNamer(&t2), "ID") - columns6, _ := GetM2MQuotedFields(t1, "MyString", t2, "ID") - if columns[0].(string) != "more_tests_id" { - t.Errorf("Expected %s, got %s", "more_tests_id", columns[0]) - } - if columns[1].(string) != "test_types_my_string" { - t.Errorf("Expected %s, got %s", "test_types_my_string", columns[1]) - } - if columns2[0].(string) != "more_tests_id" { - t.Errorf("Expected %s, got %s", "more_tests_id", columns2[0]) - } - if columns2[1].(string) != "test_types_my_string" { - t.Errorf("Expected %s, got %s", "test_types_my_string", columns2[1]) - } - if values[0].(string) != "world" { - t.Errorf("Expected %s, got %s", "world", values[0].(string)) - } - if values[1].(string) != "hello" { - t.Errorf("Expected %s, got %s", "hello", values[1].(string)) - } - if values2[0].(string) != "world" { - t.Errorf("Expected %s, got %s", "world", values2[0].(string)) - } - if values2[1].(string) != "hello" { - t.Errorf("Expected %s, got %s", "hello", values2[1].(string)) - } - if columns3[0].(string) != "more_tests_id" { - t.Errorf("Expected %s, got %s", "more_tests_id", columns[0]) - } - if columns3[1].(string) != "test_types_my_string" { - t.Errorf("Expected %s, got %s", "test_types_my_string", columns[1]) - } - if columns4[0].(string) != "more_tests_id" { - t.Errorf("Expected %s, got %s", "more_tests_id", columns2[0]) - } - if columns4[1].(string) != "test_types_my_string" { - t.Errorf("Expected %s, got %s", "test_types_my_string", columns2[1]) - } - if values3[0].(string) != "world" { - t.Errorf("Expected %s, got %s", "world", values3[0].(string)) - } - if values3[1].(string) != "hello" { - t.Errorf("Expected %s, got %s", "hello", values3[1].(string)) - } - if values4[0].(string) != "world" { - t.Errorf("Expected %s, got %s", "world", values4[0].(string)) - } - if values4[1].(string) != "hello" { - t.Errorf("Expected %s, got %s", "hello", values4[1].(string)) - } - if columns5[0].(string) != "more_tests_id" { - t.Errorf("Expected %s, got %s", "more_tests_id", columns5[0]) - } - if columns5[1].(string) != "test_types_my_string" { - t.Errorf("Expected %s, got %s", "test_types_my_string", columns5[1]) - } - if values5[0].(string) != "world" { - t.Errorf("Expected %s, got %s", "world", values5[0].(string)) - } - if values5[1].(string) != "hello" { - t.Errorf("Expected %s, got %s", "hello", values5[1].(string)) - } - if columns6[0].(string) != "`more_tests_id`" { - t.Errorf("Expected %s, got %s", "`more_tests_id`", columns6[0].(string)) - } - if columns6[1].(string) != "`test_types_my_string`" { - t.Errorf("Expected %s, got %s", "`test_types_my_string`", columns6[1].(string)) - } -} - -func TestInvalidM2MFieldTypes1(t *testing.T) { - defer func() { - t.Log(recover()) - }() - fields, values := GetM2MFields(&testType{}, "NotARealProperty", &testType2{}, "ID") - t.Errorf("Expected a panic, got `%v` and `%v` instead.", fields, values) -} - -func TestInvalidM2MFieldTypes2(t *testing.T) { - defer func() { - t.Log(recover()) - }() - fields, values := GetM2MFields(&testType{}, "MyString", &testType2{}, "NotARealProperty") - t.Errorf("Expected a panic, got `%v` and `%v` instead.", fields, values) -} - -func TestNonStructM2MFieldTypes1(t *testing.T) { - defer func() { - t.Log(recover()) - }() - fields, values := GetM2MFields(invalidSqlFieldReflector("test"), "mystring", &testType2{}, "ID") - t.Errorf("Expected a panic, got `%v` and `%v` instead.", fields, values) -} - -func TestNonStructM2MFieldTypes2(t *testing.T) { - defer func() { - t.Log(recover()) - }() - fields, values := GetM2MFields(&testType{}, "MyString", invalidSqlFieldReflector("test"), "ID") - t.Errorf("Expected a panic, got `%v` and `%v` instead.", fields, values) -} - +/* func TestUnmarshal(t *testing.T) { os.Remove("./test.db") @@ -370,4 +200,4 @@ func TestUnmarshal(t *testing.T) { t.Errorf("Expected MyString to be %s, was %s.", dummy.MyString, expectation.MyString) } os.Remove("./test.db") -} +}*/ diff --git a/sql_test.go b/sql_test.go index 434beed..97e5c95 100644 --- a/sql_test.go +++ b/sql_test.go @@ -20,33 +20,42 @@ func (t testPost) GetSQLTableName() string { func init() { p := testPost{123, "my post", 1, "this is a test post", time.Now(), nil} - fields, values := GetQuotedFields(p) - sqlTable[New(MYSQL, "INSERT").Include("INTO ?", GetTableName(p)).Include("("+VariableList(len(fields))+")", []interface{}(fields)...).Include("VALUES").Include("("+VariableList(len(values))+")", values...).Generate(" ")] = "INSERT INTO ? (?,?,?,?,?,?) VALUES (?,?,?,?,?,?);" - sqlTable[New(POSTGRES, "INSERT").Include("INTO ?", GetTableName(p)).Include("("+VariableList(len(fields))+")", []interface{}(fields)...).Include("VALUES").Include("("+VariableList(len(values))+")", values...).Generate(" ")] = "INSERT INTO $1 ($2,$3,$4,$5,$6,$7) VALUES ($8,$9,$10,$11,$12,$13);" - sqlTable[New(POSTGRES, "UPDATE "+GetTableName(p)).Include("SET").FlushExpressions(" ").IncludeIfNotEmpty(GetColumn(p, "Title")+" = ?", p.Title).IncludeIfNotNil(GetColumn(p, "Modified")+" = ?", p.Modified).Include(GetColumn(p, "Author")+" = ?", p.Author).FlushExpressions(", ").IncludeWhere().Include("? = ?", GetColumn(p, "ID"), p).Generate(" ")] = "UPDATE test_data SET `title` = $1, `author_id` = $2 WHERE $3 = $4;" - sqlTable[New(MYSQL, "UPDATE "+GetTableName(p)).Include("SET").FlushExpressions(" ").IncludeIfNotEmpty(GetColumn(p, "Title")+" = ?", p.Title).IncludeIfNotNil(GetColumn(p, "Modified")+" = ?", p.Modified).Include(GetColumn(p, "Author")+" = ?", p.Author).FlushExpressions(", ").IncludeWhere().Include("? = ?", GetColumn(p, "ID"), p).Generate(" ")] = "UPDATE test_data SET `title` = ?, `author_id` = ? WHERE ? = ?;" - p.Modified = &p.Created - p.Title = "" - fields, values = GetAbsoluteFields(p) - sqlTable[New(POSTGRES, "UPDATE "+GetTableName(p)).Include("SET").FlushExpressions(" ").IncludeIfNotEmpty(GetColumn(p, "Title")+" = ?", p.Title).IncludeIfNotNil(GetColumn(p, "Modified")+" = ?", p.Modified).Include(GetColumn(p, "Author")+" = ?", p.Author).FlushExpressions(", ").IncludeWhere().Include("? = ?", GetColumn(p, "ID"), p).Generate(" ")] = "UPDATE test_data SET `modified` = $1, `author_id` = $2 WHERE $3 = $4;" - sqlTable[New(MYSQL, "UPDATE "+GetTableName(p)).Include("SET").FlushExpressions(" ").IncludeIfNotEmpty(GetColumn(p, "Title")+" = ?", p.Title).IncludeIfNotNil(GetColumn(p, "Modified")+" = ?", p.Modified).Include(GetColumn(p, "Author")+" = ?", p.Author).FlushExpressions(", ").IncludeWhere().Include("? = ?", GetColumn(p, "ID"), p).Generate(" ")] = "UPDATE test_data SET `modified` = ?, `author_id` = ? WHERE ? = ?;" - sqlTable[New(POSTGRES, "SELECT "+QueryList(fields)).Include("FROM ?", GetTableName(p)).IncludeWhere().Include(GetColumn(p, "Created")+" > (SELECT "+GetColumn(p, "Created")+" FROM `"+GetTableName(p)+"` WHERE "+GetColumn(p, "ID")+" = ?)", 123).IncludeWhere().IncludeOrder(GetColumn(p, "Created")+" DESC").IncludeLimit(19).Generate(" ")] = "SELECT `test_data`.`id`, `test_data`.`title`, `test_data`.`author_id`, `test_data`.`body`, `test_data`.`created`, `test_data`.`modified` FROM $1 WHERE `created` > (SELECT `created` FROM `test_data` WHERE `id` = $2) ORDER BY `created` DESC LIMIT $3;" - sqlTable[New(MYSQL, "SELECT "+QueryList(fields)).Include("FROM ?", GetTableName(p)).IncludeWhere().Include(GetColumn(p, "Created")+" > (SELECT "+GetColumn(p, "Created")+" FROM `"+GetTableName(p)+"` WHERE "+GetColumn(p, "ID")+" = ?)", 123).IncludeWhere().IncludeOrder(GetColumn(p, "Created")+" DESC").IncludeLimit(19).Generate(" ")] = "SELECT `test_data`.`id`, `test_data`.`title`, `test_data`.`author_id`, `test_data`.`body`, `test_data`.`created`, `test_data`.`modified` FROM ? WHERE `created` > (SELECT `created` FROM `test_data` WHERE `id` = ?) ORDER BY `created` DESC LIMIT ?;" + sqlTable[Insert(p)] = queryResult{ + mysql: "INSERT INTO test_data (test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified) VALUES (?,?,?,?,?,?);", + postgres: "INSERT INTO test_data (test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified) VALUES ($1,$2,$3,$4,$5,$6);", + } + sqlTable[New("UPDATE "+Table(p)+" SET").Assign(p, "Title", p.Title).Assign(p, "Author", p.Author).Flush(", ").Where().Comparison(p, "ID", "=", p.ID).Flush(" ")] = queryResult{ + mysql: "UPDATE test_data SET title = ?, author_id = ? WHERE id = ?;", + postgres: "UPDATE test_data SET title = $1, author_id = $2 WHERE id = $3;", + } + sqlTable[New("SELECT "+Columns(p).String()+" FROM "+Table(p)).Where().Expression(Column(p, "Created")+" > (SELECT "+Column(p, "Created")+" FROM "+Table(p)+" WHERE "+Column(p, "ID")+" = ?)", 123).Where().OrderByDesc(Column(p, "Created")).Limit(19).Flush(" ")] = queryResult{ + postgres: "SELECT test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified FROM test_data WHERE created > (SELECT created FROM test_data WHERE id = $1) ORDER BY created DESC LIMIT $2;", + mysql: "SELECT test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified FROM test_data WHERE created > (SELECT created FROM test_data WHERE id = ?) ORDER BY created DESC LIMIT ?;", + } } -var sqlTable = map[string]string{ - New(MYSQL, "INSERT").Include("INTO ?", GetTableName(testPost{})).Include("("+VariableList(4)+")", "a", "b", "c", "d").Include("VALUES").Include("("+VariableList(4)+")", 0, 1, 2, 3).Generate(" "): "INSERT INTO ? (?,?,?,?) VALUES (?,?,?,?);", - New(POSTGRES, "INSERT").Include("INTO ?", GetTableName(testPost{})).Include("("+VariableList(4)+")", "a", "b", "c", "d").Include("VALUES").Include("("+VariableList(4)+")", 0, 1, 2, 3).Generate(" "): "INSERT INTO $1 ($2,$3,$4,$5) VALUES ($6,$7,$8,$9);", +var sqlTable = map[*Query]queryResult{ + New("INSERT INTO "+Table(testPost{})).Expression("("+VariableList(4)+")", "a", "b", "c", "d").Expression("VALUES").Expression("("+VariableList(4)+")", 0, 1, 2, 3).Flush(" "): { + mysql: "INSERT INTO test_data (?,?,?,?) VALUES (?,?,?,?);", + postgres: "INSERT INTO test_data ($1,$2,$3,$4) VALUES ($5,$6,$7,$8);", + }, } func TestSQLTable(t *testing.T) { - for output, expectation := range sqlTable { - if output != expectation { - if output == "" { - t.Errorf("Expected %s, but there was an argument count error.", expectation) - } else { - t.Errorf("Expected '%s' got '%s'", expectation, output) - } + for query, expectation := range sqlTable { + mysql, err := query.MySQLString() + if err != nil { + t.Errorf("Unexpected error: %+v\n", err) + } + postgres, err := query.PostgreSQLString() + if err != nil { + t.Errorf("Unexpected error: %+v\n", err) + } + if mysql != expectation.mysql { + t.Errorf("Expected '%s' got '%s'", expectation.mysql, mysql) + } + if postgres != expectation.postgres { + t.Errorf("Expected '%s' got '%s'", expectation.postgres, postgres) } } } From 1b9b233e171087effe64b3da4129f9ec7a0a50c0 Mon Sep 17 00:00:00 2001 From: Paddy Foran Date: Thu, 14 Jul 2016 14:46:13 -0600 Subject: [PATCH 04/21] Implement Query.String. Implement the String method for query, to return the SQL (with values injected) that would be executed on the server. This is not quite perfect yet--we're using the fmt.Sprintf `%v` output for each arg right now, instead of the actual SQL syntax, but this should be good enough to give people an idea when debugging. We could improve on this by, e.g., quoting strings. I'd like, one day, to have that output be copy/paste-able valid SQL, but, as we say to death, not today. I also added a benchmark for the String function, and renamed our current benchmarks, to make room for future benchmarking. --- query.go | 18 ++++++++++++++++-- query_test.go | 13 +++++++++++-- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/query.go b/query.go index 3785ab1..0dac54e 100644 --- a/query.go +++ b/query.go @@ -65,8 +65,22 @@ func (q *Query) checkCounts() error { } func (q *Query) String() string { - // TODO(paddy): return the query with values injected - return "" + var argPos int + var res string + toCheck := q.sql + for i := strings.Index(toCheck, "?"); i >= 0; argPos++ { + var arg interface{} + arg = "!{MISSING}" + if len(q.args) > argPos { + arg = q.args[argPos] + } + res += toCheck[:i] + res += fmt.Sprintf("%v", arg) + toCheck = toCheck[i+1:] + i = strings.Index(toCheck, "?") + } + res += toCheck + return res } func (q *Query) MySQLString() (string, error) { diff --git a/query_test.go b/query_test.go index e049e85..6256ca0 100644 --- a/query_test.go +++ b/query_test.go @@ -212,7 +212,7 @@ func TestOffset(t *testing.T) { } } -func BenchmarkMySQLQueriesFromTable(b *testing.B) { +func BenchmarkMySQLString(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() test := queryTests[b.N%len(queryTests)] @@ -221,7 +221,7 @@ func BenchmarkMySQLQueriesFromTable(b *testing.B) { } } -func BenchmarkPostgreSQLQueriesFromTable(b *testing.B) { +func BenchmarkPostgreSQLString(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() test := queryTests[b.N%len(queryTests)] @@ -229,3 +229,12 @@ func BenchmarkPostgreSQLQueriesFromTable(b *testing.B) { test.Query.PostgreSQLString() } } + +func BenchmarkQueryString(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StopTimer() + test := queryTests[b.N%len(queryTests)] + b.StartTimer() + test.Query.String() + } +} From c3c99c8ea4468495574159de4b319a72c2d6733a Mon Sep 17 00:00:00 2001 From: Paddy Foran Date: Thu, 14 Jul 2016 14:53:54 -0600 Subject: [PATCH 05/21] Benchmark query generation for query and insert. Benchmark how long it takes to build a query object when querying, and when inserting. --- sql_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql_test.go b/sql_test.go index 97e5c95..892a896 100644 --- a/sql_test.go +++ b/sql_test.go @@ -43,6 +43,7 @@ var sqlTable = map[*Query]queryResult{ func TestSQLTable(t *testing.T) { for query, expectation := range sqlTable { + t.Logf(query.String()) mysql, err := query.MySQLString() if err != nil { t.Errorf("Unexpected error: %+v\n", err) @@ -59,3 +60,17 @@ func TestSQLTable(t *testing.T) { } } } + +func BenchmarkInsertGeneration(b *testing.B) { + p := testPost{123, "my post", 1, "this is a test post", time.Now(), nil} + for i := 0; i < b.N; i++ { + Insert(p) + } +} + +func BenchmarkQueryGeneration(b *testing.B) { + p := testPost{123, "my post", 1, "this is a test post", time.Now(), nil} + for i := 0; i < b.N; i++ { + New("SELECT "+Columns(p).String()+" FROM "+Table(p)).Where().Comparison(p, "ID", "=", p.ID).Expression("OR").In(p, "ID", 123, 456, 789, 101112, 131415).OrderBy(Column(p, "Created")).Limit(10).Flush(" ") + } +} From cfeb3d3e70206878ad03849e9357a93b6320e42e Mon Sep 17 00:00:00 2001 From: Paddy Foran Date: Thu, 14 Jul 2016 14:58:54 -0600 Subject: [PATCH 06/21] Speed up PostgreSQLString generation. By removing our hand-crafted byte management and memory allocation, and relying on the strings.Index method and building the string section by section (like we do for the String method) we get a roughly 100% performance improvement. Thanks, benchmarks! --- query.go | 55 +++++++++++-------------------------------------------- 1 file changed, 11 insertions(+), 44 deletions(-) diff --git a/query.go b/query.go index 0dac54e..ae39093 100644 --- a/query.go +++ b/query.go @@ -2,9 +2,8 @@ package pan import ( "fmt" - "math" + "strconv" "strings" - "unicode/utf8" ) // Query contains the data needed to perform a single SQL query. @@ -94,49 +93,17 @@ func (q *Query) PostgreSQLString() (string, error) { if err := q.checkCounts(); err != nil { return "", err } - var pos, width, outputPos int - var r rune - var count = 1 - replacementString := "?" - replacementRune, _ := utf8.DecodeRune([]byte(replacementString)) - terminatorString := ";" - terminatorBytes := []byte(terminatorString) - toReplace := float64(strings.Count(q.sql, replacementString)) - bytesNeeded := float64(len(q.sql) + len(replacementString)) - powerCounter := float64(1) - powerMax := math.Pow(10, powerCounter) - 1 - prevMax := float64(0) - for powerMax < toReplace { - bytesNeeded += ((powerMax - prevMax) * powerCounter) - prevMax = powerMax - powerCounter += 1 - powerMax = math.Pow(10, powerCounter) - 1 - } - bytesNeeded += ((toReplace - prevMax) * powerCounter) - output := make([]byte, int(bytesNeeded)) - buffer := make([]byte, utf8.UTFMax) - for pos < len(q.sql) { - r, width = utf8.DecodeRuneInString(q.sql[pos:]) - pos += width - if r == replacementRune { - newText := []byte(fmt.Sprintf("$%d", count)) - for _, b := range newText { - output[outputPos] = b - outputPos += 1 - } - count += 1 - continue - } - used := utf8.EncodeRune(buffer[0:], r) - for b := 0; b < used; b++ { - output[outputPos] = buffer[b] - outputPos += 1 - } - } - for i := 0; i < len(terminatorBytes); i++ { - output[len(output)-(len(terminatorBytes)-i)] = terminatorBytes[i] + count := 1 + var res string + toCheck := q.sql + for i := strings.Index(toCheck, "?"); i >= 0; count++ { + res += toCheck[:i] + res += "$" + strconv.Itoa(count) + toCheck = toCheck[i+1:] + i = strings.Index(toCheck, "?") } - return string(output), nil + res += toCheck + return res + ";", nil } func (q *Query) Flush(join string) *Query { From 2148707c96e3d2c21135c8a1d3f98c1f6185423f Mon Sep 17 00:00:00 2001 From: Paddy Foran Date: Thu, 14 Jul 2016 17:54:42 -0600 Subject: [PATCH 07/21] Use a cache for reflection. Only read the column names from a struct using reflection once, then cache them; then we only re-read them when we need to retrieve values. I also dropped the buffer clearing from toSnake, because we only use the bytes we overwrote anyways, so there's really no point in zeroing them out. The cache brings between 10-50% improvement on cache times for inserts and queries, respectively. Obviously, it will change depending on your queries, but the more you read column names, the bigger the improvement will be, and the more you read values, the smaller the improvement will be. --- reflect.go | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/reflect.go b/reflect.go index d5e1683..814e31f 100644 --- a/reflect.go +++ b/reflect.go @@ -1,6 +1,7 @@ package pan import ( + "fmt" "reflect" "strings" "unicode" @@ -11,6 +12,10 @@ const ( TAG_NAME = "sql_column" // The tag that will be read ) +var ( + structReadCache = map[string][]string{} +) + func validTag(s string) bool { if s == "" { return false @@ -46,10 +51,6 @@ func toSnake(s string) string { n := utf8.EncodeRune(buf, c) snake += string(buf[0:n]) - // clear the buffer - for i := 0; i < n; i++ { - buf[i] = 0 - } } return snake } @@ -66,7 +67,12 @@ func getFieldColumn(f reflect.StructField) string { return field } -func readStruct(s SQLTableNamer) (columns []string, values []interface{}) { +// if needsValues is false, we'll attempt to use the cache and won't return any values +func readStruct(s SQLTableNamer, needsValues bool) (columns []string, values []interface{}) { + typ := fmt.Sprintf("%T", s) + if cached, ok := structReadCache[typ]; !needsValues && ok { + return cached, nil + } v := reflect.ValueOf(s) t := reflect.TypeOf(s) k := t.Kind() @@ -94,11 +100,12 @@ func readStruct(s SQLTableNamer) (columns []string, values []interface{}) { columns = append(columns, field) values = append(values, value) } + structReadCache[typ] = columns return } func Columns(s SQLTableNamer) ColumnList { - columns, _ := readStruct(s) + columns, _ := readStruct(s, false) return columns } @@ -112,7 +119,7 @@ func Column(s SQLTableNamer, property string) string { } func ColumnValues(s SQLTableNamer) []interface{} { - _, values := readStruct(s) + _, values := readStruct(s, true) return values } From e46384666df2fb49979a6a98895a7cbd1670eb77 Mon Sep 17 00:00:00 2001 From: Paddy Date: Fri, 15 Jul 2016 09:48:02 -0700 Subject: [PATCH 08/21] Fix column output, make our caching parallel-safe. Make all our tests run in parallel. Add a RWMutex to our struct reflection cache. This removes annoying race conditions we had in our caching, without a noticeable performance impact. I also added a column cache, so the Column function could only use reflection the first time it saw a type/property combo, but there was no measurable performance benefit, so I dropped it. I changed readStruct to only populate the values return variable if the needsValues argument is set to true. This has a probably immeasurable performance impact, but more importantly makes behaviour consistent, regardless of whether values are retrieved from the cache or from reflect directly. I dropped the getColumn helper, which Column called through to, as Column was the only thing calling it. So now it's just Column. I also fixed the bug that had us returning the bare column name, instead of the absolute column name, as the rest of the library does. This required an update to our sqlTable tests. --- query_test.go | 4 ++++ reflect.go | 37 +++++++++++++++++++++---------------- reflect_test.go | 7 +++++++ sql_test.go | 9 +++++---- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/query_test.go b/query_test.go index 6256ca0..b232249 100644 --- a/query_test.go +++ b/query_test.go @@ -135,6 +135,7 @@ func init() { } func TestQueriesFromTable(t *testing.T) { + t.Parallel() for pos, test := range queryTests { t.Logf("Testing: %s", test.Query.sql) mysql, mErr := test.Query.MySQLString() @@ -155,6 +156,7 @@ func TestQueriesFromTable(t *testing.T) { } func TestErrWrongNumberArgs(t *testing.T) { + t.Parallel() q := New("?") q.args = append(q.args, 1, 2, 3) err := q.checkCounts() @@ -177,6 +179,7 @@ func TestErrWrongNumberArgs(t *testing.T) { } func TestRepeatedOrder(t *testing.T) { + t.Parallel() q := New("SELECT * FROM test_data") q.OrderBy("id") q.OrderBy("name") @@ -191,6 +194,7 @@ func TestRepeatedOrder(t *testing.T) { } func TestOffset(t *testing.T) { + t.Parallel() q := New("SELECT * FROM test_data") q.Offset(10).Flush(" ") mysql, err := q.MySQLString() diff --git a/reflect.go b/reflect.go index 814e31f..49a3295 100644 --- a/reflect.go +++ b/reflect.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strings" + "sync" "unicode" "unicode/utf8" ) @@ -14,6 +15,7 @@ const ( var ( structReadCache = map[string][]string{} + structReadMutex sync.RWMutex ) func validTag(s string) bool { @@ -70,9 +72,12 @@ func getFieldColumn(f reflect.StructField) string { // if needsValues is false, we'll attempt to use the cache and won't return any values func readStruct(s SQLTableNamer, needsValues bool) (columns []string, values []interface{}) { typ := fmt.Sprintf("%T", s) + structReadMutex.RLock() if cached, ok := structReadCache[typ]; !needsValues && ok { + structReadMutex.RUnlock() return cached, nil } + structReadMutex.RUnlock() v := reflect.ValueOf(s) t := reflect.TypeOf(s) k := t.Kind() @@ -94,13 +99,18 @@ func readStruct(s SQLTableNamer, needsValues bool) (columns []string, values []i continue } field = s.GetSQLTableName() + "." + field - - // Get the value of the field - value := v.Field(i).Interface() columns = append(columns, field) - values = append(values, value) + + if needsValues { + // Get the value of the field + value := v.Field(i).Interface() + values = append(values, value) + } } + + structReadMutex.Lock() structReadCache[typ] = columns + structReadMutex.Unlock() return } @@ -113,17 +123,6 @@ func Columns(s SQLTableNamer) ColumnList { // Property must correspond exactly to the name of the property in the type, or this function will // panic. func Column(s SQLTableNamer, property string) string { - // BUG(paddy): return the table name as part of the column name - // when we put in a cache for reflect, we should refactor this to use readStruct - return getColumn(s, property) -} - -func ColumnValues(s SQLTableNamer) []interface{} { - _, values := readStruct(s, true) - return values -} - -func getColumn(s interface{}, property string) string { t := reflect.TypeOf(s) k := t.Kind() for k == reflect.Interface || k == reflect.Ptr { @@ -137,7 +136,13 @@ func getColumn(s interface{}, property string) string { if !ok { panic("Field not found in type: " + property) } - return getFieldColumn(field) + column := s.GetSQLTableName() + "." + getFieldColumn(field) + return column +} + +func ColumnValues(s SQLTableNamer) []interface{} { + _, values := readStruct(s, true) + return values } type SQLTableNamer interface { diff --git a/reflect_test.go b/reflect_test.go index 69bdc0f..a45ee91 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -23,6 +23,7 @@ func (t testType2) GetSQLTableName() string { } func TestReflectedProperties(t *testing.T) { + t.Parallel() foo := testType{ myInt: 1, MyTaggedInt: 2, @@ -58,6 +59,7 @@ var tags = map[string]bool{ } func TestValidTag(t *testing.T) { + t.Parallel() for input, validity := range tags { if validTag(input) != validity { expectedValidity := "valid" @@ -82,6 +84,7 @@ var camelToSnake = map[string]string{ } func TestCamelToSnake(t *testing.T) { + t.Parallel() for input, expectedOutput := range camelToSnake { if expectedOutput != toSnake(input) { t.Errorf("Expected `%s` to be `%s`, was `%s`", input, expectedOutput, toSnake(input)) @@ -96,6 +99,7 @@ func (i invalidSqlFieldReflector) GetSQLTableName() string { } func TestInvalidFieldReflection(t *testing.T) { + t.Parallel() columns := Columns(invalidSqlFieldReflector("test")) values := ColumnValues(invalidSqlFieldReflector("test")) if len(columns) != 0 { @@ -107,6 +111,7 @@ func TestInvalidFieldReflection(t *testing.T) { } func TestInterfaceOrPointerFieldReflection(t *testing.T) { + t.Parallel() columns := Columns(&testType{}) if len(columns) != 2 { t.Errorf("Expected %d columns, but got %v", len(columns), columns) @@ -139,6 +144,7 @@ func TestInterfaceOrPointerFieldReflection(t *testing.T) { } func TestInvalidColumnTypes(t *testing.T) { + t.Parallel() defer func() { t.Log(recover()) }() @@ -147,6 +153,7 @@ func TestInvalidColumnTypes(t *testing.T) { } func TestOmittedColumn(t *testing.T) { + t.Parallel() columns := Columns(&testType{}) for _, column := range columns { if column == "omitted_column" { diff --git a/sql_test.go b/sql_test.go index 892a896..7d9cd81 100644 --- a/sql_test.go +++ b/sql_test.go @@ -25,12 +25,12 @@ func init() { postgres: "INSERT INTO test_data (test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified) VALUES ($1,$2,$3,$4,$5,$6);", } sqlTable[New("UPDATE "+Table(p)+" SET").Assign(p, "Title", p.Title).Assign(p, "Author", p.Author).Flush(", ").Where().Comparison(p, "ID", "=", p.ID).Flush(" ")] = queryResult{ - mysql: "UPDATE test_data SET title = ?, author_id = ? WHERE id = ?;", - postgres: "UPDATE test_data SET title = $1, author_id = $2 WHERE id = $3;", + mysql: "UPDATE test_data SET test_data.title = ?, test_data.author_id = ? WHERE test_data.id = ?;", + postgres: "UPDATE test_data SET test_data.title = $1, test_data.author_id = $2 WHERE test_data.id = $3;", } sqlTable[New("SELECT "+Columns(p).String()+" FROM "+Table(p)).Where().Expression(Column(p, "Created")+" > (SELECT "+Column(p, "Created")+" FROM "+Table(p)+" WHERE "+Column(p, "ID")+" = ?)", 123).Where().OrderByDesc(Column(p, "Created")).Limit(19).Flush(" ")] = queryResult{ - postgres: "SELECT test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified FROM test_data WHERE created > (SELECT created FROM test_data WHERE id = $1) ORDER BY created DESC LIMIT $2;", - mysql: "SELECT test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified FROM test_data WHERE created > (SELECT created FROM test_data WHERE id = ?) ORDER BY created DESC LIMIT ?;", + postgres: "SELECT test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified FROM test_data WHERE test_data.created > (SELECT test_data.created FROM test_data WHERE test_data.id = $1) ORDER BY test_data.created DESC LIMIT $2;", + mysql: "SELECT test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified FROM test_data WHERE test_data.created > (SELECT test_data.created FROM test_data WHERE test_data.id = ?) ORDER BY test_data.created DESC LIMIT ?;", } } @@ -42,6 +42,7 @@ var sqlTable = map[*Query]queryResult{ } func TestSQLTable(t *testing.T) { + t.Parallel() for query, expectation := range sqlTable { t.Logf(query.String()) mysql, err := query.MySQLString() From bf67b1530775db3737a7504a794a1bd45689100e Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 15:25:09 -0700 Subject: [PATCH 09/21] Add a README. Use @agocs' README as a basis, and update it for the new syntax, and clarify some important points. --- README.md | 354 +++++++++++------------------------------------------- 1 file changed, 68 insertions(+), 286 deletions(-) diff --git a/README.md b/README.md index c673256..a75c687 100644 --- a/README.md +++ b/README.md @@ -1,326 +1,108 @@ -[![Build Status](https://travis-ci.org/secondbit/pan.png)](https://travis-ci.org/secondbit/pan) - -# pan +# Importing pan import "github.com/DramaFever/pan" +# About pan +`pan` is an SQL query building and response unmarshalling library for Go. It is designed to be compatible with MySQL and PostgreSQL, but should be more or less agnostic. Please let us know if your favourite SQL flavour is not supported. +Pan is not designed to be an ORM, but it still eliminates much of the boilerplate code around writing queries and scanning over rows. -## Constants -``` go -const ( - MYSQL dbengine = iota - POSTGRES -) -``` -``` go -const TAG_NAME = "sql_column" // The tag that will be read - -``` - - -## func GetAbsoluteColumnName -``` go -func GetAbsoluteColumnName(s sqlTableNamer, property string) string -``` -GetAbsoluteColumnName returns the field name associated with the specified property in the passed value. -Property must correspond exactly to the name of the property in the type, or this function will -panic. - - -## func GetAbsoluteFields -``` go -func GetAbsoluteFields(s sqlTableNamer) (fields []interface{}, values []interface{}) -``` -GetAbsoluteFields returns a slice of the fields in the passed type, with their names -drawn from tags or inferred from the property name (which will be lower-cased with underscores, -e.g. CamelCase => camel_case) and a corresponding slice of interface{}s containing the values for -those properties. Fields will be surrounded in \` marks and prefixed with their table name, as -determined by the passed type's GetSQLTableName. The format will be \`table_name\`.\`field_name\`. - - -## func GetColumn -``` go -func GetColumn(s interface{}, property string) string -``` -GetColumn returns the field name associated with the specified property in the passed value. -Property must correspond exactly to the name of the property in the type, or this function will -panic. - - -## func GetM2MAbsoluteColumnName -``` go -func GetM2MAbsoluteColumnName(t sqlTableNamer, field string, t2 sqlTableNamer) string -``` -GetM2MAbsoluteColumnName returns the column name for the supplied field in a many-to-many relationship table, -including the table name. The field belongs to the first sqlTableNamer, the second sqlTableNamer is the other -table in the many-to-many relationship. - - -## func GetM2MColumnName -``` go -func GetM2MColumnName(t sqlTableNamer, field string) string -``` -GetM2MColumnName returns the column name for the supplied field in a many-to-many relationship table. - - -## func GetM2MFields -``` go -func GetM2MFields(t1 sqlTableNamer, field1 string, t2 sqlTableNamer, field2 string) (columns, values []interface{}) -``` -GetM2MFields returns a slice of the columns that should be in a table that maps the many-to-many relationship of -the types supplied, with their corresponding values. The field parameters specify the primary keys used in -the relationship table to map to that type. - +Pan’s design focuses on reducing repetition and hardcoded strings in your queries, but without limiting your ability to write any form of query you want. It is meant to be the smallest possible abstraction on top of SQL. -## func GetM2MQuotedColumnName -``` go -func GetM2MQuotedColumnName(t sqlTableNamer, field string) string -``` -GetM2MQuotedColumnName returns the column name for the supplied field in a many-to-many relationship table, -including the quote marks around the column name. +Docs can be found on [GoDoc.org](https://godoc.org/darlinggo.co/pan). +# Using pan -## func GetM2MQuotedFields -``` go -func GetM2MQuotedFields(t1 sqlTableNamer, field1 string, t2 sqlTableNamer, field2 string) (columns, values []interface{}) -``` -GetM2MQuotedFields wraps the fields returned by GetM2MFields in quotes. - - -## func GetM2MTableName -``` go -func GetM2MTableName(t1, t2 sqlTableNamer) string -``` -GetM2MTableName returns a consistent table name for a many-to-many relationship between two tables. No -matter what order the fields are passed in, the resulting table name will always be consistent. +Pan revolves around structs that fill the `SQLTableNamer` interface, by implementing the `GetSQLTableName() string` function, which just returns the name of the table that should store the data for that struct. +Let's say you have a `Person` in your code. -## func GetQuotedFields -``` go -func GetQuotedFields(s sqlTableNamer) (fields []interface{}, values []interface{}) -``` -GetQuotedFields returns a slice of the fields in the passed type, with their names -drawn from tags or inferred from the property name (which will be lower-cased with underscores, -e.g. CamelCase => camel_case) and a corresponding slice of interface{}s containing the values for -those properties. Fields will be surrounding in ` marks. - - -## func GetTableName -``` go -func GetTableName(t sqlTableNamer) string -``` -GetTableName returns the table name for any type that implements the `GetSQLTableName() string` -method signature. The returned string will be used as the name of the table to store the data -for all instances of the type. - - -## func GetUnquotedColumn -``` go -func GetUnquotedColumn(s interface{}, property string) string -``` - -## func QueryList -``` go -func QueryList(fields []interface{}) string -``` -QueryList joins the passed fields into a string that can be used when selecting the fields to return -or specifying fields in an update or insert. - - -## func Unmarshal -``` go -func Unmarshal(s Scannable, dst interface{}) error -``` - -## func VariableList -``` go -func VariableList(num int) string -``` -VariableList returns a list of `num` variable placeholders for use in SQL queries involving slices -and arrays. - - - -## type Query -``` go -type Query struct { - SQL string - Args []interface{} - Expressions []string - IncludesWhere bool - IncludesOrder bool - IncludesLimit bool - Engine dbengine +```go +type Person struct { + ID int `sql_column:"person_id"` + FName string `sql_column:"fname"` + LName string `sql_column:"lname"` + Age int } ``` -Query contains the data needed to perform a single SQL query. - - - - - - - - - -### func New -``` go -func New(engine dbengine, query string) *Query -``` -New creates a new Query object. The passed engine is used to format variables. The passed string is used to prefix the query. +And you have a corresponding `Person` table: - - -### func (\*Query) FlushExpressions -``` go -func (q *Query) FlushExpressions(join string) *Query -``` -FlushExpressions joins the Query's Expressions with the join string, then concatenates them -to the Query's SQL. It then resets the Query's Expressions. This permits Expressions to be joined -by different strings within a single Query. - - - -### func (\*Query) Generate -``` go -func (q *Query) Generate(join string) string -``` -Generate creates a string from the Query, joining its SQL property and its Expressions. Expressions are joined -using the join string supplied. - - - -### func (\*Query) Include -``` go -func (q *Query) Include(key string, values ...interface{}) *Query -``` -Include adds the supplied key (which should be an expression) to the Query's Expressions and the value -to the Query's Args. - - - -### func (\*Query) IncludeIfNotEmpty -``` go -func (q *Query) IncludeIfNotEmpty(key string, value interface{}) *Query ``` -IncludeIfNotEmpty adds the supplied key (which should be an expression) to the Query's Expressions if -and only if the value parameter is not the empty value for its type. If the key is added to the Query's -Expressions, the value is added to the Query's Args. - - - -### func (\*Query) IncludeIfNotNil -``` go -func (q *Query) IncludeIfNotNil(key string, value interface{}) *Query -``` -IncludeIfNotNil adds the supplied key (which should be an expression) to the Query's Expressions if -and only if the value parameter is not a nil value. If the key is added to the Query's Expressions, the -value is added to the Query's Args. - - - -### func (\*Query) IncludeLimit -``` go -func (q *Query) IncludeLimit(limit int64) *Query ++-----------+-------------+------+-----+---------+-------+ +| Field | Type | Null | Key | Default | Extra | ++-----------+-------------+------+-----+---------+-------+ +| person_id | int | NO | | NULL | | +| fname | varchar(20) | NO | | '' | | +| lname | varchar(20) | NO | | '' | | +| age | int | NO | | 0 | | ++-----------+-------------+------+-----+---------+-------+ ``` -IncludeLimit includes the LIMIT clause if the LIMIT clause has not already been included in the Query. -This cannot detect LIMIT clauses that are manually added to the Query's SQL; it only tracks IncludeLimit(). -The passed int is used as the limit in the resulting query. +> **Note**: Unless you're using sql.NullString or equivalent, it's not recommended to allow `NULL` in your data. It may cause you trouble when unmarshaling. +To use that `Person` type with pan, you need it to fill the `SQLTableNamer` interface, letting pan know to use the `person` table in your database: -### func (\*Query) IncludeOffset -``` go -func (q *Query) IncludeOffset(offset int64) *Query +```go +func (p Person)GetSQLTableName()string{ + return "person" +} ``` +## Creating a query -### func (\*Query) IncludeOrder -``` go -func (q *Query) IncludeOrder(orderClause string) *Query +```go +// selects all rows +var p Person +query := pan.New(pan.MYSQL, "SELECT "+pan.Columns(p).String()+" FROM "+pan.Table(p)) ``` -IncludeOrder includes the ORDER BY clause if the ORDER BY clause has not already been included in the Query. -This cannot detect ORDER BY clauses that are manually added to the Query's SQL; it only tracks IncludeOrder(). -The passed string is used as the expression to order by. - +or -### func (\*Query) IncludeWhere -``` go -func (q *Query) IncludeWhere() *Query +```go +// selects one person +var p Person +query := pan.New(pan.MYSQL, "SELECT "+pan.Columns(p).String()+" FROM "+pan.Table(p)).Where() +query.Comparison(p, "ID", "=", 1) +query.Flush(" ") ``` -IncludeWhere includes the WHERE clause if the WHERE clause has not already been included in the Query. -This cannot detect WHERE clauses that are manually added to the Query's SQL; it only tracks IncludeWhere(). +That `Flush` command is important: pan works by creating a buffer of strings, and then joining them by some separator character. Flush takes the separator character (in this case, a space) and uses it to join all the buffered strings (in this case, the `WHERE` statement and the `person_id = ?` statement), and then adds the result to its query. +> It's safe to call `Flush` even if there are no buffered strings, so a good practice is to just call `Flush` after the entire query is built, just to make sure you don't leave anything buffered. -### func (\*Query) InnerJoin -``` go -func (q *Query) InnerJoin(table, expression string) *Query -``` - - -### func (\*Query) String -``` go -func (q *Query) String() string -``` -String fulfills the String interface for Queries, and returns the generated SQL query after every instance of ? -has been replaced with a counter prefixed with $ (e.g., $1, $2, $3). There is no support for using ?, quoted or not, -within an expression. All instances of the ? character that are not meant to be substitutions should be as arguments -in prepared statements. - +The `pan.Columns()` function returns the column names that a struct's properties correspond to. `pan.Columns().String()` joins them into a list of columns that can be passed right to the `SELECT` expression, making it easy to support reading only the columns you need, maintaining forward compatibility—your code will never choke on unexpected columns being added. +## Executing the query and reading results -## type Scannable -``` go -type Scannable interface { - Scan(dest ...interface{}) error +```go +mysql, err := query.MySQLString() // could also be PostgreSQLString +if err != nil { + // handle the error } -``` - - - - - - - - - - -## type WrongNumberArgsError -``` go -type WrongNumberArgsError struct { - NumExpected int - NumFound int +rows, err := db.Query(mysql, query.Args...) +if err != nil { + // handle the error +} +var people []Person +for rows.Next() { + var p Person + err := pan.Unmarshal(rows, &p) // put the results into the struct + if err != nil { + // handle the error + } + people = append(people, p) } ``` -WrongNumberArgsError is thrown when a Query is evaluated whose Args does not match its Expressions. - - - - - - - - - - - -### func (WrongNumberArgsError) Error -``` go -func (e WrongNumberArgsError) Error() string -``` -Error fulfills the error interface, returning the expected number of arguments and the number supplied. - - - +## A note about time +If you're going to be reading the MySQL `time` type and you plan to parse that into a Go `time`, you must include `&parseTime=true` in your DSN. +## How Struct Properties Turn Into Column Names +There are a couple rules about how struct properties become column names. First, only exported struct properties are used; unexported properties are ignored. +By default, a struct property's name is snake-cased, and that is used as the column name. For example, `Name` would become `name`, and `MyInt` would become `my_int`. -- - - -Generated by [godoc2md](http://godoc.org/github.com/davecheney/godoc2md) +If you want more control or want to make columns explicit, the `sql_column` struct tag can be used to override this behaviour. From 31b2b6bc94b8a1d8d5f5c45144cf6efa5ad0c0b9 Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 15:30:03 -0700 Subject: [PATCH 10/21] Minor readme tweaks. Remove the section on parseTime, as that's not a pan restriction (it depends on the driver being used), fix a typo, update wording on struct property mapping. --- README.md | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index a75c687..0f9b3c0 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ And you have a corresponding `Person` table: To use that `Person` type with pan, you need it to fill the `SQLTableNamer` interface, letting pan know to use the `person` table in your database: ```go -func (p Person)GetSQLTableName()string{ +func (p Person) GetSQLTableName()string{ return "person" } ``` @@ -60,7 +60,7 @@ query := pan.New(pan.MYSQL, "SELECT "+pan.Columns(p).String()+" FROM "+pan.Table or ```go -// selects one person +// selects one row var p Person query := pan.New(pan.MYSQL, "SELECT "+pan.Columns(p).String()+" FROM "+pan.Table(p)).Where() query.Comparison(p, "ID", "=", 1) @@ -95,13 +95,9 @@ for rows.Next() { } ``` -## A note about time +## How struct properties map to columns -If you're going to be reading the MySQL `time` type and you plan to parse that into a Go `time`, you must include `&parseTime=true` in your DSN. - -## How Struct Properties Turn Into Column Names - -There are a couple rules about how struct properties become column names. First, only exported struct properties are used; unexported properties are ignored. +There are a couple rules about how struct properties map to column names. First, only exported struct properties are used; unexported properties are ignored. By default, a struct property's name is snake-cased, and that is used as the column name. For example, `Name` would become `name`, and `MyInt` would become `my_int`. From 306036f2b4ec0144c131fa179b753c0d11521b48 Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 15:31:20 -0700 Subject: [PATCH 11/21] Typos will be the death of me. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0f9b3c0..a57d62e 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ And you have a corresponding `Person` table: To use that `Person` type with pan, you need it to fill the `SQLTableNamer` interface, letting pan know to use the `person` table in your database: ```go -func (p Person) GetSQLTableName()string{ +func (p Person) GetSQLTableName() string{ return "person" } ``` From e3ddc3db1b4d9d3b31a7fe5b99ee4971dca9f130 Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 15:32:06 -0700 Subject: [PATCH 12/21] I wish I could gofmt this README. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a57d62e..cfc4e67 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ And you have a corresponding `Person` table: To use that `Person` type with pan, you need it to fill the `SQLTableNamer` interface, letting pan know to use the `person` table in your database: ```go -func (p Person) GetSQLTableName() string{ +func (p Person) GetSQLTableName() string { return "person" } ``` From 30e981f84a0752ec4fc9d23bf9fea5a0e365b885 Mon Sep 17 00:00:00 2001 From: Paddy Foran Date: Sat, 16 Jul 2016 16:14:10 -0700 Subject: [PATCH 13/21] Use sql.Rows' Columns method when unmarshaling. Take advantage of having the Column names available to us to do a better job of matching columns to properties. Drop docstrings, which are all out of date and need rewritten. Rename VariableList to Placeholders, which is really what it's generating. Update our Unmarshal test to use Postgres, not sqlite, which doesn't support table.column_name syntax, and was giving me trouble. Eventually, we should find a fix for this? --- query.go | 8 ++---- reflect.go | 75 ++++++++++++++++++++++++++++++++++--------------- reflect_test.go | 43 +++++++++++++++++----------- sql_test.go | 2 +- 4 files changed, 81 insertions(+), 47 deletions(-) diff --git a/query.go b/query.go index ae39093..9c8a9ca 100644 --- a/query.go +++ b/query.go @@ -6,7 +6,6 @@ import ( "strings" ) -// Query contains the data needed to perform a single SQL query. type Query struct { sql string args []interface{} @@ -21,7 +20,6 @@ func (c ColumnList) String() string { return strings.Join(c, ", ") } -// New creates a new Query object. The passed engine is used to format variables. The passed string is used to prefix the query. func New(query string) *Query { return &Query{ sql: query, @@ -38,18 +36,16 @@ func Insert(obj SQLTableNamer, values ...SQLTableNamer) *Query { for _, v := range inserts { columnValues := ColumnValues(v) - query.Expression("("+VariableList(len(columnValues))+")", columnValues...) + query.Expression("("+Placeholders(len(columnValues))+")", columnValues...) } return query.Flush(" ") } -// ErrWrongNumberArgs is thrown when a Query is evaluated whose Args does not match its Expressions. type ErrWrongNumberArgs struct { NumExpected int NumFound int } -// Error fulfills the error interface, returning the expected number of arguments and the number supplied. func (e ErrWrongNumberArgs) Error() string { return fmt.Sprintf("Expected %d arguments, got %d.", e.NumExpected, e.NumFound) } @@ -134,7 +130,7 @@ func (q *Query) Comparison(obj SQLTableNamer, property, operator string, value i } func (q *Query) In(obj SQLTableNamer, property string, values ...interface{}) *Query { - return q.Expression(Column(obj, property)+" IN("+VariableList(len(values))+")", values...) + return q.Expression(Column(obj, property)+" IN("+Placeholders(len(values))+")", values...) } func (q *Query) Assign(obj SQLTableNamer, property string, value interface{}) *Query { diff --git a/reflect.go b/reflect.go index 49a3295..894d076 100644 --- a/reflect.go +++ b/reflect.go @@ -3,6 +3,7 @@ package pan import ( "fmt" "reflect" + "sort" "strings" "sync" "unicode" @@ -69,7 +70,7 @@ func getFieldColumn(f reflect.StructField) string { return field } -// if needsValues is false, we'll attempt to use the cache and won't return any values +// if needsValues is false, we'll attempt to use the cache and `values` will be nil func readStruct(s SQLTableNamer, needsValues bool) (columns []string, values []interface{}) { typ := fmt.Sprintf("%T", s) structReadMutex.RLock() @@ -119,9 +120,6 @@ func Columns(s SQLTableNamer) ColumnList { return columns } -// GetColumn returns the field name associated with the specified property in the passed value. -// Property must correspond exactly to the name of the property in the type, or this function will -// panic. func Column(s SQLTableNamer, property string) string { t := reflect.TypeOf(s) k := t.Kind() @@ -149,16 +147,11 @@ type SQLTableNamer interface { GetSQLTableName() string } -// GetTableName returns the table name for any type that implements the `GetSQLTableName() string` -// method signature. The returned string will be used as the name of the table to store the data -// for all instances of the type. func Table(t SQLTableNamer) string { return t.GetSQLTableName() } -// VariableList returns a list of `num` variable placeholders for use in SQL queries involving slices -// and arrays. -func VariableList(num int) string { +func Placeholders(num int) string { placeholders := make([]string, num) for pos := 0; pos < num; pos++ { placeholders[pos] = "?" @@ -166,18 +159,46 @@ func VariableList(num int) string { return strings.Join(placeholders, ",") } -// QueryList joins the passed fields into a string that can be used when selecting the fields to return -// or specifying fields in an update or insert. -func QueryList(fields []interface{}) string { - strs := make([]string, len(fields)) - for pos, field := range fields { - strs[pos] = field.(string) - } - return strings.Join(strs, ", ") +type Scannable interface { + Scan(dst ...interface{}) error + Columns() ([]string, error) } -type Scannable interface { - Scan(dest ...interface{}) error +type pointer struct { + addr interface{} + column string + sortOrder int +} + +type pointers []pointer + +func (p pointers) Len() int { return len(p) } + +func (p pointers) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + +func (p pointers) Less(i, j int) bool { return p[i].sortOrder < p[j].sortOrder } + +func getColumnAddrs(s Scannable, in []pointer) ([]interface{}, error) { + columns, err := s.Columns() + if err != nil { + return nil, err + } + var results pointers + for _, pointer := range in { + for pos, column := range columns { + if column == pointer.column { + pointer.sortOrder = pos + results = append(results, pointer) + break + } + } + } + sort.Sort(results) + i := make([]interface{}, 0, len(results)) + for _, res := range results { + i = append(i, res.addr) + } + return i, nil } func Unmarshal(s Scannable, dst interface{}) error { @@ -192,7 +213,7 @@ func Unmarshal(s Scannable, dst interface{}) error { if k != reflect.Struct { return s.Scan(dst) } - pointers := []interface{}{} + props := []pointer{} for i := 0; i < t.NumField(); i++ { if t.Field(i).PkgPath != "" { // skip unexported fields @@ -204,7 +225,15 @@ func Unmarshal(s Scannable, dst interface{}) error { } // Get the value of the field - pointers = append(pointers, v.Field(i).Addr().Interface()) + props = append(props, pointer{ + addr: v.Field(i).Addr().Interface(), + column: field, + }) + } + + addrs, err := getColumnAddrs(s, props) + if err != nil { + return err } - return s.Scan(pointers...) + return s.Scan(addrs...) } diff --git a/reflect_test.go b/reflect_test.go index a45ee91..a44f007 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -1,6 +1,12 @@ package pan -import "testing" +import ( + "database/sql" + "os" + "testing" + + _ "github.com/lib/pq" +) type testType struct { myInt int @@ -162,11 +168,11 @@ func TestOmittedColumn(t *testing.T) { } } -/* func TestUnmarshal(t *testing.T) { - os.Remove("./test.db") - - db, err := sql.Open("sqlite3", "./test.db") + if os.Getenv("PG_DSN") == "" { + t.Skip() + } + db, err := sql.Open("postgres", os.Getenv("PG_DSN")) if err != nil { t.Error(err) } @@ -184,27 +190,30 @@ func TestUnmarshal(t *testing.T) { if err != nil { t.Error(err) } - fields, values := GetQuotedFields(dummy) - q := New(MYSQL, "INSERT INTO "+GetTableName(dummy)) - q.Include("(" + QueryList(fields) + ")") - q.Include("VALUES") - q.Include("("+VariableList(len(values))+")", values...) - q.FlushExpressions(" ") - _, err = db.Exec(q.String(), q.Args...) + q := Insert(dummy) + mysql, err := q.MySQLString() if err != nil { t.Error(err) } - fields, _ = GetQuotedFieldsAndExpressions(dummy) - row := db.QueryRow("SELECT " + QueryList(fields) + " FROM test_types;") - err = Unmarshal(row, &expectation) + _, err = db.Exec(mysql, q.Args()...) if err != nil { + t.Log(q.String()) t.Error(err) } + rows, err := db.Query("SELECT " + Columns(dummy).String() + " FROM test_types;") + if err != nil { + t.Error(err) + } + for rows.Next() { + err = Unmarshal(rows, &expectation) + if err != nil { + t.Error(err) + } + } if expectation.MyTaggedInt != dummy.MyTaggedInt { t.Errorf("Expected MyTaggedInt to be %d, was %d.", dummy.MyTaggedInt, expectation.MyTaggedInt) } if expectation.MyString != dummy.MyString { t.Errorf("Expected MyString to be %s, was %s.", dummy.MyString, expectation.MyString) } - os.Remove("./test.db") -}*/ +} diff --git a/sql_test.go b/sql_test.go index 7d9cd81..f04e368 100644 --- a/sql_test.go +++ b/sql_test.go @@ -35,7 +35,7 @@ func init() { } var sqlTable = map[*Query]queryResult{ - New("INSERT INTO "+Table(testPost{})).Expression("("+VariableList(4)+")", "a", "b", "c", "d").Expression("VALUES").Expression("("+VariableList(4)+")", 0, 1, 2, 3).Flush(" "): { + New("INSERT INTO "+Table(testPost{})).Expression("("+Placeholders(4)+")", "a", "b", "c", "d").Expression("VALUES").Expression("("+Placeholders(4)+")", 0, 1, 2, 3).Flush(" "): { mysql: "INSERT INTO test_data (?,?,?,?) VALUES (?,?,?,?);", postgres: "INSERT INTO test_data ($1,$2,$3,$4) VALUES ($5,$6,$7,$8);", }, From 81cfe236f23b95983ec35565c9761287e0e56a7f Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 16:49:58 -0700 Subject: [PATCH 14/21] Introduce column flags. Add a Flag type that will let callers specify whether they want the absolute column name (table.column), double-quoted ("column"), ticked (`column`), or a combination thereof. This is exposed only on the Columns and Column methods; convenience methods built on top only ever use the bare column name. People can deal with it. Stop exporting the TAG_NAME constant, there's no reason to. `Placeholders` now generates placholders with a space after the comma. We're back to using SQLite for the tests, now that it works again. --- README.md | 15 +++++++++++++ query.go | 8 +++++++ reflect.go | 59 +++++++++++++++++++++++++++++++++++++++---------- reflect_test.go | 14 ++++++------ sql_test.go | 16 +++++++------- 5 files changed, 85 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index cfc4e67..47240e6 100644 --- a/README.md +++ b/README.md @@ -102,3 +102,18 @@ There are a couple rules about how struct properties map to column names. First, By default, a struct property's name is snake-cased, and that is used as the column name. For example, `Name` would become `name`, and `MyInt` would become `my_int`. If you want more control or want to make columns explicit, the `sql_column` struct tag can be used to override this behaviour. + +## Column flags + +Sometimes, you need more than the base column name; you may need the full name (`table.column`) or you may be using special characters/need to quote the column name (`"column"` for Postgres, `\`column`\` for MySQL). To support these use cases, the `Column` and `Columns` functions take a variable number of flags (including none): + +```go +Columns() // returns column format +Columns(FlagFull) // returns table.column format +Columns(FlagDoubleQuoted) // returns "column" format +Columns(FlagTicked) // returns `column` format +Columns(FlagFull, FlagDoubleQuoted) // returns "table"."column" format +Columns(FlagFull, FlagTicked) // returns `table`.`column` format +``` + +This behaviour is not exposed through the convenience functions built on top of `Column` and `Columns`; you'll need to use `Expression` to rebuild them by hand. Usually, this can be done simply; look at the source code for those convenience functions for examples. diff --git a/query.go b/query.go index 9c8a9ca..2dd4110 100644 --- a/query.go +++ b/query.go @@ -6,6 +6,12 @@ import ( "strings" ) +const ( + FlagFull Flag = iota + FlagTicked + FlagDoubleQuoted +) + type Query struct { sql string args []interface{} @@ -20,6 +26,8 @@ func (c ColumnList) String() string { return strings.Join(c, ", ") } +type Flag int + func New(query string) *Query { return &Query{ sql: query, diff --git a/reflect.go b/reflect.go index 894d076..617c938 100644 --- a/reflect.go +++ b/reflect.go @@ -11,7 +11,7 @@ import ( ) const ( - TAG_NAME = "sql_column" // The tag that will be read + tagName = "sql_column" // The tag that will be read ) var ( @@ -60,7 +60,7 @@ func toSnake(s string) string { func getFieldColumn(f reflect.StructField) string { // Get the SQL column name, from the tag or infer it - field := f.Tag.Get(TAG_NAME) + field := f.Tag.Get(tagName) if field == "-" { return "" } @@ -70,13 +70,49 @@ func getFieldColumn(f reflect.StructField) string { return field } +func hasFlags(list []Flag, passed ...Flag) bool { + for _, candidate := range passed { + var found bool + for _, f := range list { + if f == candidate { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +func decorateColumns(columns []string, table string, flags ...Flag) []string { + results := make([]string, 0, len(columns)) + for _, name := range columns { + if hasFlags(flags, FlagTicked) { + name = "`" + name + "`" + } else if hasFlags(flags, FlagDoubleQuoted) { + name = `"` + name + `"` + } + if hasFlags(flags, FlagFull, FlagTicked) { + name = "`" + table + "`." + name + } else if hasFlags(flags, FlagFull, FlagDoubleQuoted) { + name = `"` + table + `".` + name + } else if hasFlags(flags, FlagFull) { + name = table + "." + name + } + results = append(results, name) + } + return results +} + // if needsValues is false, we'll attempt to use the cache and `values` will be nil -func readStruct(s SQLTableNamer, needsValues bool) (columns []string, values []interface{}) { +func readStruct(s SQLTableNamer, needsValues bool, flags ...Flag) (columns []string, values []interface{}) { typ := fmt.Sprintf("%T", s) structReadMutex.RLock() if cached, ok := structReadCache[typ]; !needsValues && ok { structReadMutex.RUnlock() - return cached, nil + return decorateColumns(cached, s.GetSQLTableName(), flags...), nil } structReadMutex.RUnlock() v := reflect.ValueOf(s) @@ -99,7 +135,6 @@ func readStruct(s SQLTableNamer, needsValues bool) (columns []string, values []i if field == "" { continue } - field = s.GetSQLTableName() + "." + field columns = append(columns, field) if needsValues { @@ -112,15 +147,15 @@ func readStruct(s SQLTableNamer, needsValues bool) (columns []string, values []i structReadMutex.Lock() structReadCache[typ] = columns structReadMutex.Unlock() - return + return decorateColumns(columns, s.GetSQLTableName(), flags...), values } -func Columns(s SQLTableNamer) ColumnList { - columns, _ := readStruct(s, false) +func Columns(s SQLTableNamer, flags ...Flag) ColumnList { + columns, _ := readStruct(s, false, flags...) return columns } -func Column(s SQLTableNamer, property string) string { +func Column(s SQLTableNamer, property string, flags ...Flag) string { t := reflect.TypeOf(s) k := t.Kind() for k == reflect.Interface || k == reflect.Ptr { @@ -134,8 +169,8 @@ func Column(s SQLTableNamer, property string) string { if !ok { panic("Field not found in type: " + property) } - column := s.GetSQLTableName() + "." + getFieldColumn(field) - return column + columns := decorateColumns([]string{getFieldColumn(field)}, s.GetSQLTableName(), flags...) + return columns[0] } func ColumnValues(s SQLTableNamer) []interface{} { @@ -156,7 +191,7 @@ func Placeholders(num int) string { for pos := 0; pos < num; pos++ { placeholders[pos] = "?" } - return strings.Join(placeholders, ",") + return strings.Join(placeholders, ", ") } type Scannable interface { diff --git a/reflect_test.go b/reflect_test.go index a44f007..7bd69b9 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" ) type testType struct { @@ -36,7 +36,7 @@ func TestReflectedProperties(t *testing.T) { MyString: "hello", myTaggedString: "world", } - columns := Columns(foo) + columns := Columns(foo, FlagFull) if len(columns) != 2 { t.Errorf("Columns should have length %d, has length %d", 2, len(columns)) } @@ -46,7 +46,7 @@ func TestReflectedProperties(t *testing.T) { } for pos, column := range columns { if column != "test_types.tagged_int" && column != "test_types.my_string" { - t.Errorf("Unknown column found: %v'", column) + t.Errorf("Unknown column found: %v", column) } if column == "test_types.tagged_int" && values[pos].(int) != 2 { t.Errorf("Expected tagged_int to be %d, got %v", 2, values[pos]) @@ -169,10 +169,9 @@ func TestOmittedColumn(t *testing.T) { } func TestUnmarshal(t *testing.T) { - if os.Getenv("PG_DSN") == "" { - t.Skip() - } - db, err := sql.Open("postgres", os.Getenv("PG_DSN")) + os.Remove("./test.db") + + db, err := sql.Open("sqlite3", "./test.db") if err != nil { t.Error(err) } @@ -216,4 +215,5 @@ func TestUnmarshal(t *testing.T) { if expectation.MyString != dummy.MyString { t.Errorf("Expected MyString to be %s, was %s.", dummy.MyString, expectation.MyString) } + os.Remove("./test.db") } diff --git a/sql_test.go b/sql_test.go index f04e368..878ecfe 100644 --- a/sql_test.go +++ b/sql_test.go @@ -21,23 +21,23 @@ func (t testPost) GetSQLTableName() string { func init() { p := testPost{123, "my post", 1, "this is a test post", time.Now(), nil} sqlTable[Insert(p)] = queryResult{ - mysql: "INSERT INTO test_data (test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified) VALUES (?,?,?,?,?,?);", - postgres: "INSERT INTO test_data (test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified) VALUES ($1,$2,$3,$4,$5,$6);", + mysql: "INSERT INTO test_data (id, title, author_id, body, created, modified) VALUES (?, ?, ?, ?, ?, ?);", + postgres: "INSERT INTO test_data (id, title, author_id, body, created, modified) VALUES ($1, $2, $3, $4, $5, $6);", } sqlTable[New("UPDATE "+Table(p)+" SET").Assign(p, "Title", p.Title).Assign(p, "Author", p.Author).Flush(", ").Where().Comparison(p, "ID", "=", p.ID).Flush(" ")] = queryResult{ - mysql: "UPDATE test_data SET test_data.title = ?, test_data.author_id = ? WHERE test_data.id = ?;", - postgres: "UPDATE test_data SET test_data.title = $1, test_data.author_id = $2 WHERE test_data.id = $3;", + mysql: "UPDATE test_data SET title = ?, author_id = ? WHERE id = ?;", + postgres: "UPDATE test_data SET title = $1, author_id = $2 WHERE id = $3;", } sqlTable[New("SELECT "+Columns(p).String()+" FROM "+Table(p)).Where().Expression(Column(p, "Created")+" > (SELECT "+Column(p, "Created")+" FROM "+Table(p)+" WHERE "+Column(p, "ID")+" = ?)", 123).Where().OrderByDesc(Column(p, "Created")).Limit(19).Flush(" ")] = queryResult{ - postgres: "SELECT test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified FROM test_data WHERE test_data.created > (SELECT test_data.created FROM test_data WHERE test_data.id = $1) ORDER BY test_data.created DESC LIMIT $2;", - mysql: "SELECT test_data.id, test_data.title, test_data.author_id, test_data.body, test_data.created, test_data.modified FROM test_data WHERE test_data.created > (SELECT test_data.created FROM test_data WHERE test_data.id = ?) ORDER BY test_data.created DESC LIMIT ?;", + postgres: "SELECT id, title, author_id, body, created, modified FROM test_data WHERE created > (SELECT created FROM test_data WHERE id = $1) ORDER BY created DESC LIMIT $2;", + mysql: "SELECT id, title, author_id, body, created, modified FROM test_data WHERE created > (SELECT created FROM test_data WHERE id = ?) ORDER BY created DESC LIMIT ?;", } } var sqlTable = map[*Query]queryResult{ New("INSERT INTO "+Table(testPost{})).Expression("("+Placeholders(4)+")", "a", "b", "c", "d").Expression("VALUES").Expression("("+Placeholders(4)+")", 0, 1, 2, 3).Flush(" "): { - mysql: "INSERT INTO test_data (?,?,?,?) VALUES (?,?,?,?);", - postgres: "INSERT INTO test_data ($1,$2,$3,$4) VALUES ($5,$6,$7,$8);", + mysql: "INSERT INTO test_data (?, ?, ?, ?) VALUES (?, ?, ?, ?);", + postgres: "INSERT INTO test_data ($1, $2, $3, $4) VALUES ($5, $6, $7, $8);", }, } From b8a66c8f3c75638f44e1bc90653adc1e6a45937d Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 16:59:25 -0700 Subject: [PATCH 15/21] Test against most recent versions of Go. Add more versions of Go to .travis.yml so we can test against more versions of Go. --- .travis.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.travis.yml b/.travis.yml index f6e4872..48915e7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,4 +2,8 @@ language: go go: - 1.2 + - 1.3 + - 1.4 + - 1.5 + - 1.6 - tip From b52b93afa8b64d86b9f4a49c4fe71bb502ece199 Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 17:02:09 -0700 Subject: [PATCH 16/21] Remove go1.2 tests. The go-sqlite3 library we're using in our tests has no support for Go 1.2, apparently. The tests keep breaking on it. So let's not test against that, and support 1.3 and above. --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 48915e7..bde823d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: go go: - - 1.2 - 1.3 - 1.4 - 1.5 From 0b7ab635fed5ac8bc9458d0af7a262b0c509e81c Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 17:09:43 -0700 Subject: [PATCH 17/21] Allow adding additional pointers to Unmarshal. Allow people to add additional pointers to Unmarshal, so they can use the struct columns as the basis for a query, but still add things on to them. This obviates the need for the expression tag. --- reflect.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/reflect.go b/reflect.go index 617c938..bc6991a 100644 --- a/reflect.go +++ b/reflect.go @@ -236,7 +236,7 @@ func getColumnAddrs(s Scannable, in []pointer) ([]interface{}, error) { return i, nil } -func Unmarshal(s Scannable, dst interface{}) error { +func Unmarshal(s Scannable, dst interface{}, additional ...interface{}) error { t := reflect.TypeOf(dst) v := reflect.ValueOf(dst) k := t.Kind() @@ -270,5 +270,6 @@ func Unmarshal(s Scannable, dst interface{}) error { if err != nil { return err } + addrs = append(addrs, additional...) return s.Scan(addrs...) } From 5457ff8007a8ef57df0b85c89f330d52ae389607 Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 17:19:45 -0700 Subject: [PATCH 18/21] Add note about mailing list, add CONTRIBUTORS. Add a note to our README linking to the mailing list, and add a CONTRIBUTORS file. --- CONTRIBUTORS | 6 ++++++ README.md | 2 ++ 2 files changed, 8 insertions(+) create mode 100644 CONTRIBUTORS diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 0000000..30928e7 --- /dev/null +++ b/CONTRIBUTORS @@ -0,0 +1,6 @@ +# These contributors are the actual individuals that contributed code +# to pan. They are not necessarily copyright holders; the copyright +# holders are listed in the AUTHORS file. + +Paddy +Chris Agocs diff --git a/README.md b/README.md index 47240e6..1c0d4f3 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ Pan’s design focuses on reducing repetition and hardcoded strings in your quer Docs can be found on [GoDoc.org](https://godoc.org/darlinggo.co/pan). +If you're using pan, we encourage you to join the [pan mailing list](https://groups.google.com/a/darlinggo.co/group/pan), which will be our main mode of communication. + # Using pan Pan revolves around structs that fill the `SQLTableNamer` interface, by implementing the `GetSQLTableName() string` function, which just returns the name of the table that should store the data for that struct. From 561d17e0f776b3b4b6c615611528b69b980f3f46 Mon Sep 17 00:00:00 2001 From: Paddy Date: Sat, 16 Jul 2016 17:21:36 -0700 Subject: [PATCH 19/21] Update import path in README. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1c0d4f3..fb583d0 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Importing pan - import "github.com/DramaFever/pan" + import "darlinggo.co/pan" # About pan From a3a749a996e7f200db9de356f1b73bed067b5c29 Mon Sep 17 00:00:00 2001 From: Paddy Date: Mon, 25 Jul 2016 21:37:39 -0700 Subject: [PATCH 20/21] Add docs, return ErrNeedsFlush. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add docstrings so golint is happy. Explicitly ignore the output of our String function when benchmarking, to make go vet happy. Add an ErrNeedsFlush error, which is returned when a Query is used but has dangling expressions that weren’t properly flushed. Add a test that ErrNeedsFlush is returned, as expected. --- query.go | 87 +++++++++++++++++++++++++++++++++++++++++++++++++ query_test.go | 12 ++++++- reflect.go | 22 +++++++++++++ reflect_test.go | 8 ++--- 4 files changed, 124 insertions(+), 5 deletions(-) diff --git a/query.go b/query.go index 2dd4110..cfd8be0 100644 --- a/query.go +++ b/query.go @@ -1,17 +1,36 @@ package pan import ( + "errors" "fmt" "strconv" "strings" ) const ( + // FlagFull returns columns in their absolute table.column format. FlagFull Flag = iota + // FlagTicked returns columns using ticks to quote the column name, like `column`. FlagTicked + // FlagDoubleQuoted returns columns using double quotes to quote the column name, like "column". FlagDoubleQuoted ) +var ( + // ErrNeedsFlush is returned when a Query is used while it has expressions left in its buffer + // that haven’t been flushed using the Query’s Flush method. + ErrNeedsFlush = errors.New("Query has dangling buffer, its Flush method needs to be called") +) + +// Query represents an SQL query that is being built. It can be used from its empty value, +// or it can be instantiated with the New method. +// +// Query instances are used to build SQL query string and argument lists, and consist of an +// SQL string and a buffer. The Flush method must be called before the Query is used, or you +// may leave expressions dangling in the buffer. +// +// The Query type is not meant to be concurrency-safe; if you need to modify it from multiple +// goroutines, you need to coordinate that access yourself. type Query struct { sql string args []interface{} @@ -20,14 +39,20 @@ type Query struct { includesOrder bool } +// ColumnList represents a set of columns. type ColumnList []string +// String returns the columns in the ColumnList, joined by ", ", often used to create an +// SQL-formatted list of column names. func (c ColumnList) String() string { return strings.Join(c, ", ") } +// Flag represents a modification to the returned values from our Column or Columns functions. +// See the constants defined in this package for valid values. type Flag int +// New returns a new Query instance, primed for use. func New(query string) *Query { return &Query{ sql: query, @@ -35,6 +60,9 @@ func New(query string) *Query { } } +// Insert returns a Query instance containing SQL that will insert the passed `values` into +// the database. All `values` will be inserted into the same table, so invalid SQL will be +// generated if all `values` are not the same type. func Insert(obj SQLTableNamer, values ...SQLTableNamer) *Query { inserts := make([]SQLTableNamer, 0, len(values)+1) inserts = append(inserts, obj) @@ -49,11 +77,16 @@ func Insert(obj SQLTableNamer, values ...SQLTableNamer) *Query { return query.Flush(" ") } +// ErrWrongNumberArgs is returned when you’ve generated a Query with a certain number of +// placeholders, but supplied a different number of arguments. The NumExpected property +// holds the number of placeholders in the Query, and the NumFound property holds the +// number of arguments supplied. type ErrWrongNumberArgs struct { NumExpected int NumFound int } +// Error fills the error interface. func (e ErrWrongNumberArgs) Error() string { return fmt.Sprintf("Expected %d arguments, got %d.", e.NumExpected, e.NumFound) } @@ -67,6 +100,10 @@ func (q *Query) checkCounts() error { return nil } +// String returns a version of your Query with all the arguments in the place of their +// placeholders. It does not do any sanitization, and is vulnerable to SQL injection. +// It is meant as a debugging aid, not to be executed. The string will almost certainly +// not be valid SQL. func (q *Query) String() string { var argPos int var res string @@ -86,14 +123,30 @@ func (q *Query) String() string { return res } +// MySQLString returns a SQL string that can be passed to MySQL to execute your query. +// If the number of placeholders do not match the number of arguments provided to your +// Query, an ErrWrongNumberArgs error will be returned. If there are still expressions +// left in the buffer (meaning the Flush method wasn't called) an ErrNeedsFlush error +// will be returned. func (q *Query) MySQLString() (string, error) { + if len(q.expressions) != 0 { + return "", ErrNeedsFlush + } if err := q.checkCounts(); err != nil { return "", err } return q.sql + ";", nil } +// PostgreSQLString returns an SQL string that can be passed to PostgreSQL to execute +// your query. If the number of placeholders do not match the number of arguments +// provided to your Query, an ErrWrongNumberArgs error will be returned. If there are +// still expressions left in the buffer (meaning the Flush method wasn't called) an +// ErrNeedsFlush error will be returned. func (q *Query) PostgreSQLString() (string, error) { + if len(q.expressions) != 0 { + return "", ErrNeedsFlush + } if err := q.checkCounts(); err != nil { return "", err } @@ -110,6 +163,11 @@ func (q *Query) PostgreSQLString() (string, error) { return res + ";", nil } +// Flush flushes the expressions in the Query’s buffer, adding them to the SQL string +// being built. It must be called before a Query can be used. Any pending expressions +// (anything since the last Flush or since the Query was instantiated) are joined using +// `join`, then added onto the Query’s SQL string, with a space between the SQL string +// and the expressions. func (q *Query) Flush(join string) *Query { q.sql = strings.TrimSpace(q.sql) + " " q.sql += strings.TrimSpace(strings.Join(q.expressions, join)) @@ -117,12 +175,18 @@ func (q *Query) Flush(join string) *Query { return q } +// Expression adds a raw string and optional values to the Query’s buffer. func (q *Query) Expression(key string, values ...interface{}) *Query { q.expressions = append(q.expressions, key) q.args = append(q.args, values...) return q } +// Where adds a WHERE keyword to the Query’s buffer, then calls Flush on the Query, +// using a space as the join parameter. +// +// Where can only be called once per Query; calling it multiple times on the same Query +// will be no-ops after the first. func (q *Query) Where() *Query { if q.includesWhere { return q @@ -133,14 +197,26 @@ func (q *Query) Where() *Query { return q } +// Comparison adds a comparison expression to the Query’s buffer. A comparison takes the +// form of `column operator ?`, with `value` added as an argument to the Query. Column is +// determined by finding the column name for the passed property on the passed SQLTableNamer. +// The passed property must be a string that matches, identically, the property name; if it +// does not, it will panic. func (q *Query) Comparison(obj SQLTableNamer, property, operator string, value interface{}) *Query { return q.Expression(Column(obj, property)+" "+operator+" ?", value) } +// In adds an expression to the Query’s buffer in the form of "column IN (value, value, value)". +// `values` are the variables to match against, and `obj` and `property` are used to determine +// the column. `property` must exactly match the name of a property on `obj`, or the call will +// panic. func (q *Query) In(obj SQLTableNamer, property string, values ...interface{}) *Query { return q.Expression(Column(obj, property)+" IN("+Placeholders(len(values))+")", values...) } +// Assign adds an expression to the Query’s buffer in the form of "column = ?", and adds `value` +// to the arguments for this query. `obj` and `property` are used to determine the column. +// `property` must exactly match the name of a property on `obj`, or the call will panic. func (q *Query) Assign(obj SQLTableNamer, property string, value interface{}) *Query { return q.Expression(Column(obj, property)+" = ?", value) } @@ -155,22 +231,33 @@ func (q *Query) orderBy(orderClause, dir string) *Query { return q } +// OrderBy adds an expression to the Query’s buffer in the form of "ORDER BY column". func (q *Query) OrderBy(column string) *Query { return q.orderBy(column, "") } +// OrderByDesc adds an expression to the Query’s buffer in the form of "ORDER BY column DESC". func (q *Query) OrderByDesc(column string) *Query { return q.orderBy(column, " DESC") } +// Limit adds an expression to the Query’s buffer in the form of "LIMIT ?", and adds `limit` as +// an argument to the Query. func (q *Query) Limit(limit int64) *Query { return q.Expression("LIMIT ?", limit) } +// Offset adds an expression to the Query’s buffer in the form of "OFFSET ?", and adds `offset` +// as an argument to the Query. func (q *Query) Offset(offset int64) *Query { return q.Expression("OFFSET ?", offset) } +// Args returns a slice of the arguments attached to the Query, which should be used when executing +// your SQL to fill the placeholders. +// +// Note that Args returns its internal slice; you should copy the returned slice over before modifying +// it. func (q *Query) Args() []interface{} { return q.args } diff --git a/query_test.go b/query_test.go index b232249..ac67233 100644 --- a/query_test.go +++ b/query_test.go @@ -106,6 +106,16 @@ var queryTests = []queryTest{ args: []interface{}{0}, }, }, + queryTest{ + ExpectedResult: queryResult{ + err: ErrNeedsFlush, + }, + Query: &Query{ + sql: "SELECT * FROM mytable WHERE", + args: []interface{}{0}, + expressions: []string{"this = ?"}, + }, + }, } func init() { @@ -239,6 +249,6 @@ func BenchmarkQueryString(b *testing.B) { b.StopTimer() test := queryTests[b.N%len(queryTests)] b.StartTimer() - test.Query.String() + _ = test.Query.String() } } diff --git a/reflect.go b/reflect.go index bc6991a..87c719d 100644 --- a/reflect.go +++ b/reflect.go @@ -150,11 +150,16 @@ func readStruct(s SQLTableNamer, needsValues bool, flags ...Flag) (columns []str return decorateColumns(columns, s.GetSQLTableName(), flags...), values } +// Columns returns a ColumnList containing the names of the columns +// in `s`. func Columns(s SQLTableNamer, flags ...Flag) ColumnList { columns, _ := readStruct(s, false, flags...) return columns } +// Column returns the name of the column that `property` maps to for `s`. +// `property` must be the exact name of a property on `s`, or Column will +// panic. func Column(s SQLTableNamer, property string, flags ...Flag) string { t := reflect.TypeOf(s) k := t.Kind() @@ -173,19 +178,28 @@ func Column(s SQLTableNamer, property string, flags ...Flag) string { return columns[0] } +// ColumnValues returns the values in `s` for each column in `s`, in the +// same order `Columns` returns the names. func ColumnValues(s SQLTableNamer) []interface{} { _, values := readStruct(s, true) return values } +// SQLTableNamer is used to represent a type that corresponds to an SQL +// table. It must define the GetSQLTableName method, returning the name +// of the SQL table to store data for that type in. type SQLTableNamer interface { GetSQLTableName() string } +// Table is a convenient shorthand wrapper for the GetSQLTableName method +// on `t`. func Table(t SQLTableNamer) string { return t.GetSQLTableName() } +// Placeholders returns a formatted string containing `num` placeholders. +// The placeholders will be comma-separated. func Placeholders(num int) string { placeholders := make([]string, num) for pos := 0; pos < num; pos++ { @@ -194,6 +208,9 @@ func Placeholders(num int) string { return strings.Join(placeholders, ", ") } +// Scannable defines a type that can insert the results of a Query into +// the SQLTableNamer a Query was built from, and can list off the column +// names, in order, that those results represent. type Scannable interface { Scan(dst ...interface{}) error Columns() ([]string, error) @@ -236,6 +253,11 @@ func getColumnAddrs(s Scannable, in []pointer) ([]interface{}, error) { return i, nil } +// Unmarshal reads the Scannable `s` into the variable at `d`, and returns an +// error if it is unable to. If there are more values than `d` has properties +// associated with columns, `additional` can be supplied to catch the extra values. +// The variables in `additional` must be a compatible type with and be in the same +// order as the columns of `s`. func Unmarshal(s Scannable, dst interface{}, additional ...interface{}) error { t := reflect.TypeOf(dst) v := reflect.ValueOf(dst) diff --git a/reflect_test.go b/reflect_test.go index 7bd69b9..8428cfb 100644 --- a/reflect_test.go +++ b/reflect_test.go @@ -98,16 +98,16 @@ func TestCamelToSnake(t *testing.T) { } } -type invalidSqlFieldReflector string +type invalidSQLFieldReflector string -func (i invalidSqlFieldReflector) GetSQLTableName() string { +func (i invalidSQLFieldReflector) GetSQLTableName() string { return "invalid_reflection_table" } func TestInvalidFieldReflection(t *testing.T) { t.Parallel() - columns := Columns(invalidSqlFieldReflector("test")) - values := ColumnValues(invalidSqlFieldReflector("test")) + columns := Columns(invalidSQLFieldReflector("test")) + values := ColumnValues(invalidSQLFieldReflector("test")) if len(columns) != 0 { t.Errorf("Expected %d columns, got %d.", 0, len(columns)) } From 5d3855a0d32d31fa8b630acdc050b6bcb61691e5 Mon Sep 17 00:00:00 2001 From: Paddy Date: Mon, 25 Jul 2016 21:54:16 -0700 Subject: [PATCH 21/21] =?UTF-8?q?Bail=20out=20of=20Flush=20if=20it?= =?UTF-8?q?=E2=80=99s=20already=20been=20called.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If there’s nothing for Flush to do, then just bail out. This makes sure it’s safe to call Flush multiple times. --- query.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/query.go b/query.go index cfd8be0..66766ac 100644 --- a/query.go +++ b/query.go @@ -169,6 +169,9 @@ func (q *Query) PostgreSQLString() (string, error) { // `join`, then added onto the Query’s SQL string, with a space between the SQL string // and the expressions. func (q *Query) Flush(join string) *Query { + if len(q.expressions) < 1 { + return q + } q.sql = strings.TrimSpace(q.sql) + " " q.sql += strings.TrimSpace(strings.Join(q.expressions, join)) q.expressions = q.expressions[0:0]