diff --git a/vms/platformvm/block/executor/proposal_block_test.go b/vms/platformvm/block/executor/proposal_block_test.go index c0a597d7733..44f644f8fa5 100644 --- a/vms/platformvm/block/executor/proposal_block_test.go +++ b/vms/platformvm/block/executor/proposal_block_test.go @@ -93,6 +93,7 @@ func TestApricotProposalBlockTimeVerification(t *testing.T) { onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() onParentAccept.EXPECT().GetSoVExcess().Return(gas.Gas(0)).AnyTimes() onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() + onParentAccept.EXPECT().NumActiveSubnetOnlyValidators().Return(0).AnyTimes() onParentAccept.EXPECT().GetCurrentStakerIterator().Return( iterator.FromSlice(&state.Staker{ @@ -165,6 +166,7 @@ func TestBanffProposalBlockTimeVerification(t *testing.T) { onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() onParentAccept.EXPECT().GetSoVExcess().Return(gas.Gas(0)).AnyTimes() onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() + onParentAccept.EXPECT().NumActiveSubnetOnlyValidators().Return(0).AnyTimes() onParentAccept.EXPECT().GetCurrentSupply(constants.PrimaryNetworkID).Return(uint64(1000), nil).AnyTimes() env.blkManager.(*manager).blkIDToState[parentID] = &blockState{ diff --git a/vms/platformvm/block/executor/standard_block_test.go b/vms/platformvm/block/executor/standard_block_test.go index d9ad860d3d3..006287c0450 100644 --- a/vms/platformvm/block/executor/standard_block_test.go +++ b/vms/platformvm/block/executor/standard_block_test.go @@ -61,6 +61,7 @@ func TestApricotStandardBlockTimeVerification(t *testing.T) { onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() onParentAccept.EXPECT().GetSoVExcess().Return(gas.Gas(0)).AnyTimes() onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() + onParentAccept.EXPECT().NumActiveSubnetOnlyValidators().Return(0).AnyTimes() // wrong height apricotChildBlk, err := block.NewApricotStandardBlock( @@ -137,6 +138,7 @@ func TestBanffStandardBlockTimeVerification(t *testing.T) { onParentAccept.EXPECT().GetFeeState().Return(gas.State{}).AnyTimes() onParentAccept.EXPECT().GetSoVExcess().Return(gas.Gas(0)).AnyTimes() onParentAccept.EXPECT().GetAccruedFees().Return(uint64(0)).AnyTimes() + onParentAccept.EXPECT().NumActiveSubnetOnlyValidators().Return(0).AnyTimes() txID := ids.GenerateTestID() utxo := &avax.UTXO{ diff --git a/vms/platformvm/block/executor/verifier_test.go b/vms/platformvm/block/executor/verifier_test.go index a076616701f..a04d93e7222 100644 --- a/vms/platformvm/block/executor/verifier_test.go +++ b/vms/platformvm/block/executor/verifier_test.go @@ -105,6 +105,7 @@ func TestVerifierVisitProposalBlock(t *testing.T) { parentOnAcceptState.EXPECT().GetFeeState().Return(gas.State{}).Times(2) parentOnAcceptState.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(2) parentOnAcceptState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(2) + parentOnAcceptState.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(2) backend := &backend{ lastAccepted: parentID, @@ -338,6 +339,7 @@ func TestVerifierVisitStandardBlock(t *testing.T) { parentState.EXPECT().GetFeeState().Return(gas.State{}).Times(1) parentState.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) parentState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) + parentState.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(1) parentStatelessBlk.EXPECT().Height().Return(uint64(1)).Times(1) mempool.EXPECT().Remove(apricotBlk.Txs()).Times(1) @@ -601,6 +603,7 @@ func TestBanffAbortBlockTimestampChecks(t *testing.T) { s.EXPECT().GetFeeState().Return(gas.State{}).Times(3) s.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(3) s.EXPECT().GetAccruedFees().Return(uint64(0)).Times(3) + s.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(3) onDecisionState, err := state.NewDiff(parentID, backend) require.NoError(err) @@ -700,6 +703,7 @@ func TestBanffCommitBlockTimestampChecks(t *testing.T) { s.EXPECT().GetFeeState().Return(gas.State{}).Times(3) s.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(3) s.EXPECT().GetAccruedFees().Return(uint64(0)).Times(3) + s.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(3) onDecisionState, err := state.NewDiff(parentID, backend) require.NoError(err) @@ -817,6 +821,7 @@ func TestVerifierVisitStandardBlockWithDuplicateInputs(t *testing.T) { parentState.EXPECT().GetFeeState().Return(gas.State{}).Times(1) parentState.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) parentState.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) + parentState.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(1) parentStatelessBlk.EXPECT().Parent().Return(grandParentID).Times(1) err = verifier.ApricotStandardBlock(blk) diff --git a/vms/platformvm/state/diff.go b/vms/platformvm/state/diff.go index da73854346e..4b8922f7449 100644 --- a/vms/platformvm/state/diff.go +++ b/vms/platformvm/state/diff.go @@ -35,15 +35,17 @@ type diff struct { parentID ids.ID stateVersions Versions - timestamp time.Time - feeState gas.State - sovExcess gas.Gas - accruedFees uint64 + timestamp time.Time + feeState gas.State + sovExcess gas.Gas + accruedFees uint64 + parentActiveSOVs int // Subnet ID --> supply of native asset of the subnet currentSupply map[ids.ID]uint64 expiryDiff *expiryDiff + sovDiff *subnetOnlyValidatorsDiff currentStakerDiffs diffStakers // map of subnetID -> nodeID -> total accrued delegatee rewards @@ -83,7 +85,9 @@ func NewDiff( feeState: parentState.GetFeeState(), sovExcess: parentState.GetSoVExcess(), accruedFees: parentState.GetAccruedFees(), + parentActiveSOVs: parentState.NumActiveSubnetOnlyValidators(), expiryDiff: newExpiryDiff(), + sovDiff: newSubnetOnlyValidatorsDiff(), subnetOwners: make(map[ids.ID]fx.Owner), subnetConversions: make(map[ids.ID]SubnetConversion), }, nil @@ -194,6 +198,70 @@ func (d *diff) DeleteExpiry(entry ExpiryEntry) { d.expiryDiff.DeleteExpiry(entry) } +func (d *diff) GetActiveSubnetOnlyValidatorsIterator() (iterator.Iterator[SubnetOnlyValidator], error) { + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return nil, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + + parentIterator, err := parentState.GetActiveSubnetOnlyValidatorsIterator() + if err != nil { + return nil, err + } + + return d.sovDiff.getActiveSubnetOnlyValidatorsIterator(parentIterator), nil +} + +func (d *diff) NumActiveSubnetOnlyValidators() int { + return d.parentActiveSOVs + d.sovDiff.numAddedActive +} + +func (d *diff) WeightOfSubnetOnlyValidators(subnetID ids.ID) (uint64, error) { + if weight, modified := d.sovDiff.modifiedTotalWeight[subnetID]; modified { + return weight, nil + } + + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return 0, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + + return parentState.WeightOfSubnetOnlyValidators(subnetID) +} + +func (d *diff) GetSubnetOnlyValidator(validationID ids.ID) (SubnetOnlyValidator, error) { + if sov, modified := d.sovDiff.modified[validationID]; modified { + if sov.isDeleted() { + return SubnetOnlyValidator{}, database.ErrNotFound + } + return sov, nil + } + + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return SubnetOnlyValidator{}, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + + return parentState.GetSubnetOnlyValidator(validationID) +} + +func (d *diff) HasSubnetOnlyValidator(subnetID ids.ID, nodeID ids.NodeID) (bool, error) { + if has, modified := d.sovDiff.hasSubnetOnlyValidator(subnetID, nodeID); modified { + return has, nil + } + + parentState, ok := d.stateVersions.GetState(d.parentID) + if !ok { + return false, fmt.Errorf("%w: %s", ErrMissingParentState, d.parentID) + } + + return parentState.HasSubnetOnlyValidator(subnetID, nodeID) +} + +func (d *diff) PutSubnetOnlyValidator(sov SubnetOnlyValidator) error { + return d.sovDiff.putSubnetOnlyValidator(d, sov) +} + func (d *diff) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) { // If the validator was modified in this diff, return the modified // validator. @@ -504,6 +572,26 @@ func (d *diff) Apply(baseState Chain) error { baseState.DeleteExpiry(entry) } } + // Ensure that all sov deletions happen before any sov additions. This + // ensures that a subnetID+nodeID pair that was deleted and then re-added in + // a single diff can't get reordered into the addition happening first; + // which would return an error. + for _, sov := range d.sovDiff.modified { + if !sov.isDeleted() { + continue + } + if err := baseState.PutSubnetOnlyValidator(sov); err != nil { + return err + } + } + for _, sov := range d.sovDiff.modified { + if sov.isDeleted() { + continue + } + if err := baseState.PutSubnetOnlyValidator(sov); err != nil { + return err + } + } for _, subnetValidatorDiffs := range d.currentStakerDiffs.validatorDiffs { for _, validatorDiff := range subnetValidatorDiffs { switch validatorDiff.validatorStatus { diff --git a/vms/platformvm/state/diff_test.go b/vms/platformvm/state/diff_test.go index 82e376c3b5f..9cc67ff8d9e 100644 --- a/vms/platformvm/state/diff_test.go +++ b/vms/platformvm/state/diff_test.go @@ -4,6 +4,7 @@ package state import ( + "math/rand" "testing" "time" @@ -282,6 +283,99 @@ func TestDiffExpiry(t *testing.T) { } } +func TestDiffSubnetOnlyValidatorsErrors(t *testing.T) { + sov := SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + SubnetID: ids.GenerateTestID(), + NodeID: ids.GenerateTestNodeID(), + Weight: 1, // Not removed + } + + tests := []struct { + name string + initialEndAccumulatedFee uint64 + sov SubnetOnlyValidator + expectedErr error + }{ + { + name: "mutate active constants", + initialEndAccumulatedFee: 1, + sov: SubnetOnlyValidator{ + ValidationID: sov.ValidationID, + NodeID: ids.GenerateTestNodeID(), + }, + expectedErr: ErrMutatedSubnetOnlyValidator, + }, + { + name: "mutate inactive constants", + initialEndAccumulatedFee: 0, + sov: SubnetOnlyValidator{ + ValidationID: sov.ValidationID, + NodeID: ids.GenerateTestNodeID(), + }, + expectedErr: ErrMutatedSubnetOnlyValidator, + }, + { + name: "conflicting legacy subnetID and nodeID pair", + initialEndAccumulatedFee: 1, + sov: SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + NodeID: defaultValidatorNodeID, + }, + expectedErr: ErrConflictingSubnetOnlyValidator, + }, + { + name: "duplicate active subnetID and nodeID pair", + initialEndAccumulatedFee: 1, + sov: SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + NodeID: sov.NodeID, + }, + expectedErr: ErrDuplicateSubnetOnlyValidator, + }, + { + name: "duplicate inactive subnetID and nodeID pair", + initialEndAccumulatedFee: 0, + sov: SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + NodeID: sov.NodeID, + }, + expectedErr: ErrDuplicateSubnetOnlyValidator, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + state := newTestState(t, memdb.New()) + + require.NoError(state.PutCurrentValidator(&Staker{ + TxID: ids.GenerateTestID(), + SubnetID: sov.SubnetID, + NodeID: defaultValidatorNodeID, + })) + + sov.EndAccumulatedFee = test.initialEndAccumulatedFee + require.NoError(state.PutSubnetOnlyValidator(sov)) + + d, err := NewDiffOn(state) + require.NoError(err) + + // Initialize subnetID, weight, and endAccumulatedFee as they are + // constant among all tests. + test.sov.SubnetID = sov.SubnetID + test.sov.Weight = 1 // Not removed + test.sov.EndAccumulatedFee = rand.Uint64() //#nosec G404 + err = d.PutSubnetOnlyValidator(test.sov) + require.ErrorIs(err, test.expectedErr) + + // The invalid addition should not have modified the diff. + assertChainsEqual(t, state, d) + }) + } +} + func TestDiffCurrentValidator(t *testing.T) { require := require.New(t) ctrl := gomock.NewController(t) @@ -292,6 +386,7 @@ func TestDiffCurrentValidator(t *testing.T) { state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) + state.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -328,6 +423,7 @@ func TestDiffPendingValidator(t *testing.T) { state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) + state.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -370,6 +466,7 @@ func TestDiffCurrentDelegator(t *testing.T) { state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) + state.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -415,6 +512,7 @@ func TestDiffPendingDelegator(t *testing.T) { state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) + state.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -554,6 +652,7 @@ func TestDiffTx(t *testing.T) { state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) + state.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -653,6 +752,7 @@ func TestDiffUTXO(t *testing.T) { state.EXPECT().GetFeeState().Return(gas.State{}).Times(1) state.EXPECT().GetSoVExcess().Return(gas.Gas(0)).Times(1) state.EXPECT().GetAccruedFees().Return(uint64(0)).Times(1) + state.EXPECT().NumActiveSubnetOnlyValidators().Return(0).Times(1) d, err := NewDiffOn(state) require.NoError(err) @@ -707,6 +807,18 @@ func assertChainsEqual(t *testing.T, expected, actual Chain) { ) } + expectedActiveSOVsIterator, expectedErr := expected.GetActiveSubnetOnlyValidatorsIterator() + actualActiveSOVsIterator, actualErr := actual.GetActiveSubnetOnlyValidatorsIterator() + require.Equal(expectedErr, actualErr) + if expectedErr == nil { + require.Equal( + iterator.ToSlice(expectedActiveSOVsIterator), + iterator.ToSlice(actualActiveSOVsIterator), + ) + } + + require.Equal(expected.NumActiveSubnetOnlyValidators(), actual.NumActiveSubnetOnlyValidators()) + expectedCurrentStakerIterator, expectedErr := expected.GetCurrentStakerIterator() actualCurrentStakerIterator, actualErr := actual.GetCurrentStakerIterator() require.Equal(expectedErr, actualErr) diff --git a/vms/platformvm/state/mock_chain.go b/vms/platformvm/state/mock_chain.go index 56c49592451..5a6d21b2dd1 100644 --- a/vms/platformvm/state/mock_chain.go +++ b/vms/platformvm/state/mock_chain.go @@ -204,6 +204,21 @@ func (mr *MockChainMockRecorder) GetAccruedFees() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccruedFees", reflect.TypeOf((*MockChain)(nil).GetAccruedFees)) } +// GetActiveSubnetOnlyValidatorsIterator mocks base method. +func (m *MockChain) GetActiveSubnetOnlyValidatorsIterator() (iterator.Iterator[SubnetOnlyValidator], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveSubnetOnlyValidatorsIterator") + ret0, _ := ret[0].(iterator.Iterator[SubnetOnlyValidator]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveSubnetOnlyValidatorsIterator indicates an expected call of GetActiveSubnetOnlyValidatorsIterator. +func (mr *MockChainMockRecorder) GetActiveSubnetOnlyValidatorsIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSubnetOnlyValidatorsIterator", reflect.TypeOf((*MockChain)(nil).GetActiveSubnetOnlyValidatorsIterator)) +} + // GetCurrentDelegatorIterator mocks base method. func (m *MockChain) GetCurrentDelegatorIterator(subnetID ids.ID, nodeID ids.NodeID) (iterator.Iterator[*Staker], error) { m.ctrl.T.Helper() @@ -382,6 +397,21 @@ func (mr *MockChainMockRecorder) GetSubnetConversion(subnetID any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubnetConversion", reflect.TypeOf((*MockChain)(nil).GetSubnetConversion), subnetID) } +// GetSubnetOnlyValidator mocks base method. +func (m *MockChain) GetSubnetOnlyValidator(validationID ids.ID) (SubnetOnlyValidator, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubnetOnlyValidator", validationID) + ret0, _ := ret[0].(SubnetOnlyValidator) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubnetOnlyValidator indicates an expected call of GetSubnetOnlyValidator. +func (mr *MockChainMockRecorder) GetSubnetOnlyValidator(validationID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubnetOnlyValidator", reflect.TypeOf((*MockChain)(nil).GetSubnetOnlyValidator), validationID) +} + // GetSubnetOwner mocks base method. func (m *MockChain) GetSubnetOwner(subnetID ids.ID) (fx.Owner, error) { m.ctrl.T.Helper() @@ -472,6 +502,35 @@ func (mr *MockChainMockRecorder) HasExpiry(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockChain)(nil).HasExpiry), arg0) } +// HasSubnetOnlyValidator mocks base method. +func (m *MockChain) HasSubnetOnlyValidator(subnetID ids.ID, nodeID ids.NodeID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasSubnetOnlyValidator", subnetID, nodeID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasSubnetOnlyValidator indicates an expected call of HasSubnetOnlyValidator. +func (mr *MockChainMockRecorder) HasSubnetOnlyValidator(subnetID, nodeID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSubnetOnlyValidator", reflect.TypeOf((*MockChain)(nil).HasSubnetOnlyValidator), subnetID, nodeID) +} + +// NumActiveSubnetOnlyValidators mocks base method. +func (m *MockChain) NumActiveSubnetOnlyValidators() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NumActiveSubnetOnlyValidators") + ret0, _ := ret[0].(int) + return ret0 +} + +// NumActiveSubnetOnlyValidators indicates an expected call of NumActiveSubnetOnlyValidators. +func (mr *MockChainMockRecorder) NumActiveSubnetOnlyValidators() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumActiveSubnetOnlyValidators", reflect.TypeOf((*MockChain)(nil).NumActiveSubnetOnlyValidators)) +} + // PutCurrentDelegator mocks base method. func (m *MockChain) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -536,6 +595,20 @@ func (mr *MockChainMockRecorder) PutPendingValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutPendingValidator", reflect.TypeOf((*MockChain)(nil).PutPendingValidator), staker) } +// PutSubnetOnlyValidator mocks base method. +func (m *MockChain) PutSubnetOnlyValidator(sov SubnetOnlyValidator) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PutSubnetOnlyValidator", sov) + ret0, _ := ret[0].(error) + return ret0 +} + +// PutSubnetOnlyValidator indicates an expected call of PutSubnetOnlyValidator. +func (mr *MockChainMockRecorder) PutSubnetOnlyValidator(sov any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutSubnetOnlyValidator", reflect.TypeOf((*MockChain)(nil).PutSubnetOnlyValidator), sov) +} + // SetAccruedFees mocks base method. func (m *MockChain) SetAccruedFees(f uint64) { m.ctrl.T.Helper() @@ -633,3 +706,18 @@ func (mr *MockChainMockRecorder) SetTimestamp(tm any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTimestamp", reflect.TypeOf((*MockChain)(nil).SetTimestamp), tm) } + +// WeightOfSubnetOnlyValidators mocks base method. +func (m *MockChain) WeightOfSubnetOnlyValidators(subnetID ids.ID) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WeightOfSubnetOnlyValidators", subnetID) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WeightOfSubnetOnlyValidators indicates an expected call of WeightOfSubnetOnlyValidators. +func (mr *MockChainMockRecorder) WeightOfSubnetOnlyValidators(subnetID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WeightOfSubnetOnlyValidators", reflect.TypeOf((*MockChain)(nil).WeightOfSubnetOnlyValidators), subnetID) +} diff --git a/vms/platformvm/state/mock_diff.go b/vms/platformvm/state/mock_diff.go index b8362386af9..30f17a78122 100644 --- a/vms/platformvm/state/mock_diff.go +++ b/vms/platformvm/state/mock_diff.go @@ -218,6 +218,21 @@ func (mr *MockDiffMockRecorder) GetAccruedFees() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccruedFees", reflect.TypeOf((*MockDiff)(nil).GetAccruedFees)) } +// GetActiveSubnetOnlyValidatorsIterator mocks base method. +func (m *MockDiff) GetActiveSubnetOnlyValidatorsIterator() (iterator.Iterator[SubnetOnlyValidator], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveSubnetOnlyValidatorsIterator") + ret0, _ := ret[0].(iterator.Iterator[SubnetOnlyValidator]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveSubnetOnlyValidatorsIterator indicates an expected call of GetActiveSubnetOnlyValidatorsIterator. +func (mr *MockDiffMockRecorder) GetActiveSubnetOnlyValidatorsIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSubnetOnlyValidatorsIterator", reflect.TypeOf((*MockDiff)(nil).GetActiveSubnetOnlyValidatorsIterator)) +} + // GetCurrentDelegatorIterator mocks base method. func (m *MockDiff) GetCurrentDelegatorIterator(subnetID ids.ID, nodeID ids.NodeID) (iterator.Iterator[*Staker], error) { m.ctrl.T.Helper() @@ -396,6 +411,21 @@ func (mr *MockDiffMockRecorder) GetSubnetConversion(subnetID any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubnetConversion", reflect.TypeOf((*MockDiff)(nil).GetSubnetConversion), subnetID) } +// GetSubnetOnlyValidator mocks base method. +func (m *MockDiff) GetSubnetOnlyValidator(validationID ids.ID) (SubnetOnlyValidator, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubnetOnlyValidator", validationID) + ret0, _ := ret[0].(SubnetOnlyValidator) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubnetOnlyValidator indicates an expected call of GetSubnetOnlyValidator. +func (mr *MockDiffMockRecorder) GetSubnetOnlyValidator(validationID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubnetOnlyValidator", reflect.TypeOf((*MockDiff)(nil).GetSubnetOnlyValidator), validationID) +} + // GetSubnetOwner mocks base method. func (m *MockDiff) GetSubnetOwner(subnetID ids.ID) (fx.Owner, error) { m.ctrl.T.Helper() @@ -486,6 +516,35 @@ func (mr *MockDiffMockRecorder) HasExpiry(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockDiff)(nil).HasExpiry), arg0) } +// HasSubnetOnlyValidator mocks base method. +func (m *MockDiff) HasSubnetOnlyValidator(subnetID ids.ID, nodeID ids.NodeID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasSubnetOnlyValidator", subnetID, nodeID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasSubnetOnlyValidator indicates an expected call of HasSubnetOnlyValidator. +func (mr *MockDiffMockRecorder) HasSubnetOnlyValidator(subnetID, nodeID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSubnetOnlyValidator", reflect.TypeOf((*MockDiff)(nil).HasSubnetOnlyValidator), subnetID, nodeID) +} + +// NumActiveSubnetOnlyValidators mocks base method. +func (m *MockDiff) NumActiveSubnetOnlyValidators() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NumActiveSubnetOnlyValidators") + ret0, _ := ret[0].(int) + return ret0 +} + +// NumActiveSubnetOnlyValidators indicates an expected call of NumActiveSubnetOnlyValidators. +func (mr *MockDiffMockRecorder) NumActiveSubnetOnlyValidators() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumActiveSubnetOnlyValidators", reflect.TypeOf((*MockDiff)(nil).NumActiveSubnetOnlyValidators)) +} + // PutCurrentDelegator mocks base method. func (m *MockDiff) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -550,6 +609,20 @@ func (mr *MockDiffMockRecorder) PutPendingValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutPendingValidator", reflect.TypeOf((*MockDiff)(nil).PutPendingValidator), staker) } +// PutSubnetOnlyValidator mocks base method. +func (m *MockDiff) PutSubnetOnlyValidator(sov SubnetOnlyValidator) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PutSubnetOnlyValidator", sov) + ret0, _ := ret[0].(error) + return ret0 +} + +// PutSubnetOnlyValidator indicates an expected call of PutSubnetOnlyValidator. +func (mr *MockDiffMockRecorder) PutSubnetOnlyValidator(sov any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutSubnetOnlyValidator", reflect.TypeOf((*MockDiff)(nil).PutSubnetOnlyValidator), sov) +} + // SetAccruedFees mocks base method. func (m *MockDiff) SetAccruedFees(f uint64) { m.ctrl.T.Helper() @@ -647,3 +720,18 @@ func (mr *MockDiffMockRecorder) SetTimestamp(tm any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTimestamp", reflect.TypeOf((*MockDiff)(nil).SetTimestamp), tm) } + +// WeightOfSubnetOnlyValidators mocks base method. +func (m *MockDiff) WeightOfSubnetOnlyValidators(subnetID ids.ID) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WeightOfSubnetOnlyValidators", subnetID) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WeightOfSubnetOnlyValidators indicates an expected call of WeightOfSubnetOnlyValidators. +func (mr *MockDiffMockRecorder) WeightOfSubnetOnlyValidators(subnetID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WeightOfSubnetOnlyValidators", reflect.TypeOf((*MockDiff)(nil).WeightOfSubnetOnlyValidators), subnetID) +} diff --git a/vms/platformvm/state/mock_state.go b/vms/platformvm/state/mock_state.go index 5eab1b07ee9..b8d0713d99e 100644 --- a/vms/platformvm/state/mock_state.go +++ b/vms/platformvm/state/mock_state.go @@ -319,6 +319,21 @@ func (mr *MockStateMockRecorder) GetAccruedFees() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccruedFees", reflect.TypeOf((*MockState)(nil).GetAccruedFees)) } +// GetActiveSubnetOnlyValidatorsIterator mocks base method. +func (m *MockState) GetActiveSubnetOnlyValidatorsIterator() (iterator.Iterator[SubnetOnlyValidator], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveSubnetOnlyValidatorsIterator") + ret0, _ := ret[0].(iterator.Iterator[SubnetOnlyValidator]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveSubnetOnlyValidatorsIterator indicates an expected call of GetActiveSubnetOnlyValidatorsIterator. +func (mr *MockStateMockRecorder) GetActiveSubnetOnlyValidatorsIterator() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSubnetOnlyValidatorsIterator", reflect.TypeOf((*MockState)(nil).GetActiveSubnetOnlyValidatorsIterator)) +} + // GetBlockIDAtHeight mocks base method. func (m *MockState) GetBlockIDAtHeight(height uint64) (ids.ID, error) { m.ctrl.T.Helper() @@ -601,6 +616,21 @@ func (mr *MockStateMockRecorder) GetSubnetIDs() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubnetIDs", reflect.TypeOf((*MockState)(nil).GetSubnetIDs)) } +// GetSubnetOnlyValidator mocks base method. +func (m *MockState) GetSubnetOnlyValidator(validationID ids.ID) (SubnetOnlyValidator, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubnetOnlyValidator", validationID) + ret0, _ := ret[0].(SubnetOnlyValidator) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubnetOnlyValidator indicates an expected call of GetSubnetOnlyValidator. +func (mr *MockStateMockRecorder) GetSubnetOnlyValidator(validationID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubnetOnlyValidator", reflect.TypeOf((*MockState)(nil).GetSubnetOnlyValidator), validationID) +} + // GetSubnetOwner mocks base method. func (m *MockState) GetSubnetOwner(subnetID ids.ID) (fx.Owner, error) { m.ctrl.T.Helper() @@ -707,6 +737,35 @@ func (mr *MockStateMockRecorder) HasExpiry(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasExpiry", reflect.TypeOf((*MockState)(nil).HasExpiry), arg0) } +// HasSubnetOnlyValidator mocks base method. +func (m *MockState) HasSubnetOnlyValidator(subnetID ids.ID, nodeID ids.NodeID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasSubnetOnlyValidator", subnetID, nodeID) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasSubnetOnlyValidator indicates an expected call of HasSubnetOnlyValidator. +func (mr *MockStateMockRecorder) HasSubnetOnlyValidator(subnetID, nodeID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSubnetOnlyValidator", reflect.TypeOf((*MockState)(nil).HasSubnetOnlyValidator), subnetID, nodeID) +} + +// NumActiveSubnetOnlyValidators mocks base method. +func (m *MockState) NumActiveSubnetOnlyValidators() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NumActiveSubnetOnlyValidators") + ret0, _ := ret[0].(int) + return ret0 +} + +// NumActiveSubnetOnlyValidators indicates an expected call of NumActiveSubnetOnlyValidators. +func (mr *MockStateMockRecorder) NumActiveSubnetOnlyValidators() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NumActiveSubnetOnlyValidators", reflect.TypeOf((*MockState)(nil).NumActiveSubnetOnlyValidators)) +} + // PutCurrentDelegator mocks base method. func (m *MockState) PutCurrentDelegator(staker *Staker) { m.ctrl.T.Helper() @@ -771,6 +830,20 @@ func (mr *MockStateMockRecorder) PutPendingValidator(staker any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutPendingValidator", reflect.TypeOf((*MockState)(nil).PutPendingValidator), staker) } +// PutSubnetOnlyValidator mocks base method. +func (m *MockState) PutSubnetOnlyValidator(sov SubnetOnlyValidator) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PutSubnetOnlyValidator", sov) + ret0, _ := ret[0].(error) + return ret0 +} + +// PutSubnetOnlyValidator indicates an expected call of PutSubnetOnlyValidator. +func (mr *MockStateMockRecorder) PutSubnetOnlyValidator(sov any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutSubnetOnlyValidator", reflect.TypeOf((*MockState)(nil).PutSubnetOnlyValidator), sov) +} + // ReindexBlocks mocks base method. func (m *MockState) ReindexBlocks(lock sync.Locker, log logging.Logger) error { m.ctrl.T.Helper() @@ -935,3 +1008,18 @@ func (mr *MockStateMockRecorder) UTXOIDs(addr, previous, limit any) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UTXOIDs", reflect.TypeOf((*MockState)(nil).UTXOIDs), addr, previous, limit) } + +// WeightOfSubnetOnlyValidators mocks base method. +func (m *MockState) WeightOfSubnetOnlyValidators(subnetID ids.ID) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WeightOfSubnetOnlyValidators", subnetID) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WeightOfSubnetOnlyValidators indicates an expected call of WeightOfSubnetOnlyValidators. +func (mr *MockStateMockRecorder) WeightOfSubnetOnlyValidators(subnetID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WeightOfSubnetOnlyValidators", reflect.TypeOf((*MockState)(nil).WeightOfSubnetOnlyValidators), subnetID) +} diff --git a/vms/platformvm/state/state.go b/vms/platformvm/state/state.go index ce4869792c6..f4ed7f26d7e 100644 --- a/vms/platformvm/state/state.go +++ b/vms/platformvm/state/state.go @@ -34,6 +34,7 @@ import ( "github.com/ava-labs/avalanchego/utils/hashing" "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/utils/timer" "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/avalanchego/vms/components/avax" @@ -86,6 +87,11 @@ var ( SupplyPrefix = []byte("supply") ChainPrefix = []byte("chain") ExpiryReplayProtectionPrefix = []byte("expiryReplayProtection") + SubnetOnlyValidatorsPrefix = []byte("subnetOnlyValidators") + WeightsPrefix = []byte("weights") + SubnetIDNodeIDPrefix = []byte("subnetIDNodeID") + ActivePrefix = []byte("active") + InactivePrefix = []byte("inactive") SingletonPrefix = []byte("singleton") TimestampKey = []byte("timestamp") @@ -103,6 +109,7 @@ var ( // execution. type Chain interface { Expiry + SubnetOnlyValidators Stakers avax.UTXOAdder avax.UTXOGetter @@ -318,6 +325,15 @@ type state struct { expiryDiff *expiryDiff expiryDB database.Database + activeSOVLookup map[ids.ID]SubnetOnlyValidator + activeSOVs *btree.BTreeG[SubnetOnlyValidator] + sovDiff *subnetOnlyValidatorsDiff + subnetOnlyValidatorsDB database.Database + weightsDB database.Database + subnetIDNodeIDDB database.Database + activeDB database.Database + inactiveDB database.Database + currentStakers *baseStakers pendingStakers *baseStakers @@ -509,6 +525,8 @@ func New( baseDB := versiondb.New(db) + subnetOnlyValidatorsDB := prefixdb.New(SubnetOnlyValidatorsPrefix, baseDB) + validatorsDB := prefixdb.New(ValidatorsPrefix, baseDB) currentValidatorsDB := prefixdb.New(CurrentPrefix, validatorsDB) @@ -635,6 +653,15 @@ func New( expiryDiff: newExpiryDiff(), expiryDB: prefixdb.New(ExpiryReplayProtectionPrefix, baseDB), + activeSOVLookup: make(map[ids.ID]SubnetOnlyValidator), + activeSOVs: btree.NewG(defaultTreeDegree, SubnetOnlyValidator.Less), + sovDiff: newSubnetOnlyValidatorsDiff(), + subnetOnlyValidatorsDB: subnetOnlyValidatorsDB, + weightsDB: prefixdb.New(WeightsPrefix, subnetOnlyValidatorsDB), + subnetIDNodeIDDB: prefixdb.New(SubnetIDNodeIDPrefix, subnetOnlyValidatorsDB), + activeDB: prefixdb.New(ActivePrefix, subnetOnlyValidatorsDB), + inactiveDB: prefixdb.New(InactivePrefix, subnetOnlyValidatorsDB), + currentStakers: newBaseStakers(), pendingStakers: newBaseStakers(), @@ -730,6 +757,64 @@ func (s *state) DeleteExpiry(entry ExpiryEntry) { s.expiryDiff.DeleteExpiry(entry) } +func (s *state) GetActiveSubnetOnlyValidatorsIterator() (iterator.Iterator[SubnetOnlyValidator], error) { + return s.sovDiff.getActiveSubnetOnlyValidatorsIterator( + iterator.FromTree(s.activeSOVs), + ), nil +} + +func (s *state) NumActiveSubnetOnlyValidators() int { + return len(s.activeSOVLookup) + s.sovDiff.numAddedActive +} + +func (s *state) WeightOfSubnetOnlyValidators(subnetID ids.ID) (uint64, error) { + if weight, modified := s.sovDiff.modifiedTotalWeight[subnetID]; modified { + return weight, nil + } + + // TODO: Add caching + return database.WithDefault(database.GetUInt64, s.weightsDB, subnetID[:], 0) +} + +func (s *state) GetSubnetOnlyValidator(validationID ids.ID) (SubnetOnlyValidator, error) { + if sov, modified := s.sovDiff.modified[validationID]; modified { + if sov.isDeleted() { + return SubnetOnlyValidator{}, database.ErrNotFound + } + return sov, nil + } + + return s.getPersistedSubnetOnlyValidator(validationID) +} + +// getPersistedSubnetOnlyValidator returns the currently persisted +// SubnetOnlyValidator with the given validationID. It is guaranteed that any +// returned validator is either active or inactive. +func (s *state) getPersistedSubnetOnlyValidator(validationID ids.ID) (SubnetOnlyValidator, error) { + if sov, ok := s.activeSOVLookup[validationID]; ok { + return sov, nil + } + + // TODO: Add caching + return getSubnetOnlyValidator(s.inactiveDB, validationID) +} + +func (s *state) HasSubnetOnlyValidator(subnetID ids.ID, nodeID ids.NodeID) (bool, error) { + if has, modified := s.sovDiff.hasSubnetOnlyValidator(subnetID, nodeID); modified { + return has, nil + } + + // TODO: Add caching + key := make([]byte, len(subnetID)+len(nodeID)) + copy(key, subnetID[:]) + copy(key[len(subnetID):], nodeID[:]) + return s.subnetIDNodeIDDB.Has(key) +} + +func (s *state) PutSubnetOnlyValidator(sov SubnetOnlyValidator) error { + return s.sovDiff.putSubnetOnlyValidator(s, sov) +} + func (s *state) GetCurrentValidator(subnetID ids.ID, nodeID ids.NodeID) (*Staker, error) { return s.currentStakers.GetValidator(subnetID, nodeID) } @@ -1389,6 +1474,7 @@ func (s *state) load() error { return errors.Join( s.loadMetadata(), s.loadExpiry(), + s.loadActiveSubnetOnlyValidators(), s.loadCurrentValidators(), s.loadPendingValidators(), s.initValidatorSets(), @@ -1484,6 +1570,34 @@ func (s *state) loadExpiry() error { return nil } +func (s *state) loadActiveSubnetOnlyValidators() error { + it := s.activeDB.NewIterator() + defer it.Release() + + for it.Next() { + key := it.Key() + validationID, err := ids.ToID(key) + if err != nil { + return fmt.Errorf("failed to unmarshal ValidationID during load: %w", err) + } + + var ( + value = it.Value() + sov = SubnetOnlyValidator{ + ValidationID: validationID, + } + ) + if _, err := block.GenesisCodec.Unmarshal(value, &sov); err != nil { + return fmt.Errorf("failed to unmarshal SubnetOnlyValidator: %w", err) + } + + s.activeSOVLookup[validationID] = sov + s.activeSOVs.ReplaceOrInsert(sov) + } + + return nil +} + func (s *state) loadCurrentValidators() error { s.currentStakers = newBaseStakers() @@ -1738,9 +1852,53 @@ func (s *state) loadPendingValidators() error { ) } -// Invariant: initValidatorSets requires loadCurrentValidators to have already -// been called. +// Invariant: initValidatorSets requires loadActiveSubnetOnlyValidators and +// loadCurrentValidators to have already been called. func (s *state) initValidatorSets() error { + // Load ACP77 validators + for validationID, sov := range s.activeSOVLookup { + pk := bls.PublicKeyFromValidUncompressedBytes(sov.PublicKey) + if err := s.validators.AddStaker(sov.SubnetID, sov.NodeID, pk, validationID, sov.Weight); err != nil { + return err + } + } + + // Load inactive weights + it := s.weightsDB.NewIterator() + defer it.Release() + + for it.Next() { + subnetID, err := ids.ToID(it.Key()) + if err != nil { + return err + } + + totalWeight, err := database.ParseUInt64(it.Value()) + if err != nil { + return err + } + + activeWeight, err := s.validators.TotalWeight(subnetID) + if err != nil { + return err + } + + inactiveWeight, err := safemath.Sub(totalWeight, activeWeight) + if err != nil { + // This should never happen, as the total weight should always be at + // least the sum of the active weights. + return err + } + if inactiveWeight == 0 { + continue + } + + if err := s.validators.AddStaker(subnetID, ids.EmptyNodeID, nil, ids.Empty, inactiveWeight); err != nil { + return err + } + } + + // Load primary network and non-ACP77 validators primaryNetworkValidators := s.currentStakers.validators[constants.PrimaryNetworkID] for subnetID, subnetValidators := range s.currentStakers.validators { if s.validators.Count(subnetID) != 0 { @@ -1799,6 +1957,7 @@ func (s *state) write(updateValidators bool, height uint64) error { s.writeCurrentStakers(codecVersion), s.writePendingStakers(), s.WriteValidatorMetadata(s.currentValidatorList, s.currentSubnetValidatorList, codecVersion), // Must be called after writeCurrentStakers + s.writeSubnetOnlyValidators(), s.writeTXs(), s.writeRewardUTXOs(), s.writeUTXOs(), @@ -1815,6 +1974,11 @@ func (s *state) write(updateValidators bool, height uint64) error { func (s *state) Close() error { return errors.Join( s.expiryDB.Close(), + s.weightsDB.Close(), + s.subnetIDNodeIDDB.Close(), + s.activeDB.Close(), + s.inactiveDB.Close(), + s.subnetOnlyValidatorsDB.Close(), s.pendingSubnetValidatorBaseDB.Close(), s.pendingSubnetDelegatorBaseDB.Close(), s.pendingDelegatorBaseDB.Close(), @@ -2063,7 +2227,8 @@ func (s *state) getInheritedPublicKey(nodeID ids.NodeID) (*bls.PublicKey, error) // updateValidatorManager updates the validator manager with the pending // validator set changes. // -// This function must be called prior to writeCurrentStakers. +// This function must be called prior to writeCurrentStakers and +// writeSubnetOnlyValidators. func (s *state) updateValidatorManager(updateValidators bool) error { if !updateValidators { return nil @@ -2115,6 +2280,121 @@ func (s *state) updateValidatorManager(updateValidators bool) error { } } + // Perform SoV deletions: + var ( + sovChanges = s.sovDiff.modified + sovChangesApplied set.Set[ids.ID] + ) + for validationID, sov := range sovChanges { + if !sov.isDeleted() { + // Additions and modifications are handled in the next loops. + continue + } + + sovChangesApplied.Add(validationID) + + priorSOV, err := s.getPersistedSubnetOnlyValidator(validationID) + if err == database.ErrNotFound { + // Deleting a non-existent validator is a noop. This can happen if + // the validator was added and then immediately removed. + continue + } + if err != nil { + return err + } + + nodeID := ids.EmptyNodeID + if priorSOV.isActive() { + nodeID = priorSOV.NodeID + } + if err := s.validators.RemoveWeight(priorSOV.SubnetID, nodeID, priorSOV.Weight); err != nil { + return fmt.Errorf("failed to delete SoV validator: %w", err) + } + } + + // Perform modifications: + for validationID, sov := range sovChanges { + if sovChangesApplied.Contains(validationID) { + continue + } + + priorSOV, err := s.getPersistedSubnetOnlyValidator(validationID) + if err == database.ErrNotFound { + // New additions are handled in the next loop. + continue + } + if err != nil { + return err + } + + sovChangesApplied.Add(validationID) + + switch { + case priorSOV.isInactive() && sov.isActive(): + // This validator is being activated. + pk := bls.PublicKeyFromValidUncompressedBytes(sov.PublicKey) + err = errors.Join( + s.validators.RemoveWeight(sov.SubnetID, ids.EmptyNodeID, priorSOV.Weight), + s.validators.AddStaker(sov.SubnetID, sov.NodeID, pk, validationID, sov.Weight), + ) + case priorSOV.isActive() && sov.isInactive(): + // This validator is being deactivated. + inactiveWeight := s.validators.GetWeight(sov.SubnetID, ids.EmptyNodeID) + if inactiveWeight == 0 { + err = s.validators.AddStaker(sov.SubnetID, ids.EmptyNodeID, nil, ids.Empty, sov.Weight) + } else { + err = s.validators.AddWeight(sov.SubnetID, ids.EmptyNodeID, sov.Weight) + } + err = errors.Join( + err, + s.validators.RemoveWeight(sov.SubnetID, sov.NodeID, priorSOV.Weight), + ) + default: + // This validator's active status isn't changing. + nodeID := ids.EmptyNodeID + if sov.isActive() { + nodeID = sov.NodeID + } + if priorSOV.Weight < sov.Weight { + err = s.validators.AddWeight(sov.SubnetID, nodeID, sov.Weight-priorSOV.Weight) + } else if priorSOV.Weight > sov.Weight { + err = s.validators.RemoveWeight(sov.SubnetID, nodeID, priorSOV.Weight-sov.Weight) + } + } + if err != nil { + return err + } + } + + // Perform additions: + for validationID, sov := range sovChanges { + if sovChangesApplied.Contains(validationID) { + continue + } + + if sov.isActive() { + pk := bls.PublicKeyFromValidUncompressedBytes(sov.PublicKey) + if err := s.validators.AddStaker(sov.SubnetID, sov.NodeID, pk, validationID, sov.Weight); err != nil { + return fmt.Errorf("failed to add SoV validator: %w", err) + } + continue + } + + // This validator is inactive + var ( + inactiveWeight = s.validators.GetWeight(sov.SubnetID, ids.EmptyNodeID) + err error + ) + if inactiveWeight == 0 { + err = s.validators.AddStaker(sov.SubnetID, ids.EmptyNodeID, nil, ids.Empty, sov.Weight) + } else { + err = s.validators.AddWeight(sov.SubnetID, ids.EmptyNodeID, sov.Weight) + } + if err != nil { + return err + } + } + // Update the stake metrics totalWeight, err := s.validators.TotalWeight(constants.PrimaryNetworkID) if err != nil { @@ -2129,14 +2409,15 @@ func (s *state) updateValidatorManager(updateValidators bool) error { // writeValidatorDiffs writes the validator set diff contained by the pending // validator set changes to disk. // -// This function must be called prior to writeCurrentStakers. +// This function must be called prior to writeCurrentStakers and +// writeSubnetOnlyValidators. func (s *state) writeValidatorDiffs(height uint64) error { type validatorChanges struct { weightDiff ValidatorWeightDiff prevPublicKey []byte newPublicKey []byte } - changes := make(map[subnetIDNodeID]*validatorChanges) + changes := make(map[subnetIDNodeID]*validatorChanges, len(s.sovDiff.modified)) // Calculate the changes to the pre-ACP-77 validator set for subnetID, subnetDiffs := range s.currentStakers.validatorDiffs { @@ -2173,6 +2454,63 @@ func (s *state) writeValidatorDiffs(height uint64) error { } } + // Perform SoV deletions: + for validationID := range s.sovDiff.modified { + priorSOV, err := s.getPersistedSubnetOnlyValidator(validationID) + if err == database.ErrNotFound { + continue + } + if err != nil { + return err + } + + var ( + diff *validatorChanges + subnetIDNodeID = subnetIDNodeID{ + subnetID: priorSOV.SubnetID, + } + ) + if priorSOV.isActive() { + subnetIDNodeID.nodeID = priorSOV.NodeID + diff = getOrDefault(changes, subnetIDNodeID) + diff.prevPublicKey = priorSOV.PublicKey + } else { + subnetIDNodeID.nodeID = ids.EmptyNodeID + diff = getOrDefault(changes, subnetIDNodeID) + } + + if err := diff.weightDiff.Add(true, priorSOV.Weight); err != nil { + return err + } + } + + // Perform SoV additions: + for _, sov := range s.sovDiff.modified { + // If the validator is being removed, we shouldn't work to re-add it. + if sov.isDeleted() { + continue + } + + var ( + diff *validatorChanges + subnetIDNodeID = subnetIDNodeID{ + subnetID: sov.SubnetID, + } + ) + if sov.isActive() { + subnetIDNodeID.nodeID = sov.NodeID + diff = getOrDefault(changes, subnetIDNodeID) + diff.newPublicKey = sov.PublicKey + } else { + subnetIDNodeID.nodeID = ids.EmptyNodeID + diff = getOrDefault(changes, subnetIDNodeID) + } + + if err := diff.weightDiff.Add(false, sov.Weight); err != nil { + return err + } + } + // Write the changes to the database for subnetIDNodeID, diff := range changes { diffKey := marshalDiffKey(subnetIDNodeID.subnetID, height, subnetIDNodeID.nodeID) @@ -2198,6 +2536,16 @@ func (s *state) writeValidatorDiffs(height uint64) error { return nil } +func getOrDefault[K comparable, V any](m map[K]*V, k K) *V { + if v, ok := m[k]; ok { + return v + } + + v := new(V) + m[k] = v + return v +} + func (s *state) writeCurrentStakers(codecVersion uint16) error { for subnetID, validatorDiffs := range s.currentStakers.validatorDiffs { // Select db to write to @@ -2352,6 +2700,94 @@ func writePendingDiff( return nil } +func (s *state) writeSubnetOnlyValidators() error { + // Write modified weights: + for subnetID, weight := range s.sovDiff.modifiedTotalWeight { + if err := database.PutUInt64(s.weightsDB, subnetID[:], weight); err != nil { + return err + } + } + maps.Clear(s.sovDiff.modifiedTotalWeight) + + sovChanges := s.sovDiff.modified + // Perform deletions: + for validationID, sov := range sovChanges { + if !sov.isDeleted() { + // Additions and modifications are handled in the next loops. + continue + } + + // The next loops shouldn't consider this change. + delete(sovChanges, validationID) + + subnetIDNodeID := subnetIDNodeID{ + subnetID: sov.SubnetID, + nodeID: sov.NodeID, + } + subnetIDNodeIDKey := subnetIDNodeID.Marshal() + err := s.subnetIDNodeIDDB.Delete(subnetIDNodeIDKey) + if err != nil { + return err + } + + priorSOV, wasActive := s.activeSOVLookup[validationID] + if wasActive { + delete(s.activeSOVLookup, validationID) + s.activeSOVs.Delete(priorSOV) + err = deleteSubnetOnlyValidator(s.activeDB, validationID) + } else { + // It is technically possible for the validator not to exist on disk + // here, but that's fine as deleting an entry that doesn't exist is + // a noop. + err = deleteSubnetOnlyValidator(s.inactiveDB, validationID) + } + if err != nil { + return err + } + } + + // Perform modifications and additions: + for validationID, sov := range sovChanges { + priorSOV, err := s.getPersistedSubnetOnlyValidator(validationID) + switch err { + case nil: + // This is modifying an existing validator + if priorSOV.isActive() { + delete(s.activeSOVLookup, validationID) + s.activeSOVs.Delete(priorSOV) + err = deleteSubnetOnlyValidator(s.activeDB, validationID) + } else { + err = deleteSubnetOnlyValidator(s.inactiveDB, validationID) + } + case database.ErrNotFound: + // This is a new validator + subnetIDNodeID := subnetIDNodeID{ + subnetID: sov.SubnetID, + nodeID: sov.NodeID, + } + subnetIDNodeIDKey := subnetIDNodeID.Marshal() + err = s.subnetIDNodeIDDB.Put(subnetIDNodeIDKey, validationID[:]) + } + if err != nil { + return err + } + + if sov.isActive() { + s.activeSOVLookup[validationID] = sov + s.activeSOVs.ReplaceOrInsert(sov) + err = putSubnetOnlyValidator(s.activeDB, sov) + } else { + err = putSubnetOnlyValidator(s.inactiveDB, sov) + } + if err != nil { + return err + } + } + + s.sovDiff = newSubnetOnlyValidatorsDiff() + return nil +} + func (s *state) writeTXs() error { for txID, txStatus := range s.addedTxs { txID := txID diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 85df176ce42..f4f50ebdf05 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -5,6 +5,7 @@ package state import ( "context" + "maps" "math" "math/rand" "sync" @@ -23,6 +24,7 @@ import ( "github.com/ava-labs/avalanchego/snow/choices" "github.com/ava-labs/avalanchego/snow/validators" "github.com/ava-labs/avalanchego/upgrade/upgradetest" + "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/constants" "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/iterator" @@ -1490,3 +1492,441 @@ func TestStateExpiryCommitAndLoad(t *testing.T) { require.NoError(err) require.False(has) } + +func TestSubnetOnlyValidators(t *testing.T) { + sov := SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + SubnetID: ids.GenerateTestID(), + NodeID: ids.GenerateTestNodeID(), + } + + sk, err := bls.NewSecretKey() + require.NoError(t, err) + pk := bls.PublicFromSecretKey(sk) + pkBytes := bls.PublicKeyToUncompressedBytes(pk) + + otherSK, err := bls.NewSecretKey() + require.NoError(t, err) + otherPK := bls.PublicFromSecretKey(otherSK) + otherPKBytes := bls.PublicKeyToUncompressedBytes(otherPK) + + tests := []struct { + name string + initial []SubnetOnlyValidator + sovs []SubnetOnlyValidator + }{ + { + name: "empty noop", + }, + { + name: "initially active not modified", + initial: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Active + }, + }, + }, + { + name: "initially inactive not modified", + initial: []SubnetOnlyValidator{ + { + ValidationID: ids.GenerateTestID(), + SubnetID: ids.GenerateTestID(), + NodeID: ids.GenerateTestNodeID(), + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 0, // Inactive + }, + }, + }, + { + name: "initially active removed", + initial: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Active + }, + }, + sovs: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 0, // Removed + }, + }, + }, + { + name: "initially inactive removed", + initial: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 0, // Inactive + }, + }, + sovs: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 0, // Removed + }, + }, + }, + { + name: "increase active weight", + initial: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Active + }, + }, + sovs: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 2, // Increased + EndAccumulatedFee: 1, // Active + }, + }, + }, + { + name: "deactivate", + initial: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Active + }, + }, + sovs: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 0, // Inactive + }, + }, + }, + { + name: "reactivate", + initial: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 0, // Inactive + }, + }, + sovs: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Active + }, + }, + }, + { + name: "update multiple times", + initial: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Active + }, + }, + sovs: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 2, // Not removed + EndAccumulatedFee: 1, // Inactive + }, + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 3, // Not removed + EndAccumulatedFee: 1, // Inactive + }, + }, + }, + { + name: "change validationID", + initial: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Active + }, + }, + sovs: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 0, // Removed + }, + { + ValidationID: ids.GenerateTestID(), + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: otherPKBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Inactive + }, + }, + }, + { + name: "added and removed", + sovs: []SubnetOnlyValidator{ + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 1, // Not removed + EndAccumulatedFee: 1, // Active + }, + { + ValidationID: sov.ValidationID, + SubnetID: sov.SubnetID, + NodeID: sov.NodeID, + PublicKey: pkBytes, + Weight: 0, // Removed + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + db := memdb.New() + state := newTestState(t, db) + + var ( + initialSOVs = make(map[ids.ID]SubnetOnlyValidator) + subnetIDs set.Set[ids.ID] + ) + for _, sov := range test.initial { + sov.RemainingBalanceOwner = []byte{} + sov.DeactivationOwner = []byte{} + + require.NoError(state.PutSubnetOnlyValidator(sov)) + initialSOVs[sov.ValidationID] = sov + subnetIDs.Add(sov.SubnetID) + } + + state.SetHeight(0) + require.NoError(state.Commit()) + + d, err := NewDiffOn(state) + require.NoError(err) + + expectedSOVs := maps.Clone(initialSOVs) + for _, sov := range test.sovs { + sov.RemainingBalanceOwner = []byte{} + sov.DeactivationOwner = []byte{} + + require.NoError(d.PutSubnetOnlyValidator(sov)) + expectedSOVs[sov.ValidationID] = sov + subnetIDs.Add(sov.SubnetID) + } + + verifyChain := func(chain Chain) { + for _, expectedSOV := range expectedSOVs { + if expectedSOV.Weight != 0 { + continue + } + + sov, err := chain.GetSubnetOnlyValidator(expectedSOV.ValidationID) + require.ErrorIs(err, database.ErrNotFound) + require.Zero(sov) + } + + var ( + weights = make(map[ids.ID]uint64) + expectedActive []SubnetOnlyValidator + ) + for _, expectedSOV := range expectedSOVs { + if expectedSOV.Weight == 0 { + continue + } + + sov, err := chain.GetSubnetOnlyValidator(expectedSOV.ValidationID) + require.NoError(err) + require.Equal(expectedSOV, sov) + + has, err := chain.HasSubnetOnlyValidator(expectedSOV.SubnetID, expectedSOV.NodeID) + require.NoError(err) + require.True(has) + + weights[sov.SubnetID] += sov.Weight + if expectedSOV.isActive() { + expectedActive = append(expectedActive, expectedSOV) + } + } + utils.Sort(expectedActive) + + activeIterator, err := chain.GetActiveSubnetOnlyValidatorsIterator() + require.NoError(err) + require.Equal( + expectedActive, + iterator.ToSlice(activeIterator), + ) + + require.Equal(len(expectedActive), chain.NumActiveSubnetOnlyValidators()) + + for subnetID, expectedWeight := range weights { + weight, err := chain.WeightOfSubnetOnlyValidators(subnetID) + require.NoError(err) + require.Equal(expectedWeight, weight) + } + } + + verifyChain(d) + require.NoError(d.Apply(state)) + verifyChain(d) + verifyChain(state) + assertChainsEqual(t, state, d) + + state.SetHeight(1) + require.NoError(state.Commit()) + verifyChain(d) + verifyChain(state) + assertChainsEqual(t, state, d) + + sovsToValidatorSet := func( + sovs map[ids.ID]SubnetOnlyValidator, + subnetID ids.ID, + ) map[ids.NodeID]*validators.GetValidatorOutput { + validatorSet := make(map[ids.NodeID]*validators.GetValidatorOutput) + for _, sov := range sovs { + if sov.SubnetID != subnetID || sov.Weight == 0 { + continue + } + + nodeID := sov.NodeID + publicKey := bls.PublicKeyFromValidUncompressedBytes(sov.PublicKey) + // Inactive validators are combined into a single validator + // with the empty ID. + if sov.EndAccumulatedFee == 0 { + nodeID = ids.EmptyNodeID + publicKey = nil + } + + vdr, ok := validatorSet[nodeID] + if !ok { + vdr = &validators.GetValidatorOutput{ + NodeID: nodeID, + PublicKey: publicKey, + } + validatorSet[nodeID] = vdr + } + vdr.Weight += sov.Weight + } + return validatorSet + } + + reloadedState := newTestState(t, db) + for subnetID := range subnetIDs { + expectedEndValidatorSet := sovsToValidatorSet(expectedSOVs, subnetID) + endValidatorSet := state.validators.GetMap(subnetID) + require.Equal(expectedEndValidatorSet, endValidatorSet) + + reloadedEndValidatorSet := reloadedState.validators.GetMap(subnetID) + require.Equal(expectedEndValidatorSet, reloadedEndValidatorSet) + + require.NoError(state.ApplyValidatorWeightDiffs(context.Background(), endValidatorSet, 1, 1, subnetID)) + require.NoError(state.ApplyValidatorPublicKeyDiffs(context.Background(), endValidatorSet, 1, 1, subnetID)) + + initialValidatorSet := sovsToValidatorSet(initialSOVs, subnetID) + require.Equal(initialValidatorSet, endValidatorSet) + } + }) + } +} + +func TestSubnetOnlyValidatorAfterLegacyRemoval(t *testing.T) { + require := require.New(t) + + db := memdb.New() + state := newTestState(t, db) + + legacyStaker := &Staker{ + TxID: ids.GenerateTestID(), + NodeID: defaultValidatorNodeID, + PublicKey: nil, + SubnetID: ids.GenerateTestID(), + Weight: 1, + StartTime: genesistest.DefaultValidatorStartTime, + EndTime: genesistest.DefaultValidatorEndTime, + PotentialReward: 0, + } + require.NoError(state.PutCurrentValidator(legacyStaker)) + + state.SetHeight(1) + require.NoError(state.Commit()) + + state.DeleteCurrentValidator(legacyStaker) + + sov := SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + SubnetID: legacyStaker.SubnetID, + NodeID: legacyStaker.NodeID, + PublicKey: utils.RandomBytes(bls.PublicKeyLen), + RemainingBalanceOwner: utils.RandomBytes(32), + DeactivationOwner: utils.RandomBytes(32), + StartTime: 1, + Weight: 2, + MinNonce: 3, + EndAccumulatedFee: 4, + } + require.NoError(state.PutSubnetOnlyValidator(sov)) + + state.SetHeight(2) + require.NoError(state.Commit()) +} diff --git a/vms/platformvm/state/subnet_only_validator.go b/vms/platformvm/state/subnet_only_validator.go index 1da34a50299..ddaa5eb5a50 100644 --- a/vms/platformvm/state/subnet_only_validator.go +++ b/vms/platformvm/state/subnet_only_validator.go @@ -13,6 +13,8 @@ import ( "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils" + "github.com/ava-labs/avalanchego/utils/iterator" + "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/vms/platformvm/block" ) @@ -20,9 +22,43 @@ var ( _ btree.LessFunc[SubnetOnlyValidator] = SubnetOnlyValidator.Less _ utils.Sortable[SubnetOnlyValidator] = SubnetOnlyValidator{} - ErrMutatedSubnetOnlyValidator = errors.New("subnet only validator contains mutated constant fields") + ErrMutatedSubnetOnlyValidator = errors.New("subnet only validator contains mutated constant fields") + ErrConflictingSubnetOnlyValidator = errors.New("subnet only validator contains conflicting subnetID + nodeID pair") + ErrDuplicateSubnetOnlyValidator = errors.New("subnet only validator contains duplicate subnetID + nodeID pair") ) +type SubnetOnlyValidators interface { + // GetActiveSubnetOnlyValidatorsIterator returns an iterator of all the + // active subnet only validators in increasing order of EndAccumulatedFee. + GetActiveSubnetOnlyValidatorsIterator() (iterator.Iterator[SubnetOnlyValidator], error) + + // NumActiveSubnetOnlyValidators returns the number of currently active + // subnet only validators. + NumActiveSubnetOnlyValidators() int + + // WeightOfSubnetOnlyValidators returns the total active and inactive weight + // of subnet only validators on [subnetID]. + WeightOfSubnetOnlyValidators(subnetID ids.ID) (uint64, error) + + // GetSubnetOnlyValidator returns the validator with [validationID] if it + // exists. If the validator does not exist, [err] will equal + // [database.ErrNotFound]. + GetSubnetOnlyValidator(validationID ids.ID) (SubnetOnlyValidator, error) + + // HasSubnetOnlyValidator returns the validator with [validationID] if it + // exists. + HasSubnetOnlyValidator(subnetID ids.ID, nodeID ids.NodeID) (bool, error) + + // PutSubnetOnlyValidator inserts [sov] as a validator. + // + // If inserting this validator attempts to modify any of the constant fields + // of the subnet only validator struct, an error will be returned. + // + // If inserting this validator would cause the mapping of subnetID+nodeID to + // validationID to be non-unique, an error will be returned. + PutSubnetOnlyValidator(sov SubnetOnlyValidator) error +} + // SubnetOnlyValidator defines an ACP-77 validator. For a given ValidationID, it // is expected for SubnetID, NodeID, PublicKey, RemainingBalanceOwner, and // StartTime to be constant. @@ -102,6 +138,18 @@ func (v SubnetOnlyValidator) constantsAreUnmodified(o SubnetOnlyValidator) bool v.StartTime == o.StartTime } +func (v SubnetOnlyValidator) isDeleted() bool { + return v.Weight == 0 +} + +func (v SubnetOnlyValidator) isActive() bool { + return v.Weight != 0 && v.EndAccumulatedFee != 0 +} + +func (v SubnetOnlyValidator) isInactive() bool { + return v.Weight != 0 && v.EndAccumulatedFee == 0 +} + func getSubnetOnlyValidator(db database.KeyValueReader, validationID ids.ID) (SubnetOnlyValidator, error) { bytes, err := db.Get(validationID[:]) if err != nil { @@ -128,3 +176,119 @@ func putSubnetOnlyValidator(db database.KeyValueWriter, vdr SubnetOnlyValidator) func deleteSubnetOnlyValidator(db database.KeyValueDeleter, validationID ids.ID) error { return db.Delete(validationID[:]) } + +type subnetOnlyValidatorsDiff struct { + numAddedActive int // May be negative + modifiedTotalWeight map[ids.ID]uint64 // subnetID -> totalWeight + modified map[ids.ID]SubnetOnlyValidator + modifiedHasNodeIDs map[subnetIDNodeID]bool + active *btree.BTreeG[SubnetOnlyValidator] +} + +func newSubnetOnlyValidatorsDiff() *subnetOnlyValidatorsDiff { + return &subnetOnlyValidatorsDiff{ + modifiedTotalWeight: make(map[ids.ID]uint64), + modified: make(map[ids.ID]SubnetOnlyValidator), + modifiedHasNodeIDs: make(map[subnetIDNodeID]bool), + active: btree.NewG(defaultTreeDegree, SubnetOnlyValidator.Less), + } +} + +// getActiveSubnetOnlyValidatorsIterator takes in the parent iterator, removes +// all modified validators, and then adds all modified active validators. +func (d *subnetOnlyValidatorsDiff) getActiveSubnetOnlyValidatorsIterator(parentIterator iterator.Iterator[SubnetOnlyValidator]) iterator.Iterator[SubnetOnlyValidator] { + return iterator.Merge( + SubnetOnlyValidator.Less, + iterator.Filter(parentIterator, func(sov SubnetOnlyValidator) bool { + _, ok := d.modified[sov.ValidationID] + return ok + }), + iterator.FromTree(d.active), + ) +} + +func (d *subnetOnlyValidatorsDiff) hasSubnetOnlyValidator(subnetID ids.ID, nodeID ids.NodeID) (bool, bool) { + subnetIDNodeID := subnetIDNodeID{ + subnetID: subnetID, + nodeID: nodeID, + } + has, modified := d.modifiedHasNodeIDs[subnetIDNodeID] + return has, modified +} + +func (d *subnetOnlyValidatorsDiff) putSubnetOnlyValidator(state Chain, sov SubnetOnlyValidator) error { + var ( + prevWeight uint64 + prevActive bool + newActive = sov.isActive() + ) + switch priorSOV, err := state.GetSubnetOnlyValidator(sov.ValidationID); err { + case nil: + if !priorSOV.constantsAreUnmodified(sov) { + return ErrMutatedSubnetOnlyValidator + } + + prevWeight = priorSOV.Weight + prevActive = priorSOV.isActive() + case database.ErrNotFound: + // Verify that there is not a legacy subnet validator with the same + // subnetID+nodeID as this L1 validator. + _, err := state.GetCurrentValidator(sov.SubnetID, sov.NodeID) + if err == nil { + return ErrConflictingSubnetOnlyValidator + } + if err != database.ErrNotFound { + return err + } + + has, err := state.HasSubnetOnlyValidator(sov.SubnetID, sov.NodeID) + if err != nil { + return err + } + if has { + return ErrDuplicateSubnetOnlyValidator + } + default: + return err + } + + if prevWeight != sov.Weight { + weight, err := state.WeightOfSubnetOnlyValidators(sov.SubnetID) + if err != nil { + return err + } + + weight, err = math.Sub(weight, prevWeight) + if err != nil { + return err + } + weight, err = math.Add(weight, sov.Weight) + if err != nil { + return err + } + + d.modifiedTotalWeight[sov.SubnetID] = weight + } + + switch { + case prevActive && !newActive: + d.numAddedActive-- + case !prevActive && newActive: + d.numAddedActive++ + } + + if prevSOV, ok := d.modified[sov.ValidationID]; ok { + d.active.Delete(prevSOV) + } + d.modified[sov.ValidationID] = sov + + subnetIDNodeID := subnetIDNodeID{ + subnetID: sov.SubnetID, + nodeID: sov.NodeID, + } + d.modifiedHasNodeIDs[subnetIDNodeID] = !sov.isDeleted() + if sov.isActive() { + d.active.ReplaceOrInsert(sov) + } + return nil +}