diff --git a/db/interfaces.go b/db/interfaces.go index f08e25888fe..99f701d4a68 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -58,17 +58,9 @@ type Executor interface { OneSelector Inserter SelectExecer - Queryer Delete(context.Context, ...interface{}) (int64, error) Get(context.Context, interface{}, ...interface{}) (interface{}, error) Update(context.Context, ...interface{}) (int64, error) -} - -// Queryer offers the QueryContext method. Note that this is not read-only (i.e. not -// Selector), since a QueryContext can be `INSERT`, `UPDATE`, etc. The difference -// between QueryContext and ExecContext is that QueryContext can return rows. So for instance it is -// suitable for inserting rows and getting back ids. -type Queryer interface { QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) } diff --git a/db/multi.go b/db/multi.go index bcb2fbe3fc5..e9a146ce55a 100644 --- a/db/multi.go +++ b/db/multi.go @@ -7,29 +7,24 @@ import ( ) // MultiInserter makes it easy to construct a -// `INSERT INTO table (...) VALUES ... RETURNING id;` +// `INSERT INTO table (...) VALUES ...;` // query which inserts multiple rows into the same table. It can also execute // the resulting query. type MultiInserter struct { // These are validated by the constructor as containing only characters // that are allowed in an unquoted identifier. // https://mariadb.com/kb/en/identifier-names/#unquoted - table string - fields []string - returningColumn string + table string + fields []string values [][]interface{} } // NewMultiInserter creates a new MultiInserter, checking for reasonable table -// name and list of fields. returningColumn is the name of a column to be used -// in a `RETURNING xyz` clause at the end. If it is empty, no `RETURNING xyz` -// clause is used. If returningColumn is present, it must refer to a column -// that can be parsed into an int64. -// Safety: `table`, `fields`, and `returningColumn` must contain only strings -// that are known at compile time. They must not contain user-controlled -// strings. -func NewMultiInserter(table string, fields []string, returningColumn string) (*MultiInserter, error) { +// name and list of fields. +// Safety: `table` and `fields` must contain only strings that are known at +// compile time. They must not contain user-controlled strings. +func NewMultiInserter(table string, fields []string) (*MultiInserter, error) { if len(table) == 0 || len(fields) == 0 { return nil, fmt.Errorf("empty table name or fields list") } @@ -44,18 +39,11 @@ func NewMultiInserter(table string, fields []string, returningColumn string) (*M return nil, err } } - if returningColumn != "" { - err := validMariaDBUnquotedIdentifier(returningColumn) - if err != nil { - return nil, err - } - } return &MultiInserter{ - table: table, - fields: fields, - returningColumn: returningColumn, - values: make([][]interface{}, 0), + table: table, + fields: fields, + values: make([][]interface{}, 0), }, nil } @@ -84,56 +72,32 @@ func (mi *MultiInserter) query() (string, []interface{}) { questions := strings.TrimRight(questionsBuf.String(), ",") - // Safety: we are interpolating `mi.returningColumn` into an SQL query. We - // know it is a valid unquoted identifier in MariaDB because we verified - // that in the constructor. - returning := "" - if mi.returningColumn != "" { - returning = fmt.Sprintf(" RETURNING %s", mi.returningColumn) - } // Safety: we are interpolating `mi.table` and `mi.fields` into an SQL // query. We know they contain, respectively, a valid unquoted identifier // and a slice of valid unquoted identifiers because we verified that in // the constructor. We know the query overall has valid syntax because we // generate it entirely within this function. - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s%s", mi.table, strings.Join(mi.fields, ","), questions, returning) + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", mi.table, strings.Join(mi.fields, ","), questions) return query, queryArgs } // Insert inserts all the collected rows into the database represented by -// `queryer`. If a non-empty returningColumn was provided, then it returns -// the list of values from that column returned by the query. -func (mi *MultiInserter) Insert(ctx context.Context, queryer Queryer) ([]int64, error) { +// `queryer`. +func (mi *MultiInserter) Insert(ctx context.Context, db Execer) error { query, queryArgs := mi.query() - rows, err := queryer.QueryContext(ctx, query, queryArgs...) + res, err := db.ExecContext(ctx, query, queryArgs...) if err != nil { - return nil, err + return err } - ids := make([]int64, 0, len(mi.values)) - if mi.returningColumn != "" { - for rows.Next() { - var id int64 - err = rows.Scan(&id) - if err != nil { - rows.Close() - return nil, err - } - ids = append(ids, id) - } + affected, err := res.RowsAffected() + if err != nil { + return err } - - // Hack: sometimes in unittests we make a mock Queryer that returns a nil - // `*sql.Rows`. A nil `*sql.Rows` is not actually valid— calling `Close()` - // on it will panic— but here we choose to treat it like an empty list, - // and skip calling `Close()` to avoid the panic. - if rows != nil { - err = rows.Close() - if err != nil { - return nil, err - } + if affected != int64(len(mi.values)) { + return fmt.Errorf("unexpected number of rows inserted: %d != %d", affected, len(mi.values)) } - return ids, nil + return nil } diff --git a/db/multi_test.go b/db/multi_test.go index f972f4748b0..d866699bff9 100644 --- a/db/multi_test.go +++ b/db/multi_test.go @@ -7,34 +7,29 @@ import ( ) func TestNewMulti(t *testing.T) { - _, err := NewMultiInserter("", []string{"colA"}, "") + _, err := NewMultiInserter("", []string{"colA"}) test.AssertError(t, err, "Empty table name should fail") - _, err = NewMultiInserter("myTable", nil, "") + _, err = NewMultiInserter("myTable", nil) test.AssertError(t, err, "Empty fields list should fail") - mi, err := NewMultiInserter("myTable", []string{"colA"}, "") + mi, err := NewMultiInserter("myTable", []string{"colA"}) test.AssertNotError(t, err, "Single-column construction should not fail") test.AssertEquals(t, len(mi.fields), 1) - mi, err = NewMultiInserter("myTable", []string{"colA", "colB", "colC"}, "") + mi, err = NewMultiInserter("myTable", []string{"colA", "colB", "colC"}) test.AssertNotError(t, err, "Multi-column construction should not fail") test.AssertEquals(t, len(mi.fields), 3) - _, err = NewMultiInserter("", []string{"colA"}, "colB") - test.AssertError(t, err, "expected error for empty table name") - _, err = NewMultiInserter("foo\"bar", []string{"colA"}, "colB") + _, err = NewMultiInserter("foo\"bar", []string{"colA"}) test.AssertError(t, err, "expected error for invalid table name") - _, err = NewMultiInserter("myTable", []string{"colA", "foo\"bar"}, "colB") + _, err = NewMultiInserter("myTable", []string{"colA", "foo\"bar"}) test.AssertError(t, err, "expected error for invalid column name") - - _, err = NewMultiInserter("myTable", []string{"colA"}, "foo\"bar") - test.AssertError(t, err, "expected error for invalid returning column name") } func TestMultiAdd(t *testing.T) { - mi, err := NewMultiInserter("table", []string{"a", "b", "c"}, "") + mi, err := NewMultiInserter("table", []string{"a", "b", "c"}) test.AssertNotError(t, err, "Failed to create test MultiInserter") err = mi.Add([]interface{}{}) @@ -57,7 +52,7 @@ func TestMultiAdd(t *testing.T) { } func TestMultiQuery(t *testing.T) { - mi, err := NewMultiInserter("table", []string{"a", "b", "c"}, "") + mi, err := NewMultiInserter("table", []string{"a", "b", "c"}) test.AssertNotError(t, err, "Failed to create test MultiInserter") err = mi.Add([]interface{}{"one", "two", "three"}) test.AssertNotError(t, err, "Failed to insert test row") @@ -67,15 +62,4 @@ func TestMultiQuery(t *testing.T) { query, queryArgs := mi.query() test.AssertEquals(t, query, "INSERT INTO table (a,b,c) VALUES (?,?,?),(?,?,?)") test.AssertDeepEquals(t, queryArgs, []interface{}{"one", "two", "three", "egy", "kettö", "három"}) - - mi, err = NewMultiInserter("table", []string{"a", "b", "c"}, "id") - test.AssertNotError(t, err, "Failed to create test MultiInserter") - err = mi.Add([]interface{}{"one", "two", "three"}) - test.AssertNotError(t, err, "Failed to insert test row") - err = mi.Add([]interface{}{"egy", "kettö", "három"}) - test.AssertNotError(t, err, "Failed to insert test row") - - query, queryArgs = mi.query() - test.AssertEquals(t, query, "INSERT INTO table (a,b,c) VALUES (?,?,?),(?,?,?) RETURNING id") - test.AssertDeepEquals(t, queryArgs, []interface{}{"one", "two", "three", "egy", "kettö", "három"}) } diff --git a/features/features.go b/features/features.go index 25b3ba83d1d..40f2ad8caa4 100644 --- a/features/features.go +++ b/features/features.go @@ -29,6 +29,7 @@ type Config struct { CertCheckerRequiresCorrespondence bool ECDSAForAll bool CheckRenewalExemptionAtWFE bool + InsertAuthzsIndividually bool // ServeRenewalInfo exposes the renewalInfo endpoint in the directory and for // GET requests. WARNING: This feature is a draft and highly unstable. @@ -115,13 +116,6 @@ type Config struct { // // This flag should only be used in conjunction with UseKvLimitsForNewOrder. DisableLegacyLimitWrites bool - - // InsertAuthzsIndividually causes the SA's NewOrderAndAuthzs method to - // create each new authz one at a time, rather than using MultiInserter. - // Although this is expected to be a performance penalty, it is necessary to - // get the AUTO_INCREMENT ID of each new authz without relying on MariaDB's - // unique "INSERT ... RETURNING" functionality. - InsertAuthzsIndividually bool } var fMu = new(sync.RWMutex) diff --git a/sa/model.go b/sa/model.go index fa3ce717a29..ac99370aaa2 100644 --- a/sa/model.go +++ b/sa/model.go @@ -1047,12 +1047,12 @@ func deleteOrderFQDNSet( return nil } -func addIssuedNames(ctx context.Context, queryer db.Queryer, cert *x509.Certificate, isRenewal bool) error { +func addIssuedNames(ctx context.Context, queryer db.Execer, cert *x509.Certificate, isRenewal bool) error { if len(cert.DNSNames) == 0 { return berrors.InternalServerError("certificate has no DNSNames") } - multiInserter, err := db.NewMultiInserter("issuedNames", []string{"reversedName", "serial", "notBefore", "renewal"}, "") + multiInserter, err := db.NewMultiInserter("issuedNames", []string{"reversedName", "serial", "notBefore", "renewal"}) if err != nil { return err } @@ -1067,8 +1067,7 @@ func addIssuedNames(ctx context.Context, queryer db.Queryer, cert *x509.Certific return err } } - _, err = multiInserter.Insert(ctx, queryer) - return err + return multiInserter.Insert(ctx, queryer) } func addKeyHash(ctx context.Context, db db.Inserter, cert *x509.Certificate) error { diff --git a/sa/sa.go b/sa/sa.go index 05ca1b02e56..74c58a4a47e 100644 --- a/sa/sa.go +++ b/sa/sa.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" "fmt" - "strings" "time" "github.com/jmhodges/clock" @@ -473,53 +472,17 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb output, err := db.WithTransaction(ctx, ssa.dbMap, func(tx db.Executor) (interface{}, error) { // First, insert all of the new authorizations and record their IDs. - newAuthzIDs := make([]int64, 0) - if features.Get().InsertAuthzsIndividually { - for _, authz := range req.NewAuthzs { - am, err := newAuthzReqToModel(authz) - if err != nil { - return nil, err - } - err = tx.Insert(ctx, am) - if err != nil { - return nil, err - } - newAuthzIDs = append(newAuthzIDs, am.ID) + newAuthzIDs := make([]int64, 0, len(req.NewAuthzs)) + for _, authz := range req.NewAuthzs { + am, err := newAuthzReqToModel(authz) + if err != nil { + return nil, err } - } else { - if len(req.NewAuthzs) != 0 { - inserter, err := db.NewMultiInserter("authz2", strings.Split(authzFields, ", "), "id") - if err != nil { - return nil, err - } - for _, authz := range req.NewAuthzs { - am, err := newAuthzReqToModel(authz) - if err != nil { - return nil, err - } - err = inserter.Add([]interface{}{ - am.ID, - am.IdentifierType, - am.IdentifierValue, - am.RegistrationID, - statusToUint[core.StatusPending], - am.Expires, - am.Challenges, - nil, - nil, - am.Token, - nil, - nil, - }) - if err != nil { - return nil, err - } - } - newAuthzIDs, err = inserter.Insert(ctx, tx) - if err != nil { - return nil, err - } + err = tx.Insert(ctx, am) + if err != nil { + return nil, err } + newAuthzIDs = append(newAuthzIDs, am.ID) } // Second, insert the new order. @@ -549,7 +512,7 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb } // Third, insert all of the orderToAuthz relations. - inserter, err := db.NewMultiInserter("orderToAuthz2", []string{"orderID", "authzID"}, "") + inserter, err := db.NewMultiInserter("orderToAuthz2", []string{"orderID", "authzID"}) if err != nil { return nil, err } @@ -565,7 +528,7 @@ func (ssa *SQLStorageAuthority) NewOrderAndAuthzs(ctx context.Context, req *sapb return nil, err } } - _, err = inserter.Insert(ctx, tx) + err = inserter.Insert(ctx, tx) if err != nil { return nil, err } diff --git a/sa/sa_test.go b/sa/sa_test.go index 5e24e349de1..e909fe155f0 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -9,7 +9,6 @@ import ( "crypto/x509" "database/sql" "encoding/base64" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -1063,15 +1062,28 @@ func TestFQDNSetsExists(t *testing.T) { test.Assert(t, exists.Exists, "FQDN set does exist") } -type queryRecorder struct { - query string - args []interface{} +type execRecorder struct { + valuesPerRow int + query string + args []interface{} } -func (e *queryRecorder) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (e *execRecorder) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { e.query = query e.args = args - return nil, nil + return rowsResult{int64(len(args) / e.valuesPerRow)}, nil +} + +type rowsResult struct { + rowsAffected int64 +} + +func (r rowsResult) LastInsertId() (int64, error) { + return r.rowsAffected, nil +} + +func (r rowsResult) RowsAffected() (int64, error) { + return r.rowsAffected, nil } func TestAddIssuedNames(t *testing.T) { @@ -1154,7 +1166,7 @@ func TestAddIssuedNames(t *testing.T) { for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { - var e queryRecorder + e := execRecorder{valuesPerRow: 4} err := addIssuedNames( ctx, &e, @@ -1232,9 +1244,6 @@ func TestNewOrderAndAuthzs(t *testing.T) { sa, _, cleanup := initSA(t) defer cleanup() - features.Set(features.Config{InsertAuthzsIndividually: true}) - defer features.Reset() - reg := createWorkingRegistration(t, sa) // Insert two pre-existing authorizations to reference @@ -1289,9 +1298,6 @@ func TestNewOrderAndAuthzs_NonNilInnerOrder(t *testing.T) { sa, fc, cleanup := initSA(t) defer cleanup() - features.Set(features.Config{InsertAuthzsIndividually: true}) - defer features.Reset() - reg := createWorkingRegistration(t, sa) expires := fc.Now().Add(2 * time.Hour) @@ -1313,9 +1319,6 @@ func TestNewOrderAndAuthzs_MismatchedRegID(t *testing.T) { sa, _, cleanup := initSA(t) defer cleanup() - features.Set(features.Config{InsertAuthzsIndividually: true}) - defer features.Reset() - _, err := sa.NewOrderAndAuthzs(context.Background(), &sapb.NewOrderAndAuthzsRequest{ NewOrder: &sapb.NewOrderRequest{ RegistrationID: 1, @@ -1334,9 +1337,6 @@ func TestNewOrderAndAuthzs_NewAuthzExpectedFields(t *testing.T) { sa, fc, cleanup := initSA(t) defer cleanup() - features.Set(features.Config{InsertAuthzsIndividually: true}) - defer features.Reset() - reg := createWorkingRegistration(t, sa) expires := fc.Now().Add(time.Hour) domain := "a.com" @@ -1385,55 +1385,6 @@ func TestNewOrderAndAuthzs_NewAuthzExpectedFields(t *testing.T) { test.AssertBoxedNil(t, am.ValidationRecord, "am.ValidationRecord should be nil") } -func BenchmarkNewOrderAndAuthzs(b *testing.B) { - for _, flag := range []bool{false, true} { - for _, numIdents := range []int{1, 2, 5, 10, 20, 50, 100} { - b.Run(fmt.Sprintf("%t/%d", flag, numIdents), func(b *testing.B) { - sa, _, cleanup := initSA(b) - defer cleanup() - - if flag { - features.Set(features.Config{InsertAuthzsIndividually: true}) - defer features.Reset() - } - - reg := createWorkingRegistration(b, sa) - - dnsNames := make([]string, 0, numIdents) - newAuthzs := make([]*sapb.NewAuthzRequest, 0, numIdents) - for range numIdents { - var nameBytes [3]byte - _, _ = rand.Read(nameBytes[:]) - name := fmt.Sprintf("%s.example.com", hex.EncodeToString(nameBytes[:])) - - dnsNames = append(dnsNames, name) - newAuthzs = append(newAuthzs, &sapb.NewAuthzRequest{ - RegistrationID: reg.Id, - Identifier: identifier.NewDNS(name).AsProto(), - ChallengeTypes: []string{string(core.ChallengeTypeDNS01)}, - Token: core.NewToken(), - Expires: timestamppb.New(sa.clk.Now().Add(24 * time.Hour)), - }) - } - - b.ResetTimer() - - _, err := sa.NewOrderAndAuthzs(context.Background(), &sapb.NewOrderAndAuthzsRequest{ - NewOrder: &sapb.NewOrderRequest{ - RegistrationID: reg.Id, - Expires: timestamppb.New(sa.clk.Now().Add(24 * time.Hour)), - DnsNames: dnsNames, - }, - NewAuthzs: newAuthzs, - }) - if err != nil { - b.Error(err) - } - }) - } - } -} - func TestSetOrderProcessing(t *testing.T) { sa, fc, cleanup := initSA(t) defer cleanup() diff --git a/test/config-next/sa.json b/test/config-next/sa.json index f271926f630..5afcf09153b 100644 --- a/test/config-next/sa.json +++ b/test/config-next/sa.json @@ -51,8 +51,7 @@ "features": { "MultipleCertificateProfiles": true, "TrackReplacementCertificatesARI": true, - "DisableLegacyLimitWrites": true, - "InsertAuthzsIndividually": true + "DisableLegacyLimitWrites": true } }, "syslog": { diff --git a/test/config/sa.json b/test/config/sa.json index 9dc66125661..591a23110ad 100644 --- a/test/config/sa.json +++ b/test/config/sa.json @@ -47,7 +47,9 @@ } } }, - "features": {} + "features": { + "InsertAuthzsIndividually": true + } }, "syslog": { "stdoutlevel": 6,