From 071f780c0941297701a5ba468f6fe56e08e96800 Mon Sep 17 00:00:00 2001 From: yacovm Date: Mon, 21 Oct 2024 16:57:19 +0200 Subject: [PATCH 01/20] Make bootstrapping handle its own timeouts (#3410) Signed-off-by: Yacov Manevich --- chains/manager.go | 2 - message/internal_msg_builder.go | 18 --- message/ops.go | 4 - snow/engine/common/bootstrap_tracker.go | 4 +- snow/engine/common/engine.go | 3 - snow/engine/common/no_ops_handlers.go | 8 -- snow/engine/common/timer.go | 109 ++++++++++++++++- snow/engine/common/timer_test.go | 115 ++++++++++++++++++ snow/engine/common/traced_engine.go | 7 -- snow/engine/enginetest/bootstrap_tracker.go | 2 +- snow/engine/enginetest/engine.go | 14 --- snow/engine/enginetest/timer.go | 9 -- snow/engine/snowman/bootstrap/bootstrapper.go | 26 ++-- .../snowman/bootstrap/bootstrapper_test.go | 15 ++- snow/engine/snowman/bootstrap/config.go | 1 - snow/engine/snowman/engine.go | 4 - snow/networking/handler/handler.go | 31 ----- .../networking/handler/handlermock/handler.go | 12 -- subnets/subnet.go | 30 ++--- vms/platformvm/vm_test.go | 1 + 20 files changed, 267 insertions(+), 148 deletions(-) create mode 100644 snow/engine/common/timer_test.go diff --git a/chains/manager.go b/chains/manager.go index 61e40f789dd..906c6f136df 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -962,7 +962,6 @@ func (m *manager) createAvalancheChain( StartupTracker: startupTracker, Sender: snowmanMessageSender, BootstrapTracker: sb, - Timer: h, PeerTracker: peerTracker, AncestorsMaxContainersReceived: m.BootstrapAncestorsMaxContainersReceived, DB: blockBootstrappingDB, @@ -1357,7 +1356,6 @@ func (m *manager) createSnowmanChain( StartupTracker: startupTracker, Sender: messageSender, BootstrapTracker: sb, - Timer: h, PeerTracker: peerTracker, AncestorsMaxContainersReceived: m.BootstrapAncestorsMaxContainersReceived, DB: bootstrappingDB, diff --git a/message/internal_msg_builder.go b/message/internal_msg_builder.go index 2fabb2ae00c..1aa79ccb482 100644 --- a/message/internal_msg_builder.go +++ b/message/internal_msg_builder.go @@ -16,7 +16,6 @@ import ( var ( disconnected = &Disconnected{} gossipRequest = &GossipRequest{} - timeout = &Timeout{} _ fmt.Stringer = (*GetStateSummaryFrontierFailed)(nil) _ chainIDGetter = (*GetStateSummaryFrontierFailed)(nil) @@ -50,8 +49,6 @@ var ( _ fmt.Stringer = (*Disconnected)(nil) _ fmt.Stringer = (*GossipRequest)(nil) - - _ fmt.Stringer = (*Timeout)(nil) ) type GetStateSummaryFrontierFailed struct { @@ -391,18 +388,3 @@ func InternalGossipRequest( expiration: mockable.MaxTime, } } - -type Timeout struct{} - -func (Timeout) String() string { - return "" -} - -func InternalTimeout(nodeID ids.NodeID) InboundMessage { - return &inboundMessage{ - nodeID: nodeID, - op: TimeoutOp, - message: timeout, - expiration: mockable.MaxTime, - } -} diff --git a/message/ops.go b/message/ops.go index 6ac8a6aff9c..21728ddad05 100644 --- a/message/ops.go +++ b/message/ops.go @@ -59,7 +59,6 @@ const ( DisconnectedOp NotifyOp GossipRequestOp - TimeoutOp ) var ( @@ -115,7 +114,6 @@ var ( DisconnectedOp, NotifyOp, GossipRequestOp, - TimeoutOp, } ConsensusOps = append(ConsensusExternalOps, ConsensusInternalOps...) @@ -264,8 +262,6 @@ func (op Op) String() string { return "notify" case GossipRequestOp: return "gossip_request" - case TimeoutOp: - return "timeout" default: return "unknown" } diff --git a/snow/engine/common/bootstrap_tracker.go b/snow/engine/common/bootstrap_tracker.go index bd2ef43cf1f..ade9aee354e 100644 --- a/snow/engine/common/bootstrap_tracker.go +++ b/snow/engine/common/bootstrap_tracker.go @@ -14,5 +14,7 @@ type BootstrapTracker interface { // Bootstrapped marks the named chain as being bootstrapped Bootstrapped(chainID ids.ID) - OnBootstrapCompleted() chan struct{} + // AllBootstrapped returns a channel that is closed when all chains in this + // subnet have been bootstrapped + AllBootstrapped() <-chan struct{} } diff --git a/snow/engine/common/engine.go b/snow/engine/common/engine.go index dc39504dc76..8c1e98d25b2 100644 --- a/snow/engine/common/engine.go +++ b/snow/engine/common/engine.go @@ -416,9 +416,6 @@ type InternalHandler interface { // Notify this engine of peer changes. validators.Connector - // Notify this engine that a registered timeout has fired. - Timeout(context.Context) error - // Gossip to the network a container on the accepted frontier Gossip(context.Context) error diff --git a/snow/engine/common/no_ops_handlers.go b/snow/engine/common/no_ops_handlers.go index d728e101eb1..042e686d422 100644 --- a/snow/engine/common/no_ops_handlers.go +++ b/snow/engine/common/no_ops_handlers.go @@ -340,14 +340,6 @@ func (nop *noOpInternalHandler) Disconnected(_ context.Context, nodeID ids.NodeI return nil } -func (nop *noOpInternalHandler) Timeout(context.Context) error { - nop.log.Debug("dropping request", - zap.String("reason", "unhandled by this gear"), - zap.Stringer("messageOp", message.TimeoutOp), - ) - return nil -} - func (nop *noOpInternalHandler) Gossip(context.Context) error { nop.log.Debug("dropping request", zap.String("reason", "unhandled by this gear"), diff --git a/snow/engine/common/timer.go b/snow/engine/common/timer.go index 432bb9170cc..dcddfcbd2fc 100644 --- a/snow/engine/common/timer.go +++ b/snow/engine/common/timer.go @@ -3,12 +3,109 @@ package common -import "time" +import ( + "sync" + "time" +) -// Timer describes the standard interface for specifying a timeout -type Timer interface { - // RegisterTimeout specifies how much time to delay the next timeout message - // by. If the subnet has been bootstrapped, the timeout will fire - // immediately. +// PreemptionSignal signals when to preempt the pendingTimeoutToken of the timeout handler. +type PreemptionSignal struct { + activateOnce sync.Once + initOnce sync.Once + signal chan struct{} +} + +func (ps *PreemptionSignal) init() { + ps.signal = make(chan struct{}) +} + +// Listen returns a read-only channel that is closed when Preempt() is invoked. +func (ps *PreemptionSignal) Listen() <-chan struct{} { + ps.initOnce.Do(ps.init) + return ps.signal +} + +// Preempt causes any past and future calls of Listen to return a closed channel. +func (ps *PreemptionSignal) Preempt() { + ps.initOnce.Do(ps.init) + ps.activateOnce.Do(func() { + close(ps.signal) + }) +} + +// timeoutScheduler schedules timeouts to be dispatched in the future. +// Only a single timeout can be pending to be scheduled at any given time. +// Once a preemption signal is closed, all timeouts are immediately dispatched. +type timeoutScheduler struct { + newTimer func(duration time.Duration) *time.Timer + onTimeout func() + preemptionSignal <-chan struct{} + pendingTimeoutToken chan struct{} +} + +// NewTimeoutScheduler constructs a new timeout scheduler with the given function to be invoked upon a timeout, +// unless the preemptionSignal is closed and in which case it invokes the function immediately. +func NewTimeoutScheduler(onTimeout func(), preemptionSignal <-chan struct{}) *timeoutScheduler { + pendingTimout := make(chan struct{}, 1) + pendingTimout <- struct{}{} + return &timeoutScheduler{ + preemptionSignal: preemptionSignal, + newTimer: time.NewTimer, + onTimeout: onTimeout, + pendingTimeoutToken: pendingTimout, + } +} + +// RegisterTimeout fires the function the timeout scheduler is initialized with no later than the given timeout. +func (th *timeoutScheduler) RegisterTimeout(d time.Duration) { + // There can only be a single timeout pending at any time, and once a timeout is scheduled, + // we prevent future timeouts to be scheduled until the timeout triggers by taking the pendingTimeoutToken. + // Any subsequent attempt to register a timeout would fail obtaining the pendingTimeoutToken, + // and return. + if !th.acquirePendingTimeoutToken() { + return + } + + go th.scheduleTimeout(d) +} + +func (th *timeoutScheduler) scheduleTimeout(d time.Duration) { + timer := th.newTimer(d) + defer timer.Stop() + + select { + case <-timer.C: + case <-th.preemptionSignal: + } + + // Relinquish the pendingTimeoutToken. + // This is needed to be done before onTimeout() is invoked, + // and that's why onTimeout() is deferred to be called at the end of the function. + // If we trigger the timeout prematurely before we relinquish the pendingTimeoutToken, + // A subsequent timeout scheduling attempt that originates from the triggering of the current timeout + // will fail, as the pendingTimeoutToken is not yet available. + th.pendingTimeoutToken <- struct{}{} + + th.onTimeout() +} + +func (th *timeoutScheduler) acquirePendingTimeoutToken() bool { + select { + case <-th.pendingTimeoutToken: + return true + default: + return false + } +} + +// TimeoutRegistrar describes the standard interface for specifying a timeout +type TimeoutRegistrar interface { + // RegisterTimeout specifies how much time to delay the next timeout message by. + // + // If there is already a pending timeout message, this call is a no-op. + // However, it is guaranteed that the timeout will fire at least once after + // calling this function. + // + // If the subnet has been bootstrapped, the timeout will fire immediately via calling Preempt(). RegisterTimeout(time.Duration) } diff --git a/snow/engine/common/timer_test.go b/snow/engine/common/timer_test.go new file mode 100644 index 00000000000..49d8a2c362c --- /dev/null +++ b/snow/engine/common/timer_test.go @@ -0,0 +1,115 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package common + +import ( + "sync" + "testing" + "time" +) + +func TestTimeoutScheduler(t *testing.T) { + for _, testCase := range []struct { + expectedInvocationCount int + desc string + shouldPreempt bool + clock chan time.Time + initClock func(chan time.Time) + advanceTime func(chan time.Time) + }{ + { + desc: "multiple pendingTimeoutToken one after the other with preemption", + expectedInvocationCount: 10, + shouldPreempt: true, + clock: make(chan time.Time, 1), + initClock: func(chan time.Time) {}, + advanceTime: func(chan time.Time) {}, + }, + { + desc: "multiple pendingTimeoutToken one after the other", + expectedInvocationCount: 10, + clock: make(chan time.Time, 1), + initClock: func(clock chan time.Time) { + clock <- time.Now() + }, + advanceTime: func(clock chan time.Time) { + clock <- time.Now() + }, + }, + } { + t.Run(testCase.desc, func(*testing.T) { + // Not enough invocations means the test would stall. + // Too many invocations means a negative counter panic. + var wg sync.WaitGroup + wg.Add(testCase.expectedInvocationCount) + + testCase.initClock(testCase.clock) + + var preemptionSignal PreemptionSignal + ps := preemptionSignal.Listen() + + if testCase.shouldPreempt { + preemptionSignal.Preempt() + } + + // Order enforces timeouts to be registered once after another, + // in order to make the tests deterministic. + order := make(chan struct{}) + + newTimer := makeMockedTimer(testCase.clock) + + onTimeout := func() { + order <- struct{}{} + wg.Done() + testCase.advanceTime(testCase.clock) + } + + ts := NewTimeoutScheduler(onTimeout, ps) + ts.newTimer = newTimer + + for i := 0; i < testCase.expectedInvocationCount; i++ { + ts.RegisterTimeout(time.Hour) + <-order + } + + wg.Wait() + }) + } +} + +func TestTimeoutSchedulerConcurrentRegister(*testing.T) { + // Not enough invocations means the test would stall. + // Too many invocations means a negative counter panic. + + clock := make(chan time.Time, 2) + newTimer := makeMockedTimer(clock) + + var wg sync.WaitGroup + wg.Add(1) + + preemptChan := make(<-chan struct{}) + + ts := NewTimeoutScheduler(wg.Done, preemptChan) + ts.newTimer = newTimer + + ts.RegisterTimeout(time.Hour) // First timeout is registered + ts.RegisterTimeout(time.Hour) // Second should not + + // Clock ticks are after registering, in order to ensure onTimeout() isn't fired until second registration is invoked. + clock <- time.Now() + clock <- time.Now() + + wg.Wait() +} + +func makeMockedTimer(clock chan time.Time) func(time.Duration) *time.Timer { + return func(time.Duration) *time.Timer { + // We use a duration of 0 to not leave a lingering timer + // after the test finishes. + // Then we replace the time channel to have control over the timer. + timer := time.NewTimer(0) + timer.C = clock + return timer + } +} diff --git a/snow/engine/common/traced_engine.go b/snow/engine/common/traced_engine.go index a6ab2dee187..9c92a01e25b 100644 --- a/snow/engine/common/traced_engine.go +++ b/snow/engine/common/traced_engine.go @@ -329,13 +329,6 @@ func (e *tracedEngine) Disconnected(ctx context.Context, nodeID ids.NodeID) erro return e.engine.Disconnected(ctx, nodeID) } -func (e *tracedEngine) Timeout(ctx context.Context) error { - ctx, span := e.tracer.Start(ctx, "tracedEngine.Timeout") - defer span.End() - - return e.engine.Timeout(ctx) -} - func (e *tracedEngine) Gossip(ctx context.Context) error { ctx, span := e.tracer.Start(ctx, "tracedEngine.Gossip") defer span.End() diff --git a/snow/engine/enginetest/bootstrap_tracker.go b/snow/engine/enginetest/bootstrap_tracker.go index 481e28d1036..cce54ebf43c 100644 --- a/snow/engine/enginetest/bootstrap_tracker.go +++ b/snow/engine/enginetest/bootstrap_tracker.go @@ -54,7 +54,7 @@ func (s *BootstrapTracker) Bootstrapped(chainID ids.ID) { } } -func (s *BootstrapTracker) OnBootstrapCompleted() chan struct{} { +func (s *BootstrapTracker) AllBootstrapped() <-chan struct{} { if s.OnBootstrapCompletedF != nil { return s.OnBootstrapCompletedF() } else if s.CantOnBootstrapCompleted && s.T != nil { diff --git a/snow/engine/enginetest/engine.go b/snow/engine/enginetest/engine.go index 4bafbfd27b8..1d2a4f6a4a4 100644 --- a/snow/engine/enginetest/engine.go +++ b/snow/engine/enginetest/engine.go @@ -19,7 +19,6 @@ import ( ) var ( - errTimeout = errors.New("unexpectedly called Timeout") errGossip = errors.New("unexpectedly called Gossip") errNotify = errors.New("unexpectedly called Notify") errGetStateSummaryFrontier = errors.New("unexpectedly called GetStateSummaryFrontier") @@ -189,19 +188,6 @@ func (e *Engine) Start(ctx context.Context, startReqID uint32) error { return errStart } -func (e *Engine) Timeout(ctx context.Context) error { - if e.TimeoutF != nil { - return e.TimeoutF(ctx) - } - if !e.CantTimeout { - return nil - } - if e.T != nil { - require.FailNow(e.T, errTimeout.Error()) - } - return errTimeout -} - func (e *Engine) Gossip(ctx context.Context) error { if e.GossipF != nil { return e.GossipF(ctx) diff --git a/snow/engine/enginetest/timer.go b/snow/engine/enginetest/timer.go index f2161b5c8c1..9eaff381cc4 100644 --- a/snow/engine/enginetest/timer.go +++ b/snow/engine/enginetest/timer.go @@ -8,12 +8,8 @@ import ( "time" "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/snow/engine/common" ) -var _ common.Timer = (*Timer)(nil) - // Timer is a test timer type Timer struct { T *testing.T @@ -23,11 +19,6 @@ type Timer struct { RegisterTimeoutF func(time.Duration) } -// Default set the default callable value to [cant] -func (t *Timer) Default(cant bool) { - t.CantRegisterTimout = cant -} - func (t *Timer) RegisterTimeout(delay time.Duration) { if t.RegisterTimeoutF != nil { t.RegisterTimeoutF(delay) diff --git a/snow/engine/snowman/bootstrap/bootstrapper.go b/snow/engine/snowman/bootstrap/bootstrapper.go index bece6f32a1b..024925c7928 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper.go +++ b/snow/engine/snowman/bootstrap/bootstrapper.go @@ -70,7 +70,7 @@ type Bootstrapper struct { Config shouldHalt func() bool *metrics - + TimeoutRegistrar common.TimeoutRegistrar // list of NoOpsHandler for messages dropped by bootstrapper common.StateSummaryFrontierHandler common.AcceptedStateSummaryHandler @@ -119,7 +119,7 @@ type Bootstrapper struct { func New(config Config, onFinished func(ctx context.Context, lastReqID uint32) error) (*Bootstrapper, error) { metrics, err := newMetrics(config.Ctx.Registerer) - return &Bootstrapper{ + bs := &Bootstrapper{ shouldHalt: config.ShouldHalt, nonVerifyingParser: config.NonVerifyingParse, Config: config, @@ -139,7 +139,19 @@ func New(config Config, onFinished func(ctx context.Context, lastReqID uint32) e executedStateTransitions: math.MaxInt, onFinished: onFinished, - }, err + } + + timeout := func() { + config.Ctx.Lock.Lock() + defer config.Ctx.Lock.Unlock() + + if err := bs.Timeout(); err != nil { + bs.Config.Ctx.Log.Warn("Encountered error during bootstrapping: %w", zap.Error(err)) + } + } + bs.TimeoutRegistrar = common.NewTimeoutScheduler(timeout, config.BootstrapTracker.AllBootstrapped()) + + return bs, err } func (b *Bootstrapper) Context() *snow.ConsensusContext { @@ -703,8 +715,8 @@ func (b *Bootstrapper) tryStartExecuting(ctx context.Context) error { log("waiting for the remaining chains in this subnet to finish syncing") // Restart bootstrapping after [bootstrappingDelay] to keep up to date // on the latest tip. - b.Config.Timer.RegisterTimeout(bootstrappingDelay) b.awaitingTimeout = true + b.TimeoutRegistrar.RegisterTimeout(bootstrappingDelay) return nil } return b.onFinished(ctx, b.requestID) @@ -722,16 +734,16 @@ func (b *Bootstrapper) getLastAccepted(ctx context.Context) (snowman.Block, erro return lastAccepted, nil } -func (b *Bootstrapper) Timeout(ctx context.Context) error { +func (b *Bootstrapper) Timeout() error { if !b.awaitingTimeout { return errUnexpectedTimeout } b.awaitingTimeout = false if !b.Config.BootstrapTracker.IsBootstrapped() { - return b.restartBootstrapping(ctx) + return b.restartBootstrapping(context.TODO()) } - return b.onFinished(ctx, b.requestID) + return b.onFinished(context.TODO(), b.requestID) } func (b *Bootstrapper) restartBootstrapping(ctx context.Context) error { diff --git a/snow/engine/snowman/bootstrap/bootstrapper_test.go b/snow/engine/snowman/bootstrap/bootstrapper_test.go index 14024dc65fb..772cf51281e 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper_test.go +++ b/snow/engine/snowman/bootstrap/bootstrapper_test.go @@ -103,7 +103,6 @@ func newConfig(t *testing.T) (Config, ids.NodeID, *enginetest.Sender, *blocktest PeerTracker: peerTracker, Sender: sender, BootstrapTracker: bootstrapTracker, - Timer: &enginetest.Timer{}, AncestorsMaxContainersReceived: 2000, DB: memdb.New(), VM: vm, @@ -155,7 +154,6 @@ func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { PeerTracker: peerTracker, Sender: sender, BootstrapTracker: &enginetest.BootstrapTracker{}, - Timer: &enginetest.Timer{}, AncestorsMaxContainersReceived: 2000, DB: memdb.New(), VM: vm, @@ -180,6 +178,7 @@ func TestBootstrapperStartsOnlyIfEnoughStakeIsConnected(t *testing.T) { } bs, err := New(cfg, dummyCallback) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} vm.CantSetState = false vm.CantConnected = true @@ -236,6 +235,7 @@ func TestBootstrapperSingleFrontier(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -264,6 +264,7 @@ func TestBootstrapperUnknownByzantineResponse(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -309,6 +310,7 @@ func TestBootstrapperPartialFetch(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -359,6 +361,7 @@ func TestBootstrapperEmptyResponse(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -407,6 +410,7 @@ func TestBootstrapperAncestors(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -452,6 +456,7 @@ func TestBootstrapperFinalized(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -494,6 +499,7 @@ func TestRestartBootstrapping(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -558,6 +564,7 @@ func TestBootstrapOldBlockAfterStateSync(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -598,6 +605,7 @@ func TestBootstrapContinueAfterHalt(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} getBlockF := vm.GetBlockF vm.GetBlockF = func(ctx context.Context, blkID ids.ID) (snowman.Block, error) { @@ -690,7 +698,6 @@ func TestBootstrapNoParseOnNew(t *testing.T) { PeerTracker: peerTracker, Sender: sender, BootstrapTracker: bootstrapTracker, - Timer: &enginetest.Timer{}, AncestorsMaxContainersReceived: 2000, DB: intervalDB, VM: vm, @@ -728,6 +735,7 @@ func TestBootstrapperReceiveStaleAncestorsMessage(t *testing.T) { }, ) require.NoError(err) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(bs.Start(context.Background(), 0)) @@ -772,6 +780,7 @@ func TestBootstrapperRollbackOnSetState(t *testing.T) { return nil }, ) + bs.TimeoutRegistrar = &enginetest.Timer{} require.NoError(err) vm.SetStateF = func(context.Context, snow.State) error { diff --git a/snow/engine/snowman/bootstrap/config.go b/snow/engine/snowman/bootstrap/config.go index 1211e68ebbb..d5aa9e011d4 100644 --- a/snow/engine/snowman/bootstrap/config.go +++ b/snow/engine/snowman/bootstrap/config.go @@ -23,7 +23,6 @@ type Config struct { StartupTracker tracker.Startup Sender common.Sender BootstrapTracker common.BootstrapTracker - Timer common.Timer // PeerTracker manages the set of nodes that we fetch the next block from. PeerTracker *p2p.PeerTracker diff --git a/snow/engine/snowman/engine.go b/snow/engine/snowman/engine.go index b1c1698b53d..2e454402dad 100644 --- a/snow/engine/snowman/engine.go +++ b/snow/engine/snowman/engine.go @@ -431,10 +431,6 @@ func (e *Engine) QueryFailed(ctx context.Context, nodeID ids.NodeID, requestID u return e.executeDeferredWork(ctx) } -func (*Engine) Timeout(context.Context) error { - return nil -} - func (e *Engine) Shutdown(ctx context.Context) error { e.Ctx.Log.Info("shutting down consensus engine") diff --git a/snow/networking/handler/handler.go b/snow/networking/handler/handler.go index 224e0abc51b..d25d6b8a149 100644 --- a/snow/networking/handler/handler.go +++ b/snow/networking/handler/handler.go @@ -49,7 +49,6 @@ var ( ) type Handler interface { - common.Timer health.Checker Context() *snow.ConsensusContext @@ -90,7 +89,6 @@ type handler struct { validators validators.Manager // Receives messages from the VM msgFromVMChan <-chan common.Message - preemptTimeouts chan struct{} gossipFrequency time.Duration engineManager *EngineManager @@ -110,7 +108,6 @@ type handler struct { asyncMessageQueue MessageQueue // Worker pool for handling asynchronous consensus messages asyncMessagePool errgroup.Group - timeouts chan struct{} closeOnce sync.Once startClosingTime time.Time @@ -147,9 +144,7 @@ func New( ctx: ctx, validators: validators, msgFromVMChan: msgFromVMChan, - preemptTimeouts: subnet.OnBootstrapCompleted(), gossipFrequency: gossipFrequency, - timeouts: make(chan struct{}, 1), closingChan: make(chan struct{}), closed: make(chan struct{}), resourceTracker: resourceTracker, @@ -297,26 +292,6 @@ func (h *handler) Len() int { return h.syncMessageQueue.Len() + h.asyncMessageQueue.Len() } -func (h *handler) RegisterTimeout(d time.Duration) { - go func() { - timer := time.NewTimer(d) - defer timer.Stop() - - select { - case <-timer.C: - case <-h.preemptTimeouts: - } - - // If there is already a timeout ready to fire - just drop the - // additional timeout. This ensures that all goroutines that are spawned - // here are able to close if the chain is shutdown. - select { - case h.timeouts <- struct{}{}: - default: - } - }() -} - // Note: It is possible for Stop to be called before/concurrently with Start. // // Invariant: Stop must never block. @@ -421,9 +396,6 @@ func (h *handler) dispatchChans(ctx context.Context) { case <-gossiper.C: msg = message.InternalGossipRequest(h.ctx.NodeID) - - case <-h.timeouts: - msg = message.InternalTimeout(h.ctx.NodeID) } if err := h.handleChanMsg(msg); err != nil { @@ -936,9 +908,6 @@ func (h *handler) handleChanMsg(msg message.InboundMessage) error { case *message.GossipRequest: return engine.Gossip(context.TODO()) - case *message.Timeout: - return engine.Timeout(context.TODO()) - default: return fmt.Errorf( "attempt to submit unhandled chan msg %s", diff --git a/snow/networking/handler/handlermock/handler.go b/snow/networking/handler/handlermock/handler.go index d8e3c23bb0f..71d58ce7d09 100644 --- a/snow/networking/handler/handlermock/handler.go +++ b/snow/networking/handler/handlermock/handler.go @@ -127,18 +127,6 @@ func (mr *HandlerMockRecorder) Push(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*Handler)(nil).Push), arg0, arg1) } -// RegisterTimeout mocks base method. -func (m *Handler) RegisterTimeout(arg0 time.Duration) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RegisterTimeout", arg0) -} - -// RegisterTimeout indicates an expected call of RegisterTimeout. -func (mr *HandlerMockRecorder) RegisterTimeout(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterTimeout", reflect.TypeOf((*Handler)(nil).RegisterTimeout), arg0) -} - // SetEngineManager mocks base method. func (m *Handler) SetEngineManager(arg0 *handler.EngineManager) { m.ctrl.T.Helper() diff --git a/subnets/subnet.go b/subnets/subnet.go index 95425ba3050..296be810123 100644 --- a/subnets/subnet.go +++ b/subnets/subnet.go @@ -34,23 +34,25 @@ type Subnet interface { } type subnet struct { - lock sync.RWMutex - bootstrapping set.Set[ids.ID] - bootstrapped set.Set[ids.ID] - once sync.Once - bootstrappedSema chan struct{} - config Config - myNodeID ids.NodeID + lock sync.RWMutex + bootstrapping set.Set[ids.ID] + bootstrapped set.Set[ids.ID] + config Config + myNodeID ids.NodeID + bootstrapSignal common.PreemptionSignal } func New(myNodeID ids.NodeID, config Config) Subnet { return &subnet{ - bootstrappedSema: make(chan struct{}), - config: config, - myNodeID: myNodeID, + config: config, + myNodeID: myNodeID, } } +func (s *subnet) AllBootstrapped() <-chan struct{} { + return s.bootstrapSignal.Listen() +} + func (s *subnet) IsBootstrapped() bool { s.lock.RLock() defer s.lock.RUnlock() @@ -68,13 +70,7 @@ func (s *subnet) Bootstrapped(chainID ids.ID) { return } - s.once.Do(func() { - close(s.bootstrappedSema) - }) -} - -func (s *subnet) OnBootstrapCompleted() chan struct{} { - return s.bootstrappedSema + s.bootstrapSignal.Preempt() } func (s *subnet) AddChain(chainID ids.ID) bool { diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index 2d1740a650e..7048778b19b 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -1386,6 +1386,7 @@ func TestBootstrapPartiallyAccepted(t *testing.T) { engine.Start, ) require.NoError(err) + bootstrapper.TimeoutRegistrar = &enginetest.Timer{} h.SetEngineManager(&handler.EngineManager{ Avalanche: &handler.Engine{ From 44d21785fedae3540c223fd60dcceecde3a1f168 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Mon, 21 Oct 2024 11:43:26 -0400 Subject: [PATCH 02/20] Wrap `TestDiffExpiry` sub-tests in `t.Run` (#3483) --- vms/platformvm/state/diff_test.go | 98 ++++++++++++++++--------------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/vms/platformvm/state/diff_test.go b/vms/platformvm/state/diff_test.go index 013a918b60f..3625986d780 100644 --- a/vms/platformvm/state/diff_test.go +++ b/vms/platformvm/state/diff_test.go @@ -201,64 +201,66 @@ func TestDiffExpiry(t *testing.T) { } for _, test := range tests { - require := require.New(t) + t.Run(test.name, func(t *testing.T) { + require := require.New(t) - state := newTestState(t, memdb.New()) - for _, expiry := range test.initialExpiries { - state.PutExpiry(expiry) - } - - d, err := NewDiffOn(state) - require.NoError(err) - - var ( - expectedExpiries = set.Of(test.initialExpiries...) - unexpectedExpiries set.Set[ExpiryEntry] - ) - for _, op := range test.ops { - if op.put { - d.PutExpiry(op.entry) - expectedExpiries.Add(op.entry) - unexpectedExpiries.Remove(op.entry) - } else { - d.DeleteExpiry(op.entry) - expectedExpiries.Remove(op.entry) - unexpectedExpiries.Add(op.entry) + state := newTestState(t, memdb.New()) + for _, expiry := range test.initialExpiries { + state.PutExpiry(expiry) } - } - // If expectedExpiries is empty, we want expectedExpiriesSlice to be - // nil. - var expectedExpiriesSlice []ExpiryEntry - if expectedExpiries.Len() > 0 { - expectedExpiriesSlice = expectedExpiries.List() - utils.Sort(expectedExpiriesSlice) - } - - verifyChain := func(chain Chain) { - expiryIterator, err := chain.GetExpiryIterator() + d, err := NewDiffOn(state) require.NoError(err) - require.Equal( - expectedExpiriesSlice, - iterator.ToSlice(expiryIterator), + + var ( + expectedExpiries = set.Of(test.initialExpiries...) + unexpectedExpiries set.Set[ExpiryEntry] ) + for _, op := range test.ops { + if op.put { + d.PutExpiry(op.entry) + expectedExpiries.Add(op.entry) + unexpectedExpiries.Remove(op.entry) + } else { + d.DeleteExpiry(op.entry) + expectedExpiries.Remove(op.entry) + unexpectedExpiries.Add(op.entry) + } + } - for expiry := range expectedExpiries { - has, err := chain.HasExpiry(expiry) - require.NoError(err) - require.True(has) + // If expectedExpiries is empty, we want expectedExpiriesSlice to be + // nil. + var expectedExpiriesSlice []ExpiryEntry + if expectedExpiries.Len() > 0 { + expectedExpiriesSlice = expectedExpiries.List() + utils.Sort(expectedExpiriesSlice) } - for expiry := range unexpectedExpiries { - has, err := chain.HasExpiry(expiry) + + verifyChain := func(chain Chain) { + expiryIterator, err := chain.GetExpiryIterator() require.NoError(err) - require.False(has) + require.Equal( + expectedExpiriesSlice, + iterator.ToSlice(expiryIterator), + ) + + for expiry := range expectedExpiries { + has, err := chain.HasExpiry(expiry) + require.NoError(err) + require.True(has) + } + for expiry := range unexpectedExpiries { + has, err := chain.HasExpiry(expiry) + require.NoError(err) + require.False(has) + } } - } - verifyChain(d) - require.NoError(d.Apply(state)) - verifyChain(state) - assertChainsEqual(t, d, state) + verifyChain(d) + require.NoError(d.Apply(state)) + verifyChain(state) + assertChainsEqual(t, d, state) + }) } } From 58d5b8ddf85fe3a4a9a9b0739e6d3844e6d1ea9f Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Mon, 21 Oct 2024 18:42:17 -0400 Subject: [PATCH 03/20] Populate BLS key diffs for subnet validators --- vms/platformvm/state/state.go | 317 ++++++++++++++++++++-------------- 1 file changed, 186 insertions(+), 131 deletions(-) diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 58c2056570b..a34b587a673 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -59,8 +59,9 @@ const ( var ( _ State = (*state)(nil) - errValidatorSetAlreadyPopulated = errors.New("validator set already populated") - errIsNotSubnet = errors.New("is not a subnet") + errValidatorSetAlreadyPopulated = errors.New("validator set already populated") + errIsNotSubnet = errors.New("is not a subnet") + errMissingPrimaryNetworkValidator = errors.New("missing primary network validator") BlockIDPrefix = []byte("blockID") BlockPrefix = []byte("block") @@ -2007,164 +2008,218 @@ func (s *state) writeExpiry() error { func (s *state) writeCurrentStakers(updateValidators bool, height uint64, codecVersion uint16) error { for subnetID, validatorDiffs := range s.currentStakers.validatorDiffs { + // We must write the primary network stakers last because writing subnet + // validator diffs may depend on the primary network validator diffs to + // inherit the public keys. + if subnetID == constants.PrimaryNetworkID { + continue + } + delete(s.currentStakers.validatorDiffs, subnetID) - // Select db to write to - validatorDB := s.currentSubnetValidatorList - delegatorDB := s.currentSubnetDelegatorList - if subnetID == constants.PrimaryNetworkID { - validatorDB = s.currentValidatorList - delegatorDB = s.currentDelegatorList + err := s.writeCurrentStakersSubnetDiff( + subnetID, + validatorDiffs, + updateValidators, + height, + codecVersion, + ) + if err != nil { + return err } + } - // Record the change in weight and/or public key for each validator. - for nodeID, validatorDiff := range validatorDiffs { - // Copy [nodeID] so it doesn't get overwritten next iteration. - nodeID := nodeID + if validatorDiffs, ok := s.currentStakers.validatorDiffs[constants.PrimaryNetworkID]; ok { + delete(s.currentStakers.validatorDiffs, constants.PrimaryNetworkID) - weightDiff := &ValidatorWeightDiff{ - Decrease: validatorDiff.validatorStatus == deleted, - } - switch validatorDiff.validatorStatus { - case added: - staker := validatorDiff.validator - weightDiff.Amount = staker.Weight - - // Invariant: Only the Primary Network contains non-nil public - // keys. - if staker.PublicKey != nil { - // Record that the public key for the validator is being - // added. This means the prior value for the public key was - // nil. - err := s.validatorPublicKeyDiffsDB.Put( - marshalDiffKey(constants.PrimaryNetworkID, height, nodeID), - nil, - ) - if err != nil { - return err - } - } + err := s.writeCurrentStakersSubnetDiff( + constants.PrimaryNetworkID, + validatorDiffs, + updateValidators, + height, + codecVersion, + ) + if err != nil { + return err + } + } - // The validator is being added. - // - // Invariant: It's impossible for a delegator to have been - // rewarded in the same block that the validator was added. - startTime := uint64(staker.StartTime.Unix()) - metadata := &validatorMetadata{ - txID: staker.TxID, - lastUpdated: staker.StartTime, - - UpDuration: 0, - LastUpdated: startTime, - StakerStartTime: startTime, - PotentialReward: staker.PotentialReward, - PotentialDelegateeReward: 0, - } + // TODO: Move validator set management out of the state package + // + // Attempt to update the stake metrics + if !updateValidators { + return nil + } - metadataBytes, err := MetadataCodec.Marshal(codecVersion, metadata) - if err != nil { - return fmt.Errorf("failed to serialize current validator: %w", err) - } + totalWeight, err := s.validators.TotalWeight(constants.PrimaryNetworkID) + if err != nil { + return fmt.Errorf("failed to get total weight of primary network: %w", err) + } - if err = validatorDB.Put(staker.TxID[:], metadataBytes); err != nil { - return fmt.Errorf("failed to write current validator to list: %w", err) - } + s.metrics.SetLocalStake(s.validators.GetWeight(constants.PrimaryNetworkID, s.ctx.NodeID)) + s.metrics.SetTotalStake(totalWeight) + return nil +} - s.validatorState.LoadValidatorMetadata(nodeID, subnetID, metadata) - case deleted: - staker := validatorDiff.validator - weightDiff.Amount = staker.Weight - - // Invariant: Only the Primary Network contains non-nil public - // keys. - if staker.PublicKey != nil { - // Record that the public key for the validator is being - // removed. This means we must record the prior value of the - // public key. - // - // Note: We store the uncompressed public key here as it is - // significantly more efficient to parse when applying - // diffs. - err := s.validatorPublicKeyDiffsDB.Put( - marshalDiffKey(constants.PrimaryNetworkID, height, nodeID), - bls.PublicKeyToUncompressedBytes(staker.PublicKey), - ) - if err != nil { - return err - } - } +func (s *state) writeCurrentStakersSubnetDiff( + subnetID ids.ID, + validatorDiffs map[ids.NodeID]*diffValidator, + updateValidators bool, + height uint64, + codecVersion uint16, +) error { + // Select db to write to + validatorDB := s.currentSubnetValidatorList + delegatorDB := s.currentSubnetDelegatorList + if subnetID == constants.PrimaryNetworkID { + validatorDB = s.currentValidatorList + delegatorDB = s.currentDelegatorList + } - if err := validatorDB.Delete(staker.TxID[:]); err != nil { - return fmt.Errorf("failed to delete current staker: %w", err) + // Record the change in weight and/or public key for each validator. + for nodeID, validatorDiff := range validatorDiffs { + var ( + staker *Staker + pk *bls.PublicKey + weightDiff = &ValidatorWeightDiff{ + Decrease: validatorDiff.validatorStatus == deleted, + } + ) + if validatorDiff.validatorStatus != unmodified { + staker = validatorDiff.validator + + pk = staker.PublicKey + // For non-primary network validators, the public key is inherited + // from the primary network. + if subnetID != constants.PrimaryNetworkID { + if vdr, ok := s.currentStakers.validators[constants.PrimaryNetworkID][nodeID]; ok && vdr.validator != nil { + // The primary network validator is still present after + // writing. + pk = vdr.validator.PublicKey + } else if vdr, ok := s.currentStakers.validatorDiffs[constants.PrimaryNetworkID][nodeID]; ok && vdr.validator != nil { + // The primary network validator is being removed during + // writing. + pk = vdr.validator.PublicKey + } else { + // This should never happen. + return ErrMissingPrimaryNetworkValidator } - - s.validatorState.DeleteValidatorMetadata(nodeID, subnetID) } - err := writeCurrentDelegatorDiff( - delegatorDB, - weightDiff, - validatorDiff, - codecVersion, - ) - if err != nil { - return err + weightDiff.Amount = staker.Weight + } + + switch validatorDiff.validatorStatus { + case added: + if pk != nil { + // Record that the public key for the validator is being added. + // This means the prior value for the public key was nil. + err := s.validatorPublicKeyDiffsDB.Put( + marshalDiffKey(subnetID, height, nodeID), + nil, + ) + if err != nil { + return err + } } - if weightDiff.Amount == 0 { - // No weight change to record; go to next validator. - continue + // The validator is being added. + // + // Invariant: It's impossible for a delegator to have been rewarded + // in the same block that the validator was added. + startTime := uint64(staker.StartTime.Unix()) + metadata := &validatorMetadata{ + txID: staker.TxID, + lastUpdated: staker.StartTime, + + UpDuration: 0, + LastUpdated: startTime, + StakerStartTime: startTime, + PotentialReward: staker.PotentialReward, + PotentialDelegateeReward: 0, } - err = s.validatorWeightDiffsDB.Put( - marshalDiffKey(subnetID, height, nodeID), - marshalWeightDiff(weightDiff), - ) + metadataBytes, err := MetadataCodec.Marshal(codecVersion, metadata) if err != nil { - return err + return fmt.Errorf("failed to serialize current validator: %w", err) } - // TODO: Move the validator set management out of the state package - if !updateValidators { - continue + if err = validatorDB.Put(staker.TxID[:], metadataBytes); err != nil { + return fmt.Errorf("failed to write current validator to list: %w", err) } - if weightDiff.Decrease { - err = s.validators.RemoveWeight(subnetID, nodeID, weightDiff.Amount) - } else { - if validatorDiff.validatorStatus == added { - staker := validatorDiff.validator - err = s.validators.AddStaker( - subnetID, - nodeID, - staker.PublicKey, - staker.TxID, - weightDiff.Amount, - ) - } else { - err = s.validators.AddWeight(subnetID, nodeID, weightDiff.Amount) + s.validatorState.LoadValidatorMetadata(nodeID, subnetID, metadata) + case deleted: + if pk != nil { + // Record that the public key for the validator is being + // removed. This means we must record the prior value of the + // public key. + // + // Note: We store the uncompressed public key here as it is + // significantly more efficient to parse when applying diffs. + err := s.validatorPublicKeyDiffsDB.Put( + marshalDiffKey(subnetID, height, nodeID), + bls.PublicKeyToUncompressedBytes(pk), + ) + if err != nil { + return err } } - if err != nil { - return fmt.Errorf("failed to update validator weight: %w", err) + + if err := validatorDB.Delete(staker.TxID[:]); err != nil { + return fmt.Errorf("failed to delete current staker: %w", err) } + + s.validatorState.DeleteValidatorMetadata(nodeID, subnetID) } - } - // TODO: Move validator set management out of the state package - // - // Attempt to update the stake metrics - if !updateValidators { - return nil - } + err := writeCurrentDelegatorDiff( + delegatorDB, + weightDiff, + validatorDiff, + codecVersion, + ) + if err != nil { + return err + } - totalWeight, err := s.validators.TotalWeight(constants.PrimaryNetworkID) - if err != nil { - return fmt.Errorf("failed to get total weight of primary network: %w", err) - } + if weightDiff.Amount == 0 { + // No weight change to record; go to next validator. + continue + } - s.metrics.SetLocalStake(s.validators.GetWeight(constants.PrimaryNetworkID, s.ctx.NodeID)) - s.metrics.SetTotalStake(totalWeight) + err = s.validatorWeightDiffsDB.Put( + marshalDiffKey(subnetID, height, nodeID), + marshalWeightDiff(weightDiff), + ) + if err != nil { + return err + } + + // TODO: Move the validator set management out of the state package + if !updateValidators { + continue + } + + if weightDiff.Decrease { + err = s.validators.RemoveWeight(subnetID, nodeID, weightDiff.Amount) + } else { + if validatorDiff.validatorStatus == added { + err = s.validators.AddStaker( + subnetID, + nodeID, + pk, + staker.TxID, + weightDiff.Amount, + ) + } else { + err = s.validators.AddWeight(subnetID, nodeID, weightDiff.Amount) + } + } + if err != nil { + return fmt.Errorf("failed to update validator weight: %w", err) + } + } return nil } From 273fbbecd3da929db6d00bab40511884efe5ee0c Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Mon, 21 Oct 2024 18:50:47 -0400 Subject: [PATCH 04/20] Populate BLS key diffs for subnet validators --- vms/platformvm/state/state.go | 12 +- vms/platformvm/state/state_test.go | 351 +++++++++++++++++---------- vms/platformvm/validators/manager.go | 8 +- 3 files changed, 241 insertions(+), 130 deletions(-) diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index a34b587a673..bed7b4746fb 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -190,6 +190,7 @@ type State interface { validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight uint64, endHeight uint64, + subnetID ids.ID, ) error SetHeight(height uint64) @@ -1244,10 +1245,11 @@ func (s *state) ApplyValidatorPublicKeyDiffs( validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight uint64, endHeight uint64, + subnetID ids.ID, ) error { diffIter := s.validatorPublicKeyDiffsDB.NewIteratorWithStartAndPrefix( - marshalStartDiffKey(constants.PrimaryNetworkID, startHeight), - constants.PrimaryNetworkID[:], + marshalStartDiffKey(subnetID, startHeight), + subnetID[:], ) defer diffIter.Release() @@ -2101,8 +2103,10 @@ func (s *state) writeCurrentStakersSubnetDiff( // writing. pk = vdr.validator.PublicKey } else { - // This should never happen. - return ErrMissingPrimaryNetworkValidator + // This should never happen as the primary network diffs are + // written last and subnet validator times must be a subset + // of the primary network validator times. + return errMissingPrimaryNetworkValidator } } diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index d2912950043..ba8dbd65481 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -28,6 +28,7 @@ import ( "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/units" "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -996,35 +997,43 @@ func TestStateAddRemoveValidator(t *testing.T) { state := newTestState(t, memdb.New()) var ( - numNodes = 3 - subnetID = ids.GenerateTestID() - startTime = time.Now() - endTime = startTime.Add(24 * time.Hour) - stakers = make([]Staker, numNodes) + numNodes = 5 + subnetID = ids.GenerateTestID() + startTime = time.Now() + endTime = startTime.Add(24 * time.Hour) + primaryStakers = make([]Staker, numNodes) + subnetStakers = make([]Staker, numNodes) ) - for i := 0; i < numNodes; i++ { - stakers[i] = Staker{ + for i := range primaryStakers { + sk, err := bls.NewSecretKey() + require.NoError(err) + + primaryStakers[i] = Staker{ TxID: ids.GenerateTestID(), NodeID: ids.GenerateTestNodeID(), + PublicKey: bls.PublicFromSecretKey(sk), + SubnetID: constants.PrimaryNetworkID, Weight: uint64(i + 1), StartTime: startTime.Add(time.Duration(i) * time.Second), EndTime: endTime.Add(time.Duration(i) * time.Second), PotentialReward: uint64(i + 1), } - if i%2 == 0 { - stakers[i].SubnetID = subnetID - } else { - sk, err := bls.NewSecretKey() - require.NoError(err) - stakers[i].PublicKey = bls.PublicFromSecretKey(sk) - stakers[i].SubnetID = constants.PrimaryNetworkID + } + for i, primaryStaker := range primaryStakers { + subnetStakers[i] = Staker{ + TxID: ids.GenerateTestID(), + NodeID: primaryStaker.NodeID, + PublicKey: nil, // Key is inherited from the primary network + SubnetID: subnetID, + Weight: uint64(i + 1), + StartTime: primaryStaker.StartTime, + EndTime: primaryStaker.EndTime, + PotentialReward: uint64(i + 1), } } type diff struct { addedValidators []Staker - addedDelegators []Staker - removedDelegators []Staker removedValidators []Staker expectedPrimaryValidatorSet map[ids.NodeID]*validators.GetValidatorOutput @@ -1037,101 +1046,176 @@ func TestStateAddRemoveValidator(t *testing.T) { expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, { - // Add a subnet validator - addedValidators: []Staker{stakers[0]}, - expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, - expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[0].NodeID: { - NodeID: stakers[0].NodeID, - Weight: stakers[0].Weight, + // Add primary validator 0 + addedValidators: []Staker{primaryStakers[0]}, + expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + primaryStakers[0].NodeID: { + NodeID: primaryStakers[0].NodeID, + PublicKey: primaryStakers[0].PublicKey, + Weight: primaryStakers[0].Weight, }, }, + expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, { - // Remove a subnet validator - removedValidators: []Staker{stakers[0]}, - expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, - expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, + // Add subnet validator 0 + addedValidators: []Staker{subnetStakers[0]}, + expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + primaryStakers[0].NodeID: { + NodeID: primaryStakers[0].NodeID, + PublicKey: primaryStakers[0].PublicKey, + Weight: primaryStakers[0].Weight, + }, + }, + expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + subnetStakers[0].NodeID: { + NodeID: subnetStakers[0].NodeID, + PublicKey: primaryStakers[0].PublicKey, + Weight: subnetStakers[0].Weight, + }, + }, }, - { // Add a primary network validator - addedValidators: []Staker{stakers[1]}, + { + // Remove subnet validator 0 + removedValidators: []Staker{subnetStakers[0]}, expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[1].NodeID: { - NodeID: stakers[1].NodeID, - PublicKey: stakers[1].PublicKey, - Weight: stakers[1].Weight, + primaryStakers[0].NodeID: { + NodeID: primaryStakers[0].NodeID, + PublicKey: primaryStakers[0].PublicKey, + Weight: primaryStakers[0].Weight, }, }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, { - // Do nothing + // Add primary network validator 1, and subnet validator 1 + addedValidators: []Staker{primaryStakers[1], subnetStakers[1]}, + // Remove primary network validator 0, and subnet validator 1 + removedValidators: []Staker{primaryStakers[0], subnetStakers[1]}, expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[1].NodeID: { - NodeID: stakers[1].NodeID, - PublicKey: stakers[1].PublicKey, - Weight: stakers[1].Weight, + primaryStakers[1].NodeID: { + NodeID: primaryStakers[1].NodeID, + PublicKey: primaryStakers[1].PublicKey, + Weight: primaryStakers[1].Weight, }, }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, - { // Remove a primary network validator - removedValidators: []Staker{stakers[1]}, - expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, - expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, + { + // Add primary network validator 2, and subnet validator 2 + addedValidators: []Staker{primaryStakers[2], subnetStakers[2]}, + // Remove primary network validator 1 + removedValidators: []Staker{primaryStakers[1]}, + expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + primaryStakers[2].NodeID: { + NodeID: primaryStakers[2].NodeID, + PublicKey: primaryStakers[2].PublicKey, + Weight: primaryStakers[2].Weight, + }, + }, + expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + subnetStakers[2].NodeID: { + NodeID: subnetStakers[2].NodeID, + PublicKey: primaryStakers[2].PublicKey, + Weight: subnetStakers[2].Weight, + }, + }, }, { - // Add 2 subnet validators and a primary network validator - addedValidators: []Staker{stakers[0], stakers[1], stakers[2]}, + // Add primary network and subnet validators 3 & 4 + addedValidators: []Staker{primaryStakers[3], primaryStakers[4], subnetStakers[3], subnetStakers[4]}, expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[1].NodeID: { - NodeID: stakers[1].NodeID, - PublicKey: stakers[1].PublicKey, - Weight: stakers[1].Weight, + primaryStakers[2].NodeID: { + NodeID: primaryStakers[2].NodeID, + PublicKey: primaryStakers[2].PublicKey, + Weight: primaryStakers[2].Weight, + }, + primaryStakers[3].NodeID: { + NodeID: primaryStakers[3].NodeID, + PublicKey: primaryStakers[3].PublicKey, + Weight: primaryStakers[3].Weight, + }, + primaryStakers[4].NodeID: { + NodeID: primaryStakers[4].NodeID, + PublicKey: primaryStakers[4].PublicKey, + Weight: primaryStakers[4].Weight, }, }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[0].NodeID: { - NodeID: stakers[0].NodeID, - Weight: stakers[0].Weight, + subnetStakers[2].NodeID: { + NodeID: subnetStakers[2].NodeID, + PublicKey: primaryStakers[2].PublicKey, + Weight: subnetStakers[2].Weight, + }, + subnetStakers[3].NodeID: { + NodeID: subnetStakers[3].NodeID, + PublicKey: primaryStakers[3].PublicKey, + Weight: subnetStakers[3].Weight, }, - stakers[2].NodeID: { - NodeID: stakers[2].NodeID, - Weight: stakers[2].Weight, + subnetStakers[4].NodeID: { + NodeID: subnetStakers[4].NodeID, + PublicKey: primaryStakers[4].PublicKey, + Weight: subnetStakers[4].Weight, }, }, }, { - // Remove 2 subnet validators and a primary network validator. - removedValidators: []Staker{stakers[0], stakers[1], stakers[2]}, + // Remove primary network and subnet validators 2 & 3 & 4 + removedValidators: []Staker{ + primaryStakers[2], primaryStakers[3], primaryStakers[4], + subnetStakers[2], subnetStakers[3], subnetStakers[4], + }, + expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, + expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, + }, + { + // Do nothing expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, } for currentIndex, diff := range diffs { - for _, added := range diff.addedValidators { - added := added - require.NoError(state.PutCurrentValidator(&added)) - } - for _, added := range diff.addedDelegators { - added := added - state.PutCurrentDelegator(&added) + d, err := NewDiffOn(state) + require.NoError(err) + + type subnetIDNodeID struct { + subnetID ids.ID + nodeID ids.NodeID } - for _, removed := range diff.removedDelegators { - removed := removed - state.DeleteCurrentDelegator(&removed) + var expectedValidators set.Set[subnetIDNodeID] + for _, added := range diff.addedValidators { + require.NoError(d.PutCurrentValidator(&added)) + + expectedValidators.Add(subnetIDNodeID{ + subnetID: added.SubnetID, + nodeID: added.NodeID, + }) } for _, removed := range diff.removedValidators { - removed := removed - state.DeleteCurrentValidator(&removed) + d.DeleteCurrentValidator(&removed) + + expectedValidators.Remove(subnetIDNodeID{ + subnetID: removed.SubnetID, + nodeID: removed.NodeID, + }) } + require.NoError(d.Apply(state)) + currentHeight := uint64(currentIndex + 1) state.SetHeight(currentHeight) require.NoError(state.Commit()) for _, added := range diff.addedValidators { + subnetNodeID := subnetIDNodeID{ + subnetID: added.SubnetID, + nodeID: added.NodeID, + } + if !expectedValidators.Contains(subnetNodeID) { + continue + } + gotValidator, err := state.GetCurrentValidator(added.SubnetID, added.NodeID) require.NoError(err) require.Equal(added, *gotValidator) @@ -1142,37 +1226,84 @@ func TestStateAddRemoveValidator(t *testing.T) { require.ErrorIs(err, database.ErrNotFound) } + primaryValidatorSet := state.validators.GetMap(constants.PrimaryNetworkID) + delete(primaryValidatorSet, defaultValidatorNodeID) // Ignore the genesis validator + require.Equal(diff.expectedPrimaryValidatorSet, primaryValidatorSet) + + require.Equal(diff.expectedSubnetValidatorSet, state.validators.GetMap(subnetID)) + for i := 0; i < currentIndex; i++ { prevDiff := diffs[i] prevHeight := uint64(i + 1) - primaryValidatorSet := copyValidatorSet(diff.expectedPrimaryValidatorSet) - require.NoError(state.ApplyValidatorWeightDiffs( - context.Background(), - primaryValidatorSet, - currentHeight, - prevHeight+1, - constants.PrimaryNetworkID, - )) - requireEqualWeightsValidatorSet(require, prevDiff.expectedPrimaryValidatorSet, primaryValidatorSet) - - require.NoError(state.ApplyValidatorPublicKeyDiffs( - context.Background(), - primaryValidatorSet, - currentHeight, - prevHeight+1, - )) - requireEqualPublicKeysValidatorSet(require, prevDiff.expectedPrimaryValidatorSet, primaryValidatorSet) - - subnetValidatorSet := copyValidatorSet(diff.expectedSubnetValidatorSet) - require.NoError(state.ApplyValidatorWeightDiffs( - context.Background(), - subnetValidatorSet, - currentHeight, - prevHeight+1, - subnetID, - )) - requireEqualWeightsValidatorSet(require, prevDiff.expectedSubnetValidatorSet, subnetValidatorSet) + { + primaryValidatorSet := copyValidatorSet(diff.expectedPrimaryValidatorSet) + require.NoError(state.ApplyValidatorWeightDiffs( + context.Background(), + primaryValidatorSet, + currentHeight, + prevHeight+1, + constants.PrimaryNetworkID, + )) + require.NoError(state.ApplyValidatorPublicKeyDiffs( + context.Background(), + primaryValidatorSet, + currentHeight, + prevHeight+1, + constants.PrimaryNetworkID, + )) + require.Equal(prevDiff.expectedPrimaryValidatorSet, primaryValidatorSet) + } + + { + legacySubnetValidatorSet := copyValidatorSet(diff.expectedSubnetValidatorSet) + require.NoError(state.ApplyValidatorWeightDiffs( + context.Background(), + legacySubnetValidatorSet, + currentHeight, + prevHeight+1, + subnetID, + )) + + // Update the public keys of the subnet validators with the current + // primary network validator public keys + for nodeID, vdr := range legacySubnetValidatorSet { + if primaryVdr, ok := diff.expectedPrimaryValidatorSet[nodeID]; ok { + vdr.PublicKey = primaryVdr.PublicKey + } else { + vdr.PublicKey = nil + } + } + + require.NoError(state.ApplyValidatorPublicKeyDiffs( + context.Background(), + legacySubnetValidatorSet, + currentHeight, + prevHeight+1, + constants.PrimaryNetworkID, + )) + require.Equal(prevDiff.expectedSubnetValidatorSet, legacySubnetValidatorSet) + } + + { + subnetValidatorSet := copyValidatorSet(diff.expectedSubnetValidatorSet) + require.NoError(state.ApplyValidatorWeightDiffs( + context.Background(), + subnetValidatorSet, + currentHeight, + prevHeight+1, + subnetID, + )) + + require.NoError(state.ApplyValidatorPublicKeyDiffs( + context.Background(), + subnetValidatorSet, + currentHeight, + prevHeight+1, + subnetID, + )) + require.Equal(prevDiff.expectedSubnetValidatorSet, subnetValidatorSet) + } } } } @@ -1188,36 +1319,6 @@ func copyValidatorSet( return result } -func requireEqualWeightsValidatorSet( - require *require.Assertions, - expected map[ids.NodeID]*validators.GetValidatorOutput, - actual map[ids.NodeID]*validators.GetValidatorOutput, -) { - require.Len(actual, len(expected)) - for nodeID, expectedVdr := range expected { - require.Contains(actual, nodeID) - - actualVdr := actual[nodeID] - require.Equal(expectedVdr.NodeID, actualVdr.NodeID) - require.Equal(expectedVdr.Weight, actualVdr.Weight) - } -} - -func requireEqualPublicKeysValidatorSet( - require *require.Assertions, - expected map[ids.NodeID]*validators.GetValidatorOutput, - actual map[ids.NodeID]*validators.GetValidatorOutput, -) { - require.Len(actual, len(expected)) - for nodeID, expectedVdr := range expected { - require.Contains(actual, nodeID) - - actualVdr := actual[nodeID] - require.Equal(expectedVdr.NodeID, actualVdr.NodeID) - require.Equal(expectedVdr.PublicKey, actualVdr.PublicKey) - } -} - func TestParsedStateBlock(t *testing.T) { var ( require = require.New(t) diff --git a/vms/platformvm/validators/manager.go b/vms/platformvm/validators/manager.go index 7f1ea5ea640..142db3e7635 100644 --- a/vms/platformvm/validators/manager.go +++ b/vms/platformvm/validators/manager.go @@ -85,6 +85,7 @@ type State interface { validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight uint64, endHeight uint64, + subnetID ids.ID, ) error } @@ -271,7 +272,7 @@ func (m *manager) makePrimaryNetworkValidatorSet( validatorSet, currentHeight, lastDiffHeight, - constants.PlatformChainID, + constants.PrimaryNetworkID, ) if err != nil { return nil, 0, err @@ -282,6 +283,7 @@ func (m *manager) makePrimaryNetworkValidatorSet( validatorSet, currentHeight, lastDiffHeight, + constants.PrimaryNetworkID, ) return validatorSet, currentHeight, err } @@ -348,6 +350,10 @@ func (m *manager) makeSubnetValidatorSet( subnetValidatorSet, currentHeight, lastDiffHeight, + // TODO: Etna introduces L1s whose validators specify their own public + // keys, rather than inheriting them from the primary network. + // Therefore, this will need to use the subnetID after Etna. + constants.PrimaryNetworkID, ) return subnetValidatorSet, currentHeight, err } From 290ef974191e6ebef1660b69c391a494e2ac9455 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Mon, 21 Oct 2024 20:31:16 -0400 Subject: [PATCH 05/20] Update mocks --- vms/platformvm/state/mock_state.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index a1759398257..9d256842251 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -149,17 +149,17 @@ func (mr *MockStateMockRecorder) AddUTXO(utxo any) *gomock.Call { } // ApplyValidatorPublicKeyDiffs mocks base method. -func (m *MockState) ApplyValidatorPublicKeyDiffs(ctx context.Context, validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight, endHeight uint64) error { +func (m *MockState) ApplyValidatorPublicKeyDiffs(ctx context.Context, validators map[ids.NodeID]*validators.GetValidatorOutput, startHeight, endHeight uint64, subnetID ids.ID) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ApplyValidatorPublicKeyDiffs", ctx, validators, startHeight, endHeight) + ret := m.ctrl.Call(m, "ApplyValidatorPublicKeyDiffs", ctx, validators, startHeight, endHeight, subnetID) ret0, _ := ret[0].(error) return ret0 } // ApplyValidatorPublicKeyDiffs indicates an expected call of ApplyValidatorPublicKeyDiffs. -func (mr *MockStateMockRecorder) ApplyValidatorPublicKeyDiffs(ctx, validators, startHeight, endHeight any) *gomock.Call { +func (mr *MockStateMockRecorder) ApplyValidatorPublicKeyDiffs(ctx, validators, startHeight, endHeight, subnetID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyValidatorPublicKeyDiffs", reflect.TypeOf((*MockState)(nil).ApplyValidatorPublicKeyDiffs), ctx, validators, startHeight, endHeight) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyValidatorPublicKeyDiffs", reflect.TypeOf((*MockState)(nil).ApplyValidatorPublicKeyDiffs), ctx, validators, startHeight, endHeight, subnetID) } // ApplyValidatorWeightDiffs mocks base method. From 9155e1fd37f436f772568640f35cc813c4405c21 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Mon, 21 Oct 2024 20:34:50 -0400 Subject: [PATCH 06/20] Fix tests --- vms/platformvm/state/state_test.go | 850 ++++++++++------------------- 1 file changed, 283 insertions(+), 567 deletions(-) diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index ba8dbd65481..419586f7871 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -5,7 +5,6 @@ package state import ( "context" - "fmt" "math" "math/rand" "sync" @@ -28,6 +27,7 @@ import ( "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/maybe" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/units" "github.com/ava-labs/avalanchego/utils/wrappers" @@ -110,620 +110,336 @@ func TestStateSyncGenesis(t *testing.T) { ) } -// Whenever we store a staker, a whole bunch a data structures are updated +// Whenever we store a staker, a whole bunch of data structures are updated // This test is meant to capture which updates are carried out func TestPersistStakers(t *testing.T) { - tests := map[string]struct { - // Insert or delete a staker to state and store it - storeStaker func(*require.Assertions, ids.ID /*=subnetID*/, *state) *Staker - - // Check that the staker is duly stored/removed in P-chain state - checkStakerInState func(*require.Assertions, *state, *Staker) - - // Check whether validators are duly reported in the validator set, - // with the right weight and showing the BLS key - checkValidatorsSet func(*require.Assertions, *state, *Staker) + const ( + primaryValidatorDuration = 28 * 24 * time.Hour + primaryDelegatorDuration = 14 * 24 * time.Hour + subnetValidatorDuration = 21 * 24 * time.Hour + subnetDelegatorDuration = 14 * 24 * time.Hour + + primaryValidatorReward = iota + primaryDelegatorReward + ) + var ( + primaryValidatorStartTime = time.Now().Truncate(time.Second) + primaryValidatorEndTime = primaryValidatorStartTime.Add(primaryValidatorDuration) + primaryValidatorEndTimeUnix = uint64(primaryValidatorEndTime.Unix()) + + primaryDelegatorStartTime = primaryValidatorStartTime + primaryDelegatorEndTime = primaryDelegatorStartTime.Add(primaryDelegatorDuration) + primaryDelegatorEndTimeUnix = uint64(primaryDelegatorEndTime.Unix()) + + primaryValidatorData = txs.Validator{ + NodeID: ids.GenerateTestNodeID(), + End: primaryValidatorEndTimeUnix, + Wght: 1234, + } + primaryDelegatorData = txs.Validator{ + NodeID: primaryValidatorData.NodeID, + End: primaryDelegatorEndTimeUnix, + Wght: 6789, + } + ) - // Check that node duly track stakers uptimes - checkValidatorUptimes func(*require.Assertions, *state, *Staker) + unsignedAddPrimaryNetworkValidator := createPermissionlessValidatorTx(t, constants.PrimaryNetworkID, primaryValidatorData) + addPrimaryNetworkValidator := &txs.Tx{Unsigned: unsignedAddPrimaryNetworkValidator} + require.NoError(t, addPrimaryNetworkValidator.Initialize(txs.Codec)) - // Check whether weight/bls keys diffs are duly stored - checkDiffs func(*require.Assertions, *state, *Staker, uint64) - }{ - "add current validator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - var ( - startTime = time.Now().Unix() - endTime = time.Now().Add(14 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - End: uint64(endTime), - Wght: 1234, - } - validatorReward uint64 = 5678 - ) + primaryNetworkPendingValidatorStaker, err := NewPendingStaker( + addPrimaryNetworkValidator.ID(), + unsignedAddPrimaryNetworkValidator, + ) + require.NoError(t, err) - utx := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utx} - r.NoError(addPermValTx.Initialize(txs.Codec)) + primaryNetworkCurrentValidatorStaker, err := NewCurrentStaker( + addPrimaryNetworkValidator.ID(), + unsignedAddPrimaryNetworkValidator, + primaryValidatorStartTime, + primaryValidatorReward, + ) + require.NoError(t, err) - staker, err := NewCurrentStaker( - addPermValTx.ID(), - utx, - time.Unix(startTime, 0), - validatorReward, - ) - r.NoError(err) + unsignedAddPrimaryNetworkDelegator := createPermissionlessDelegatorTx(constants.PrimaryNetworkID, primaryDelegatorData) + addPrimaryNetworkDelegator := &txs.Tx{Unsigned: unsignedAddPrimaryNetworkDelegator} + require.NoError(t, addPrimaryNetworkDelegator.Initialize(txs.Codec)) - r.NoError(s.PutCurrentValidator(staker)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - return staker - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - retrievedStaker, err := s.GetCurrentValidator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.Equal(staker, retrievedStaker) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - valsMap := s.validators.GetMap(staker.SubnetID) - r.Contains(valsMap, staker.NodeID) - r.Equal( - &validators.GetValidatorOutput{ - NodeID: staker.NodeID, - PublicKey: staker.PublicKey, - Weight: staker.Weight, - }, - valsMap[staker.NodeID], - ) - }, - checkValidatorUptimes: func(r *require.Assertions, s *state, staker *Staker) { - upDuration, lastUpdated, err := s.GetUptime(staker.NodeID) - if staker.SubnetID != constants.PrimaryNetworkID { - // only primary network validators have uptimes - r.ErrorIs(err, database.ErrNotFound) - } else { - r.NoError(err) - r.Equal(upDuration, time.Duration(0)) - r.Equal(lastUpdated, staker.StartTime) - } - }, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - weightDiffBytes, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.NoError(err) - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - r.NoError(err) - r.Equal(&ValidatorWeightDiff{ - Decrease: false, - Amount: staker.Weight, - }, weightDiff) - - blsDiffBytes, err := s.validatorPublicKeyDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - if staker.SubnetID == constants.PrimaryNetworkID { - r.NoError(err) - r.Nil(blsDiffBytes) - } else { - r.ErrorIs(err, database.ErrNotFound) - } - }, - }, - "add current delegator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // insert the delegator and its validator - var ( - valStartTime = time.Now().Truncate(time.Second).Unix() - delStartTime = time.Unix(valStartTime, 0).Add(time.Hour).Unix() - delEndTime = time.Unix(delStartTime, 0).Add(30 * 24 * time.Hour).Unix() - valEndTime = time.Unix(valStartTime, 0).Add(365 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - End: uint64(valEndTime), - Wght: 1234, - } - validatorReward uint64 = 5678 - - delegatorData = txs.Validator{ - NodeID: validatorsData.NodeID, - End: uint64(delEndTime), - Wght: validatorsData.Wght / 2, - } - delegatorReward uint64 = 5432 - ) + primaryNetworkPendingDelegatorStaker, err := NewPendingStaker( + addPrimaryNetworkDelegator.ID(), + unsignedAddPrimaryNetworkDelegator, + ) + require.NoError(t, err) - utxVal := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utxVal} - r.NoError(addPermValTx.Initialize(txs.Codec)) + primaryNetworkCurrentDelegatorStaker, err := NewCurrentStaker( + addPrimaryNetworkDelegator.ID(), + unsignedAddPrimaryNetworkDelegator, + primaryDelegatorStartTime, + primaryDelegatorReward, + ) + require.NoError(t, err) - val, err := NewCurrentStaker( - addPermValTx.ID(), - utxVal, - time.Unix(valStartTime, 0), - validatorReward, - ) - r.NoError(err) + tests := map[string]struct { + initialStakers []*Staker + initialTxs []*txs.Tx - utxDel := createPermissionlessDelegatorTx(subnetID, delegatorData) - addPermDelTx := &txs.Tx{Unsigned: utxDel} - r.NoError(addPermDelTx.Initialize(txs.Codec)) + // Staker to insert or remove + staker *Staker + tx *txs.Tx // If tx is nil, the staker is being removed - del, err := NewCurrentStaker( - addPermDelTx.ID(), - utxDel, - time.Unix(delStartTime, 0), - delegatorReward, - ) - r.NoError(err) + // Check that the staker is duly stored/removed in P-chain state + expectedCurrentValidator *Staker + expectedPendingValidator *Staker + expectedCurrentDelegators []*Staker + expectedPendingDelegators []*Staker - r.NoError(s.PutCurrentValidator(val)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) + // Check that the validator entry has been set correctly in the + // in-memory validator set. + expectedValidatorSetOutput *validators.GetValidatorOutput - s.PutCurrentDelegator(del) - s.AddTx(addPermDelTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - return del - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - delIt, err := s.GetCurrentDelegatorIterator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.True(delIt.Next()) - retrievedDelegator := delIt.Value() - r.False(delIt.Next()) - delIt.Release() - r.Equal(staker, retrievedDelegator) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - val, err := s.GetCurrentValidator(staker.SubnetID, staker.NodeID) - r.NoError(err) - - valsMap := s.validators.GetMap(staker.SubnetID) - r.Contains(valsMap, staker.NodeID) - valOut := valsMap[staker.NodeID] - r.Equal(valOut.NodeID, staker.NodeID) - r.Equal(valOut.Weight, val.Weight+staker.Weight) - }, - checkValidatorUptimes: func(*require.Assertions, *state, *Staker) {}, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - // validator's weight must increase of delegator's weight amount - weightDiffBytes, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.NoError(err) - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - r.NoError(err) - r.Equal(&ValidatorWeightDiff{ - Decrease: false, - Amount: staker.Weight, - }, weightDiff) + // Check whether weight/bls keys diffs are duly stored + expectedWeightDiff *ValidatorWeightDiff + expectedPublicKeyDiff maybe.Maybe[*bls.PublicKey] + }{ + "add current primary network validator": { + staker: primaryNetworkCurrentValidatorStaker, + tx: addPrimaryNetworkValidator, + expectedCurrentValidator: primaryNetworkCurrentValidatorStaker, + expectedValidatorSetOutput: &validators.GetValidatorOutput{ + NodeID: primaryNetworkCurrentValidatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: primaryNetworkCurrentValidatorStaker.Weight, + }, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: false, + Amount: primaryNetworkCurrentValidatorStaker.Weight, }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), }, - "add pending validator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - var ( - startTime = time.Now().Unix() - endTime = time.Now().Add(14 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - Start: uint64(startTime), - End: uint64(endTime), - Wght: 1234, - } - ) - - utx := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utx} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - staker, err := NewPendingStaker( - addPermValTx.ID(), - utx, - ) - r.NoError(err) - - r.NoError(s.PutPendingValidator(staker)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - return staker - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - retrievedStaker, err := s.GetPendingValidator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.Equal(staker, retrievedStaker) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - // pending validators are not showed in validators set - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) - }, - checkValidatorUptimes: func(r *require.Assertions, s *state, staker *Staker) { - // pending validators uptime is not tracked - _, _, err := s.GetUptime(staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) + "add current primary network delegator": { + initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkCurrentDelegatorStaker, + tx: addPrimaryNetworkDelegator, + expectedCurrentValidator: primaryNetworkCurrentValidatorStaker, + expectedCurrentDelegators: []*Staker{primaryNetworkCurrentDelegatorStaker}, + expectedValidatorSetOutput: &validators.GetValidatorOutput{ + NodeID: primaryNetworkCurrentDelegatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: primaryNetworkCurrentDelegatorStaker.Weight + primaryNetworkCurrentValidatorStaker.Weight, + }, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: false, + Amount: primaryNetworkCurrentDelegatorStaker.Weight, }, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - // pending validators weight diff and bls diffs are not stored - _, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.ErrorIs(err, database.ErrNotFound) - - _, err = s.validatorPublicKeyDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.ErrorIs(err, database.ErrNotFound) + }, + "add pending primary network validator": { + staker: primaryNetworkPendingValidatorStaker, + tx: addPrimaryNetworkValidator, + expectedPendingValidator: primaryNetworkPendingValidatorStaker, + }, + "add pending primary network delegator": { + initialStakers: []*Staker{primaryNetworkPendingValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkPendingDelegatorStaker, + tx: addPrimaryNetworkDelegator, + expectedPendingValidator: primaryNetworkPendingValidatorStaker, + expectedPendingDelegators: []*Staker{primaryNetworkPendingDelegatorStaker}, + }, + "delete current primary network validator": { + initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkCurrentValidatorStaker, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: true, + Amount: primaryNetworkCurrentValidatorStaker.Weight, }, + expectedPublicKeyDiff: maybe.Some(primaryNetworkCurrentValidatorStaker.PublicKey), }, - "add pending delegator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // insert the delegator and its validator - var ( - valStartTime = time.Now().Truncate(time.Second).Unix() - delStartTime = time.Unix(valStartTime, 0).Add(time.Hour).Unix() - delEndTime = time.Unix(delStartTime, 0).Add(30 * 24 * time.Hour).Unix() - valEndTime = time.Unix(valStartTime, 0).Add(365 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - Start: uint64(valStartTime), - End: uint64(valEndTime), - Wght: 1234, - } - - delegatorData = txs.Validator{ - NodeID: validatorsData.NodeID, - Start: uint64(delStartTime), - End: uint64(delEndTime), - Wght: validatorsData.Wght / 2, - } - ) - - utxVal := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utxVal} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - val, err := NewPendingStaker(addPermValTx.ID(), utxVal) - r.NoError(err) - - utxDel := createPermissionlessDelegatorTx(subnetID, delegatorData) - addPermDelTx := &txs.Tx{Unsigned: utxDel} - r.NoError(addPermDelTx.Initialize(txs.Codec)) - - del, err := NewPendingStaker(addPermDelTx.ID(), utxDel) - r.NoError(err) - - r.NoError(s.PutPendingValidator(val)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - - s.PutPendingDelegator(del) - s.AddTx(addPermDelTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - - return del + "delete current primary network delegator": { + initialStakers: []*Staker{ + primaryNetworkCurrentValidatorStaker, + primaryNetworkCurrentDelegatorStaker, + }, + initialTxs: []*txs.Tx{ + addPrimaryNetworkValidator, + addPrimaryNetworkDelegator, + }, + staker: primaryNetworkCurrentDelegatorStaker, + expectedCurrentValidator: primaryNetworkCurrentValidatorStaker, + expectedValidatorSetOutput: &validators.GetValidatorOutput{ + NodeID: primaryNetworkCurrentValidatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: primaryNetworkCurrentValidatorStaker.Weight, + }, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: true, + Amount: primaryNetworkCurrentDelegatorStaker.Weight, }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - delIt, err := s.GetPendingDelegatorIterator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.True(delIt.Next()) - retrievedDelegator := delIt.Value() - r.False(delIt.Next()) - delIt.Release() - r.Equal(staker, retrievedDelegator) + }, + "delete pending primary network validator": { + initialStakers: []*Staker{primaryNetworkPendingValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkPendingValidatorStaker, + }, + "delete pending primary network delegator": { + initialStakers: []*Staker{ + primaryNetworkPendingValidatorStaker, + primaryNetworkPendingDelegatorStaker, }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) + initialTxs: []*txs.Tx{ + addPrimaryNetworkValidator, + addPrimaryNetworkDelegator, }, - checkValidatorUptimes: func(*require.Assertions, *state, *Staker) {}, - checkDiffs: func(*require.Assertions, *state, *Staker, uint64) {}, + staker: primaryNetworkPendingDelegatorStaker, + expectedPendingValidator: primaryNetworkPendingValidatorStaker, }, - "delete current validator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // add them remove the validator - var ( - startTime = time.Now().Unix() - endTime = time.Now().Add(14 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - End: uint64(endTime), - Wght: 1234, - } - validatorReward uint64 = 5678 - ) + } - utx := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utx} - r.NoError(addPermValTx.Initialize(txs.Codec)) + for name, test := range tests { + t.Run(name, func(t *testing.T) { + require := require.New(t) - staker, err := NewCurrentStaker( - addPermValTx.ID(), - utx, - time.Unix(startTime, 0), - validatorReward, - ) - r.NoError(err) + db := memdb.New() + state := newTestState(t, db) + + // create and store the initial stakers + for _, staker := range test.initialStakers { + switch { + case staker.Priority.IsCurrentValidator(): + require.NoError(state.PutCurrentValidator(staker)) + case staker.Priority.IsPendingValidator(): + require.NoError(state.PutPendingValidator(staker)) + case staker.Priority.IsCurrentDelegator(): + state.PutCurrentDelegator(staker) + case staker.Priority.IsPendingDelegator(): + state.PutPendingDelegator(staker) + } + } + for _, tx := range test.initialTxs { + state.AddTx(tx, status.Committed) + } - r.NoError(s.PutCurrentValidator(staker)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) + state.SetHeight(0) + require.NoError(state.Commit()) - s.DeleteCurrentValidator(staker) - r.NoError(s.Commit()) - return staker - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - _, err := s.GetCurrentValidator(staker.SubnetID, staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - // deleted validators are not showed in the validators set anymore - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) - }, - checkValidatorUptimes: func(r *require.Assertions, s *state, staker *Staker) { - // uptimes of delete validators are dropped - _, _, err := s.GetUptime(staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) - }, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - weightDiffBytes, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.NoError(err) - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - r.NoError(err) - r.Equal(&ValidatorWeightDiff{ - Decrease: true, - Amount: staker.Weight, - }, weightDiff) - - blsDiffBytes, err := s.validatorPublicKeyDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - if staker.SubnetID == constants.PrimaryNetworkID { - r.NoError(err) - r.Equal(bls.PublicKeyFromValidUncompressedBytes(blsDiffBytes), staker.PublicKey) + // create and store the staker under test + switch { + case test.staker.Priority.IsCurrentValidator(): + if test.tx != nil { + require.NoError(state.PutCurrentValidator(test.staker)) } else { - r.ErrorIs(err, database.ErrNotFound) + state.DeleteCurrentValidator(test.staker) } - }, - }, - "delete current delegator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // insert validator and delegator, then remove the delegator - var ( - valStartTime = time.Now().Truncate(time.Second).Unix() - delStartTime = time.Unix(valStartTime, 0).Add(time.Hour).Unix() - delEndTime = time.Unix(delStartTime, 0).Add(30 * 24 * time.Hour).Unix() - valEndTime = time.Unix(valStartTime, 0).Add(365 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - End: uint64(valEndTime), - Wght: 1234, - } - validatorReward uint64 = 5678 - - delegatorData = txs.Validator{ - NodeID: validatorsData.NodeID, - End: uint64(delEndTime), - Wght: validatorsData.Wght / 2, - } - delegatorReward uint64 = 5432 - ) - - utxVal := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utxVal} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - val, err := NewCurrentStaker( - addPermValTx.ID(), - utxVal, - time.Unix(valStartTime, 0), - validatorReward, - ) - r.NoError(err) - - utxDel := createPermissionlessDelegatorTx(subnetID, delegatorData) - addPermDelTx := &txs.Tx{Unsigned: utxDel} - r.NoError(addPermDelTx.Initialize(txs.Codec)) + case test.staker.Priority.IsPendingValidator(): + if test.tx != nil { + require.NoError(state.PutPendingValidator(test.staker)) + } else { + state.DeletePendingValidator(test.staker) + } + case test.staker.Priority.IsCurrentDelegator(): + if test.tx != nil { + state.PutCurrentDelegator(test.staker) + } else { + state.DeleteCurrentDelegator(test.staker) + } + case test.staker.Priority.IsPendingDelegator(): + if test.tx != nil { + state.PutPendingDelegator(test.staker) + } else { + state.DeletePendingDelegator(test.staker) + } + } + if test.tx != nil { + state.AddTx(test.tx, status.Committed) + } - del, err := NewCurrentStaker( - addPermDelTx.ID(), - utxDel, - time.Unix(delStartTime, 0), - delegatorReward, - ) - r.NoError(err) + state.SetHeight(1) + require.NoError(state.Commit()) - r.NoError(s.PutCurrentValidator(val)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker + // Perform the checks once immediately after committing to the + // state, and once after re-loading the state from disk. + for i := 0; i < 2; i++ { + currentValidator, err := state.GetCurrentValidator(test.staker.SubnetID, test.staker.NodeID) + if test.expectedCurrentValidator == nil { + require.ErrorIs(err, database.ErrNotFound) - s.PutCurrentDelegator(del) - s.AddTx(addPermDelTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) + // Only current validators should have uptimes + _, _, err := state.GetUptime(test.staker.NodeID) + require.ErrorIs(err, database.ErrNotFound) + } else { + require.NoError(err) + require.Equal(test.expectedCurrentValidator, currentValidator) + + // Current validators should also have uptimes + upDuration, lastUpdated, err := state.GetUptime(currentValidator.NodeID) + require.NoError(err) + require.Zero(upDuration) + require.Equal(currentValidator.StartTime, lastUpdated) + } - s.DeleteCurrentDelegator(del) - r.NoError(s.Commit()) + pendingValidator, err := state.GetPendingValidator(test.staker.SubnetID, test.staker.NodeID) + if test.expectedPendingValidator == nil { + require.ErrorIs(err, database.ErrNotFound) + } else { + require.NoError(err) + require.Equal(test.expectedPendingValidator, pendingValidator) + } - return del - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - delIt, err := s.GetCurrentDelegatorIterator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.False(delIt.Next()) - delIt.Release() - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - val, err := s.GetCurrentValidator(staker.SubnetID, staker.NodeID) - r.NoError(err) - - valsMap := s.validators.GetMap(staker.SubnetID) - r.Contains(valsMap, staker.NodeID) - valOut := valsMap[staker.NodeID] - r.Equal(valOut.NodeID, staker.NodeID) - r.Equal(valOut.Weight, val.Weight) - }, - checkValidatorUptimes: func(*require.Assertions, *state, *Staker) {}, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - // validator's weight must decrease of delegator's weight amount - weightDiffBytes, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.NoError(err) - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - r.NoError(err) - r.Equal(&ValidatorWeightDiff{ - Decrease: true, - Amount: staker.Weight, - }, weightDiff) - }, - }, - "delete pending validator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - var ( - startTime = time.Now().Unix() - endTime = time.Now().Add(14 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - Start: uint64(startTime), - End: uint64(endTime), - Wght: 1234, - } + it, err := state.GetCurrentDelegatorIterator(test.staker.SubnetID, test.staker.NodeID) + require.NoError(err) + require.Equal( + test.expectedCurrentDelegators, + iterator.ToSlice(it), ) - utx := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utx} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - staker, err := NewPendingStaker( - addPermValTx.ID(), - utx, + it, err = state.GetPendingDelegatorIterator(test.staker.SubnetID, test.staker.NodeID) + require.NoError(err) + require.Equal( + test.expectedPendingDelegators, + iterator.ToSlice(it), ) - r.NoError(err) - - r.NoError(s.PutPendingValidator(staker)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - s.DeletePendingValidator(staker) - r.NoError(s.Commit()) - - return staker - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - _, err := s.GetPendingValidator(staker.SubnetID, staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) - }, - checkValidatorUptimes: func(r *require.Assertions, s *state, staker *Staker) { - _, _, err := s.GetUptime(staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) - }, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - _, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.ErrorIs(err, database.ErrNotFound) - - _, err = s.validatorPublicKeyDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.ErrorIs(err, database.ErrNotFound) - }, - }, - "delete pending delegator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // insert validator and delegator the remove the validator - var ( - valStartTime = time.Now().Truncate(time.Second).Unix() - delStartTime = time.Unix(valStartTime, 0).Add(time.Hour).Unix() - delEndTime = time.Unix(delStartTime, 0).Add(30 * 24 * time.Hour).Unix() - valEndTime = time.Unix(valStartTime, 0).Add(365 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - Start: uint64(valStartTime), - End: uint64(valEndTime), - Wght: 1234, - } - - delegatorData = txs.Validator{ - NodeID: validatorsData.NodeID, - Start: uint64(delStartTime), - End: uint64(delEndTime), - Wght: validatorsData.Wght / 2, - } + require.Equal( + test.expectedValidatorSetOutput, + state.validators.GetMap(test.staker.SubnetID)[test.staker.NodeID], ) - utxVal := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utxVal} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - val, err := NewPendingStaker(addPermValTx.ID(), utxVal) - r.NoError(err) - - utxDel := createPermissionlessDelegatorTx(subnetID, delegatorData) - addPermDelTx := &txs.Tx{Unsigned: utxDel} - r.NoError(addPermDelTx.Initialize(txs.Codec)) - - del, err := NewPendingStaker(addPermDelTx.ID(), utxDel) - r.NoError(err) - - r.NoError(s.PutPendingValidator(val)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - - s.PutPendingDelegator(del) - s.AddTx(addPermDelTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - - s.DeletePendingDelegator(del) - r.NoError(s.Commit()) - return del - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - delIt, err := s.GetPendingDelegatorIterator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.False(delIt.Next()) - delIt.Release() - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) - }, - checkValidatorUptimes: func(*require.Assertions, *state, *Staker) {}, - checkDiffs: func(*require.Assertions, *state, *Staker, uint64) {}, - }, - } - - subnetIDs := []ids.ID{constants.PrimaryNetworkID, ids.GenerateTestID()} - for _, subnetID := range subnetIDs { - for name, test := range tests { - t.Run(fmt.Sprintf("%s - subnetID %s", name, subnetID), func(t *testing.T) { - require := require.New(t) - - db := memdb.New() - state := newTestState(t, db) - - // create and store the staker - staker := test.storeStaker(require, subnetID, state) + diffKey := marshalDiffKey(test.staker.SubnetID, 1, test.staker.NodeID) + weightDiffBytes, err := state.validatorWeightDiffsDB.Get(diffKey) + if test.expectedWeightDiff == nil { + require.ErrorIs(err, database.ErrNotFound) + } else { + require.NoError(err) - // check all relevant data are stored - test.checkStakerInState(require, state, staker) - test.checkValidatorsSet(require, state, staker) - test.checkValidatorUptimes(require, state, staker) - test.checkDiffs(require, state, staker, 0 /*height*/) + weightDiff, err := unmarshalWeightDiff(weightDiffBytes) + require.NoError(err) + require.Equal(test.expectedWeightDiff, weightDiff) + } - // rebuild the state - rebuiltState := newTestState(t, db) + publicKeyDiffBytes, err := state.validatorPublicKeyDiffsDB.Get(diffKey) + if test.expectedPublicKeyDiff.IsNothing() { + require.ErrorIs(err, database.ErrNotFound) + } else if expectedPublicKeyDiff := test.expectedPublicKeyDiff.Value(); expectedPublicKeyDiff == nil { + require.NoError(err) + require.Empty(publicKeyDiffBytes) + } else { + require.NoError(err) + require.Equal(expectedPublicKeyDiff, bls.PublicKeyFromValidUncompressedBytes(publicKeyDiffBytes)) + } - // check again that all relevant data are still available in rebuilt state - test.checkStakerInState(require, rebuiltState, staker) - test.checkValidatorsSet(require, rebuiltState, staker) - test.checkValidatorUptimes(require, rebuiltState, staker) - test.checkDiffs(require, rebuiltState, staker, 0 /*height*/) - }) - } + // re-load the state from disk + state = newTestState(t, db) + } + }) } } -func createPermissionlessValidatorTx(r *require.Assertions, subnetID ids.ID, validatorsData txs.Validator) *txs.AddPermissionlessValidatorTx { +func createPermissionlessValidatorTx(t testing.TB, subnetID ids.ID, validatorsData txs.Validator) *txs.AddPermissionlessValidatorTx { var sig signer.Signer = &signer.Empty{} if subnetID == constants.PrimaryNetworkID { sk, err := bls.NewSecretKey() - r.NoError(err) + require.NoError(t, err) sig = signer.NewProofOfPossession(sk) } From 803d0c4d777aa587949df49f40b21b17cacafeed Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Tue, 22 Oct 2024 07:07:59 -0400 Subject: [PATCH 07/20] nit --- vms/platformvm/state/state_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 419586f7871..29d37c15a34 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -428,7 +428,7 @@ func TestPersistStakers(t *testing.T) { require.Equal(expectedPublicKeyDiff, bls.PublicKeyFromValidUncompressedBytes(publicKeyDiffBytes)) } - // re-load the state from disk + // re-load the state from disk for the second iteration state = newTestState(t, db) } }) From a2c777337e26d7e63426644d61b00c8e0d21a81c Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Tue, 22 Oct 2024 08:57:04 -0400 Subject: [PATCH 08/20] Update test and populate public keys during startup --- vms/platformvm/state/state.go | 21 ++++++--- vms/platformvm/state/state_test.go | 72 ++++++++++++++++++++++++++---- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index bed7b4746fb..8a81931f111 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -1718,19 +1718,28 @@ func (s *state) loadPendingValidators() error { // Invariant: initValidatorSets requires loadCurrentValidators to have already // been called. func (s *state) initValidatorSets() error { - for subnetID, validators := range s.currentStakers.validators { + primaryNetworkValidators := s.currentStakers.validators[constants.PrimaryNetworkID] + for subnetID, subnetValidators := range s.currentStakers.validators { if s.validators.Count(subnetID) != 0 { // Enforce the invariant that the validator set is empty here. return fmt.Errorf("%w: %s", errValidatorSetAlreadyPopulated, subnetID) } - for nodeID, validator := range validators { - validatorStaker := validator.validator - if err := s.validators.AddStaker(subnetID, nodeID, validatorStaker.PublicKey, validatorStaker.TxID, validatorStaker.Weight); err != nil { + for nodeID, subnetValidator := range subnetValidators { + primaryValidator, ok := primaryNetworkValidators[nodeID] + if !ok { + return fmt.Errorf("%w: %s", errMissingPrimaryNetworkValidator, nodeID) + } + + var ( + primaryStaker = primaryValidator.validator + subnetStaker = subnetValidator.validator + ) + if err := s.validators.AddStaker(subnetID, nodeID, primaryStaker.PublicKey, subnetStaker.TxID, subnetStaker.Weight); err != nil { return err } - delegatorIterator := iterator.FromTree(validator.delegators) + delegatorIterator := iterator.FromTree(subnetValidator.delegators) for delegatorIterator.Next() { delegatorStaker := delegatorIterator.Value() if err := s.validators.AddWeight(subnetID, nodeID, delegatorStaker.Weight); err != nil { @@ -2106,7 +2115,7 @@ func (s *state) writeCurrentStakersSubnetDiff( // This should never happen as the primary network diffs are // written last and subnet validator times must be a subset // of the primary network validator times. - return errMissingPrimaryNetworkValidator + return fmt.Errorf("%w: %s", errMissingPrimaryNetworkValidator, nodeID) } } diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 29d37c15a34..b725761bab9 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -117,10 +117,10 @@ func TestPersistStakers(t *testing.T) { primaryValidatorDuration = 28 * 24 * time.Hour primaryDelegatorDuration = 14 * 24 * time.Hour subnetValidatorDuration = 21 * 24 * time.Hour - subnetDelegatorDuration = 14 * 24 * time.Hour primaryValidatorReward = iota primaryDelegatorReward + subnetValidatorReward ) var ( primaryValidatorStartTime = time.Now().Truncate(time.Second) @@ -131,6 +131,10 @@ func TestPersistStakers(t *testing.T) { primaryDelegatorEndTime = primaryDelegatorStartTime.Add(primaryDelegatorDuration) primaryDelegatorEndTimeUnix = uint64(primaryDelegatorEndTime.Unix()) + subnetValidatorStartTime = primaryValidatorStartTime + subnetValidatorEndTime = subnetValidatorStartTime.Add(subnetValidatorDuration) + subnetValidatorEndTimeUnix = uint64(subnetValidatorEndTime.Unix()) + primaryValidatorData = txs.Validator{ NodeID: ids.GenerateTestNodeID(), End: primaryValidatorEndTimeUnix, @@ -141,6 +145,13 @@ func TestPersistStakers(t *testing.T) { End: primaryDelegatorEndTimeUnix, Wght: 6789, } + subnetValidatorData = txs.Validator{ + NodeID: primaryValidatorData.NodeID, + End: subnetValidatorEndTimeUnix, + Wght: 9876, + } + + subnetID = ids.GenerateTestID() ) unsignedAddPrimaryNetworkValidator := createPermissionlessValidatorTx(t, constants.PrimaryNetworkID, primaryValidatorData) @@ -179,6 +190,18 @@ func TestPersistStakers(t *testing.T) { ) require.NoError(t, err) + unsignedAddSubnetValidator := createPermissionlessValidatorTx(t, subnetID, subnetValidatorData) + addSubnetValidator := &txs.Tx{Unsigned: unsignedAddSubnetValidator} + require.NoError(t, addSubnetValidator.Initialize(txs.Codec)) + + subnetCurrentValidatorStaker, err := NewCurrentStaker( + addSubnetValidator.ID(), + unsignedAddSubnetValidator, + subnetValidatorStartTime, + subnetValidatorReward, + ) + require.NoError(t, err) + tests := map[string]struct { initialStakers []*Staker initialTxs []*txs.Tx @@ -246,6 +269,23 @@ func TestPersistStakers(t *testing.T) { expectedPendingValidator: primaryNetworkPendingValidatorStaker, expectedPendingDelegators: []*Staker{primaryNetworkPendingDelegatorStaker}, }, + "add current subnet validator": { + initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: subnetCurrentValidatorStaker, + tx: addSubnetValidator, + expectedCurrentValidator: subnetCurrentValidatorStaker, + expectedValidatorSetOutput: &validators.GetValidatorOutput{ + NodeID: subnetCurrentValidatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: subnetCurrentValidatorStaker.Weight, + }, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: false, + Amount: subnetCurrentValidatorStaker.Weight, + }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), + }, "delete current primary network validator": { initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, @@ -294,6 +334,16 @@ func TestPersistStakers(t *testing.T) { staker: primaryNetworkPendingDelegatorStaker, expectedPendingValidator: primaryNetworkPendingValidatorStaker, }, + "delete current subnet validator": { + initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker, subnetCurrentValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator, addSubnetValidator}, + staker: subnetCurrentValidatorStaker, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: true, + Amount: subnetCurrentValidatorStaker.Weight, + }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](primaryNetworkCurrentValidatorStaker.PublicKey), + }, } for name, test := range tests { @@ -364,18 +414,22 @@ func TestPersistStakers(t *testing.T) { if test.expectedCurrentValidator == nil { require.ErrorIs(err, database.ErrNotFound) - // Only current validators should have uptimes - _, _, err := state.GetUptime(test.staker.NodeID) - require.ErrorIs(err, database.ErrNotFound) + if test.staker.SubnetID == constants.PrimaryNetworkID { + // Uptimes are only considered for primary network validators + _, _, err := state.GetUptime(test.staker.NodeID) + require.ErrorIs(err, database.ErrNotFound) + } } else { require.NoError(err) require.Equal(test.expectedCurrentValidator, currentValidator) - // Current validators should also have uptimes - upDuration, lastUpdated, err := state.GetUptime(currentValidator.NodeID) - require.NoError(err) - require.Zero(upDuration) - require.Equal(currentValidator.StartTime, lastUpdated) + if test.staker.SubnetID == constants.PrimaryNetworkID { + // Uptimes are only considered for primary network validators + upDuration, lastUpdated, err := state.GetUptime(currentValidator.NodeID) + require.NoError(err) + require.Zero(upDuration) + require.Equal(currentValidator.StartTime, lastUpdated) + } } pendingValidator, err := state.GetPendingValidator(test.staker.SubnetID, test.staker.NodeID) From 95c42a16942d6ba0a8a67a0ebd29ed7129310efa Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Wed, 23 Oct 2024 11:39:45 -0400 Subject: [PATCH 09/20] comment --- vms/platformvm/state/state.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 8a81931f111..725b9e010c4 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -1726,6 +1726,8 @@ func (s *state) initValidatorSets() error { } for nodeID, subnetValidator := range subnetValidators { + // The subnet validator's Public Key is inherited from the + // corresponding primary network validator. primaryValidator, ok := primaryNetworkValidators[nodeID] if !ok { return fmt.Errorf("%w: %s", errMissingPrimaryNetworkValidator, nodeID) From e8f4c99963ca08ccb7cc91048d5507fed95f315f Mon Sep 17 00:00:00 2001 From: yacovm Date: Wed, 23 Oct 2024 17:59:59 +0200 Subject: [PATCH 10/20] Move RPC metrics registration after its client's initialization (#3488) Signed-off-by: Yacov Manevich --- vms/rpcchainvm/vm_client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vms/rpcchainvm/vm_client.go b/vms/rpcchainvm/vm_client.go index c5a358df8ed..05add90700d 100644 --- a/vms/rpcchainvm/vm_client.go +++ b/vms/rpcchainvm/vm_client.go @@ -156,10 +156,6 @@ func (vm *VMClient) Initialize( return err } - if err := chainCtx.Metrics.Register("", vm); err != nil { - return err - } - // Initialize the database dbServerListener, err := grpcutils.NewListener() if err != nil { @@ -229,6 +225,10 @@ func (vm *VMClient) Initialize( return err } + if err := chainCtx.Metrics.Register("", vm); err != nil { + return err + } + id, err := ids.ToID(resp.LastAcceptedId) if err != nil { return err From 552a800639d6778c493e4eb21a86aa0d03fd66d6 Mon Sep 17 00:00:00 2001 From: Darioush Jalali Date: Wed, 23 Oct 2024 09:04:35 -0700 Subject: [PATCH 11/20] database: add applicable dbtests for linkeddb (#3486) Signed-off-by: Darioush Jalali --- database/dbtest/dbtest.go | 53 ++++++++++++++++++++++++----------- database/linkeddb/db_test.go | 25 +++++++++++++++++ database/linkeddb/linkeddb.go | 2 ++ 3 files changed, 63 insertions(+), 17 deletions(-) create mode 100644 database/linkeddb/db_test.go diff --git a/database/dbtest/dbtest.go b/database/dbtest/dbtest.go index dc203db09d1..92b2b2cd46b 100644 --- a/database/dbtest/dbtest.go +++ b/database/dbtest/dbtest.go @@ -22,12 +22,20 @@ import ( "github.com/ava-labs/avalanchego/utils/units" ) +// TestsBasic is a list of all basic database tests that require only +// a KeyValueReaderWriterDeleter. +var TestsBasic = map[string]func(t *testing.T, db database.KeyValueReaderWriterDeleter){ + "SimpleKeyValue": TestSimpleKeyValue, + "OverwriteKeyValue": TestOverwriteKeyValue, + "EmptyKey": TestEmptyKey, + "KeyEmptyValue": TestKeyEmptyValue, + "MemorySafetyDatabase": TestMemorySafetyDatabase, + "ModifyValueAfterPut": TestModifyValueAfterPut, + "PutGetEmpty": TestPutGetEmpty, +} + // Tests is a list of all database tests var Tests = map[string]func(t *testing.T, db database.Database){ - "SimpleKeyValue": TestSimpleKeyValue, - "OverwriteKeyValue": TestOverwriteKeyValue, - "EmptyKey": TestEmptyKey, - "KeyEmptyValue": TestKeyEmptyValue, "SimpleKeyValueClosed": TestSimpleKeyValueClosed, "NewBatchClosed": TestNewBatchClosed, "BatchPut": TestBatchPut, @@ -49,23 +57,29 @@ var Tests = map[string]func(t *testing.T, db database.Database){ "IteratorError": TestIteratorError, "IteratorErrorAfterRelease": TestIteratorErrorAfterRelease, "CompactNoPanic": TestCompactNoPanic, - "MemorySafetyDatabase": TestMemorySafetyDatabase, "MemorySafetyBatch": TestMemorySafetyBatch, "AtomicClear": TestAtomicClear, "Clear": TestClear, "AtomicClearPrefix": TestAtomicClearPrefix, "ClearPrefix": TestClearPrefix, - "ModifyValueAfterPut": TestModifyValueAfterPut, "ModifyValueAfterBatchPut": TestModifyValueAfterBatchPut, "ModifyValueAfterBatchPutReplay": TestModifyValueAfterBatchPutReplay, "ConcurrentBatches": TestConcurrentBatches, "ManySmallConcurrentKVPairBatches": TestManySmallConcurrentKVPairBatches, - "PutGetEmpty": TestPutGetEmpty, +} + +func init() { + // Add all basic database tests to the database tests + for name, test := range TestsBasic { + Tests[name] = func(t *testing.T, db database.Database) { + test(t, db) + } + } } // TestSimpleKeyValue tests to make sure that simple Put + Get + Delete + Has // calls return the expected values. -func TestSimpleKeyValue(t *testing.T, db database.Database) { +func TestSimpleKeyValue(t *testing.T, db database.KeyValueReaderWriterDeleter) { require := require.New(t) key := []byte("hello") @@ -101,7 +115,7 @@ func TestSimpleKeyValue(t *testing.T, db database.Database) { require.NoError(db.Delete(key)) } -func TestOverwriteKeyValue(t *testing.T, db database.Database) { +func TestOverwriteKeyValue(t *testing.T, db database.KeyValueReaderWriterDeleter) { require := require.New(t) key := []byte("hello") @@ -117,7 +131,7 @@ func TestOverwriteKeyValue(t *testing.T, db database.Database) { require.Equal(value2, gotValue) } -func TestKeyEmptyValue(t *testing.T, db database.Database) { +func TestKeyEmptyValue(t *testing.T, db database.KeyValueReaderWriterDeleter) { require := require.New(t) key := []byte("hello") @@ -133,7 +147,7 @@ func TestKeyEmptyValue(t *testing.T, db database.Database) { require.Empty(value) } -func TestEmptyKey(t *testing.T, db database.Database) { +func TestEmptyKey(t *testing.T, db database.KeyValueReaderWriterDeleter) { require := require.New(t) var ( @@ -202,7 +216,7 @@ func TestSimpleKeyValueClosed(t *testing.T, db database.Database) { // TestMemorySafetyDatabase ensures it is safe to modify a key after passing it // to Database.Put and Database.Get. -func TestMemorySafetyDatabase(t *testing.T, db database.Database) { +func TestMemorySafetyDatabase(t *testing.T, db database.KeyValueReaderWriterDeleter) { require := require.New(t) key := []byte("1key") @@ -211,9 +225,14 @@ func TestMemorySafetyDatabase(t *testing.T, db database.Database) { key2 := []byte("2key") value2 := []byte("value2") - // Put both K/V pairs in the database + // Put key in the database directly require.NoError(db.Put(key, value)) - require.NoError(db.Put(key2, value2)) + + // Put key2 in the database by modifying key, which should be safe + // to modify after the Put call + key[0] = key2[0] + require.NoError(db.Put(key, value2)) + key[0] = keyCopy[0] // Get the value for [key] gotVal, err := db.Get(key) @@ -1042,7 +1061,7 @@ func testClearPrefix(t *testing.T, db database.Database, clearF func(database.Da require.NoError(db.Close()) } -func TestModifyValueAfterPut(t *testing.T, db database.Database) { +func TestModifyValueAfterPut(t *testing.T, db database.KeyValueReaderWriterDeleter) { require := require.New(t) key := []byte{1} @@ -1166,7 +1185,7 @@ func runConcurrentBatches( return eg.Wait() } -func TestPutGetEmpty(t *testing.T, db database.Database) { +func TestPutGetEmpty(t *testing.T, db database.KeyValueReaderWriterDeleter) { require := require.New(t) key := []byte("hello") @@ -1184,7 +1203,7 @@ func TestPutGetEmpty(t *testing.T, db database.Database) { require.Empty(value) // May be nil or empty byte slice. } -func FuzzKeyValue(f *testing.F, db database.Database) { +func FuzzKeyValue(f *testing.F, db database.KeyValueReaderWriterDeleter) { f.Fuzz(func(t *testing.T, key []byte, value []byte) { require := require.New(t) diff --git a/database/linkeddb/db_test.go b/database/linkeddb/db_test.go new file mode 100644 index 00000000000..cad51252bf3 --- /dev/null +++ b/database/linkeddb/db_test.go @@ -0,0 +1,25 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package linkeddb + +import ( + "testing" + + "github.com/ava-labs/avalanchego/database/dbtest" + "github.com/ava-labs/avalanchego/database/memdb" +) + +func TestInterface(t *testing.T) { + for name, test := range dbtest.TestsBasic { + t.Run(name, func(t *testing.T) { + db := NewDefault(memdb.New()) + test(t, db) + }) + } +} + +func FuzzKeyValue(f *testing.F) { + db := NewDefault(memdb.New()) + dbtest.FuzzKeyValue(f, db) +} diff --git a/database/linkeddb/linkeddb.go b/database/linkeddb/linkeddb.go index b7bc6867976..4e609d329cf 100644 --- a/database/linkeddb/linkeddb.go +++ b/database/linkeddb/linkeddb.go @@ -108,6 +108,8 @@ func (ldb *linkedDB) Put(key, value []byte) error { } // The key isn't currently in the list, so we should add it as the head. + // Note we will copy the key so it's safe to store references to it. + key = slices.Clone(key) newHead := node{Value: slices.Clone(value)} if headKey, err := ldb.getHeadKey(); err == nil { // The list currently has a head, so we need to update the old head. From 041450690a1db1b54ebcbf2cf5f0abd17f4a1964 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Wed, 23 Oct 2024 12:08:07 -0400 Subject: [PATCH 12/20] Add SoV Excess to P-chain state (#3482) --- .../block/executor/proposal_block_test.go | 2 ++ .../block/executor/standard_block_test.go | 2 ++ .../block/executor/verifier_test.go | 5 ++++ vms/platformvm/state/diff.go | 11 ++++++++ vms/platformvm/state/diff_test.go | 25 +++++++++++++++++ vms/platformvm/state/mock_chain.go | 26 ++++++++++++++++++ vms/platformvm/state/mock_diff.go | 26 ++++++++++++++++++ vms/platformvm/state/mock_state.go | 26 ++++++++++++++++++ vms/platformvm/state/state.go | 27 +++++++++++++++++++ vms/platformvm/state/state_test.go | 16 +++++++++++ 10 files changed, 166 insertions(+) diff --git a/vms/platformvm/block/executor/proposal_block_test.go b/vms/platformvm/block/executor/proposal_block_test.go index 4a4154b02e5..c0a597d7733 100644 --- a/vms/platformvm/block/executor/proposal_block_test.go +++ b/vms/platformvm/block/executor/proposal_block_test.go @@ -91,6 +91,7 @@ func TestApricotProposalBlockTimeVerification(t *testing.T) { // setup state to validate proposal block transaction onParentAccept.EXPECT().GetTimestamp().Return(chainTime).AnyTimes() onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() + onParentAccept.EXPECT().GetSoVExcess().Return(gas.Gas(0)).AnyTimes() onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() onParentAccept.EXPECT().GetCurrentStakerIterator().Return( @@ -162,6 +163,7 @@ func TestBanffProposalBlockTimeVerification(t *testing.T) { onParentAccept := state.NewMockDiff(ctrl) onParentAccept.EXPECT().GetTimestamp().Return(parentTime).AnyTimes() onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() + onParentAccept.EXPECT().GetSoVExcess().Return(gas.Gas(0)).AnyTimes() onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() onParentAccept.EXPECT().GetCurrentSupply(constants.PrimaryNetworkID).Return(uint64(1000), nil).AnyTimes() diff --git a/vms/platformvm/block/executor/standard_block_test.go b/vms/platformvm/block/executor/standard_block_test.go index 8e62937c923..d9ad860d3d3 100644 --- a/vms/platformvm/block/executor/standard_block_test.go +++ b/vms/platformvm/block/executor/standard_block_test.go @@ -59,6 +59,7 @@ func TestApricotStandardBlockTimeVerification(t *testing.T) { chainTime := env.clk.Time().Truncate(time.Second) onParentAccept.EXPECT().GetTimestamp().Return(chainTime).AnyTimes() onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() + onParentAccept.EXPECT().GetSoVExcess().Return(gas.Gas(0)).AnyTimes() onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() // wrong height @@ -134,6 +135,7 @@ func TestBanffStandardBlockTimeVerification(t *testing.T) { onParentAccept.EXPECT().GetTimestamp().Return(chainTime).AnyTimes() onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() + onParentAccept.EXPECT().GetSoVExcess().Return(gas.Gas(0)).AnyTimes() onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() txID := ids.GenerateTestID() diff --git a/vms/platformvm/block/executor/verifier_test.go b/vms/platformvm/block/executor/verifier_test.go index f57b8fb4ed5..a076616701f 100644 --- a/vms/platformvm/block/executor/verifier_test.go +++ b/vms/platformvm/block/executor/verifier_test.go @@ -103,6 +103,7 @@ func TestVerifierVisitProposalBlock(t *testing.T) { // One call for each of onCommitState and onAbortState. parentOnAcceptState.EXPECT().GetTimestamp().Return(timestamp).Times(2) parentOnAcceptState.EXPECT().GetFeeState().Return(gas.State{}).Times(2) + parentOnAcceptState.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(2) parentOnAcceptState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(2) backend := &backend{ @@ -335,6 +336,7 @@ func TestVerifierVisitStandardBlock(t *testing.T) { timestamp := time.Now() parentState.EXPECT().GetTimestamp().Return(timestamp).Times(1) parentState.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + parentState.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) parentState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) parentStatelessBlk.EXPECT().Height().Return(uint64(1)).Times(1) mempool.EXPECT().Remove(apricotBlk.Txs()).Times(1) @@ -597,6 +599,7 @@ func TestBanffAbortBlockTimestampChecks(t *testing.T) { s.EXPECT().GetLastAccepted().Return(parentID).Times(3) s.EXPECT().GetTimestamp().Return(parentTime).Times(3) s.EXPECT().GetFeeState().Return(gas.State{}).Times(3) + s.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(3) s.EXPECT().GetAccruedFees().Return(uint64(0)).Times(3) onDecisionState, err := state.NewDiff(parentID, backend) @@ -695,6 +698,7 @@ func TestBanffCommitBlockTimestampChecks(t *testing.T) { s.EXPECT().GetLastAccepted().Return(parentID).Times(3) s.EXPECT().GetTimestamp().Return(parentTime).Times(3) s.EXPECT().GetFeeState().Return(gas.State{}).Times(3) + s.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(3) s.EXPECT().GetAccruedFees().Return(uint64(0)).Times(3) onDecisionState, err := state.NewDiff(parentID, backend) @@ -811,6 +815,7 @@ func TestVerifierVisitStandardBlockWithDuplicateInputs(t *testing.T) { parentStatelessBlk.EXPECT().Height().Return(uint64(1)).Times(1) parentState.EXPECT().GetTimestamp().Return(timestamp).Times(1) parentState.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + parentState.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) parentState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) parentStatelessBlk.EXPECT().Parent().Return(grandParentID).Times(1) diff --git a/vms/platformvm/state/diff.go b/vms/platformvm/state/diff.go index 9fe6a62363c..da73854346e 100644 --- a/vms/platformvm/state/diff.go +++ b/vms/platformvm/state/diff.go @@ -37,6 +37,7 @@ type diff struct { timestamp time.Time feeState gas.State + sovExcess gas.Gas accruedFees uint64 // Subnet ID --> supply of native asset of the subnet @@ -80,6 +81,7 @@ func NewDiff( stateVersions: stateVersions, timestamp: parentState.GetTimestamp(), feeState: parentState.GetFeeState(), + sovExcess: parentState.GetSoVExcess(), accruedFees: parentState.GetAccruedFees(), expiryDiff: newExpiryDiff(), subnetOwners: make(map[ids.ID]fx.Owner), @@ -117,6 +119,14 @@ func (d *diff) SetFeeState(feeState gas.State) { d.feeState = feeState } +func (d *diff) GetSoVExcess() gas.Gas { + return d.sovExcess +} + +func (d *diff) SetSoVExcess(excess gas.Gas) { + d.sovExcess = excess +} + func (d *diff) GetAccruedFees() uint64 { return d.accruedFees } @@ -482,6 +492,7 @@ func (d *diff) DeleteUTXO(utxoID ids.ID) { func (d *diff) Apply(baseState Chain) error { baseState.SetTimestamp(d.timestamp) baseState.SetFeeState(d.feeState) + baseState.SetSoVExcess(d.sovExcess) baseState.SetAccruedFees(d.accruedFees) for subnetID, supply := range d.currentSupply { baseState.SetCurrentSupply(subnetID, supply) diff --git a/vms/platformvm/state/diff_test.go b/vms/platformvm/state/diff_test.go index 3625986d780..82e376c3b5f 100644 --- a/vms/platformvm/state/diff_test.go +++ b/vms/platformvm/state/diff_test.go @@ -68,6 +68,24 @@ func TestDiffFeeState(t *testing.T) { assertChainsEqual(t, state, d) } +func TestDiffSoVExcess(t *testing.T) { + require := require.New(t) + + state := newTestState(t, memdb.New()) + + d, err := NewDiffOn(state) + require.NoError(err) + + initialExcess := state.GetSoVExcess() + newExcess := initialExcess + 1 + d.SetSoVExcess(newExcess) + require.Equal(newExcess, d.GetSoVExcess()) + require.Equal(initialExcess, state.GetSoVExcess()) + + require.NoError(d.Apply(state)) + assertChainsEqual(t, state, d) +} + func TestDiffAccruedFees(t *testing.T) { require := require.New(t) @@ -272,6 +290,7 @@ func TestDiffCurrentValidator(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) @@ -307,6 +326,7 @@ func TestDiffPendingValidator(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) @@ -348,6 +368,7 @@ func TestDiffCurrentDelegator(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) @@ -392,6 +413,7 @@ func TestDiffPendingDelegator(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) @@ -530,6 +552,7 @@ func TestDiffTx(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) @@ -628,6 +651,7 @@ func TestDiffUTXO(t *testing.T) { // Called in NewDiffOn state.EXPECT().GetTimestamp().Return(time.Now()).Times(1) state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) + state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) d, err := NewDiffOn(state) @@ -705,6 +729,7 @@ func assertChainsEqual(t *testing.T, expected, actual Chain) { require.Equal(expected.GetTimestamp(), actual.GetTimestamp()) require.Equal(expected.GetFeeState(), actual.GetFeeState()) + require.Equal(expected.GetSoVExcess(), actual.GetSoVExcess()) require.Equal(expected.GetAccruedFees(), actual.GetAccruedFees()) expectedCurrentSupply, err := expected.GetCurrentSupply(constants.PrimaryNetworkID) diff --git a/vms/platformvm/state/mock_chain.go b/vms/platformvm/state/mock_chain.go index 27daeae3a10..56c49592451 100644 --- a/vms/platformvm/state/mock_chain.go +++ b/vms/platformvm/state/mock_chain.go @@ -353,6 +353,20 @@ func (mr *MockChainMockRecorder) GetPendingValidator(subnetID, nodeID any) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPendingValidator", reflect.TypeOf((*MockChain)(nil).GetPendingValidator), subnetID, nodeID) } +// GetSoVExcess mocks base method. +func (m *MockChain) GetSoVExcess() gas.Gas { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSoVExcess") + ret0, _ := ret[0].(gas.Gas) + return ret0 +} + +// GetSoVExcess indicates an expected call of GetSoVExcess. +func (mr *MockChainMockRecorder) GetSoVExcess() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSoVExcess", reflect.TypeOf((*MockChain)(nil).GetSoVExcess)) +} + // GetSubnetConversion mocks base method. func (m *MockChain) GetSubnetConversion(subnetID ids.ID) (SubnetConversion, error) { m.ctrl.T.Helper() @@ -572,6 +586,18 @@ func (mr *MockChainMockRecorder) SetFeeState(f any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFeeState", reflect.TypeOf((*MockChain)(nil).SetFeeState), f) } +// SetSoVExcess mocks base method. +func (m *MockChain) SetSoVExcess(e gas.Gas) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetSoVExcess", e) +} + +// SetSoVExcess indicates an expected call of SetSoVExcess. +func (mr *MockChainMockRecorder) SetSoVExcess(e any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSoVExcess", reflect.TypeOf((*MockChain)(nil).SetSoVExcess), e) +} + // SetSubnetConversion mocks base method. func (m *MockChain) SetSubnetConversion(subnetID ids.ID, c SubnetConversion) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/mock_diff.go b/vms/platformvm/state/mock_diff.go index 8732fc49b40..b8362386af9 100644 --- a/vms/platformvm/state/mock_diff.go +++ b/vms/platformvm/state/mock_diff.go @@ -367,6 +367,20 @@ func (mr *MockDiffMockRecorder) GetPendingValidator(subnetID, nodeID any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPendingValidator", reflect.TypeOf((*MockDiff)(nil).GetPendingValidator), subnetID, nodeID) } +// GetSoVExcess mocks base method. +func (m *MockDiff) GetSoVExcess() gas.Gas { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSoVExcess") + ret0, _ := ret[0].(gas.Gas) + return ret0 +} + +// GetSoVExcess indicates an expected call of GetSoVExcess. +func (mr *MockDiffMockRecorder) GetSoVExcess() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSoVExcess", reflect.TypeOf((*MockDiff)(nil).GetSoVExcess)) +} + // GetSubnetConversion mocks base method. func (m *MockDiff) GetSubnetConversion(subnetID ids.ID) (SubnetConversion, error) { m.ctrl.T.Helper() @@ -586,6 +600,18 @@ func (mr *MockDiffMockRecorder) SetFeeState(f any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFeeState", reflect.TypeOf((*MockDiff)(nil).SetFeeState), f) } +// SetSoVExcess mocks base method. +func (m *MockDiff) SetSoVExcess(e gas.Gas) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetSoVExcess", e) +} + +// SetSoVExcess indicates an expected call of SetSoVExcess. +func (mr *MockDiffMockRecorder) SetSoVExcess(e any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSoVExcess", reflect.TypeOf((*MockDiff)(nil).SetSoVExcess), e) +} + // SetSubnetConversion mocks base method. func (m *MockDiff) SetSubnetConversion(subnetID ids.ID, c SubnetConversion) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index a1759398257..cb05f54fc6f 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -527,6 +527,20 @@ func (mr *MockStateMockRecorder) GetRewardUTXOs(txID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRewardUTXOs", reflect.TypeOf((*MockState)(nil).GetRewardUTXOs), txID) } +// GetSoVExcess mocks base method. +func (m *MockState) GetSoVExcess() gas.Gas { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSoVExcess") + ret0, _ := ret[0].(gas.Gas) + return ret0 +} + +// GetSoVExcess indicates an expected call of GetSoVExcess. +func (mr *MockStateMockRecorder) GetSoVExcess() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSoVExcess", reflect.TypeOf((*MockState)(nil).GetSoVExcess)) +} + // GetStartTime mocks base method. func (m *MockState) GetStartTime(nodeID ids.NodeID) (time.Time, error) { m.ctrl.T.Helper() @@ -845,6 +859,18 @@ func (mr *MockStateMockRecorder) SetLastAccepted(blkID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLastAccepted", reflect.TypeOf((*MockState)(nil).SetLastAccepted), blkID) } +// SetSoVExcess mocks base method. +func (m *MockState) SetSoVExcess(e gas.Gas) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetSoVExcess", e) +} + +// SetSoVExcess indicates an expected call of SetSoVExcess. +func (mr *MockStateMockRecorder) SetSoVExcess(e any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSoVExcess", reflect.TypeOf((*MockState)(nil).SetSoVExcess), e) +} + // SetSubnetConversion mocks base method. func (m *MockState) SetSubnetConversion(subnetID ids.ID, c SubnetConversion) { m.ctrl.T.Helper() diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index 58c2056570b..53b109c2324 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -87,6 +87,7 @@ var ( TimestampKey = []byte("timestamp") FeeStateKey = []byte("fee state") + SoVExcessKey = []byte("sov excess") AccruedFeesKey = []byte("accrued fees") CurrentSupplyKey = []byte("current supply") LastAcceptedKey = []byte("last accepted") @@ -110,6 +111,9 @@ type Chain interface { GetFeeState() gas.State SetFeeState(f gas.State) + GetSoVExcess() gas.Gas + SetSoVExcess(e gas.Gas) + GetAccruedFees() uint64 SetAccruedFees(f uint64) @@ -289,6 +293,7 @@ type stateBlk struct { * |-- blocksReindexedKey -> nil * |-- timestampKey -> timestamp * |-- feeStateKey -> feeState + * |-- sovExcessKey -> sovExcess * |-- accruedFeesKey -> accruedFees * |-- currentSupplyKey -> currentSupply * |-- lastAcceptedKey -> lastAccepted @@ -386,6 +391,7 @@ type state struct { // The persisted fields represent the current database value timestamp, persistedTimestamp time.Time feeState, persistedFeeState gas.State + sovExcess, persistedSOVExcess gas.Gas accruedFees, persistedAccruedFees uint64 currentSupply, persistedCurrentSupply uint64 // [lastAccepted] is the most recently accepted block. @@ -1091,6 +1097,14 @@ func (s *state) SetFeeState(feeState gas.State) { s.feeState = feeState } +func (s *state) GetSoVExcess() gas.Gas { + return s.sovExcess +} + +func (s *state) SetSoVExcess(e gas.Gas) { + s.sovExcess = e +} + func (s *state) GetAccruedFees() uint64 { return s.accruedFees } @@ -1391,6 +1405,13 @@ func (s *state) loadMetadata() error { s.persistedFeeState = feeState s.SetFeeState(feeState) + sovExcess, err := database.WithDefault(database.GetUInt64, s.singletonDB, SoVExcessKey, 0) + if err != nil { + return err + } + s.persistedSOVExcess = gas.Gas(sovExcess) + s.SetSoVExcess(gas.Gas(sovExcess)) + accruedFees, err := database.WithDefault(database.GetUInt64, s.singletonDB, AccruedFeesKey, 0) if err != nil { return err @@ -2439,6 +2460,12 @@ func (s *state) writeMetadata() error { } s.persistedFeeState = s.feeState } + if s.sovExcess != s.persistedSOVExcess { + if err := database.PutUInt64(s.singletonDB, SoVExcessKey, uint64(s.sovExcess)); err != nil { + return fmt.Errorf("failed to write sov excess: %w", err) + } + s.persistedSOVExcess = s.sovExcess + } if s.accruedFees != s.persistedAccruedFees { if err := database.PutUInt64(s.singletonDB, AccruedFeesKey, s.accruedFees); err != nil { return fmt.Errorf("failed to write accrued fees: %w", err) diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index d2912950043..6204540bd61 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -1474,6 +1474,22 @@ func TestStateFeeStateCommitAndLoad(t *testing.T) { require.Equal(expectedFeeState, s.GetFeeState()) } +// Verify that committing the state writes the sov excess to the database and +// that loading the state fetches the sov excess from the database. +func TestStateSoVExcessCommitAndLoad(t *testing.T) { + require := require.New(t) + + db := memdb.New() + s := newTestState(t, db) + + const expectedSoVExcess gas.Gas = 10 + s.SetSoVExcess(expectedSoVExcess) + require.NoError(s.Commit()) + + s = newTestState(t, db) + require.Equal(expectedSoVExcess, s.GetSoVExcess()) +} + // Verify that committing the state writes the accrued fees to the database and // that loading the state fetches the accrued fees from the database. func TestStateAccruedFeesCommitAndLoad(t *testing.T) { From 28152cb8e0662603e9dc3e36fb51ede159ae7c14 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Wed, 23 Oct 2024 16:04:57 -0400 Subject: [PATCH 13/20] Remove deprecated X-chain pubsub server (#3490) --- go.mod | 2 +- pubsub/bloom/filter.go | 51 -------- pubsub/bloom/filter_test.go | 32 ----- pubsub/bloom/map_filter.go | 35 ------ pubsub/connection.go | 214 -------------------------------- pubsub/connections.go | 51 -------- pubsub/filter_param.go | 87 ------------- pubsub/filter_test.go | 77 ------------ pubsub/filterer.go | 8 -- pubsub/messages.go | 77 ------------ pubsub/server.go | 127 ------------------- vms/avm/pubsub_filterer.go | 44 ------- vms/avm/pubsub_filterer_test.go | 50 -------- vms/avm/service.md | 119 ------------------ vms/avm/vm.go | 7 -- 15 files changed, 1 insertion(+), 980 deletions(-) delete mode 100644 pubsub/bloom/filter.go delete mode 100644 pubsub/bloom/filter_test.go delete mode 100644 pubsub/bloom/map_filter.go delete mode 100644 pubsub/connection.go delete mode 100644 pubsub/connections.go delete mode 100644 pubsub/filter_param.go delete mode 100644 pubsub/filter_test.go delete mode 100644 pubsub/filterer.go delete mode 100644 pubsub/messages.go delete mode 100644 pubsub/server.go delete mode 100644 vms/avm/pubsub_filterer.go delete mode 100644 vms/avm/pubsub_filterer_test.go diff --git a/go.mod b/go.mod index 4cea2f92a5a..3900e0bbfc3 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,6 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/rpc v1.2.0 - github.com/gorilla/websocket v1.5.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/holiman/uint256 v1.2.4 github.com/huin/goupnp v1.3.0 @@ -121,6 +120,7 @@ require ( github.com/google/gnostic-models v0.6.8 // indirect github.com/google/gofuzz v1.2.0 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect github.com/hashicorp/go-bexpr v0.1.10 // indirect github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d // indirect diff --git a/pubsub/bloom/filter.go b/pubsub/bloom/filter.go deleted file mode 100644 index b0d023b51f1..00000000000 --- a/pubsub/bloom/filter.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package bloom - -import ( - "errors" - - "github.com/ava-labs/avalanchego/utils/bloom" -) - -const bytesPerHash = 8 - -var ( - _ Filter = (*filter)(nil) - - errMaxBytes = errors.New("too large") -) - -type Filter interface { - // Add adds to filter, assumed thread safe - Add(...[]byte) - - // Check checks filter, assumed thread safe - Check([]byte) bool -} - -func New(maxN int, p float64, maxBytes int) (Filter, error) { - numHashes, numEntries := bloom.OptimalParameters(maxN, p) - if neededBytes := 1 + numHashes*bytesPerHash + numEntries; neededBytes > maxBytes { - return nil, errMaxBytes - } - f, err := bloom.New(numHashes, numEntries) - return &filter{ - filter: f, - }, err -} - -type filter struct { - filter *bloom.Filter -} - -func (f *filter) Add(bl ...[]byte) { - for _, b := range bl { - bloom.Add(f.filter, b, nil) - } -} - -func (f *filter) Check(b []byte) bool { - return bloom.Contains(f.filter, b, nil) -} diff --git a/pubsub/bloom/filter_test.go b/pubsub/bloom/filter_test.go deleted file mode 100644 index 3b2c4b71a59..00000000000 --- a/pubsub/bloom/filter_test.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package bloom - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/utils/units" -) - -func TestNew(t *testing.T) { - var ( - require = require.New(t) - maxN = 10000 - p = 0.1 - maxBytes = 1 * units.MiB // 1 MiB - ) - f, err := New(maxN, p, maxBytes) - require.NoError(err) - require.NotNil(f) - - f.Add([]byte("hello")) - - checked := f.Check([]byte("hello")) - require.True(checked, "should have contained the key") - - checked = f.Check([]byte("bye")) - require.False(checked, "shouldn't have contained the key") -} diff --git a/pubsub/bloom/map_filter.go b/pubsub/bloom/map_filter.go deleted file mode 100644 index d0edcbe88fd..00000000000 --- a/pubsub/bloom/map_filter.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package bloom - -import ( - "sync" - - "github.com/ava-labs/avalanchego/utils/set" -) - -type mapFilter struct { - lock sync.RWMutex - values set.Set[string] -} - -func NewMap() Filter { - return &mapFilter{} -} - -func (m *mapFilter) Add(bl ...[]byte) { - m.lock.Lock() - defer m.lock.Unlock() - - for _, b := range bl { - m.values.Add(string(b)) - } -} - -func (m *mapFilter) Check(b []byte) bool { - m.lock.RLock() - defer m.lock.RUnlock() - - return m.values.Contains(string(b)) -} diff --git a/pubsub/connection.go b/pubsub/connection.go deleted file mode 100644 index 31d493355cd..00000000000 --- a/pubsub/connection.go +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package pubsub - -import ( - "encoding/json" - "errors" - "fmt" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" - "go.uber.org/zap" - - "github.com/ava-labs/avalanchego/pubsub/bloom" -) - -var ( - ErrFilterNotInitialized = errors.New("filter not initialized") - ErrAddressLimit = errors.New("address limit exceeded") - ErrInvalidFilterParam = errors.New("invalid bloom filter params") - ErrInvalidCommand = errors.New("invalid command") - _ Filter = (*connection)(nil) -) - -type Filter interface { - Check(addr []byte) bool -} - -// connection is a representation of the websocket connection. -type connection struct { - s *Server - - // The websocket connection. - conn *websocket.Conn - - // Buffered channel of outbound messages. - send chan interface{} - - fp *FilterParam - - active uint32 -} - -func (c *connection) Check(addr []byte) bool { - return c.fp.Check(addr) -} - -func (c *connection) isActive() bool { - active := atomic.LoadUint32(&c.active) - return active != 0 -} - -func (c *connection) deactivate() { - atomic.StoreUint32(&c.active, 0) -} - -func (c *connection) Send(msg interface{}) bool { - if !c.isActive() { - return false - } - select { - case c.send <- msg: - return true - default: - } - return false -} - -// readPump pumps messages from the websocket connection to the hub. -// -// The application runs readPump in a per-connection goroutine. The application -// ensures that there is at most one reader on a connection by executing all -// reads from this goroutine. -func (c *connection) readPump() { - defer func() { - c.deactivate() - c.s.removeConnection(c) - - // close is called by both the writePump and the readPump so one of them - // will always error - _ = c.conn.Close() - }() - - c.conn.SetReadLimit(maxMessageSize) - // SetReadDeadline returns an error if the connection is corrupted - if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { - return - } - c.conn.SetPongHandler(func(string) error { - return c.conn.SetReadDeadline(time.Now().Add(pongWait)) - }) - - for { - err := c.readMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - c.s.log.Debug("unexpected close in websockets", - zap.Error(err), - ) - } - break - } - } -} - -// writePump pumps messages from the hub to the websocket connection. -// -// A goroutine running writePump is started for each connection. The -// application ensures that there is at most one writer to a connection by -// executing all writes from this goroutine. -func (c *connection) writePump() { - ticker := time.NewTicker(pingPeriod) - defer func() { - c.deactivate() - ticker.Stop() - c.s.removeConnection(c) - - // close is called by both the writePump and the readPump so one of them - // will always error - _ = c.conn.Close() - }() - for { - select { - case message, ok := <-c.send: - if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { - c.s.log.Debug("closing the connection", - zap.String("reason", "failed to set the write deadline"), - zap.Error(err), - ) - return - } - if !ok { - // The hub closed the channel. Attempt to close the connection - // gracefully. - _ = c.conn.WriteMessage(websocket.CloseMessage, []byte{}) - return - } - - if err := c.conn.WriteJSON(message); err != nil { - return - } - case <-ticker.C: - if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { - c.s.log.Debug("closing the connection", - zap.String("reason", "failed to set the write deadline"), - zap.Error(err), - ) - return - } - if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { - return - } - } - } -} - -func (c *connection) readMessage() error { - _, r, err := c.conn.NextReader() - if err != nil { - return err - } - cmd := &Command{} - err = json.NewDecoder(r).Decode(cmd) - if err != nil { - return err - } - - switch { - case cmd.NewBloom != nil: - err = c.handleNewBloom(cmd.NewBloom) - case cmd.NewSet != nil: - c.handleNewSet(cmd.NewSet) - case cmd.AddAddresses != nil: - err = c.handleAddAddresses(cmd.AddAddresses) - default: - err = ErrInvalidCommand - } - if err != nil { - c.Send(&errorMsg{ - Error: err.Error(), - }) - } - return err -} - -func (c *connection) handleNewBloom(cmd *NewBloom) error { - if !cmd.IsParamsValid() { - return ErrInvalidFilterParam - } - filter, err := bloom.New(int(cmd.MaxElements), float64(cmd.CollisionProb), MaxBytes) - if err != nil { - return fmt.Errorf("bloom filter creation failed %w", err) - } - c.fp.SetFilter(filter) - return nil -} - -func (c *connection) handleNewSet(_ *NewSet) { - c.fp.NewSet() -} - -func (c *connection) handleAddAddresses(cmd *AddAddresses) error { - if err := cmd.parseAddresses(); err != nil { - return fmt.Errorf("address parse failed %w", err) - } - err := c.fp.Add(cmd.addressIds...) - if err != nil { - return fmt.Errorf("address append failed %w", err) - } - c.s.subscribedConnections.Add(c) - return nil -} diff --git a/pubsub/connections.go b/pubsub/connections.go deleted file mode 100644 index 25d35ac8cd8..00000000000 --- a/pubsub/connections.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package pubsub - -import ( - "sync" - - "github.com/ava-labs/avalanchego/utils/set" -) - -type connections struct { - lock sync.RWMutex - conns set.Set[*connection] - connsList []Filter -} - -func newConnections() *connections { - return &connections{} -} - -func (c *connections) Conns() []Filter { - c.lock.RLock() - defer c.lock.RUnlock() - - return append([]Filter{}, c.connsList...) -} - -func (c *connections) Remove(conn *connection) { - c.lock.Lock() - defer c.lock.Unlock() - - c.conns.Remove(conn) - c.createConnsList() -} - -func (c *connections) Add(conn *connection) { - c.lock.Lock() - defer c.lock.Unlock() - - c.conns.Add(conn) - c.createConnsList() -} - -func (c *connections) createConnsList() { - resp := make([]Filter, 0, len(c.conns)) - for c := range c.conns { - resp = append(resp, c) - } - c.connsList = resp -} diff --git a/pubsub/filter_param.go b/pubsub/filter_param.go deleted file mode 100644 index 5fd80a2ad70..00000000000 --- a/pubsub/filter_param.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package pubsub - -import ( - "sync" - - "github.com/ava-labs/avalanchego/pubsub/bloom" - "github.com/ava-labs/avalanchego/utils/set" -) - -type FilterParam struct { - lock sync.RWMutex - set set.Set[string] - filter bloom.Filter -} - -func NewFilterParam() *FilterParam { - return &FilterParam{ - set: set.Set[string]{}, - } -} - -func (f *FilterParam) NewSet() { - f.lock.Lock() - defer f.lock.Unlock() - - f.set = set.Set[string]{} - f.filter = nil -} - -func (f *FilterParam) Filter() bloom.Filter { - f.lock.RLock() - defer f.lock.RUnlock() - - return f.filter -} - -func (f *FilterParam) SetFilter(filter bloom.Filter) bloom.Filter { - f.lock.Lock() - defer f.lock.Unlock() - - f.filter = filter - f.set = nil - return f.filter -} - -func (f *FilterParam) Check(addr []byte) bool { - f.lock.RLock() - defer f.lock.RUnlock() - - if f.filter != nil && f.filter.Check(addr) { - return true - } - return f.set.Contains(string(addr)) -} - -func (f *FilterParam) Add(bl ...[]byte) error { - filter := f.Filter() - if filter != nil { - filter.Add(bl...) - return nil - } - - f.lock.Lock() - defer f.lock.Unlock() - - if f.set == nil { - return ErrFilterNotInitialized - } - - if len(f.set)+len(bl) > MaxAddresses { - return ErrAddressLimit - } - for _, b := range bl { - f.set.Add(string(b)) - } - return nil -} - -func (f *FilterParam) Len() int { - f.lock.RLock() - defer f.lock.RUnlock() - - return len(f.set) -} diff --git a/pubsub/filter_test.go b/pubsub/filter_test.go deleted file mode 100644 index 3b47a38e023..00000000000 --- a/pubsub/filter_test.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package pubsub - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/api" - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/pubsub/bloom" - "github.com/ava-labs/avalanchego/utils/constants" - "github.com/ava-labs/avalanchego/utils/formatting/address" -) - -func TestAddAddressesParseAddresses(t *testing.T) { - require := require.New(t) - - chainAlias := "X" - hrp := constants.GetHRP(5) - - addrID := ids.ShortID{1} - addrStr, err := address.Format(chainAlias, hrp, addrID[:]) - require.NoError(err) - - msg := &AddAddresses{JSONAddresses: api.JSONAddresses{ - Addresses: []string{ - addrStr, - }, - }} - - require.NoError(msg.parseAddresses()) - - require.Len(msg.addressIds, 1) - require.Equal(addrID[:], msg.addressIds[0]) -} - -func TestFilterParamUpdateMulti(t *testing.T) { - require := require.New(t) - - fp := NewFilterParam() - - addr1 := []byte("abc") - addr2 := []byte("def") - addr3 := []byte("xyz") - - require.NoError(fp.Add(addr1, addr2, addr3)) - require.Len(fp.set, 3) - require.Contains(fp.set, string(addr1)) - require.Contains(fp.set, string(addr2)) - require.Contains(fp.set, string(addr3)) -} - -func TestFilterParam(t *testing.T) { - require := require.New(t) - - mapFilter := bloom.NewMap() - - fp := NewFilterParam() - fp.SetFilter(mapFilter) - - addr := ids.GenerateTestShortID() - require.NoError(fp.Add(addr[:])) - require.True(fp.Check(addr[:])) - delete(fp.set, string(addr[:])) - - mapFilter.Add(addr[:]) - require.True(fp.Check(addr[:])) - require.False(fp.Check([]byte("bye"))) -} - -func TestNewBloom(t *testing.T) { - cm := &NewBloom{} - require.False(t, cm.IsParamsValid()) -} diff --git a/pubsub/filterer.go b/pubsub/filterer.go deleted file mode 100644 index 3ec2910a9c4..00000000000 --- a/pubsub/filterer.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package pubsub - -type Filterer interface { - Filter(connections []Filter) ([]bool, interface{}) -} diff --git a/pubsub/messages.go b/pubsub/messages.go deleted file mode 100644 index ec41af813cd..00000000000 --- a/pubsub/messages.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package pubsub - -import ( - "github.com/ava-labs/avalanchego/api" - "github.com/ava-labs/avalanchego/utils/formatting/address" - "github.com/ava-labs/avalanchego/utils/json" -) - -// NewBloom command for a new bloom filter -// -// Deprecated: The pubsub server is deprecated. -type NewBloom struct { - // MaxElements size of bloom filter - MaxElements json.Uint64 `json:"maxElements"` - // CollisionProb expected error rate of filter - CollisionProb json.Float64 `json:"collisionProb"` -} - -// NewSet command for a new map set -// -// Deprecated: The pubsub server is deprecated. -type NewSet struct{} - -// AddAddresses command to add addresses -// -// Deprecated: The pubsub server is deprecated. -type AddAddresses struct { - api.JSONAddresses - - // addressIds array of addresses, kept as a [][]byte for use in the bloom filter - addressIds [][]byte -} - -// Command execution command -// -// Deprecated: The pubsub server is deprecated. -type Command struct { - NewBloom *NewBloom `json:"newBloom,omitempty"` - NewSet *NewSet `json:"newSet,omitempty"` - AddAddresses *AddAddresses `json:"addAddresses,omitempty"` -} - -func (c *Command) String() string { - switch { - case c.NewBloom != nil: - return "newBloom" - case c.NewSet != nil: - return "newSet" - case c.AddAddresses != nil: - return "addAddresses" - default: - return "unknown" - } -} - -func (c *NewBloom) IsParamsValid() bool { - p := float64(c.CollisionProb) - return c.MaxElements > 0 && 0 < p && p <= 1 -} - -// parseAddresses converts the bech32 addresses to their byte format. -func (c *AddAddresses) parseAddresses() error { - if c.addressIds == nil { - c.addressIds = make([][]byte, len(c.Addresses)) - } - for i, addrStr := range c.Addresses { - _, _, addrBytes, err := address.Parse(addrStr) - if err != nil { - return err - } - c.addressIds[i] = addrBytes - } - return nil -} diff --git a/pubsub/server.go b/pubsub/server.go deleted file mode 100644 index b07dea89b34..00000000000 --- a/pubsub/server.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package pubsub - -import ( - "net/http" - "sync" - "time" - - "github.com/gorilla/websocket" - "go.uber.org/zap" - - "github.com/ava-labs/avalanchego/utils/logging" - "github.com/ava-labs/avalanchego/utils/set" - "github.com/ava-labs/avalanchego/utils/units" -) - -const ( - // Size of the ws read buffer - readBufferSize = units.KiB - - // Size of the ws write buffer - writeBufferSize = units.KiB - - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - - // Time allowed to read the next pong message from the peer. - pongWait = 60 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 - - // Maximum message size allowed from peer. - maxMessageSize = 10 * units.KiB // bytes - - // Maximum number of pending messages to send to a peer. - maxPendingMessages = 1024 // messages - - // MaxBytes the max number of bytes for a filter - MaxBytes = 1 * units.MiB - - // MaxAddresses the max number of addresses allowed - MaxAddresses = 10000 -) - -type errorMsg struct { - Error string `json:"error"` -} - -var upgrader = websocket.Upgrader{ - ReadBufferSize: readBufferSize, - WriteBufferSize: writeBufferSize, - CheckOrigin: func(*http.Request) bool { - return true - }, -} - -// Server maintains the set of active clients and sends messages to the clients. -type Server struct { - log logging.Logger - lock sync.RWMutex - // conns a list of all our connections - conns set.Set[*connection] - // subscribedConnections the connections that have activated subscriptions - subscribedConnections *connections -} - -// Deprecated: The pubsub server is deprecated. -func New(log logging.Logger) *Server { - return &Server{ - log: log, - subscribedConnections: newConnections(), - } -} - -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - wsConn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - s.log.Debug("failed to upgrade", - zap.Error(err), - ) - return - } - conn := &connection{ - s: s, - conn: wsConn, - send: make(chan interface{}, maxPendingMessages), - fp: NewFilterParam(), - active: 1, - } - s.addConnection(conn) -} - -func (s *Server) Publish(parser Filterer) { - conns := s.subscribedConnections.Conns() - toNotify, msg := parser.Filter(conns) - for i, shouldNotify := range toNotify { - if !shouldNotify { - continue - } - conn := conns[i].(*connection) - if !conn.Send(msg) { - s.log.Verbo("dropping message to subscribed connection due to too many pending messages") - } - } -} - -func (s *Server) addConnection(conn *connection) { - s.lock.Lock() - defer s.lock.Unlock() - - s.conns.Add(conn) - - go conn.writePump() - go conn.readPump() -} - -func (s *Server) removeConnection(conn *connection) { - s.subscribedConnections.Remove(conn) - - s.lock.Lock() - defer s.lock.Unlock() - - s.conns.Remove(conn) -} diff --git a/vms/avm/pubsub_filterer.go b/vms/avm/pubsub_filterer.go deleted file mode 100644 index caf0ba34839..00000000000 --- a/vms/avm/pubsub_filterer.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package avm - -import ( - "github.com/ava-labs/avalanchego/api" - "github.com/ava-labs/avalanchego/pubsub" - "github.com/ava-labs/avalanchego/vms/avm/txs" - "github.com/ava-labs/avalanchego/vms/components/avax" -) - -var _ pubsub.Filterer = (*connector)(nil) - -type connector struct { - tx *txs.Tx -} - -func NewPubSubFilterer(tx *txs.Tx) pubsub.Filterer { - return &connector{tx: tx} -} - -// Apply the filter on the addresses. -func (f *connector) Filter(filters []pubsub.Filter) ([]bool, interface{}) { - resp := make([]bool, len(filters)) - for _, utxo := range f.tx.UTXOs() { - addressable, ok := utxo.Out.(avax.Addressable) - if !ok { - continue - } - - for _, address := range addressable.Addresses() { - for i, c := range filters { - if resp[i] { - continue - } - resp[i] = c.Check(address) - } - } - } - return resp, api.JSONTxID{ - TxID: f.tx.ID(), - } -} diff --git a/vms/avm/pubsub_filterer_test.go b/vms/avm/pubsub_filterer_test.go deleted file mode 100644 index 0059b2218e3..00000000000 --- a/vms/avm/pubsub_filterer_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package avm - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/pubsub" - "github.com/ava-labs/avalanchego/vms/avm/txs" - "github.com/ava-labs/avalanchego/vms/components/avax" - "github.com/ava-labs/avalanchego/vms/secp256k1fx" -) - -type mockFilter struct { - addr []byte -} - -func (f *mockFilter) Check(addr []byte) bool { - return bytes.Equal(addr, f.addr) -} - -func TestFilter(t *testing.T) { - require := require.New(t) - - addrID := ids.ShortID{1} - tx := txs.Tx{Unsigned: &txs.BaseTx{BaseTx: avax.BaseTx{ - Outs: []*avax.TransferableOutput{ - { - Out: &secp256k1fx.TransferOutput{ - OutputOwners: secp256k1fx.OutputOwners{ - Addrs: []ids.ShortID{addrID}, - }, - }, - }, - }, - }}} - addrBytes := addrID[:] - - fp := pubsub.NewFilterParam() - require.NoError(fp.Add(addrBytes)) - - parser := NewPubSubFilterer(&tx) - fr, _ := parser.Filter([]pubsub.Filter{&mockFilter{addr: addrBytes}}) - require.Equal([]bool{true}, fr) -} diff --git a/vms/avm/service.md b/vms/avm/service.md index dfba13b05f0..ac455879433 100644 --- a/vms/avm/service.md +++ b/vms/avm/service.md @@ -2198,122 +2198,3 @@ curl -X POST --data '{ } } ``` - -### Events - -Listen for transactions on a specified address. - -This call is made to the events API endpoint: - -`/ext/bc/X/events` - -:::caution - -Endpoint deprecated as of [**v1.9.12**](https://github.com/ava-labs/avalanchego/releases/tag/v1.9.12). - -::: - -#### **Golang Example** - -```go -package main - -import ( - "encoding/json" - "log" - "net" - "net/http" - "sync" - - "github.com/ava-labs/avalanchego/api" - "github.com/ava-labs/avalanchego/pubsub" - "github.com/gorilla/websocket" -) - -func main() { - dialer := websocket.Dialer{ - NetDial: func(netw, addr string) (net.Conn, error) { - return net.Dial(netw, addr) - }, - } - - httpHeader := http.Header{} - conn, _, err := dialer.Dial("ws://localhost:9650/ext/bc/X/events", httpHeader) - if err != nil { - panic(err) - } - - waitGroup := &sync.WaitGroup{} - waitGroup.Add(1) - - readMsg := func() { - defer waitGroup.Done() - - for { - mt, msg, err := conn.ReadMessage() - if err != nil { - log.Println(err) - return - } - switch mt { - case websocket.TextMessage: - log.Println(string(msg)) - default: - log.Println(mt, string(msg)) - } - } - } - - go readMsg() - - cmd := &pubsub.Command{NewSet: &pubsub.NewSet{}} - cmdmsg, err := json.Marshal(cmd) - if err != nil { - panic(err) - } - err = conn.WriteMessage(websocket.TextMessage, cmdmsg) - if err != nil { - panic(err) - } - - var addresses []string - addresses = append(addresses, " X-fuji....") - cmd = &pubsub.Command{AddAddresses: &pubsub.AddAddresses{JSONAddresses: api.JSONAddresses{Addresses: addresses}}} - cmdmsg, err = json.Marshal(cmd) - if err != nil { - panic(err) - } - - err = conn.WriteMessage(websocket.TextMessage, cmdmsg) - if err != nil { - panic(err) - } - - waitGroup.Wait() -} -``` - -**Operations:** - -| Command | Description | Example | Arguments | -| :--------------- | :--------------------------- | :------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------- | -| **NewSet** | create a new address map set | `{"newSet":{}}` | | -| **NewBloom** | create a new bloom set. | `{"newBloom":{"maxElements":"1000","collisionProb":"0.0100"}}` | `maxElements` - number of elements in filter must be > 0 `collisionProb` - allowed collision probability must be > 0 and <= 1 | -| **AddAddresses** | add an address to the set | `{"addAddresses":{"addresses":\["X-fuji..."\]}}` | addresses - list of addresses to match | - -Calling **NewSet** or **NewBloom** resets the filter, and must be followed with **AddAddresses**. -**AddAddresses** can be called multiple times. - -**Set details:** - -- **NewSet** performs absolute address matches, if the address is in the set you will be sent the - transaction. -- **NewBloom** [Bloom filtering](https://en.wikipedia.org/wiki/Bloom_filter) can produce false - positives, but can allow a greater number of addresses to be filtered. If the addresses is in the - filter, you will be sent the transaction. - -**Example Response:** - -```json -2021/05/11 15:59:35 {"txID":"22HWKHrREyXyAiDnVmGp3TQQ79tHSSVxA9h26VfDEzoxvwveyk"} -``` diff --git a/vms/avm/vm.go b/vms/avm/vm.go index c9170ba882f..1026ba7d3de 100644 --- a/vms/avm/vm.go +++ b/vms/avm/vm.go @@ -20,7 +20,6 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/database/versiondb" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/avalanchego/pubsub" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/consensus/snowman" "github.com/ava-labs/avalanchego/snow/consensus/snowstorm" @@ -84,8 +83,6 @@ type VM struct { parser block.Parser - pubsub *pubsub.Server - appSender common.AppSender // State management @@ -195,8 +192,6 @@ func (vm *VM) Initialize( vm.db = versiondb.New(db) vm.assetToFxCache = &cache.LRU[ids.ID, set.Bits64]{Size: assetToFxCacheSize} - vm.pubsub = pubsub.New(ctx.Log) - typedFxs := make([]extensions.Fx, len(fxs)) vm.fxs = make([]*extensions.ParsedFx, len(fxs)) for i, fxContainer := range fxs { @@ -353,7 +348,6 @@ func (vm *VM) CreateHandlers(context.Context) (map[string]http.Handler, error) { return map[string]http.Handler{ "": rpcServer, "/wallet": walletServer, - "/events": vm.pubsub, }, err } @@ -681,7 +675,6 @@ func (vm *VM) onAccept(tx *txs.Tx) error { return fmt.Errorf("error indexing tx: %w", err) } - vm.pubsub.Publish(NewPubSubFilterer(tx)) vm.walletService.decided(txID) return nil } From 05295f050a410388d920a1fb6df4a40e3c4f7b30 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Thu, 24 Oct 2024 09:52:06 -0400 Subject: [PATCH 14/20] Update SoV struct to align with latest ACP-77 spec (#3492) --- vms/platformvm/state/subnet_only_validator.go | 74 ++++++--- .../state/subnet_only_validator_test.go | 148 ++++++++++++++---- 2 files changed, 172 insertions(+), 50 deletions(-) diff --git a/vms/platformvm/state/subnet_only_validator.go b/vms/platformvm/state/subnet_only_validator.go index 5af028314e6..1da34a50299 100644 --- a/vms/platformvm/state/subnet_only_validator.go +++ b/vms/platformvm/state/subnet_only_validator.go @@ -4,17 +4,28 @@ package state import ( + "bytes" + "errors" "fmt" "github.com/google/btree" "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/vms/platformvm/block" ) -var _ btree.LessFunc[*SubnetOnlyValidator] = (*SubnetOnlyValidator).Less +var ( + _ btree.LessFunc[SubnetOnlyValidator] = SubnetOnlyValidator.Less + _ utils.Sortable[SubnetOnlyValidator] = SubnetOnlyValidator{} + ErrMutatedSubnetOnlyValidator = errors.New("subnet only validator contains mutated constant fields") +) + +// SubnetOnlyValidator defines an ACP-77 validator. For a given ValidationID, it +// is expected for SubnetID, NodeID, PublicKey, RemainingBalanceOwner, and +// StartTime to be constant. type SubnetOnlyValidator struct { // ValidationID is not serialized because it is used as the key in the // database, so it doesn't need to be stored in the value. @@ -27,6 +38,14 @@ type SubnetOnlyValidator struct { // guaranteed to be populated. PublicKey []byte `serialize:"true"` + // RemainingBalanceOwner is the owner that will be used when returning the + // balance of the validator after removing accrued fees. + RemainingBalanceOwner []byte `serialize:"true"` + + // DeactivationOwner is the owner that can manually deactivate the + // validator. + DeactivationOwner []byte `serialize:"true"` + // StartTime is the unix timestamp, in seconds, when this validator was // added to the set. StartTime uint64 `serialize:"true"` @@ -46,44 +65,59 @@ type SubnetOnlyValidator struct { // accrue before this validator must be deactivated. It is equal to the // amount of fees this validator is willing to pay plus the amount of // globally accumulated fees when this validator started validating. + // + // If this value is 0, the validator is inactive. EndAccumulatedFee uint64 `serialize:"true"` } -// Less determines a canonical ordering of *SubnetOnlyValidators based on their -// EndAccumulatedFees and ValidationIDs. -// -// Returns true if: -// -// 1. This validator has a lower EndAccumulatedFee than the other. -// 2. This validator has an equal EndAccumulatedFee to the other and has a -// lexicographically lower ValidationID. -func (v *SubnetOnlyValidator) Less(o *SubnetOnlyValidator) bool { +func (v SubnetOnlyValidator) Less(o SubnetOnlyValidator) bool { + return v.Compare(o) == -1 +} + +// Compare determines a canonical ordering of SubnetOnlyValidators based on +// their EndAccumulatedFees and ValidationIDs. Lower EndAccumulatedFees result +// in an earlier ordering. +func (v SubnetOnlyValidator) Compare(o SubnetOnlyValidator) int { switch { case v.EndAccumulatedFee < o.EndAccumulatedFee: - return true + return -1 case o.EndAccumulatedFee < v.EndAccumulatedFee: - return false + return 1 default: - return v.ValidationID.Compare(o.ValidationID) == -1 + return v.ValidationID.Compare(o.ValidationID) + } +} + +// constantsAreUnmodified returns true if the constants of this validator have +// not been modified compared to the other validator. +func (v SubnetOnlyValidator) constantsAreUnmodified(o SubnetOnlyValidator) bool { + if v.ValidationID != o.ValidationID { + return true } + return v.SubnetID == o.SubnetID && + v.NodeID == o.NodeID && + bytes.Equal(v.PublicKey, o.PublicKey) && + bytes.Equal(v.RemainingBalanceOwner, o.RemainingBalanceOwner) && + bytes.Equal(v.DeactivationOwner, o.DeactivationOwner) && + v.StartTime == o.StartTime } -func getSubnetOnlyValidator(db database.KeyValueReader, validationID ids.ID) (*SubnetOnlyValidator, error) { +func getSubnetOnlyValidator(db database.KeyValueReader, validationID ids.ID) (SubnetOnlyValidator, error) { bytes, err := db.Get(validationID[:]) if err != nil { - return nil, err + return SubnetOnlyValidator{}, err } - vdr := &SubnetOnlyValidator{ + vdr := SubnetOnlyValidator{ ValidationID: validationID, } - if _, err = block.GenesisCodec.Unmarshal(bytes, vdr); err != nil { - return nil, fmt.Errorf("failed to unmarshal SubnetOnlyValidator: %w", err) + if _, err := block.GenesisCodec.Unmarshal(bytes, &vdr); err != nil { + return SubnetOnlyValidator{}, fmt.Errorf("failed to unmarshal SubnetOnlyValidator: %w", err) } - return vdr, err + return vdr, nil } -func putSubnetOnlyValidator(db database.KeyValueWriter, vdr *SubnetOnlyValidator) error { +func putSubnetOnlyValidator(db database.KeyValueWriter, vdr SubnetOnlyValidator) error { bytes, err := block.GenesisCodec.Marshal(block.CodecVersion, vdr) if err != nil { return fmt.Errorf("failed to marshal SubnetOnlyValidator: %w", err) diff --git a/vms/platformvm/state/subnet_only_validator_test.go b/vms/platformvm/state/subnet_only_validator_test.go index bcbb21e0027..6b6c86520a6 100644 --- a/vms/platformvm/state/subnet_only_validator_test.go +++ b/vms/platformvm/state/subnet_only_validator_test.go @@ -12,88 +12,176 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/database/memdb" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/crypto/bls" + "github.com/ava-labs/avalanchego/vms/platformvm/block" + "github.com/ava-labs/avalanchego/vms/platformvm/fx" + "github.com/ava-labs/avalanchego/vms/secp256k1fx" ) -func TestSubnetOnlyValidator_Less(t *testing.T) { +func TestSubnetOnlyValidator_Compare(t *testing.T) { tests := []struct { - name string - v *SubnetOnlyValidator - o *SubnetOnlyValidator - equal bool + name string + v SubnetOnlyValidator + o SubnetOnlyValidator + expected int }{ { name: "v.EndAccumulatedFee < o.EndAccumulatedFee", - v: &SubnetOnlyValidator{ + v: SubnetOnlyValidator{ ValidationID: ids.GenerateTestID(), EndAccumulatedFee: 1, }, - o: &SubnetOnlyValidator{ + o: SubnetOnlyValidator{ ValidationID: ids.GenerateTestID(), EndAccumulatedFee: 2, }, - equal: false, + expected: -1, }, { name: "v.EndAccumulatedFee = o.EndAccumulatedFee, v.ValidationID < o.ValidationID", - v: &SubnetOnlyValidator{ + v: SubnetOnlyValidator{ ValidationID: ids.ID{0}, EndAccumulatedFee: 1, }, - o: &SubnetOnlyValidator{ + o: SubnetOnlyValidator{ ValidationID: ids.ID{1}, EndAccumulatedFee: 1, }, - equal: false, + expected: -1, }, { name: "v.EndAccumulatedFee = o.EndAccumulatedFee, v.ValidationID = o.ValidationID", - v: &SubnetOnlyValidator{ + v: SubnetOnlyValidator{ ValidationID: ids.ID{0}, EndAccumulatedFee: 1, }, - o: &SubnetOnlyValidator{ + o: SubnetOnlyValidator{ ValidationID: ids.ID{0}, EndAccumulatedFee: 1, }, - equal: true, + expected: 0, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { require := require.New(t) - less := test.v.Less(test.o) - require.Equal(!test.equal, less) - - greater := test.o.Less(test.v) - require.False(greater) + require.Equal(test.expected, test.v.Compare(test.o)) + require.Equal(-test.expected, test.o.Compare(test.v)) + require.Equal(test.expected == -1, test.v.Less(test.o)) + require.False(test.o.Less(test.v)) }) } } +func TestSubnetOnlyValidator_constantsAreUnmodified(t *testing.T) { + var ( + randomSOV = func() SubnetOnlyValidator { + return SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + SubnetID: ids.GenerateTestID(), + NodeID: ids.GenerateTestNodeID(), + PublicKey: utils.RandomBytes(bls.PublicKeyLen), + RemainingBalanceOwner: utils.RandomBytes(32), + DeactivationOwner: utils.RandomBytes(32), + StartTime: rand.Uint64(), // #nosec G404 + } + } + randomizeSOV = func(sov SubnetOnlyValidator) SubnetOnlyValidator { + // Randomize unrelated fields + sov.Weight = rand.Uint64() // #nosec G404 + sov.MinNonce = rand.Uint64() // #nosec G404 + sov.EndAccumulatedFee = rand.Uint64() // #nosec G404 + return sov + } + sov = randomSOV() + ) + + t.Run("equal", func(t *testing.T) { + v := randomizeSOV(sov) + require.True(t, sov.constantsAreUnmodified(v)) + }) + t.Run("everything is different", func(t *testing.T) { + v := randomizeSOV(randomSOV()) + require.True(t, sov.constantsAreUnmodified(v)) + }) + t.Run("different subnetID", func(t *testing.T) { + v := randomizeSOV(sov) + v.SubnetID = ids.GenerateTestID() + require.False(t, sov.constantsAreUnmodified(v)) + }) + t.Run("different nodeID", func(t *testing.T) { + v := randomizeSOV(sov) + v.NodeID = ids.GenerateTestNodeID() + require.False(t, sov.constantsAreUnmodified(v)) + }) + t.Run("different publicKey", func(t *testing.T) { + v := randomizeSOV(sov) + v.PublicKey = utils.RandomBytes(bls.PublicKeyLen) + require.False(t, sov.constantsAreUnmodified(v)) + }) + t.Run("different remainingBalanceOwner", func(t *testing.T) { + v := randomizeSOV(sov) + v.RemainingBalanceOwner = utils.RandomBytes(32) + require.False(t, sov.constantsAreUnmodified(v)) + }) + t.Run("different deactivationOwner", func(t *testing.T) { + v := randomizeSOV(sov) + v.DeactivationOwner = utils.RandomBytes(32) + require.False(t, sov.constantsAreUnmodified(v)) + }) + t.Run("different startTime", func(t *testing.T) { + v := randomizeSOV(sov) + v.StartTime = rand.Uint64() // #nosec G404 + require.False(t, sov.constantsAreUnmodified(v)) + }) +} + func TestSubnetOnlyValidator_DatabaseHelpers(t *testing.T) { require := require.New(t) db := memdb.New() sk, err := bls.NewSecretKey() require.NoError(err) + pk := bls.PublicFromSecretKey(sk) + pkBytes := bls.PublicKeyToUncompressedBytes(pk) + + var remainingBalanceOwner fx.Owner = &secp256k1fx.OutputOwners{ + Threshold: 1, + Addrs: []ids.ShortID{ + ids.GenerateTestShortID(), + }, + } + remainingBalanceOwnerBytes, err := block.GenesisCodec.Marshal(block.CodecVersion, &remainingBalanceOwner) + require.NoError(err) + + var deactivationOwner fx.Owner = &secp256k1fx.OutputOwners{ + Threshold: 1, + Addrs: []ids.ShortID{ + ids.GenerateTestShortID(), + }, + } + deactivationOwnerBytes, err := block.GenesisCodec.Marshal(block.CodecVersion, &deactivationOwner) + require.NoError(err) - vdr := &SubnetOnlyValidator{ - ValidationID: ids.GenerateTestID(), - SubnetID: ids.GenerateTestID(), - NodeID: ids.GenerateTestNodeID(), - PublicKey: bls.PublicKeyToUncompressedBytes(bls.PublicFromSecretKey(sk)), - StartTime: rand.Uint64(), // #nosec G404 - Weight: rand.Uint64(), // #nosec G404 - MinNonce: rand.Uint64(), // #nosec G404 - EndAccumulatedFee: rand.Uint64(), // #nosec G404 + vdr := SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + SubnetID: ids.GenerateTestID(), + NodeID: ids.GenerateTestNodeID(), + PublicKey: pkBytes, + RemainingBalanceOwner: remainingBalanceOwnerBytes, + DeactivationOwner: deactivationOwnerBytes, + StartTime: rand.Uint64(), // #nosec G404 + Weight: rand.Uint64(), // #nosec G404 + MinNonce: rand.Uint64(), // #nosec G404 + EndAccumulatedFee: rand.Uint64(), // #nosec G404 } // Validator hasn't been put on disk yet gotVdr, err := getSubnetOnlyValidator(db, vdr.ValidationID) require.ErrorIs(err, database.ErrNotFound) - require.Nil(gotVdr) + require.Zero(gotVdr) // Place the validator on disk require.NoError(putSubnetOnlyValidator(db, vdr)) @@ -109,5 +197,5 @@ func TestSubnetOnlyValidator_DatabaseHelpers(t *testing.T) { // Verify that the validator has been removed from disk gotVdr, err = getSubnetOnlyValidator(db, vdr.ValidationID) require.ErrorIs(err, database.ErrNotFound) - require.Nil(gotVdr) + require.Zero(gotVdr) } From 993d1690b4f214ebd1d5fd28171d65322c9aa8cb Mon Sep 17 00:00:00 2001 From: yacovm Date: Thu, 24 Oct 2024 16:21:30 +0200 Subject: [PATCH 15/20] Register VM and snowman metrics after chain creation (#3489) Signed-off-by: Yacov Manevich Co-authored-by: Stephen Buttolph --- chains/manager.go | 60 ++++++++++++++++++++--------------------------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/chains/manager.go b/chains/manager.go index 906c6f136df..8d1eb4feea7 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" "github.com/ava-labs/avalanchego/api/health" @@ -490,19 +491,6 @@ func (m *manager) buildChain(chainParams ChainParameters, sb subnets.Subnet) (*c return nil, fmt.Errorf("error while creating chain's log %w", err) } - snowmanMetrics, err := metrics.MakeAndRegister( - m.snowmanGatherer, - primaryAlias, - ) - if err != nil { - return nil, err - } - - vmMetrics, err := m.getOrMakeVMRegisterer(chainParams.VMID, primaryAlias) - if err != nil { - return nil, err - } - ctx := &snow.ConsensusContext{ Context: &snow.Context{ NetworkID: m.NetworkID, @@ -520,7 +508,7 @@ func (m *manager) buildChain(chainParams ChainParameters, sb subnets.Subnet) (*c Keystore: m.Keystore.NewBlockchainKeyStore(chainParams.ID), SharedMemory: m.AtomicMemory.NewSharedMemory(chainParams.ID), BCLookup: m, - Metrics: vmMetrics, + Metrics: metrics.NewPrefixGatherer(), WarpSigner: warp.NewSigner(m.StakingBLSKey, m.NetworkID, chainParams.ID), @@ -528,7 +516,7 @@ func (m *manager) buildChain(chainParams ChainParameters, sb subnets.Subnet) (*c ChainDataDir: chainDataDir, }, PrimaryAlias: primaryAlias, - Registerer: snowmanMetrics, + Registerer: prometheus.NewRegistry(), BlockAcceptor: m.BlockAcceptorGroup, TxAcceptor: m.TxAcceptorGroup, VertexAcceptor: m.VertexAcceptorGroup, @@ -601,7 +589,15 @@ func (m *manager) buildChain(chainParams ChainParameters, sb subnets.Subnet) (*c return nil, err } - return chain, nil + vmGatherer, err := m.getOrMakeVMGatherer(chainParams.VMID) + if err != nil { + return nil, err + } + + return chain, errors.Join( + m.snowmanGatherer.Register(primaryAlias, ctx.Registerer), + vmGatherer.Register(primaryAlias, ctx.Metrics), + ) } func (m *manager) AddRegistrant(r Registrant) { @@ -1556,26 +1552,22 @@ func (m *manager) getChainConfig(id ids.ID) (ChainConfig, error) { return ChainConfig{}, nil } -func (m *manager) getOrMakeVMRegisterer(vmID ids.ID, chainAlias string) (metrics.MultiGatherer, error) { +func (m *manager) getOrMakeVMGatherer(vmID ids.ID) (metrics.MultiGatherer, error) { vmGatherer, ok := m.vmGatherer[vmID] - if !ok { - vmName := constants.VMName(vmID) - vmNamespace := metric.AppendNamespace(constants.PlatformName, vmName) - vmGatherer = metrics.NewLabelGatherer(ChainLabel) - err := m.Metrics.Register( - vmNamespace, - vmGatherer, - ) - if err != nil { - return nil, err - } - m.vmGatherer[vmID] = vmGatherer + if ok { + return vmGatherer, nil } - chainReg := metrics.NewPrefixGatherer() - err := vmGatherer.Register( - chainAlias, - chainReg, + vmName := constants.VMName(vmID) + vmNamespace := metric.AppendNamespace(constants.PlatformName, vmName) + vmGatherer = metrics.NewLabelGatherer(ChainLabel) + err := m.Metrics.Register( + vmNamespace, + vmGatherer, ) - return chainReg, err + if err != nil { + return nil, err + } + m.vmGatherer[vmID] = vmGatherer + return vmGatherer, nil } From 21cdc32e2b6508a09ff390fdee7f7b5bf90620fe Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Thu, 24 Oct 2024 10:28:07 -0400 Subject: [PATCH 16/20] Update P-chain state staker tests --- vms/platformvm/state/state_test.go | 1203 ++++++++++++---------------- 1 file changed, 514 insertions(+), 689 deletions(-) diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 6204540bd61..e9096af13a1 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -5,7 +5,6 @@ package state import ( "context" - "fmt" "math" "math/rand" "sync" @@ -28,6 +27,8 @@ import ( "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/maybe" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/units" "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -109,620 +110,387 @@ func TestStateSyncGenesis(t *testing.T) { ) } -// Whenever we store a staker, a whole bunch a data structures are updated -// This test is meant to capture which updates are carried out -func TestPersistStakers(t *testing.T) { - tests := map[string]struct { - // Insert or delete a staker to state and store it - storeStaker func(*require.Assertions, ids.ID /*=subnetID*/, *state) *Staker +// Whenever we add or remove a staker, a number of on-disk data structures +// should be updated. +// +// This test verifies that the on-disk data structures are updated as expected. +func TestState_writeStakers(t *testing.T) { + const ( + primaryValidatorDuration = 28 * 24 * time.Hour + primaryDelegatorDuration = 14 * 24 * time.Hour + subnetValidatorDuration = 21 * 24 * time.Hour + + primaryValidatorReward = iota + primaryDelegatorReward + subnetValidatorReward + ) + var ( + primaryValidatorStartTime = time.Now().Truncate(time.Second) + primaryValidatorEndTime = primaryValidatorStartTime.Add(primaryValidatorDuration) + primaryValidatorEndTimeUnix = uint64(primaryValidatorEndTime.Unix()) + + primaryDelegatorStartTime = primaryValidatorStartTime + primaryDelegatorEndTime = primaryDelegatorStartTime.Add(primaryDelegatorDuration) + primaryDelegatorEndTimeUnix = uint64(primaryDelegatorEndTime.Unix()) + + subnetValidatorStartTime = primaryValidatorStartTime + subnetValidatorEndTime = subnetValidatorStartTime.Add(subnetValidatorDuration) + subnetValidatorEndTimeUnix = uint64(subnetValidatorEndTime.Unix()) + + primaryValidatorData = txs.Validator{ + NodeID: ids.GenerateTestNodeID(), + End: primaryValidatorEndTimeUnix, + Wght: 1234, + } + primaryDelegatorData = txs.Validator{ + NodeID: primaryValidatorData.NodeID, + End: primaryDelegatorEndTimeUnix, + Wght: 6789, + } + subnetValidatorData = txs.Validator{ + NodeID: primaryValidatorData.NodeID, + End: subnetValidatorEndTimeUnix, + Wght: 9876, + } - // Check that the staker is duly stored/removed in P-chain state - checkStakerInState func(*require.Assertions, *state, *Staker) + subnetID = ids.GenerateTestID() + ) - // Check whether validators are duly reported in the validator set, - // with the right weight and showing the BLS key - checkValidatorsSet func(*require.Assertions, *state, *Staker) + unsignedAddPrimaryNetworkValidator := createPermissionlessValidatorTx(t, constants.PrimaryNetworkID, primaryValidatorData) + addPrimaryNetworkValidator := &txs.Tx{Unsigned: unsignedAddPrimaryNetworkValidator} + require.NoError(t, addPrimaryNetworkValidator.Initialize(txs.Codec)) - // Check that node duly track stakers uptimes - checkValidatorUptimes func(*require.Assertions, *state, *Staker) + primaryNetworkPendingValidatorStaker, err := NewPendingStaker( + addPrimaryNetworkValidator.ID(), + unsignedAddPrimaryNetworkValidator, + ) + require.NoError(t, err) - // Check whether weight/bls keys diffs are duly stored - checkDiffs func(*require.Assertions, *state, *Staker, uint64) - }{ - "add current validator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - var ( - startTime = time.Now().Unix() - endTime = time.Now().Add(14 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - End: uint64(endTime), - Wght: 1234, - } - validatorReward uint64 = 5678 - ) + primaryNetworkCurrentValidatorStaker, err := NewCurrentStaker( + addPrimaryNetworkValidator.ID(), + unsignedAddPrimaryNetworkValidator, + primaryValidatorStartTime, + primaryValidatorReward, + ) + require.NoError(t, err) - utx := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utx} - r.NoError(addPermValTx.Initialize(txs.Codec)) + unsignedAddPrimaryNetworkDelegator := createPermissionlessDelegatorTx(constants.PrimaryNetworkID, primaryDelegatorData) + addPrimaryNetworkDelegator := &txs.Tx{Unsigned: unsignedAddPrimaryNetworkDelegator} + require.NoError(t, addPrimaryNetworkDelegator.Initialize(txs.Codec)) - staker, err := NewCurrentStaker( - addPermValTx.ID(), - utx, - time.Unix(startTime, 0), - validatorReward, - ) - r.NoError(err) + primaryNetworkPendingDelegatorStaker, err := NewPendingStaker( + addPrimaryNetworkDelegator.ID(), + unsignedAddPrimaryNetworkDelegator, + ) + require.NoError(t, err) - r.NoError(s.PutCurrentValidator(staker)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - return staker - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - retrievedStaker, err := s.GetCurrentValidator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.Equal(staker, retrievedStaker) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - valsMap := s.validators.GetMap(staker.SubnetID) - r.Contains(valsMap, staker.NodeID) - r.Equal( - &validators.GetValidatorOutput{ - NodeID: staker.NodeID, - PublicKey: staker.PublicKey, - Weight: staker.Weight, - }, - valsMap[staker.NodeID], - ) - }, - checkValidatorUptimes: func(r *require.Assertions, s *state, staker *Staker) { - upDuration, lastUpdated, err := s.GetUptime(staker.NodeID) - if staker.SubnetID != constants.PrimaryNetworkID { - // only primary network validators have uptimes - r.ErrorIs(err, database.ErrNotFound) - } else { - r.NoError(err) - r.Equal(upDuration, time.Duration(0)) - r.Equal(lastUpdated, staker.StartTime) - } - }, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - weightDiffBytes, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.NoError(err) - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - r.NoError(err) - r.Equal(&ValidatorWeightDiff{ - Decrease: false, - Amount: staker.Weight, - }, weightDiff) - - blsDiffBytes, err := s.validatorPublicKeyDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - if staker.SubnetID == constants.PrimaryNetworkID { - r.NoError(err) - r.Nil(blsDiffBytes) - } else { - r.ErrorIs(err, database.ErrNotFound) - } - }, - }, - "add current delegator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // insert the delegator and its validator - var ( - valStartTime = time.Now().Truncate(time.Second).Unix() - delStartTime = time.Unix(valStartTime, 0).Add(time.Hour).Unix() - delEndTime = time.Unix(delStartTime, 0).Add(30 * 24 * time.Hour).Unix() - valEndTime = time.Unix(valStartTime, 0).Add(365 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - End: uint64(valEndTime), - Wght: 1234, - } - validatorReward uint64 = 5678 + primaryNetworkCurrentDelegatorStaker, err := NewCurrentStaker( + addPrimaryNetworkDelegator.ID(), + unsignedAddPrimaryNetworkDelegator, + primaryDelegatorStartTime, + primaryDelegatorReward, + ) + require.NoError(t, err) - delegatorData = txs.Validator{ - NodeID: validatorsData.NodeID, - End: uint64(delEndTime), - Wght: validatorsData.Wght / 2, - } - delegatorReward uint64 = 5432 - ) + unsignedAddSubnetValidator := createPermissionlessValidatorTx(t, subnetID, subnetValidatorData) + addSubnetValidator := &txs.Tx{Unsigned: unsignedAddSubnetValidator} + require.NoError(t, addSubnetValidator.Initialize(txs.Codec)) - utxVal := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utxVal} - r.NoError(addPermValTx.Initialize(txs.Codec)) + subnetCurrentValidatorStaker, err := NewCurrentStaker( + addSubnetValidator.ID(), + unsignedAddSubnetValidator, + subnetValidatorStartTime, + subnetValidatorReward, + ) + require.NoError(t, err) - val, err := NewCurrentStaker( - addPermValTx.ID(), - utxVal, - time.Unix(valStartTime, 0), - validatorReward, - ) - r.NoError(err) + tests := map[string]struct { + initialStakers []*Staker + initialTxs []*txs.Tx - utxDel := createPermissionlessDelegatorTx(subnetID, delegatorData) - addPermDelTx := &txs.Tx{Unsigned: utxDel} - r.NoError(addPermDelTx.Initialize(txs.Codec)) + // Staker to insert or remove + staker *Staker + addStakerTx *txs.Tx // If tx is nil, the staker is being removed - del, err := NewCurrentStaker( - addPermDelTx.ID(), - utxDel, - time.Unix(delStartTime, 0), - delegatorReward, - ) - r.NoError(err) + // Check that the staker is duly stored/removed in P-chain state + expectedCurrentValidator *Staker + expectedPendingValidator *Staker + expectedCurrentDelegators []*Staker + expectedPendingDelegators []*Staker - r.NoError(s.PutCurrentValidator(val)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) + // Check that the validator entry has been set correctly in the + // in-memory validator set. + expectedValidatorSetOutput *validators.GetValidatorOutput - s.PutCurrentDelegator(del) - s.AddTx(addPermDelTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - return del - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - delIt, err := s.GetCurrentDelegatorIterator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.True(delIt.Next()) - retrievedDelegator := delIt.Value() - r.False(delIt.Next()) - delIt.Release() - r.Equal(staker, retrievedDelegator) + // Check whether weight/bls keys diffs are duly stored + expectedWeightDiff *ValidatorWeightDiff + expectedPublicKeyDiff maybe.Maybe[*bls.PublicKey] + }{ + "add current primary network validator": { + staker: primaryNetworkCurrentValidatorStaker, + addStakerTx: addPrimaryNetworkValidator, + expectedCurrentValidator: primaryNetworkCurrentValidatorStaker, + expectedValidatorSetOutput: &validators.GetValidatorOutput{ + NodeID: primaryNetworkCurrentValidatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: primaryNetworkCurrentValidatorStaker.Weight, + }, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: false, + Amount: primaryNetworkCurrentValidatorStaker.Weight, }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - val, err := s.GetCurrentValidator(staker.SubnetID, staker.NodeID) - r.NoError(err) - - valsMap := s.validators.GetMap(staker.SubnetID) - r.Contains(valsMap, staker.NodeID) - valOut := valsMap[staker.NodeID] - r.Equal(valOut.NodeID, staker.NodeID) - r.Equal(valOut.Weight, val.Weight+staker.Weight) + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), + }, + "add current primary network delegator": { + initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkCurrentDelegatorStaker, + addStakerTx: addPrimaryNetworkDelegator, + expectedCurrentValidator: primaryNetworkCurrentValidatorStaker, + expectedCurrentDelegators: []*Staker{primaryNetworkCurrentDelegatorStaker}, + expectedValidatorSetOutput: &validators.GetValidatorOutput{ + NodeID: primaryNetworkCurrentDelegatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: primaryNetworkCurrentDelegatorStaker.Weight + primaryNetworkCurrentValidatorStaker.Weight, + }, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: false, + Amount: primaryNetworkCurrentDelegatorStaker.Weight, }, - checkValidatorUptimes: func(*require.Assertions, *state, *Staker) {}, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - // validator's weight must increase of delegator's weight amount - weightDiffBytes, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.NoError(err) - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - r.NoError(err) - r.Equal(&ValidatorWeightDiff{ - Decrease: false, - Amount: staker.Weight, - }, weightDiff) + }, + "add pending primary network validator": { + staker: primaryNetworkPendingValidatorStaker, + addStakerTx: addPrimaryNetworkValidator, + expectedPendingValidator: primaryNetworkPendingValidatorStaker, + }, + "add pending primary network delegator": { + initialStakers: []*Staker{primaryNetworkPendingValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkPendingDelegatorStaker, + addStakerTx: addPrimaryNetworkDelegator, + expectedPendingValidator: primaryNetworkPendingValidatorStaker, + expectedPendingDelegators: []*Staker{primaryNetworkPendingDelegatorStaker}, + }, + "add current subnet validator": { + initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: subnetCurrentValidatorStaker, + addStakerTx: addSubnetValidator, + expectedCurrentValidator: subnetCurrentValidatorStaker, + expectedValidatorSetOutput: &validators.GetValidatorOutput{ + NodeID: subnetCurrentValidatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: subnetCurrentValidatorStaker.Weight, + }, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: false, + Amount: subnetCurrentValidatorStaker.Weight, }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), }, - "add pending validator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - var ( - startTime = time.Now().Unix() - endTime = time.Now().Add(14 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - Start: uint64(startTime), - End: uint64(endTime), - Wght: 1234, - } - ) - - utx := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utx} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - staker, err := NewPendingStaker( - addPermValTx.ID(), - utx, - ) - r.NoError(err) - - r.NoError(s.PutPendingValidator(staker)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - return staker + "delete current primary network validator": { + initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkCurrentValidatorStaker, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: true, + Amount: primaryNetworkCurrentValidatorStaker.Weight, }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - retrievedStaker, err := s.GetPendingValidator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.Equal(staker, retrievedStaker) + expectedPublicKeyDiff: maybe.Some(primaryNetworkCurrentValidatorStaker.PublicKey), + }, + "delete current primary network delegator": { + initialStakers: []*Staker{ + primaryNetworkCurrentValidatorStaker, + primaryNetworkCurrentDelegatorStaker, + }, + initialTxs: []*txs.Tx{ + addPrimaryNetworkValidator, + addPrimaryNetworkDelegator, + }, + staker: primaryNetworkCurrentDelegatorStaker, + expectedCurrentValidator: primaryNetworkCurrentValidatorStaker, + expectedValidatorSetOutput: &validators.GetValidatorOutput{ + NodeID: primaryNetworkCurrentValidatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: primaryNetworkCurrentValidatorStaker.Weight, + }, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: true, + Amount: primaryNetworkCurrentDelegatorStaker.Weight, }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - // pending validators are not showed in validators set - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) + }, + "delete pending primary network validator": { + initialStakers: []*Staker{primaryNetworkPendingValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator}, + staker: primaryNetworkPendingValidatorStaker, + }, + "delete pending primary network delegator": { + initialStakers: []*Staker{ + primaryNetworkPendingValidatorStaker, + primaryNetworkPendingDelegatorStaker, }, - checkValidatorUptimes: func(r *require.Assertions, s *state, staker *Staker) { - // pending validators uptime is not tracked - _, _, err := s.GetUptime(staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) + initialTxs: []*txs.Tx{ + addPrimaryNetworkValidator, + addPrimaryNetworkDelegator, }, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - // pending validators weight diff and bls diffs are not stored - _, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.ErrorIs(err, database.ErrNotFound) - - _, err = s.validatorPublicKeyDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.ErrorIs(err, database.ErrNotFound) + staker: primaryNetworkPendingDelegatorStaker, + expectedPendingValidator: primaryNetworkPendingValidatorStaker, + }, + "delete current subnet validator": { + initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker, subnetCurrentValidatorStaker}, + initialTxs: []*txs.Tx{addPrimaryNetworkValidator, addSubnetValidator}, + staker: subnetCurrentValidatorStaker, + expectedWeightDiff: &ValidatorWeightDiff{ + Decrease: true, + Amount: subnetCurrentValidatorStaker.Weight, }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](primaryNetworkCurrentValidatorStaker.PublicKey), }, - "add pending delegator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // insert the delegator and its validator - var ( - valStartTime = time.Now().Truncate(time.Second).Unix() - delStartTime = time.Unix(valStartTime, 0).Add(time.Hour).Unix() - delEndTime = time.Unix(delStartTime, 0).Add(30 * 24 * time.Hour).Unix() - valEndTime = time.Unix(valStartTime, 0).Add(365 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - Start: uint64(valStartTime), - End: uint64(valEndTime), - Wght: 1234, - } + } - delegatorData = txs.Validator{ - NodeID: validatorsData.NodeID, - Start: uint64(delStartTime), - End: uint64(delEndTime), - Wght: validatorsData.Wght / 2, - } - ) + for name, test := range tests { + t.Run(name, func(t *testing.T) { + require := require.New(t) - utxVal := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utxVal} - r.NoError(addPermValTx.Initialize(txs.Codec)) + db := memdb.New() + state := newTestState(t, db) + + addOrDeleteStaker := func(staker *Staker, add bool) { + if add { + switch { + case staker.Priority.IsCurrentValidator(): + require.NoError(state.PutCurrentValidator(staker)) + case staker.Priority.IsPendingValidator(): + require.NoError(state.PutPendingValidator(staker)) + case staker.Priority.IsCurrentDelegator(): + state.PutCurrentDelegator(staker) + case staker.Priority.IsPendingDelegator(): + state.PutPendingDelegator(staker) + } + } else { + switch { + case staker.Priority.IsCurrentValidator(): + state.DeleteCurrentValidator(staker) + case staker.Priority.IsPendingValidator(): + state.DeletePendingValidator(staker) + case staker.Priority.IsCurrentDelegator(): + state.DeleteCurrentDelegator(staker) + case staker.Priority.IsPendingDelegator(): + state.DeletePendingDelegator(staker) + } + } + } - val, err := NewPendingStaker(addPermValTx.ID(), utxVal) - r.NoError(err) + // create and store the initial stakers + for _, staker := range test.initialStakers { + addOrDeleteStaker(staker, true) + } + for _, tx := range test.initialTxs { + state.AddTx(tx, status.Committed) + } - utxDel := createPermissionlessDelegatorTx(subnetID, delegatorData) - addPermDelTx := &txs.Tx{Unsigned: utxDel} - r.NoError(addPermDelTx.Initialize(txs.Codec)) + state.SetHeight(0) + require.NoError(state.Commit()) - del, err := NewPendingStaker(addPermDelTx.ID(), utxDel) - r.NoError(err) + // create and store the staker under test + addOrDeleteStaker(test.staker, test.addStakerTx != nil) + if test.addStakerTx != nil { + state.AddTx(test.addStakerTx, status.Committed) + } - r.NoError(s.PutPendingValidator(val)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) + state.SetHeight(1) + require.NoError(state.Commit()) - s.PutPendingDelegator(del) - s.AddTx(addPermDelTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) + // Perform the checks once immediately after committing to the + // state, and once after re-loading the state from disk. + for i := 0; i < 2; i++ { + currentValidator, err := state.GetCurrentValidator(test.staker.SubnetID, test.staker.NodeID) + if test.expectedCurrentValidator == nil { + require.ErrorIs(err, database.ErrNotFound) - return del - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - delIt, err := s.GetPendingDelegatorIterator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.True(delIt.Next()) - retrievedDelegator := delIt.Value() - r.False(delIt.Next()) - delIt.Release() - r.Equal(staker, retrievedDelegator) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) - }, - checkValidatorUptimes: func(*require.Assertions, *state, *Staker) {}, - checkDiffs: func(*require.Assertions, *state, *Staker, uint64) {}, - }, - "delete current validator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // add them remove the validator - var ( - startTime = time.Now().Unix() - endTime = time.Now().Add(14 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - End: uint64(endTime), - Wght: 1234, + if test.staker.SubnetID == constants.PrimaryNetworkID { + // Uptimes are only considered for primary network validators + _, _, err := state.GetUptime(test.staker.NodeID) + require.ErrorIs(err, database.ErrNotFound) } - validatorReward uint64 = 5678 - ) - - utx := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utx} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - staker, err := NewCurrentStaker( - addPermValTx.ID(), - utx, - time.Unix(startTime, 0), - validatorReward, - ) - r.NoError(err) - - r.NoError(s.PutCurrentValidator(staker)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) + } else { + require.NoError(err) + require.Equal(test.expectedCurrentValidator, currentValidator) + + if test.staker.SubnetID == constants.PrimaryNetworkID { + // Uptimes are only considered for primary network validators + upDuration, lastUpdated, err := state.GetUptime(currentValidator.NodeID) + require.NoError(err) + require.Zero(upDuration) + require.Equal(currentValidator.StartTime, lastUpdated) + } + } - s.DeleteCurrentValidator(staker) - r.NoError(s.Commit()) - return staker - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - _, err := s.GetCurrentValidator(staker.SubnetID, staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - // deleted validators are not showed in the validators set anymore - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) - }, - checkValidatorUptimes: func(r *require.Assertions, s *state, staker *Staker) { - // uptimes of delete validators are dropped - _, _, err := s.GetUptime(staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) - }, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - weightDiffBytes, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.NoError(err) - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - r.NoError(err) - r.Equal(&ValidatorWeightDiff{ - Decrease: true, - Amount: staker.Weight, - }, weightDiff) - - blsDiffBytes, err := s.validatorPublicKeyDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - if staker.SubnetID == constants.PrimaryNetworkID { - r.NoError(err) - r.Equal(bls.PublicKeyFromValidUncompressedBytes(blsDiffBytes), staker.PublicKey) + pendingValidator, err := state.GetPendingValidator(test.staker.SubnetID, test.staker.NodeID) + if test.expectedPendingValidator == nil { + require.ErrorIs(err, database.ErrNotFound) } else { - r.ErrorIs(err, database.ErrNotFound) + require.NoError(err) + require.Equal(test.expectedPendingValidator, pendingValidator) } - }, - }, - "delete current delegator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // insert validator and delegator, then remove the delegator - var ( - valStartTime = time.Now().Truncate(time.Second).Unix() - delStartTime = time.Unix(valStartTime, 0).Add(time.Hour).Unix() - delEndTime = time.Unix(delStartTime, 0).Add(30 * 24 * time.Hour).Unix() - valEndTime = time.Unix(valStartTime, 0).Add(365 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - End: uint64(valEndTime), - Wght: 1234, - } - validatorReward uint64 = 5678 - delegatorData = txs.Validator{ - NodeID: validatorsData.NodeID, - End: uint64(delEndTime), - Wght: validatorsData.Wght / 2, - } - delegatorReward uint64 = 5432 + it, err := state.GetCurrentDelegatorIterator(test.staker.SubnetID, test.staker.NodeID) + require.NoError(err) + require.Equal( + test.expectedCurrentDelegators, + iterator.ToSlice(it), ) - utxVal := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utxVal} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - val, err := NewCurrentStaker( - addPermValTx.ID(), - utxVal, - time.Unix(valStartTime, 0), - validatorReward, + it, err = state.GetPendingDelegatorIterator(test.staker.SubnetID, test.staker.NodeID) + require.NoError(err) + require.Equal( + test.expectedPendingDelegators, + iterator.ToSlice(it), ) - r.NoError(err) - - utxDel := createPermissionlessDelegatorTx(subnetID, delegatorData) - addPermDelTx := &txs.Tx{Unsigned: utxDel} - r.NoError(addPermDelTx.Initialize(txs.Codec)) - del, err := NewCurrentStaker( - addPermDelTx.ID(), - utxDel, - time.Unix(delStartTime, 0), - delegatorReward, + require.Equal( + test.expectedValidatorSetOutput, + state.validators.GetMap(test.staker.SubnetID)[test.staker.NodeID], ) - r.NoError(err) - - r.NoError(s.PutCurrentValidator(val)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - s.PutCurrentDelegator(del) - s.AddTx(addPermDelTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - - s.DeleteCurrentDelegator(del) - r.NoError(s.Commit()) - - return del - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - delIt, err := s.GetCurrentDelegatorIterator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.False(delIt.Next()) - delIt.Release() - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - val, err := s.GetCurrentValidator(staker.SubnetID, staker.NodeID) - r.NoError(err) - - valsMap := s.validators.GetMap(staker.SubnetID) - r.Contains(valsMap, staker.NodeID) - valOut := valsMap[staker.NodeID] - r.Equal(valOut.NodeID, staker.NodeID) - r.Equal(valOut.Weight, val.Weight) - }, - checkValidatorUptimes: func(*require.Assertions, *state, *Staker) {}, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - // validator's weight must decrease of delegator's weight amount - weightDiffBytes, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.NoError(err) - weightDiff, err := unmarshalWeightDiff(weightDiffBytes) - r.NoError(err) - r.Equal(&ValidatorWeightDiff{ - Decrease: true, - Amount: staker.Weight, - }, weightDiff) - }, - }, - "delete pending validator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - var ( - startTime = time.Now().Unix() - endTime = time.Now().Add(14 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - Start: uint64(startTime), - End: uint64(endTime), - Wght: 1234, - } - ) - - utx := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utx} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - staker, err := NewPendingStaker( - addPermValTx.ID(), - utx, - ) - r.NoError(err) - - r.NoError(s.PutPendingValidator(staker)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - - s.DeletePendingValidator(staker) - r.NoError(s.Commit()) + diffKey := marshalDiffKey(test.staker.SubnetID, 1, test.staker.NodeID) + weightDiffBytes, err := state.validatorWeightDiffsDB.Get(diffKey) + if test.expectedWeightDiff == nil { + require.ErrorIs(err, database.ErrNotFound) + } else { + require.NoError(err) - return staker - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - _, err := s.GetPendingValidator(staker.SubnetID, staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) - }, - checkValidatorUptimes: func(r *require.Assertions, s *state, staker *Staker) { - _, _, err := s.GetUptime(staker.NodeID) - r.ErrorIs(err, database.ErrNotFound) - }, - checkDiffs: func(r *require.Assertions, s *state, staker *Staker, height uint64) { - _, err := s.validatorWeightDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.ErrorIs(err, database.ErrNotFound) + weightDiff, err := unmarshalWeightDiff(weightDiffBytes) + require.NoError(err) + require.Equal(test.expectedWeightDiff, weightDiff) + } - _, err = s.validatorPublicKeyDiffsDB.Get(marshalDiffKey(staker.SubnetID, height, staker.NodeID)) - r.ErrorIs(err, database.ErrNotFound) - }, - }, - "delete pending delegator": { - storeStaker: func(r *require.Assertions, subnetID ids.ID, s *state) *Staker { - // insert validator and delegator the remove the validator - var ( - valStartTime = time.Now().Truncate(time.Second).Unix() - delStartTime = time.Unix(valStartTime, 0).Add(time.Hour).Unix() - delEndTime = time.Unix(delStartTime, 0).Add(30 * 24 * time.Hour).Unix() - valEndTime = time.Unix(valStartTime, 0).Add(365 * 24 * time.Hour).Unix() - - validatorsData = txs.Validator{ - NodeID: ids.GenerateTestNodeID(), - Start: uint64(valStartTime), - End: uint64(valEndTime), - Wght: 1234, - } + publicKeyDiffBytes, err := state.validatorPublicKeyDiffsDB.Get(diffKey) + if test.expectedPublicKeyDiff.IsNothing() { + require.ErrorIs(err, database.ErrNotFound) + } else { + require.NoError(err) - delegatorData = txs.Validator{ - NodeID: validatorsData.NodeID, - Start: uint64(delStartTime), - End: uint64(delEndTime), - Wght: validatorsData.Wght / 2, + expectedPublicKeyDiff := test.expectedPublicKeyDiff.Value() + if expectedPublicKeyDiff != nil { + require.Equal(expectedPublicKeyDiff, bls.PublicKeyFromValidUncompressedBytes(publicKeyDiffBytes)) + } else { + require.Empty(publicKeyDiffBytes) } - ) - - utxVal := createPermissionlessValidatorTx(r, subnetID, validatorsData) - addPermValTx := &txs.Tx{Unsigned: utxVal} - r.NoError(addPermValTx.Initialize(txs.Codec)) - - val, err := NewPendingStaker(addPermValTx.ID(), utxVal) - r.NoError(err) - - utxDel := createPermissionlessDelegatorTx(subnetID, delegatorData) - addPermDelTx := &txs.Tx{Unsigned: utxDel} - r.NoError(addPermDelTx.Initialize(txs.Codec)) - - del, err := NewPendingStaker(addPermDelTx.ID(), utxDel) - r.NoError(err) - - r.NoError(s.PutPendingValidator(val)) - s.AddTx(addPermValTx, status.Committed) // this is currently needed to reload the staker - - s.PutPendingDelegator(del) - s.AddTx(addPermDelTx, status.Committed) // this is currently needed to reload the staker - r.NoError(s.Commit()) - - s.DeletePendingDelegator(del) - r.NoError(s.Commit()) - return del - }, - checkStakerInState: func(r *require.Assertions, s *state, staker *Staker) { - delIt, err := s.GetPendingDelegatorIterator(staker.SubnetID, staker.NodeID) - r.NoError(err) - r.False(delIt.Next()) - delIt.Release() - }, - checkValidatorsSet: func(r *require.Assertions, s *state, staker *Staker) { - valsMap := s.validators.GetMap(staker.SubnetID) - r.NotContains(valsMap, staker.NodeID) - }, - checkValidatorUptimes: func(*require.Assertions, *state, *Staker) {}, - checkDiffs: func(*require.Assertions, *state, *Staker, uint64) {}, - }, - } - - subnetIDs := []ids.ID{constants.PrimaryNetworkID, ids.GenerateTestID()} - for _, subnetID := range subnetIDs { - for name, test := range tests { - t.Run(fmt.Sprintf("%s - subnetID %s", name, subnetID), func(t *testing.T) { - require := require.New(t) - - db := memdb.New() - state := newTestState(t, db) - - // create and store the staker - staker := test.storeStaker(require, subnetID, state) - - // check all relevant data are stored - test.checkStakerInState(require, state, staker) - test.checkValidatorsSet(require, state, staker) - test.checkValidatorUptimes(require, state, staker) - test.checkDiffs(require, state, staker, 0 /*height*/) - - // rebuild the state - rebuiltState := newTestState(t, db) + } - // check again that all relevant data are still available in rebuilt state - test.checkStakerInState(require, rebuiltState, staker) - test.checkValidatorsSet(require, rebuiltState, staker) - test.checkValidatorUptimes(require, rebuiltState, staker) - test.checkDiffs(require, rebuiltState, staker, 0 /*height*/) - }) - } + // re-load the state from disk for the second iteration + state = newTestState(t, db) + } + }) } } -func createPermissionlessValidatorTx(r *require.Assertions, subnetID ids.ID, validatorsData txs.Validator) *txs.AddPermissionlessValidatorTx { +func createPermissionlessValidatorTx(t testing.TB, subnetID ids.ID, validatorsData txs.Validator) *txs.AddPermissionlessValidatorTx { var sig signer.Signer = &signer.Empty{} if subnetID == constants.PrimaryNetworkID { sk, err := bls.NewSecretKey() - r.NoError(err) + require.NoError(t, err) sig = signer.NewProofOfPossession(sk) } @@ -988,43 +756,49 @@ func TestValidatorWeightDiff(t *testing.T) { } } -// Tests PutCurrentValidator, DeleteCurrentValidator, GetCurrentValidator, -// ApplyValidatorWeightDiffs, ApplyValidatorPublicKeyDiffs -func TestStateAddRemoveValidator(t *testing.T) { +func TestState_ApplyValidatorDiffs(t *testing.T) { require := require.New(t) state := newTestState(t, memdb.New()) var ( - numNodes = 3 - subnetID = ids.GenerateTestID() - startTime = time.Now() - endTime = startTime.Add(24 * time.Hour) - stakers = make([]Staker, numNodes) + numNodes = 5 + subnetID = ids.GenerateTestID() + startTime = time.Now() + endTime = startTime.Add(24 * time.Hour) + primaryStakers = make([]Staker, numNodes) + subnetStakers = make([]Staker, numNodes) ) - for i := 0; i < numNodes; i++ { - stakers[i] = Staker{ + for i := range primaryStakers { + sk, err := bls.NewSecretKey() + require.NoError(err) + + primaryStakers[i] = Staker{ TxID: ids.GenerateTestID(), NodeID: ids.GenerateTestNodeID(), + PublicKey: bls.PublicFromSecretKey(sk), + SubnetID: constants.PrimaryNetworkID, Weight: uint64(i + 1), StartTime: startTime.Add(time.Duration(i) * time.Second), EndTime: endTime.Add(time.Duration(i) * time.Second), PotentialReward: uint64(i + 1), } - if i%2 == 0 { - stakers[i].SubnetID = subnetID - } else { - sk, err := bls.NewSecretKey() - require.NoError(err) - stakers[i].PublicKey = bls.PublicFromSecretKey(sk) - stakers[i].SubnetID = constants.PrimaryNetworkID + } + for i, primaryStaker := range primaryStakers { + subnetStakers[i] = Staker{ + TxID: ids.GenerateTestID(), + NodeID: primaryStaker.NodeID, + PublicKey: nil, // Key is inherited from the primary network + SubnetID: subnetID, + Weight: uint64(i + 1), + StartTime: primaryStaker.StartTime, + EndTime: primaryStaker.EndTime, + PotentialReward: uint64(i + 1), } } type diff struct { addedValidators []Staker - addedDelegators []Staker - removedDelegators []Staker removedValidators []Staker expectedPrimaryValidatorSet map[ids.NodeID]*validators.GetValidatorOutput @@ -1037,101 +811,172 @@ func TestStateAddRemoveValidator(t *testing.T) { expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, { - // Add a subnet validator - addedValidators: []Staker{stakers[0]}, - expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, - expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[0].NodeID: { - NodeID: stakers[0].NodeID, - Weight: stakers[0].Weight, + // Add primary validator 0 + addedValidators: []Staker{primaryStakers[0]}, + expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + primaryStakers[0].NodeID: { + NodeID: primaryStakers[0].NodeID, + PublicKey: primaryStakers[0].PublicKey, + Weight: primaryStakers[0].Weight, }, }, + expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, { - // Remove a subnet validator - removedValidators: []Staker{stakers[0]}, - expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, - expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, + // Add subnet validator 0 + addedValidators: []Staker{subnetStakers[0]}, + expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + primaryStakers[0].NodeID: { + NodeID: primaryStakers[0].NodeID, + PublicKey: primaryStakers[0].PublicKey, + Weight: primaryStakers[0].Weight, + }, + }, + expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + subnetStakers[0].NodeID: { + NodeID: subnetStakers[0].NodeID, + Weight: subnetStakers[0].Weight, + }, + }, }, - { // Add a primary network validator - addedValidators: []Staker{stakers[1]}, + { + // Remove subnet validator 0 + removedValidators: []Staker{subnetStakers[0]}, expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[1].NodeID: { - NodeID: stakers[1].NodeID, - PublicKey: stakers[1].PublicKey, - Weight: stakers[1].Weight, + primaryStakers[0].NodeID: { + NodeID: primaryStakers[0].NodeID, + PublicKey: primaryStakers[0].PublicKey, + Weight: primaryStakers[0].Weight, }, }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, { - // Do nothing + // Add primary network validator 1, and subnet validator 1 + addedValidators: []Staker{primaryStakers[1], subnetStakers[1]}, + // Remove primary network validator 0, and subnet validator 1 + removedValidators: []Staker{primaryStakers[0], subnetStakers[1]}, expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[1].NodeID: { - NodeID: stakers[1].NodeID, - PublicKey: stakers[1].PublicKey, - Weight: stakers[1].Weight, + primaryStakers[1].NodeID: { + NodeID: primaryStakers[1].NodeID, + PublicKey: primaryStakers[1].PublicKey, + Weight: primaryStakers[1].Weight, }, }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, - { // Remove a primary network validator - removedValidators: []Staker{stakers[1]}, - expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, - expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, + { + // Add primary network validator 2, and subnet validator 2 + addedValidators: []Staker{primaryStakers[2], subnetStakers[2]}, + // Remove primary network validator 1 + removedValidators: []Staker{primaryStakers[1]}, + expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + primaryStakers[2].NodeID: { + NodeID: primaryStakers[2].NodeID, + PublicKey: primaryStakers[2].PublicKey, + Weight: primaryStakers[2].Weight, + }, + }, + expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ + subnetStakers[2].NodeID: { + NodeID: subnetStakers[2].NodeID, + Weight: subnetStakers[2].Weight, + }, + }, }, { - // Add 2 subnet validators and a primary network validator - addedValidators: []Staker{stakers[0], stakers[1], stakers[2]}, + // Add primary network and subnet validators 3 & 4 + addedValidators: []Staker{primaryStakers[3], primaryStakers[4], subnetStakers[3], subnetStakers[4]}, expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[1].NodeID: { - NodeID: stakers[1].NodeID, - PublicKey: stakers[1].PublicKey, - Weight: stakers[1].Weight, + primaryStakers[2].NodeID: { + NodeID: primaryStakers[2].NodeID, + PublicKey: primaryStakers[2].PublicKey, + Weight: primaryStakers[2].Weight, + }, + primaryStakers[3].NodeID: { + NodeID: primaryStakers[3].NodeID, + PublicKey: primaryStakers[3].PublicKey, + Weight: primaryStakers[3].Weight, + }, + primaryStakers[4].NodeID: { + NodeID: primaryStakers[4].NodeID, + PublicKey: primaryStakers[4].PublicKey, + Weight: primaryStakers[4].Weight, }, }, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{ - stakers[0].NodeID: { - NodeID: stakers[0].NodeID, - Weight: stakers[0].Weight, + subnetStakers[2].NodeID: { + NodeID: subnetStakers[2].NodeID, + Weight: subnetStakers[2].Weight, + }, + subnetStakers[3].NodeID: { + NodeID: subnetStakers[3].NodeID, + Weight: subnetStakers[3].Weight, }, - stakers[2].NodeID: { - NodeID: stakers[2].NodeID, - Weight: stakers[2].Weight, + subnetStakers[4].NodeID: { + NodeID: subnetStakers[4].NodeID, + Weight: subnetStakers[4].Weight, }, }, }, { - // Remove 2 subnet validators and a primary network validator. - removedValidators: []Staker{stakers[0], stakers[1], stakers[2]}, + // Remove primary network and subnet validators 2 & 3 & 4 + removedValidators: []Staker{ + primaryStakers[2], primaryStakers[3], primaryStakers[4], + subnetStakers[2], subnetStakers[3], subnetStakers[4], + }, + expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, + expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, + }, + { + // Do nothing expectedPrimaryValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, expectedSubnetValidatorSet: map[ids.NodeID]*validators.GetValidatorOutput{}, }, } for currentIndex, diff := range diffs { - for _, added := range diff.addedValidators { - added := added - require.NoError(state.PutCurrentValidator(&added)) - } - for _, added := range diff.addedDelegators { - added := added - state.PutCurrentDelegator(&added) + d, err := NewDiffOn(state) + require.NoError(err) + + type subnetIDNodeID struct { + subnetID ids.ID + nodeID ids.NodeID } - for _, removed := range diff.removedDelegators { - removed := removed - state.DeleteCurrentDelegator(&removed) + var expectedValidators set.Set[subnetIDNodeID] + for _, added := range diff.addedValidators { + require.NoError(d.PutCurrentValidator(&added)) + + expectedValidators.Add(subnetIDNodeID{ + subnetID: added.SubnetID, + nodeID: added.NodeID, + }) } for _, removed := range diff.removedValidators { - removed := removed - state.DeleteCurrentValidator(&removed) + d.DeleteCurrentValidator(&removed) + + expectedValidators.Remove(subnetIDNodeID{ + subnetID: removed.SubnetID, + nodeID: removed.NodeID, + }) } + require.NoError(d.Apply(state)) + currentHeight := uint64(currentIndex + 1) state.SetHeight(currentHeight) require.NoError(state.Commit()) + // Verify that the current state is as expected. for _, added := range diff.addedValidators { + subnetNodeID := subnetIDNodeID{ + subnetID: added.SubnetID, + nodeID: added.NodeID, + } + if !expectedValidators.Contains(subnetNodeID) { + continue + } + gotValidator, err := state.GetCurrentValidator(added.SubnetID, added.NodeID) require.NoError(err) require.Equal(added, *gotValidator) @@ -1142,37 +987,47 @@ func TestStateAddRemoveValidator(t *testing.T) { require.ErrorIs(err, database.ErrNotFound) } + primaryValidatorSet := state.validators.GetMap(constants.PrimaryNetworkID) + delete(primaryValidatorSet, defaultValidatorNodeID) // Ignore the genesis validator + require.Equal(diff.expectedPrimaryValidatorSet, primaryValidatorSet) + + require.Equal(diff.expectedSubnetValidatorSet, state.validators.GetMap(subnetID)) + + // Verify that applying diffs against the current state results in the + // expected state. for i := 0; i < currentIndex; i++ { prevDiff := diffs[i] prevHeight := uint64(i + 1) - primaryValidatorSet := copyValidatorSet(diff.expectedPrimaryValidatorSet) - require.NoError(state.ApplyValidatorWeightDiffs( - context.Background(), - primaryValidatorSet, - currentHeight, - prevHeight+1, - constants.PrimaryNetworkID, - )) - requireEqualWeightsValidatorSet(require, prevDiff.expectedPrimaryValidatorSet, primaryValidatorSet) - - require.NoError(state.ApplyValidatorPublicKeyDiffs( - context.Background(), - primaryValidatorSet, - currentHeight, - prevHeight+1, - )) - requireEqualPublicKeysValidatorSet(require, prevDiff.expectedPrimaryValidatorSet, primaryValidatorSet) - - subnetValidatorSet := copyValidatorSet(diff.expectedSubnetValidatorSet) - require.NoError(state.ApplyValidatorWeightDiffs( - context.Background(), - subnetValidatorSet, - currentHeight, - prevHeight+1, - subnetID, - )) - requireEqualWeightsValidatorSet(require, prevDiff.expectedSubnetValidatorSet, subnetValidatorSet) + { + primaryValidatorSet := copyValidatorSet(diff.expectedPrimaryValidatorSet) + require.NoError(state.ApplyValidatorWeightDiffs( + context.Background(), + primaryValidatorSet, + currentHeight, + prevHeight+1, + constants.PrimaryNetworkID, + )) + require.NoError(state.ApplyValidatorPublicKeyDiffs( + context.Background(), + primaryValidatorSet, + currentHeight, + prevHeight+1, + )) + require.Equal(prevDiff.expectedPrimaryValidatorSet, primaryValidatorSet) + } + + { + subnetValidatorSet := copyValidatorSet(diff.expectedSubnetValidatorSet) + require.NoError(state.ApplyValidatorWeightDiffs( + context.Background(), + subnetValidatorSet, + currentHeight, + prevHeight+1, + subnetID, + )) + require.Equal(prevDiff.expectedSubnetValidatorSet, subnetValidatorSet) + } } } } @@ -1188,36 +1043,6 @@ func copyValidatorSet( return result } -func requireEqualWeightsValidatorSet( - require *require.Assertions, - expected map[ids.NodeID]*validators.GetValidatorOutput, - actual map[ids.NodeID]*validators.GetValidatorOutput, -) { - require.Len(actual, len(expected)) - for nodeID, expectedVdr := range expected { - require.Contains(actual, nodeID) - - actualVdr := actual[nodeID] - require.Equal(expectedVdr.NodeID, actualVdr.NodeID) - require.Equal(expectedVdr.Weight, actualVdr.Weight) - } -} - -func requireEqualPublicKeysValidatorSet( - require *require.Assertions, - expected map[ids.NodeID]*validators.GetValidatorOutput, - actual map[ids.NodeID]*validators.GetValidatorOutput, -) { - require.Len(actual, len(expected)) - for nodeID, expectedVdr := range expected { - require.Contains(actual, nodeID) - - actualVdr := actual[nodeID] - require.Equal(expectedVdr.NodeID, actualVdr.NodeID) - require.Equal(expectedVdr.PublicKey, actualVdr.PublicKey) - } -} - func TestParsedStateBlock(t *testing.T) { var ( require = require.New(t) From 91a7cd7f30798db77b7c7049d9fe662b81914e7c Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Thu, 24 Oct 2024 10:43:46 -0400 Subject: [PATCH 17/20] fix test --- vms/platformvm/state/state_test.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index e9096af13a1..522e5a91240 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -278,15 +278,13 @@ func TestState_writeStakers(t *testing.T) { addStakerTx: addSubnetValidator, expectedCurrentValidator: subnetCurrentValidatorStaker, expectedValidatorSetOutput: &validators.GetValidatorOutput{ - NodeID: subnetCurrentValidatorStaker.NodeID, - PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, - Weight: subnetCurrentValidatorStaker.Weight, + NodeID: subnetCurrentValidatorStaker.NodeID, + Weight: subnetCurrentValidatorStaker.Weight, }, expectedWeightDiff: &ValidatorWeightDiff{ Decrease: false, Amount: subnetCurrentValidatorStaker.Weight, }, - expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), }, "delete current primary network validator": { initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, @@ -344,7 +342,6 @@ func TestState_writeStakers(t *testing.T) { Decrease: true, Amount: subnetCurrentValidatorStaker.Weight, }, - expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](primaryNetworkCurrentValidatorStaker.PublicKey), }, } From c848feef30038c116ad8e7389c54abf92c8d69d5 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Thu, 24 Oct 2024 10:44:54 -0400 Subject: [PATCH 18/20] fix test --- vms/platformvm/state/state_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index c6f6c0d52ec..f3c952ff90d 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -278,13 +278,15 @@ func TestState_writeStakers(t *testing.T) { addStakerTx: addSubnetValidator, expectedCurrentValidator: subnetCurrentValidatorStaker, expectedValidatorSetOutput: &validators.GetValidatorOutput{ - NodeID: subnetCurrentValidatorStaker.NodeID, - Weight: subnetCurrentValidatorStaker.Weight, + NodeID: subnetCurrentValidatorStaker.NodeID, + PublicKey: primaryNetworkCurrentValidatorStaker.PublicKey, + Weight: subnetCurrentValidatorStaker.Weight, }, expectedWeightDiff: &ValidatorWeightDiff{ Decrease: false, Amount: subnetCurrentValidatorStaker.Weight, }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](nil), }, "delete current primary network validator": { initialStakers: []*Staker{primaryNetworkCurrentValidatorStaker}, @@ -342,6 +344,7 @@ func TestState_writeStakers(t *testing.T) { Decrease: true, Amount: subnetCurrentValidatorStaker.Weight, }, + expectedPublicKeyDiff: maybe.Some[*bls.PublicKey](primaryNetworkCurrentValidatorStaker.PublicKey), }, } From 79c9b40074f82ae33ebf0e4ac8ac2191c3e81144 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Thu, 24 Oct 2024 10:49:15 -0400 Subject: [PATCH 19/20] nit --- vms/platformvm/state/state.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index cfb5888ecfa..c70547265cf 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -2079,12 +2079,11 @@ func (s *state) writeCurrentStakers(updateValidators bool, height uint64, codecV } // TODO: Move validator set management out of the state package - // - // Attempt to update the stake metrics if !updateValidators { return nil } + // Update the stake metrics totalWeight, err := s.validators.TotalWeight(constants.PrimaryNetworkID) if err != nil { return fmt.Errorf("failed to get total weight of primary network: %w", err) From ed2fcfd4d058ba134ee4a22a55c40fbe56c6b972 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Thu, 24 Oct 2024 11:16:03 -0400 Subject: [PATCH 20/20] simplify test --- tests/e2e/p/permissionless_layer_one.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/e2e/p/permissionless_layer_one.go b/tests/e2e/p/permissionless_layer_one.go index 44cb13b2445..de4aaa11a08 100644 --- a/tests/e2e/p/permissionless_layer_one.go +++ b/tests/e2e/p/permissionless_layer_one.go @@ -248,19 +248,9 @@ var _ = e2e.DescribePChain("[Permissionless L1]", func() { }) advanceProposerVMPChainHeight := func() { - // We first must wait at least [RecentlyAcceptedWindowTTL] to ensure - // the next block will evict the prior block from the windower. + // We must wait at least [RecentlyAcceptedWindowTTL] to ensure the + // next block will reference the last accepted P-chain height. time.Sleep((5 * platformvmvalidators.RecentlyAcceptedWindowTTL) / 4) - - // Now we must: - // 1. issue a block which should include the old P-chain height. - // 2. issue a block which should include the new P-chain height. - for range 2 { - _, err = pWallet.IssueBaseTx(nil, tc.WithDefaultContext()) - require.NoError(err) - } - // Now that a block has been issued with the new P-chain height, the - // next block will use that height for warp message verification. } tc.By("advancing the proposervm P-chain height", advanceProposerVMPChainHeight)