From 4bf7afad94ba98781f57dc87ec5b05d8c5926c10 Mon Sep 17 00:00:00 2001 From: Ben Krieger Date: Wed, 30 Oct 2024 13:00:09 -0400 Subject: [PATCH] Reimplement SQLite upsert, loadOrStore, and remove without transactions --- fdotest/server_state.go | 12 ++++ sqlite/sqlite.go | 144 +++++++++++++++++++++------------------- 2 files changed, 86 insertions(+), 70 deletions(-) diff --git a/fdotest/server_state.go b/fdotest/server_state.go index 1a7791a..24e5eb6 100644 --- a/fdotest/server_state.go +++ b/fdotest/server_state.go @@ -588,6 +588,18 @@ func RunServerStateSuite(t *testing.T, state AllServerState) { //nolint:gocyclo if _, err := state.Voucher(context.TODO(), newGUID); err != nil { t.Fatal(err) } + + // Remove voucher + removed, err := state.RemoveVoucher(context.TODO(), newGUID) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(removed, ov) { + t.Errorf("removed voucher should match replaced %+v, got %+v", ov, removed) + } + if _, err := state.RemoveVoucher(context.TODO(), newGUID); !errors.Is(err, fdo.ErrNotFound) { + t.Fatalf("removed voucher GUID should return not found, got error %v", err) + } }) t.Run("OwnerKeyPersistentState", func(t *testing.T) { diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index ff369fd..3074768 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -218,32 +218,20 @@ func (db *DB) NewToken(ctx context.Context, protocol protocol.Protocol) (string, } func (db *DB) loadOrStoreSecret(ctx context.Context) ([]byte, error) { - tx, err := db.db.BeginTx(ctx, nil) - if err != nil { - return nil, fmt.Errorf("error starting transaction: %w", err) - } - defer func() { _ = tx.Rollback() }() - - var readSecret []byte - if err := query(db.debugCtx(ctx), tx, "secrets", []string{"secret"}, map[string]any{"type": "hmac"}, &readSecret); err != nil && !errors.Is(err, fdo.ErrNotFound) { - return nil, fmt.Errorf("error reading hmac secret: %w", err) - } - if len(readSecret) > 0 { - return readSecret, nil - } - - // Insert new secret - var secret [64]byte - if _, err := rand.Read(secret[:]); err != nil { + // Insert (or ignore) a new HMAC secret + secret := make([]byte, 64) + if _, err := rand.Read(secret); err != nil { return nil, err } - if err := insert(db.debugCtx(ctx), tx, "secrets", map[string]any{"type": "hmac", "secret": secret[:]}, nil); err != nil { + if err := db.insertOrIgnore(ctx, "secrets", map[string]any{"type": "hmac", "secret": secret}); err != nil { return nil, fmt.Errorf("error writing hmac secret: %w", err) } - if err := tx.Commit(); err != nil { - return nil, err + + // Read secret + if err := db.query(ctx, "secrets", []string{"secret"}, map[string]any{"type": "hmac"}, &secret); err != nil { + return nil, fmt.Errorf("error reading hmac secret: %w", err) } - return secret[:], nil + return secret, nil } type contextKey struct{} @@ -308,20 +296,7 @@ func (db *DB) sessionID(ctx context.Context) ([]byte, bool) { } func (db *DB) insert(ctx context.Context, table string, kvs, upsertWhere map[string]any) error { - if len(upsertWhere) == 0 { - return insert(ctx, db.db, table, kvs, upsertWhere) - } - - tx, err := db.db.BeginTx(ctx, nil) - if err != nil { - return fmt.Errorf("error starting transaction: %w", err) - } - defer func() { _ = tx.Rollback() }() - - if err := insert(ctx, tx, table, kvs, upsertWhere); err != nil { - return err - } - return tx.Commit() + return insert(db.debugCtx(ctx), db.db, table, kvs, upsertWhere) } func (db *DB) insertOrIgnore(ctx context.Context, table string, kvs map[string]any) error { @@ -336,17 +311,25 @@ func (db *DB) query(ctx context.Context, table string, columns []string, where m return query(db.debugCtx(ctx), db.db, table, columns, where, into...) } +// Allows using *sql.DB or *sql.Tx type execer interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) } +// Allows using *sql.DB or *sql.Tx type querier interface { QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row } +// Allows using *sql.DB or *sql.Tx +type queryexecer interface { + querier + execer +} + func insert(ctx context.Context, db execer, table string, kvs, upsertWhere map[string]any) error { var orIgnore string - if upsertWhere != nil { + if upsertWhere != nil && len(upsertWhere) == 0 { orIgnore = "OR IGNORE " } @@ -357,22 +340,41 @@ func insert(ctx context.Context, db execer, table string, kvs, upsertWhere map[s } markers := slices.Repeat([]string{"?"}, len(columns)) + var upsert string + if len(upsertWhere) > 0 { + upsertWhereKeys := slices.Collect(maps.Keys(upsertWhere)) + + var updates []string + for _, key := range columns { + if slices.Contains(upsertWhereKeys, key) { + continue + } + updates = append(updates, fmt.Sprintf("`%s` = excluded.`%s`", key, key)) + } + + whereClauses := make([]string, len(upsertWhereKeys)) + for i, key := range upsertWhereKeys { + whereClauses[i] = fmt.Sprintf("`%s` = ?", key) + args = append(args, upsertWhere[key]) + } + + upsert = " ON CONFLICT DO UPDATE SET " + upsert += strings.Join(updates, " AND ") + upsert += " WHERE " + upsert += strings.Join(whereClauses, " AND ") + } + query := fmt.Sprintf( - "INSERT %sINTO %s (%s) VALUES (%s)", + "INSERT %sINTO %s (%s) VALUES (%s)%s", orIgnore, table, "`"+strings.Join(columns, "`, `")+"`", strings.Join(markers, ", "), + upsert, ) - debug(ctx, "sqlite: %s\n%+v", query, kvs) - if _, err := db.ExecContext(ctx, query, args...); err != nil { - return err - } - - if len(upsertWhere) > 0 { - return update(ctx, db, table, kvs, upsertWhere) - } - return nil + debug(ctx, "sqlite: %s\n%+v", query, args) + _, err := db.ExecContext(ctx, query, args...) + return err } func update(ctx context.Context, db execer, table string, kvs, where map[string]any) error { @@ -440,7 +442,7 @@ func query(ctx context.Context, db querier, table string, columns []string, wher return nil } -func remove(ctx context.Context, db execer, table string, where map[string]any) error { +func remove(ctx context.Context, db queryexecer, table string, where map[string]any, returning map[string]any) error { whereKeys := slices.Collect(maps.Keys(where)) clauses := make([]string, len(whereKeys)) for i, key := range whereKeys { @@ -451,12 +453,33 @@ func remove(ctx context.Context, db execer, table string, where map[string]any) whereVals[i] = where[key] } + var returningQuery string + returningArgs := make([]any, len(returning)) + if len(returning) > 0 { + returningKeys := slices.Collect(maps.Keys(returning)) + returningQuery = " RETURNING `" + strings.Join(returningKeys, "`, `") + "`" + for i, key := range returningKeys { + returningArgs[i] = returning[key] + } + } + query := fmt.Sprintf( - `DELETE FROM %s WHERE %s`, + `DELETE FROM %s WHERE %s%s`, table, strings.Join(clauses, " AND "), + returningQuery, ) - debug(ctx, "sqlite: %s\n%+v", query, where) + debug(ctx, "sqlite: %s\n%+v", query, whereVals) + + if returningQuery != "" { + row := db.QueryRowContext(ctx, query, whereVals...) + if err := row.Scan(returningArgs...); errors.Is(err, sql.ErrNoRows) { + return fdo.ErrNotFound + } else if err != nil { + return err + } + return nil + } result, err := db.ExecContext(ctx, query, whereVals...) if err != nil { @@ -829,19 +852,9 @@ func (db *DB) ReplaceVoucher(ctx context.Context, guid protocol.GUID, ov *fdo.Vo // RemoveVoucher untracks a voucher, deleting it, and returns it for extension. func (db *DB) RemoveVoucher(ctx context.Context, guid protocol.GUID) (*fdo.Voucher, error) { - ctx = db.debugCtx(ctx) - - tx, err := db.db.BeginTx(ctx, nil) - if err != nil { - return nil, fmt.Errorf("error starting transaction: %w", err) - } - defer func() { _ = tx.Rollback() }() - var data []byte - if err := query(ctx, tx, "owner_vouchers", []string{"cbor"}, - map[string]any{"guid": guid[:]}, - &data, - ); err != nil { + if err := remove(db.debugCtx(ctx), db.db, "owner_vouchers", + map[string]any{"guid": guid[:]}, map[string]any{"cbor": &data}); err != nil { return nil, err } if data == nil { @@ -852,15 +865,6 @@ func (db *DB) RemoveVoucher(ctx context.Context, guid protocol.GUID) (*fdo.Vouch if err := cbor.Unmarshal(data, &ov); err != nil { return nil, fmt.Errorf("error unmarshaling ownership voucher: %w", err) } - - if err := remove(ctx, tx, "owner_vouchers", map[string]any{"guid": guid[:]}); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - return &ov, nil }