From 5242ecfda1830885ec1ef793903cb9a7f7bb7205 Mon Sep 17 00:00:00 2001 From: Goran Rojovic Date: Thu, 7 Nov 2024 11:19:42 +0100 Subject: [PATCH] fix: replace existing certificate --- aggsender/db/aggsender_db_storage.go | 36 +++++++++- aggsender/db/aggsender_db_storage_test.go | 83 +++++++++++++++++++++++ 2 files changed, 116 insertions(+), 3 deletions(-) diff --git a/aggsender/db/aggsender_db_storage.go b/aggsender/db/aggsender_db_storage.go index 400941a9..ec4ecade 100644 --- a/aggsender/db/aggsender_db_storage.go +++ b/aggsender/db/aggsender_db_storage.go @@ -89,9 +89,15 @@ func (a *AggSenderSQLStorage) GetCertificatesByStatus(ctx context.Context, // GetCertificateByHeight returns a certificate by its height func (a *AggSenderSQLStorage) GetCertificateByHeight(ctx context.Context, + height uint64) (types.CertificateInfo, error) { + return getCertificateByHeight(ctx, a.db, height) +} + +// getCertificateByHeight returns a certificate by its height using the provided db +func getCertificateByHeight(ctx context.Context, db meddler.DB, height uint64) (types.CertificateInfo, error) { var certificateInfo types.CertificateInfo - if err := meddler.QueryRow(a.db, &certificateInfo, + if err := meddler.QueryRow(db, &certificateInfo, "SELECT * FROM certificate_info WHERE height = $1;", height); err != nil { return types.CertificateInfo{}, getSelectQueryError(height, err) } @@ -124,9 +130,23 @@ func (a *AggSenderSQLStorage) SaveLastSentCertificate(ctx context.Context, certi } }() + cert, err := getCertificateByHeight(ctx, tx, certificate.Height) + if err != nil && !errors.Is(err, db.ErrNotFound) { + return err + } + + if cert.CertificateID != (common.Hash{}) { + // we already have a certificate with this height + // we need to delete it before inserting the new one + if err = deleteCertificate(ctx, tx, cert.CertificateID); err != nil { + return err + } + } + if err = meddler.Insert(tx, "certificate_info", &certificate); err != nil { return fmt.Errorf("error inserting certificate info: %w", err) } + if err = tx.Commit(); err != nil { return err } @@ -150,9 +170,10 @@ func (a *AggSenderSQLStorage) DeleteCertificate(ctx context.Context, certificate } }() - if _, err = tx.Exec(`DELETE FROM certificate_info WHERE certificate_id = $1;`, certificateID); err != nil { - return fmt.Errorf("error deleting certificate info: %w", err) + if err = deleteCertificate(ctx, a.db, certificateID); err != nil { + return err } + if err = tx.Commit(); err != nil { return err } @@ -162,6 +183,15 @@ func (a *AggSenderSQLStorage) DeleteCertificate(ctx context.Context, certificate return nil } +// deleteCertificate deletes a certificate from the storage using the provided db +func deleteCertificate(ctx context.Context, db meddler.DB, certificateID common.Hash) error { + if _, err := db.Exec(`DELETE FROM certificate_info WHERE certificate_id = $1;`, certificateID); err != nil { + return fmt.Errorf("error deleting certificate info: %w", err) + } + + return nil +} + // UpdateCertificateStatus updates the status of a certificate func (a *AggSenderSQLStorage) UpdateCertificateStatus(ctx context.Context, certificate types.CertificateInfo) error { tx, err := db.NewTx(ctx, a.db) diff --git a/aggsender/db/aggsender_db_storage_test.go b/aggsender/db/aggsender_db_storage_test.go index cfb7af7c..8b51daed 100644 --- a/aggsender/db/aggsender_db_storage_test.go +++ b/aggsender/db/aggsender_db_storage_test.go @@ -202,3 +202,86 @@ func Test_Storage(t *testing.T) { require.NoError(t, storage.clean()) }) } + +func Test_SaveLastSentCertificate(t *testing.T) { + ctx := context.Background() + + path := path.Join(t.TempDir(), "file::memory:?cache=shared") + log.Debugf("sqlite path: %s", path) + require.NoError(t, migrations.RunMigrations(path)) + + storage, err := NewAggSenderSQLStorage(log.WithFields("aggsender-db"), path) + require.NoError(t, err) + + t.Run("SaveNewCertificate", func(t *testing.T) { + certificate := types.CertificateInfo{ + Height: 1, + CertificateID: common.HexToHash("0x1"), + NewLocalExitRoot: common.HexToHash("0x2"), + FromBlock: 1, + ToBlock: 2, + Status: agglayer.Settled, + } + require.NoError(t, storage.SaveLastSentCertificate(ctx, certificate)) + + certificateFromDB, err := storage.GetCertificateByHeight(ctx, certificate.Height) + require.NoError(t, err) + require.Equal(t, certificate, certificateFromDB) + require.NoError(t, storage.clean()) + }) + + t.Run("UpdateExistingCertificate", func(t *testing.T) { + certificate := types.CertificateInfo{ + Height: 2, + CertificateID: common.HexToHash("0x3"), + NewLocalExitRoot: common.HexToHash("0x4"), + FromBlock: 3, + ToBlock: 4, + Status: agglayer.InError, + } + require.NoError(t, storage.SaveLastSentCertificate(ctx, certificate)) + + // Update the certificate with the same height + updatedCertificate := types.CertificateInfo{ + Height: 2, + CertificateID: common.HexToHash("0x5"), + NewLocalExitRoot: common.HexToHash("0x6"), + FromBlock: 3, + ToBlock: 6, + Status: agglayer.Pending, + } + require.NoError(t, storage.SaveLastSentCertificate(ctx, updatedCertificate)) + + certificateFromDB, err := storage.GetCertificateByHeight(ctx, updatedCertificate.Height) + require.NoError(t, err) + require.Equal(t, updatedCertificate, certificateFromDB) + require.NoError(t, storage.clean()) + }) + + t.Run("SaveCertificateWithRollback", func(t *testing.T) { + // Simulate an error during the transaction to trigger a rollback + certificate := types.CertificateInfo{ + Height: 3, + CertificateID: common.HexToHash("0x7"), + NewLocalExitRoot: common.HexToHash("0x8"), + FromBlock: 7, + ToBlock: 8, + Status: agglayer.Settled, + } + + // Close the database to force an error + require.NoError(t, storage.db.Close()) + + err := storage.SaveLastSentCertificate(ctx, certificate) + require.Error(t, err) + + // Reopen the database and check that the certificate was not saved + storage.db, err = db.NewSQLiteDB(path) + require.NoError(t, err) + + certificateFromDB, err := storage.GetCertificateByHeight(ctx, certificate.Height) + require.ErrorIs(t, err, db.ErrNotFound) + require.Equal(t, types.CertificateInfo{}, certificateFromDB) + require.NoError(t, storage.clean()) + }) +}