diff --git a/persistence/context.go b/persistence/context.go index 58ad0e51d..15a89dfe1 100644 --- a/persistence/context.go +++ b/persistence/context.go @@ -12,6 +12,7 @@ import ( "github.com/pokt-network/pocket/persistence/indexer" coreTypes "github.com/pokt-network/pocket/shared/core/types" "github.com/pokt-network/pocket/shared/modules" + "go.uber.org/multierr" ) var _ modules.PersistenceRWContext = &PostgresContext{} @@ -36,29 +37,35 @@ type PostgresContext struct { networkId string } -func (p *PostgresContext) NewSavePoint(bytes []byte) error { - p.logger.Info().Bool("TODO", true).Msg("NewSavePoint not implemented") +// SetSavePoint generates a new Savepoint for this context. +func (p *PostgresContext) SetSavePoint() error { + if err := p.stateTrees.Savepoint(); err != nil { + return err + } return nil } -// TECHDEBT(#327): Guarantee atomicity betweens `prepareBlock`, `insertBlock` and `storeBlock` for save points & rollbacks. -func (p *PostgresContext) RollbackToSavePoint(bytes []byte) error { - p.logger.Info().Bool("TODO", true).Msg("RollbackToSavePoint not fully implemented") - return p.tx.Rollback(context.TODO()) +// RollbackToSavepoint triggers a rollback for the current pgx transaction and the underylying submodule stores. +func (p *PostgresContext) RollbackToSavePoint() error { + ctx, _ := p.getCtxAndTx() + pgErr := p.tx.Rollback(ctx) + treesErr := p.stateTrees.Rollback() + return multierr.Combine(pgErr, treesErr) } -// IMPROVE(#361): Guarantee the integrity of the state // Full details in the thread from the PR review: https://github.com/pokt-network/pocket/pull/285#discussion_r1018471719 func (p *PostgresContext) ComputeStateHash() (string, error) { stateHash, err := p.stateTrees.Update(p.tx, uint64(p.Height)) if err != nil { return "", err } + if err := p.stateTrees.Commit(); err != nil { + return "", err + } p.stateHash = stateHash return p.stateHash, nil } -// TECHDEBT(#327): Make sure these operations are atomic func (p *PostgresContext) Commit(proposerAddr, quorumCert []byte) error { p.logger.Info().Int64("height", p.Height).Msg("About to commit block & context") diff --git a/persistence/db.go b/persistence/db.go index 2a65e7819..73e64cada 100644 --- a/persistence/db.go +++ b/persistence/db.go @@ -37,6 +37,7 @@ var protocolActorSchemas = []types.ProtocolActorSchema{ types.ValidatorActor, } +// TECHDEBT(#595): Properly handle context threading and passing for the entire persistence module func (pg *PostgresContext) getCtxAndTx() (context.Context, pgx.Tx) { return context.TODO(), pg.tx } diff --git a/persistence/docs/CHANGELOG.md b/persistence/docs/CHANGELOG.md index 1bdae0d55..7f6dfb0b3 100644 --- a/persistence/docs/CHANGELOG.md +++ b/persistence/docs/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.0.0.60] - 2023-07-11 + +- Adds savepoints and rollbacks implementation to TreeStore + ## [0.0.0.60] - 2023-06-26 - Add place-holder for local context and servicer token usage support methods diff --git a/persistence/trees/atomic_test.go b/persistence/trees/atomic_test.go new file mode 100644 index 000000000..06fdbaf8c --- /dev/null +++ b/persistence/trees/atomic_test.go @@ -0,0 +1,92 @@ +package trees + +import ( + "encoding/hex" + "testing" + + "github.com/golang/mock/gomock" + "github.com/pokt-network/pocket/logger" + mock_types "github.com/pokt-network/pocket/persistence/types/mocks" + "github.com/pokt-network/pocket/shared/modules" + mockModules "github.com/pokt-network/pocket/shared/modules/mocks" + + "github.com/stretchr/testify/require" +) + +const ( + // the root hash of a tree store where each tree is empty but present and initialized + h0 = "302f2956c084cc3e0e760cf1b8c2da5de79c45fa542f68a660a5fc494b486972" + // the root hash of a tree store where each tree has has key foo value bar added to it + h1 = "7d5712ea1507915c40e295845fa58773baa405b24b87e9d99761125d826ff915" +) + +func TestTreeStore_AtomicUpdatesWithSuccessfulRollback(t *testing.T) { + ctrl := gomock.NewController(t) + + mockTxIndexer := mock_types.NewMockTxIndexer(ctrl) + mockBus := mockModules.NewMockBus(ctrl) + mockPersistenceMod := mockModules.NewMockPersistenceModule(ctrl) + + mockBus.EXPECT().GetPersistenceModule().AnyTimes().Return(mockPersistenceMod) + mockPersistenceMod.EXPECT().GetTxIndexer().AnyTimes().Return(mockTxIndexer) + + ts := &treeStore{ + logger: logger.Global.CreateLoggerForModule(modules.TreeStoreSubmoduleName), + treeStoreDir: ":memory:", + } + require.NoError(t, ts.setupTrees()) + require.NotEmpty(t, ts.merkleTrees[TransactionsTreeName]) + + hash0 := ts.getStateHash() + require.NotEmpty(t, hash0) + require.Equal(t, hash0, h0) + + require.NoError(t, ts.Savepoint()) + + // insert test data into every tree + for _, treeName := range stateTreeNames { + err := ts.merkleTrees[treeName].tree.Update([]byte("foo"), []byte("bar")) + require.NoError(t, err) + } + + // commit the above changes + require.NoError(t, ts.Commit()) + + // assert state hash is changed + hash1 := ts.getStateHash() + require.NotEmpty(t, hash1) + require.NotEqual(t, hash0, hash1) + require.Equal(t, hash1, h1) + + // set a new savepoint + require.NoError(t, ts.Savepoint()) + require.NotEmpty(t, ts.prevState.merkleTrees) + require.NotEmpty(t, ts.prevState.rootTree) + // assert that savepoint creation doesn't mutate state hash + require.Equal(t, hash1, hex.EncodeToString(ts.prevState.rootTree.tree.Root())) + + // verify that creating a savepoint does not change state hash + hash2 := ts.getStateHash() + require.Equal(t, hash2, hash1) + require.Equal(t, hash2, h1) + + // validate that state tree was updated and a previous savepoint is created + for _, treeName := range stateTreeNames { + require.NotEmpty(t, ts.merkleTrees[treeName]) + require.NotEmpty(t, ts.prevState.merkleTrees[treeName]) + } + + // insert additional test data into all of the trees + for _, treeName := range stateTreeNames { + require.NoError(t, ts.merkleTrees[treeName].tree.Update([]byte("fiz"), []byte("buz"))) + } + + // rollback the changes made to the trees above BEFORE anything was committed + err := ts.Rollback() + require.NoError(t, err) + + // validate that the state hash is unchanged after new data was inserted but rolled back before commitment + hash3 := ts.getStateHash() + require.Equal(t, hash3, hash2) + require.Equal(t, hash3, h1) +} diff --git a/persistence/trees/main_test.go b/persistence/trees/main_test.go new file mode 100644 index 000000000..9d5615ecb --- /dev/null +++ b/persistence/trees/main_test.go @@ -0,0 +1,12 @@ +//go:build test + +package trees + +import ( + "crypto/sha256" + "hash" +) + +type TreeStore = treeStore + +var SMTTreeHasher hash.Hash = sha256.New() diff --git a/persistence/trees/module_test.go b/persistence/trees/module_test.go index 7c7bc660c..91ec5249f 100644 --- a/persistence/trees/module_test.go +++ b/persistence/trees/module_test.go @@ -48,7 +48,6 @@ func TestTreeStore_Create(t *testing.T) { treemod, err := trees.Create(mockBus, trees.WithTreeStoreDirectory(":memory:")) assert.NoError(t, err) - got := treemod.GetBus() assert.Equal(t, got, mockBus) diff --git a/persistence/trees/prove_test.go b/persistence/trees/prove_test.go new file mode 100644 index 000000000..5d6cdb4c3 --- /dev/null +++ b/persistence/trees/prove_test.go @@ -0,0 +1,90 @@ +package trees + +import ( + "fmt" + "testing" + + "github.com/pokt-network/pocket/persistence/kvstore" + "github.com/pokt-network/smt" + "github.com/stretchr/testify/require" +) + +func TestTreeStore_Prove(t *testing.T) { + nodeStore := kvstore.NewMemKVStore() + tree := smt.NewSparseMerkleTree(nodeStore, smtTreeHasher) + testTree := &stateTree{ + name: "test", + tree: tree, + nodeStore: nodeStore, + } + + require.NoError(t, testTree.tree.Update([]byte("key"), []byte("value"))) + require.NoError(t, testTree.tree.Commit()) + + treeStore := &treeStore{ + merkleTrees: make(map[string]*stateTree, 1), + } + treeStore.merkleTrees["test"] = testTree + + testCases := []struct { + name string + treeName string + key []byte + value []byte + valid bool + expectedErr error + }{ + { + name: "valid inclusion proof: key and value in tree", + treeName: "test", + key: []byte("key"), + value: []byte("value"), + valid: true, + expectedErr: nil, + }, + { + name: "valid exclusion proof: key not in tree", + treeName: "test", + key: []byte("key2"), + value: nil, + valid: true, + expectedErr: nil, + }, + { + name: "invalid proof: tree not in store", + treeName: "unstored tree", + key: []byte("key"), + value: []byte("value"), + valid: false, + expectedErr: fmt.Errorf("tree not found: %s", "unstored tree"), + }, + { + name: "invalid inclusion proof: key in tree, wrong value", + treeName: "test", + key: []byte("key"), + value: []byte("wrong value"), + valid: false, + expectedErr: nil, + }, + { + name: "invalid exclusion proof: key in tree", + treeName: "test", + key: []byte("key"), + value: nil, + valid: false, + expectedErr: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + valid, err := treeStore.Prove(tc.treeName, tc.key, tc.value) + require.Equal(t, valid, tc.valid) + if tc.expectedErr == nil { + require.NoError(t, err) + return + } + require.ErrorAs(t, err, &tc.expectedErr) + }) + } +} diff --git a/persistence/trees/trees.go b/persistence/trees/trees.go index 2d47cdc43..8fdc43676 100644 --- a/persistence/trees/trees.go +++ b/persistence/trees/trees.go @@ -1,7 +1,15 @@ // package trees maintains a set of sparse merkle trees -// each backed by the KVStore interface. It offers an atomic +// each backed by the `KVStore` interface. It offers an atomic // commit and rollback mechanism for interacting with -// that core resource map of merkle trees. +// its core resource - a set of merkle trees. +// - `Update` is called, which will fetch and apply the contextual changes to the respective trees. +// - `Savepoint` is first called to create a new anchor in time that can be rolled back to +// - `Commit` must be called after any `Update` calls to persist changes applied to disk. +// - If `Rollback` is called at any point before committing, it rolls the TreeStore state back to the +// earlier savepoint. This means that the caller is responsible for correctly managing atomic updates +// of the TreeStore. +// In most contexts, this is from the perspective of the `utility/unit_of_work` package. + package trees import ( @@ -74,6 +82,9 @@ type stateTree struct { var _ modules.TreeStoreModule = &treeStore{} +// ErrFailedRollback is thrown when a rollback fails to reset the TreeStore to a known good state +var ErrFailedRollback = fmt.Errorf("failed to rollback") + // treeStore stores a set of merkle trees that it manages. // It fulfills the modules.treeStore interface // * It is responsible for atomic commit or rollback behavior of the underlying @@ -88,6 +99,18 @@ type treeStore struct { treeStoreDir string rootTree *stateTree merkleTrees map[string]*stateTree + + // prevState holds a previous view of the worldState. + // The tree store rolls back to this view if errors are encountered during block application. + prevState *worldState +} + +// worldState holds a (de)serializable view of the entire tree state. +// TECHDEBT(#566) - Hook this up to node CLI subcommands +type worldState struct { + treeStoreDir string + rootTree *stateTree + merkleTrees map[string]*stateTree } // GetTree returns the root hash and nodeStore for the matching tree stored in the TreeStore. @@ -241,9 +264,6 @@ func (t *treeStore) updateMerkleTrees(pgtx pgx.Tx, txi indexer.TxIndexer, height } } - if err := t.Commit(); err != nil { - return "", fmt.Errorf("failed to commit: %w", err) - } return t.getStateHash(), nil } @@ -279,6 +299,67 @@ func (t *treeStore) getStateHash() string { return hexHash } +//////////////////////////////// +// AtomicStore Implementation // +//////////////////////////////// + +// Savepoint generates a new savepoint (i.e. a worldState) for the tree store and saves it internally. +func (t *treeStore) Savepoint() error { + w, err := t.save() + if err != nil { + return err + } + t.prevState = w + return nil +} + +// Rollback returns the treeStore to the last saved worldState maintained by the treeStore. +// If no worldState has been saved, it returns ErrFailedRollback +func (t *treeStore) Rollback() error { + if t.prevState != nil { + t.merkleTrees = t.prevState.merkleTrees + t.rootTree = t.prevState.rootTree + return nil + } + t.logger.Err(ErrFailedRollback) + return ErrFailedRollback +} + +// save commits any pending changes to the trees and creates a copy of the current worldState, +// then saves that copy as a rollback point for later use if errors are encountered. +// OPTIMIZE: Consider saving only the root hash of each tree and the tree directory here and then +// load the trees up in Rollback instead of setting them up here. +func (t *treeStore) save() (*worldState, error) { + if err := t.Commit(); err != nil { + return nil, err + } + + w := &worldState{ + treeStoreDir: t.treeStoreDir, + merkleTrees: map[string]*stateTree{}, + } + + for treeName := range t.merkleTrees { + root, nodeStore := t.GetTree(treeName) + tree := smt.ImportSparseMerkleTree(nodeStore, smtTreeHasher, root) + w.merkleTrees[treeName] = &stateTree{ + name: treeName, + tree: tree, + nodeStore: nodeStore, + } + } + + root, nodeStore := t.GetTree(RootTreeName) + tree := smt.ImportSparseMerkleTree(nodeStore, smtTreeHasher, root) + w.rootTree = &stateTree{ + name: RootTreeName, + tree: tree, + nodeStore: nodeStore, + } + + return w, nil +} + //////////////////////// // Actor Tree Helpers // //////////////////////// @@ -304,7 +385,6 @@ func (t *treeStore) updateActorsTree(actorType coreTypes.ActorType, actors []*co return err } } - return nil } diff --git a/persistence/trees/trees_test.go b/persistence/trees/trees_test.go index e59e3ba1f..aa8c41ab4 100644 --- a/persistence/trees/trees_test.go +++ b/persistence/trees/trees_test.go @@ -1,111 +1,205 @@ -package trees +package trees_test import ( - "fmt" + "encoding/hex" + "log" + "math/big" "testing" - "github.com/pokt-network/pocket/persistence/kvstore" - "github.com/pokt-network/smt" + "github.com/pokt-network/pocket/persistence" + "github.com/pokt-network/pocket/persistence/trees" + "github.com/pokt-network/pocket/runtime" + "github.com/pokt-network/pocket/runtime/configs" + "github.com/pokt-network/pocket/runtime/test_artifacts" + "github.com/pokt-network/pocket/runtime/test_artifacts/keygen" + core_types "github.com/pokt-network/pocket/shared/core/types" + "github.com/pokt-network/pocket/shared/crypto" + "github.com/pokt-network/pocket/shared/messaging" + "github.com/pokt-network/pocket/shared/modules" + "github.com/pokt-network/pocket/shared/utils" + "github.com/stretchr/testify/require" ) -// TECHDEBT(#836): Tests added in https://github.com/pokt-network/pocket/pull/836 +var ( + defaultChains = []string{"0001"} + defaultStakeBig = big.NewInt(1000000000000000) + defaultStake = utils.BigIntToString(defaultStakeBig) + defaultStakeStatus = int32(core_types.StakeStatus_Staked) + defaultPauseHeight = int64(-1) // pauseHeight=-1 implies not paused + defaultUnstakingHeight = int64(-1) // unstakingHeight=-1 implies not unstaking + + testSchema = "test_schema" + + genesisStateNumValidators = 5 + genesisStateNumServicers = 1 + genesisStateNumApplications = 1 +) + +const ( + treesHash1 = "5282ee91a3ec0a6f2b30e4780b369bae78c80ef3ea40587fef6ae263bf41f244" +) + func TestTreeStore_Update(t *testing.T) { - // TODO: Write test case for the Update method - t.Skip("TODO: Write test case for Update method") + pool, resource, dbUrl := test_artifacts.SetupPostgresDocker() + t.Cleanup(func() { + require.NoError(t, pool.Purge(resource)) + }) + + t.Run("should update actor trees, commit, and modify the state hash", func(t *testing.T) { + pmod := newTestPersistenceModule(t, dbUrl) + context := newTestPostgresContext(t, 0, pmod) + + require.NoError(t, context.SetSavePoint()) + + hash1, err := context.ComputeStateHash() + require.NoError(t, err) + require.NotEmpty(t, hash1) + require.Equal(t, hash1, treesHash1) + + _, err = createAndInsertDefaultTestApp(t, context) + require.NoError(t, err) + + require.NoError(t, context.SetSavePoint()) + + hash2, err := context.ComputeStateHash() + require.NoError(t, err) + require.NotEmpty(t, hash2) + require.NotEqual(t, hash1, hash2) + }) + + t.Run("should fail to rollback when no treestore savepoint is set", func(t *testing.T) { + pmod := newTestPersistenceModule(t, dbUrl) + context := newTestPostgresContext(t, 0, pmod) + + err := context.RollbackToSavePoint() + require.Error(t, err) + require.ErrorIs(t, err, trees.ErrFailedRollback) + }) } -func TestTreeStore_New(t *testing.T) { - // TODO: Write test case for the NewStateTrees function - t.Skip("TODO: Write test case for NewStateTrees function") +func newTestPersistenceModule(t *testing.T, databaseURL string) modules.PersistenceModule { + t.Helper() + teardownDeterministicKeygen := keygen.GetInstance().SetSeed(42) + defer teardownDeterministicKeygen() + + cfg := newTestDefaultConfig(t, databaseURL) + genesisState, _ := test_artifacts.NewGenesisState( + genesisStateNumValidators, + genesisStateNumServicers, + genesisStateNumApplications, + genesisStateNumServicers, + ) + + runtimeMgr := runtime.NewManager(cfg, genesisState) + + bus, err := runtime.CreateBus(runtimeMgr) + require.NoError(t, err) + + persistenceMod, err := persistence.Create(bus) + require.NoError(t, err) + + return persistenceMod.(modules.PersistenceModule) } -func TestTreeStore_DebugClearAll(t *testing.T) { - // TODO: Write test case for the DebugClearAll method - t.Skip("TODO: Write test case for DebugClearAll method") +// fetches a new default node configuration for testing +func newTestDefaultConfig(t *testing.T, databaseURL string) *configs.Config { + t.Helper() + cfg := &configs.Config{ + Persistence: &configs.PersistenceConfig{ + PostgresUrl: databaseURL, + NodeSchema: testSchema, + BlockStorePath: ":memory:", + TxIndexerPath: ":memory:", + TreesStoreDir: ":memory:", + MaxConnsCount: 5, + MinConnsCount: 1, + MaxConnLifetime: "5m", + MaxConnIdleTime: "1m", + HealthCheckPeriod: "30s", + }, + } + return cfg } +func createAndInsertDefaultTestApp(t *testing.T, db *persistence.PostgresContext) (*core_types.Actor, error) { + t.Helper() + app := newTestApp(t) -// TODO_AFTER(#861): Implement this test with the test suite available in #861 -func TestTreeStore_GetTreeHashes(t *testing.T) { - t.Skip("TODO: Write test case for GetTreeHashes method") // context: https://github.com/pokt-network/pocket/pull/915#discussion_r1267313664 + addrBz, err := hex.DecodeString(app.Address) + require.NoError(t, err) + + pubKeyBz, err := hex.DecodeString(app.PublicKey) + require.NoError(t, err) + + outputBz, err := hex.DecodeString(app.Output) + require.NoError(t, err) + return app, db.InsertApp( + addrBz, + pubKeyBz, + outputBz, + false, + defaultStakeStatus, + defaultStake, + defaultChains, + defaultPauseHeight, + defaultUnstakingHeight) } -func TestTreeStore_Prove(t *testing.T) { - nodeStore := kvstore.NewMemKVStore() - tree := smt.NewSparseMerkleTree(nodeStore, smtTreeHasher) - testTree := &stateTree{ - name: "test", - tree: tree, - nodeStore: nodeStore, - } +// TECHDEBT(#796): Test helpers should be consolidated in a single place +func newTestApp(t *testing.T) *core_types.Actor { + operatorKey, err := crypto.GeneratePublicKey() + require.NoError(t, err) - require.NoError(t, testTree.tree.Update([]byte("key"), []byte("value"))) - require.NoError(t, testTree.tree.Commit()) + outputAddr, err := crypto.GenerateAddress() + require.NoError(t, err) - treeStore := &treeStore{ - merkleTrees: make(map[string]*stateTree, 1), + return &core_types.Actor{ + Address: hex.EncodeToString(operatorKey.Address()), + PublicKey: hex.EncodeToString(operatorKey.Bytes()), + Chains: defaultChains, + StakedAmount: defaultStake, + PausedHeight: defaultPauseHeight, + UnstakingHeight: defaultUnstakingHeight, + Output: hex.EncodeToString(outputAddr), } - treeStore.merkleTrees["test"] = testTree - - testCases := []struct { - name string - treeName string - key []byte - value []byte - valid bool - expectedErr error - }{ - { - name: "valid inclusion proof: key and value in tree", - treeName: "test", - key: []byte("key"), - value: []byte("value"), - valid: true, - expectedErr: nil, - }, - { - name: "valid exclusion proof: key not in tree", - treeName: "test", - key: []byte("key2"), - value: nil, - valid: true, - expectedErr: nil, - }, - { - name: "invalid proof: tree not in store", - treeName: "unstored tree", - key: []byte("key"), - value: []byte("value"), - valid: false, - expectedErr: fmt.Errorf("tree not found: %s", "unstored tree"), - }, - { - name: "invalid inclusion proof: key in tree, wrong value", - treeName: "test", - key: []byte("key"), - value: []byte("wrong value"), - valid: false, - expectedErr: nil, - }, - { - name: "invalid exclusion proof: key in tree", - treeName: "test", - key: []byte("key"), - value: nil, - valid: false, - expectedErr: nil, - }, +} + +// TECHDEBT(#796): Test helpers should be consolidated in a single place +func newTestPostgresContext(t testing.TB, height int64, testPersistenceMod modules.PersistenceModule) *persistence.PostgresContext { + t.Helper() + rwCtx, err := testPersistenceMod.NewRWContext(height) + if err != nil { + log.Fatalf("Error creating new context: %v\n", err) } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - valid, err := treeStore.Prove(tc.treeName, tc.key, tc.value) - require.Equal(t, valid, tc.valid) - if tc.expectedErr == nil { - require.NoError(t, err) - return - } - require.ErrorAs(t, err, &tc.expectedErr) - }) + postgresCtx, ok := rwCtx.(*persistence.PostgresContext) + if !ok { + log.Fatalf("Error casting RW context to Postgres context") } + + // TECHDEBT: This should not be part of `NewTestPostgresContext`. It causes unnecessary resets + // if we call `NewTestPostgresContext` more than once in a single test. + t.Cleanup(func() { + resetStateToGenesis(testPersistenceMod) + }) + + return postgresCtx +} + +// This is necessary for unit tests that are dependant on a baseline genesis state +func resetStateToGenesis(m modules.PersistenceModule) { + if err := m.ReleaseWriteContext(); err != nil { + log.Fatalf("Error releasing write context: %v\n", err) + } + if err := m.HandleDebugMessage(&messaging.DebugMessage{ + Action: messaging.DebugMessageAction_DEBUG_PERSISTENCE_RESET_TO_GENESIS, + Message: nil, + }); err != nil { + log.Fatalf("Error clearing state: %v\n", err) + } +} + +// TODO_AFTER(#861): Implement this test with the test suite available in #861 +func TestTreeStore_GetTreeHashes(t *testing.T) { + t.Skip("TODO: Write test case for GetTreeHashes method") // context: https://github.com/pokt-network/pocket/pull/915#discussion_r1267313664 } diff --git a/shared/modules/persistence_module.go b/shared/modules/persistence_module.go index b3c1e56e2..38c7681ba 100644 --- a/shared/modules/persistence_module.go +++ b/shared/modules/persistence_module.go @@ -46,6 +46,13 @@ type PersistenceModule interface { GetLocalContext() (PersistenceLocalContext, error) } +// AtomicStore defines the interface for stores to implement to guarantee atomic commits to the persistence layer +type AtomicStore interface { + Savepoint() error + Commit() error + Rollback() error +} + // Interface defining the context within which the node can operate with the persistence layer. // Operations in the context of a PersistenceContext are isolated from other operations and // other persistence contexts until committed, enabling parallelizability along other operations. @@ -74,8 +81,8 @@ type PersistenceRWContext interface { // PersistenceWriteContext has no use-case independent of `PersistenceRWContext`, but is a useful abstraction type PersistenceWriteContext interface { // Context Operations - NewSavePoint([]byte) error - RollbackToSavePoint([]byte) error + SetSavePoint() error + RollbackToSavePoint() error Release() // Commits (and releases) the current context to disk (i.e. finality). diff --git a/shared/modules/treestore_module.go b/shared/modules/treestore_module.go index 35b240e51..a79f7a14f 100644 --- a/shared/modules/treestore_module.go +++ b/shared/modules/treestore_module.go @@ -21,13 +21,16 @@ type TreeStoreModule interface { Submodule treeStoreFactory - // Update returns the new state hash for a given height. + AtomicStore + + // Update returns the computed state hash for a given height. // * Height is passed through to the Update function and is used to query the TxIndexer for transactions // to update into the merkle tree set // * Passing a higher height will cause a change but repeatedly calling the same or a lower height will // not incur a change. // * By nature of it taking a pgx transaction at runtime, Update inherits the pgx transaction's read view of the // database. + // * Commit must be called after Update to persist any changes it made to disk. Update(pgtx pgx.Tx, height uint64) (string, error) // DebugClearAll completely clears the state of the trees. For debugging purposes only. DebugClearAll() error diff --git a/utility/unit_of_work/block.go b/utility/unit_of_work/block.go index 6076822b5..914cf1ac9 100644 --- a/utility/unit_of_work/block.go +++ b/utility/unit_of_work/block.go @@ -208,33 +208,18 @@ func (uow *baseUtilityUnitOfWork) prevBlockByzantineValidators() ([][]byte, erro return nil, nil } -// TODO: This has not been tested or investigated in detail -func (uow *baseUtilityUnitOfWork) revertLastSavePoint() coreTypes.Error { - // TODO(@deblasis): Implement this - // if len(u.savePointsSet) == 0 { - // return coreTypes.ErrEmptySavePoints() - // } - // var key []byte - // popIndex := len(u.savePointsList) - 1 - // key, u.savePointsList = u.savePointsList[popIndex], u.savePointsList[:popIndex] - // delete(u.savePointsSet, hex.EncodeToString(key)) - // if err := u.store.RollbackToSavePoint(key); err != nil { - // return coreTypes.ErrRollbackSavePoint(err) - // } +func (uow *baseUtilityUnitOfWork) revertToLastSavepoint() coreTypes.Error { + if err := uow.persistenceRWContext.RollbackToSavePoint(); err != nil { + uow.logger.Err(err).Msgf("failed to rollback to savepoint at height %d", uow.height) + return coreTypes.ErrRollbackSavePoint(err) + } return nil } -//nolint:unused // TODO: This has not been tested or investigated in detail -func (uow *baseUtilityUnitOfWork) newSavePoint(txHashBz []byte) coreTypes.Error { - // TODO(@deblasis): Implement this - // if err := u.store.NewSavePoint(txHashBz); err != nil { - // return coreTypes.ErrNewSavePoint(err) - // } - // txHash := hex.EncodeToString(txHashBz) - // if _, exists := u.savePointsSet[txHash]; exists { - // return coreTypes.ErrDuplicateSavePoint() - // } - // u.savePointsList = append(u.savePointsList, txHashBz) - // u.savePointsSet[txHash] = struct{}{} +func (uow *baseUtilityUnitOfWork) newSavePoint() coreTypes.Error { + if err := uow.persistenceRWContext.SetSavePoint(); err != nil { + uow.logger.Err(err).Msgf("failed to create new savepoint at height %d", uow.height) + return coreTypes.ErrNewSavePoint(err) + } return nil } diff --git a/utility/unit_of_work/module.go b/utility/unit_of_work/module.go index 2624e5bac..22547e090 100644 --- a/utility/unit_of_work/module.go +++ b/utility/unit_of_work/module.go @@ -1,12 +1,10 @@ package unit_of_work import ( - "fmt" - coreTypes "github.com/pokt-network/pocket/shared/core/types" - "github.com/pokt-network/pocket/shared/mempool" "github.com/pokt-network/pocket/shared/modules" "github.com/pokt-network/pocket/shared/modules/base_modules" + "go.uber.org/multierr" ) const ( @@ -48,6 +46,7 @@ func (uow *baseUtilityUnitOfWork) SetProposalBlock(blockHash string, proposerAdd return nil } +// ApplyBlock atomically applies a block to the persistence layer for a given height. func (uow *baseUtilityUnitOfWork) ApplyBlock() error { log := uow.logger.With().Fields(map[string]interface{}{ "source": "ApplyBlock", @@ -58,51 +57,55 @@ func (uow *baseUtilityUnitOfWork) ApplyBlock() error { return coreTypes.ErrProposalBlockNotSet() } + // initialize a new savepoint before applying the block + if err := uow.newSavePoint(); err != nil { + return err + } + // begin block lifecycle phase log.Debug().Msg("calling beginBlock") if err := uow.beginBlock(); err != nil { return err } + // processProposalBlockTransactions indexes the transactions into the TxIndexer. + // If it fails, it returns an error which triggers a rollback below to undo the changes + // that processProposalBlockTransactions could have caused. log.Debug().Msg("processing transactions from proposal block") - txMempool := uow.GetBus().GetUtilityModule().GetMempool() - if err := uow.processProposalBlockTransactions(txMempool); err != nil { - return err + if err := uow.processProposalBlockTransactions(); err != nil { + rollErr := uow.revertToLastSavepoint() + return multierr.Combine(rollErr, err) } - // end block lifecycle phase + // end block lifecycle phase calls endBlock and reverts to the last known savepoint if it encounters any errors log.Debug().Msg("calling endBlock") if err := uow.endBlock(uow.proposalProposerAddr); err != nil { - return err + rollErr := uow.revertToLastSavepoint() + return multierr.Combine(rollErr, err) } + // return the app hash (consensus module will get the validator set directly) - log.Debug().Msg("computing state hash") stateHash, err := uow.persistenceRWContext.ComputeStateHash() if err != nil { - log.Fatal().Err(err).Bool("TODO", true).Msg("Updating the app hash failed. TODO: Look into roll-backing the entire commit...") - return coreTypes.ErrAppHash(err) + rollErr := uow.persistenceRWContext.RollbackToSavePoint() + return coreTypes.ErrAppHash(multierr.Append(err, rollErr)) } // IMPROVE(#655): this acts as a feature flag to allow tests to ignore the check if needed, ideally the tests should have a way to determine // the hash and set it into the proposal block it's currently hard to do because the state is different at every test run (non-determinism) if uow.proposalStateHash != IgnoreProposalBlockCheckHash { if uow.proposalStateHash != stateHash { - log.Fatal().Bool("TODO", true). - Str("proposalStateHash", uow.proposalStateHash). - Str("stateHash", stateHash). - Msg("State hash mismatch. TODO: Look into roll-backing the entire commit...") - return coreTypes.ErrAppHash(fmt.Errorf("state hash mismatch: expected %s from the proposal, got %s", uow.proposalStateHash, stateHash)) + return uow.revertToLastSavepoint() } } - log.Info().Str("state_hash", stateHash).Msgf("ApplyBlock succeeded!") + log.Info().Str("state_hash", stateHash).Msgf("🧱 ApplyBlock succeeded!") uow.stateHash = stateHash return nil } -// TODO(@deblasis): change tracking here func (uow *baseUtilityUnitOfWork) Commit(quorumCert []byte) error { uow.logger.Debug().Msg("committing the rwPersistenceContext...") if err := uow.persistenceRWContext.Commit(uow.proposalProposerAddr, quorumCert); err != nil { @@ -112,7 +115,6 @@ func (uow *baseUtilityUnitOfWork) Commit(quorumCert []byte) error { return nil } -// TODO(@deblasis): change tracking reset here func (uow *baseUtilityUnitOfWork) Release() error { rwCtx := uow.persistenceRWContext if rwCtx != nil { @@ -138,9 +140,10 @@ func (uow *baseUtilityUnitOfWork) isProposalBlockSet() bool { // processProposalBlockTransactions processes the transactions from the proposal block stored in the current // unit of work. It applies the transactions to the persistence context, indexes them, and removes that from // the mempool if they are present. -func (uow *baseUtilityUnitOfWork) processProposalBlockTransactions(txMempool mempool.TXMempool) (err error) { +func (uow *baseUtilityUnitOfWork) processProposalBlockTransactions() (err error) { // CONSIDERATION: should we check that `uow.proposalBlockTxs` is not nil and return an error if so or allow empty blocks? // For reference, see Tendermint: https://docs.tendermint.com/v0.34/tendermint-core/configuration.html#empty-blocks-vs-no-empty-blocks + txMempool := uow.GetBus().GetUtilityModule().GetMempool() for index, txProtoBytes := range uow.proposalBlockTxs { tx, err := coreTypes.TxFromBytes(txProtoBytes) if err != nil { diff --git a/utility/unit_of_work/uow_leader.go b/utility/unit_of_work/uow_leader.go index 2c10d76dc..cfa2e6707 100644 --- a/utility/unit_of_work/uow_leader.go +++ b/utility/unit_of_work/uow_leader.go @@ -2,6 +2,7 @@ package unit_of_work import ( "encoding/hex" + "fmt" "github.com/pokt-network/pocket/logger" coreTypes "github.com/pokt-network/pocket/shared/core/types" @@ -58,7 +59,11 @@ func (uow *leaderUtilityUnitOfWork) CreateProposalBlock(proposer []byte, maxTxBy // Compute & return the new state hash stateHash, err = uow.persistenceRWContext.ComputeStateHash() if err != nil { - log.Fatal().Err(err).Bool("TODO", true).Msg("Updating the app hash failed. TODO: Look into roll-backing the entire commit...") + if err := uow.persistenceRWContext.RollbackToSavePoint(); err != nil { + log.Error().Msgf("failed to recover from rollback at height %+v: %+v", uow.height, err) + return "", nil, err + } + return "", nil, fmt.Errorf("rollback at height %d: failed to compute state hash: %w", uow.height, err) } log.Info().Str("state_hash", stateHash).Msg("Finished successfully") @@ -99,7 +104,7 @@ func (uow *leaderUtilityUnitOfWork) reapMempool(txMempool mempool.TXMempool, max if err != nil { uow.logger.Err(err).Msg("Error handling the transaction") // TODO(#327): Properly implement 'unhappy path' for save points - if err := uow.revertLastSavePoint(); err != nil { + if err := uow.revertToLastSavepoint(); err != nil { return nil, err } txsTotalBz -= txBzSize diff --git a/utility/unit_of_work/uow_leader_test.go b/utility/unit_of_work/uow_leader_test.go new file mode 100644 index 000000000..feef2f37d --- /dev/null +++ b/utility/unit_of_work/uow_leader_test.go @@ -0,0 +1,129 @@ +package unit_of_work + +import ( + "fmt" + "math/big" + "reflect" + "testing" + + "github.com/golang/mock/gomock" + "github.com/pokt-network/pocket/shared/modules" + mockModules "github.com/pokt-network/pocket/shared/modules/mocks" + "github.com/pokt-network/pocket/shared/utils" + "github.com/stretchr/testify/require" +) + +var DefaultStakeBig = big.NewInt(1000000000000000) + +func Test_leaderUtilityUnitOfWork_CreateProposalBlock(t *testing.T) { + t.Helper() + + type fields struct { + leaderUOW func(t *testing.T) *leaderUtilityUnitOfWork + } + type args struct { + proposer []byte + maxTxBytes uint64 + } + tests := []struct { + name string + args args + fields fields + wantStateHash string + wantTxs [][]byte + wantErr bool + }{ + { + name: "should revert a failed block proposal", + args: args{}, + fields: fields{ + leaderUOW: func(t *testing.T) *leaderUtilityUnitOfWork { + ctrl := gomock.NewController(t) + + mockrwcontext := newDefaultMockRWContext(t, ctrl) + mockrwcontext.EXPECT().RollbackToSavePoint().Times(1) + mockrwcontext.EXPECT().ComputeStateHash().Return("", fmt.Errorf("rollback error")) + + mockUtilityMod := newDefaultMockUtilityModule(t, ctrl) + mockbus := mockModules.NewMockBus(ctrl) + mockbus.EXPECT().GetUtilityModule().Return(mockUtilityMod).AnyTimes() + + luow := NewLeaderUOW(0, mockrwcontext, mockrwcontext) + luow.SetBus(mockbus) + + return luow + }, + }, + wantErr: true, + wantTxs: nil, + }, + { + name: "should apply a unit of work", + args: args{}, + fields: fields{ + leaderUOW: func(t *testing.T) *leaderUtilityUnitOfWork { + ctrl := gomock.NewController(t) + + mockrwcontext := newDefaultMockRWContext(t, ctrl) + mockrwcontext.EXPECT().ComputeStateHash().Return("foo", nil).Times(1) + + mockUtilityMod := newDefaultMockUtilityModule(t, ctrl) + mockbus := mockModules.NewMockBus(ctrl) + mockbus.EXPECT().GetUtilityModule().Return(mockUtilityMod).AnyTimes() + + luow := NewLeaderUOW(0, mockrwcontext, mockrwcontext) + luow.SetBus(mockbus) + + return luow + }, + }, + wantErr: false, + wantStateHash: "foo", + wantTxs: [][]byte{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + luow := tt.fields.leaderUOW(t) + gotHash, gotTxs, err := luow.CreateProposalBlock(tt.args.proposer, tt.args.maxTxBytes) + if (err != nil) != tt.wantErr { + t.Errorf("leaderUtilityUnitOfWork.CreateProposalBlock() error = %v, wantErr %v", err, tt.wantErr) + return + } + require.Equal(t, gotHash, tt.wantStateHash) + if !reflect.DeepEqual(gotTxs, tt.wantTxs) { + t.Errorf("leaderUtilityUnitOfWork.CreateProposalBlock() gotTxs = %v, want %v", gotTxs, tt.wantTxs) + } + }) + } +} + +func newDefaultMockRWContext(t *testing.T, ctrl *gomock.Controller) *mockModules.MockPersistenceRWContext { + t.Helper() + + mockrwcontext := mockModules.NewMockPersistenceRWContext(ctrl) + mockrwcontext.EXPECT().SetPoolAmount(gomock.Any(), gomock.Any()).AnyTimes() + mockrwcontext.EXPECT().GetIntParam(gomock.Any(), gomock.Any()).Return(0, nil).AnyTimes() + mockrwcontext.EXPECT().GetPoolAmount(gomock.Any(), gomock.Any()).Return(utils.BigIntToString(DefaultStakeBig), nil).Times(1) + mockrwcontext.EXPECT().AddAccountAmount(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockrwcontext.EXPECT().AddPoolAmount(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockrwcontext.EXPECT().GetAppsReadyToUnstake(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockrwcontext.EXPECT().GetServicersReadyToUnstake(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockrwcontext.EXPECT().GetValidatorsReadyToUnstake(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockrwcontext.EXPECT().GetFishermenReadyToUnstake(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockrwcontext.EXPECT().SetServicerStatusAndUnstakingHeightIfPausedBefore(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockrwcontext.EXPECT().SetAppStatusAndUnstakingHeightIfPausedBefore(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockrwcontext.EXPECT().SetValidatorsStatusAndUnstakingHeightIfPausedBefore(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockrwcontext.EXPECT().SetFishermanStatusAndUnstakingHeightIfPausedBefore(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + + return mockrwcontext +} + +func newDefaultMockUtilityModule(t *testing.T, ctrl *gomock.Controller) *mockModules.MockUtilityModule { + mockUtilityMod := mockModules.NewMockUtilityModule(ctrl) + testmempool := NewTestingMempool(t) + mockUtilityMod.EXPECT().GetModuleName().Return(modules.UtilityModuleName).AnyTimes() + mockUtilityMod.EXPECT().SetBus(gomock.Any()).Return().AnyTimes() + mockUtilityMod.EXPECT().GetMempool().Return(testmempool).AnyTimes() + return mockUtilityMod +}