From 2850e9780754aac489060df593a13114abc3459c Mon Sep 17 00:00:00 2001
From: Kris Jacque
Date: Tue, 30 Apr 2024 11:32:55 -0600
Subject: [PATCH] DAOS-15549 control: Pause group update if provider is changed (#14182)

If the system fabric provider is updated, group updates could be sent
with stale information (a mix of old and new URIs), and engines may
attempt to communicate with peers whose URIs have not yet been updated
to the new provider. We need to prevent this.

This patch temporarily pauses group updates when it detects a fabric
provider change, allowing all known ranks to rejoin with URIs for the
new provider. Once all ranks have rejoined, group updates are resumed
and a new one is triggered to the engines.

Signed-off-by: Kris Jacque
---
 src/control/server/mgmt_system.go      |  67 +++++
 src/control/server/mgmt_system_test.go | 379 ++++++++++++++++++++++++-
 src/control/server/server.go           |   1 +
 src/control/server/util_test.go        |  18 +-
 4 files changed, 447 insertions(+), 18 deletions(-)

diff --git a/src/control/server/mgmt_system.go b/src/control/server/mgmt_system.go
index 1ab741438b2..8df7861ed5c 100644
--- a/src/control/server/mgmt_system.go
+++ b/src/control/server/mgmt_system.go
@@ -14,6 +14,7 @@ import (
 	"path/filepath"
 	"reflect"
 	"runtime"
+	"strconv"
 	"strings"
 	"time"
 
@@ -41,6 +42,7 @@ import (
 )
 
 const fabricProviderProp = "fabric_providers"
+const groupUpdatePauseProp = "group_update_paused"
 
 // GetAttachInfo handles a request to retrieve a map of ranks to fabric URIs, in addition
 // to client network autoconfiguration hints.
@@ -215,6 +217,13 @@ func (svc *mgmtSvc) join(ctx context.Context, req *mgmtpb.JoinReq, peerAddr *net
 		MapVersion: joinResponse.MapVersion,
 	}
 
+	if svc.isGroupUpdatePaused() && svc.allRanksJoined() {
+		if err := svc.resumeGroupUpdate(); err != nil {
+			svc.log.Errorf("failed to resume group update: %s", err.Error())
+		}
+		// join loop will trigger a new group update after this
+	}
+
 	// If the rank is local to the MS leader, then we need to wire up at least
 	// one in order to perform a CaRT group update.
 	if common.IsLocalAddr(peerAddr) && req.Idx == 0 {
@@ -234,6 +243,29 @@ func (svc *mgmtSvc) join(ctx context.Context, req *mgmtpb.JoinReq, peerAddr *net
 	return resp, nil
 }
 
+// allRanksJoined checks whether all ranks that the system knows about, and that are not admin
+// excluded, are joined.
+//
+// NB: This checks the state to determine if the rank is joined. There is a potential hole here,
+// in a case where the system was killed with ranks in the joined state, rather than stopping the
+// ranks first. In that case we may fire this off too early.
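+//
+// Used by join() to decide whether a group update that was paused for a fabric provider change
+// can safely be resumed.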
+func (svc *mgmtSvc) allRanksJoined() bool { + var total int + var joined int + var err error + if total, err = svc.sysdb.MemberCount(); err != nil { + svc.log.Errorf("failed to get total member count: %s", err) + return false + } + + if joined, err = svc.sysdb.MemberCount(system.MemberStateJoined, system.MemberStateAdminExcluded); err != nil { + svc.log.Errorf("failed to get joined member count: %s", err) + return false + } + + return total == joined +} + func (svc *mgmtSvc) checkReqFabricProvider(req *mgmtpb.JoinReq, peerAddr *net.TCPAddr, publisher events.Publisher) error { joinProv, err := getProviderFromURI(req.Uri) if err != nil { @@ -272,6 +304,27 @@ func (svc *mgmtSvc) setFabricProviders(val string) error { return system.SetMgmtProperty(svc.sysdb, fabricProviderProp, val) } +func (svc *mgmtSvc) isGroupUpdatePaused() bool { + propStr, err := system.GetMgmtProperty(svc.sysdb, groupUpdatePauseProp) + if err != nil { + return false + } + result, err := strconv.ParseBool(propStr) + if err != nil { + svc.log.Errorf("invalid value for mgmt prop %q: %s", groupUpdatePauseProp, err.Error()) + return false + } + return result +} + +func (svc *mgmtSvc) pauseGroupUpdate() error { + return system.SetMgmtProperty(svc.sysdb, groupUpdatePauseProp, "true") +} + +func (svc *mgmtSvc) resumeGroupUpdate() error { + return system.SetMgmtProperty(svc.sysdb, groupUpdatePauseProp, "false") +} + func (svc *mgmtSvc) updateFabricProviders(provList []string, publisher events.Publisher) error { provStr := strings.Join(provList, ",") @@ -298,7 +351,16 @@ func (svc *mgmtSvc) updateFabricProviders(provList []string, publisher events.Pu curProv, provStr, numJoined) } + if err := svc.pauseGroupUpdate(); err != nil { + return errors.Wrapf(err, "unable to pause group update before provider change") + } + if err := svc.setFabricProviders(provStr); err != nil { + if guErr := svc.resumeGroupUpdate(); guErr != nil { + // something is very wrong if this happens + svc.log.Errorf("unable to resume group update after provider change failed: %s", guErr.Error()) + } + return errors.Wrapf(err, "changing fabric provider prop") } publisher.Publish(newFabricProvChangedEvent(curProv, provStr)) @@ -326,6 +388,11 @@ func (svc *mgmtSvc) reqGroupUpdate(ctx context.Context, sync bool) { // NB: This method must not be called concurrently, as out-of-order // group updates may trigger engine assertions. 
func (svc *mgmtSvc) doGroupUpdate(ctx context.Context, forced bool) error { + if svc.isGroupUpdatePaused() { + svc.log.Debugf("group update requested (force: %v), but temporarily paused", forced) + return nil + } + if forced { if err := svc.sysdb.IncMapVer(); err != nil { return err diff --git a/src/control/server/mgmt_system_test.go b/src/control/server/mgmt_system_test.go index 00d959f928a..8632fb45ffe 100644 --- a/src/control/server/mgmt_system_test.go +++ b/src/control/server/mgmt_system_test.go @@ -31,6 +31,7 @@ import ( "github.com/daos-stack/daos/src/control/drpc" "github.com/daos-stack/daos/src/control/events" "github.com/daos-stack/daos/src/control/lib/control" + "github.com/daos-stack/daos/src/control/lib/daos" "github.com/daos-stack/daos/src/control/lib/hardware" "github.com/daos-stack/daos/src/control/lib/ranklist" "github.com/daos-stack/daos/src/control/logging" @@ -1928,11 +1929,12 @@ func TestServer_MgmtSvc_Join(t *testing.T) { newProviderMember.PrimaryFabricURI = fmt.Sprintf("verbs://%s", test.MockHostAddr(1)) for name, tc := range map[string]struct { - req *mgmtpb.JoinReq - guResp *mgmtpb.GroupUpdateResp - expGuReq *mgmtpb.GroupUpdateReq - expResp *mgmtpb.JoinResp - expErr error + req *mgmtpb.JoinReq + pauseGroupUpdate bool + guResp *mgmtpb.GroupUpdateResp + expGuReq *mgmtpb.GroupUpdateReq + expResp *mgmtpb.JoinResp + expErr error }{ "bad sys": { req: &mgmtpb.JoinReq{ @@ -2015,6 +2017,7 @@ func TestServer_MgmtSvc_Join(t *testing.T) { }, }, "provider doesn't match": { + pauseGroupUpdate: true, req: &mgmtpb.JoinReq{ Rank: curMember.Rank.Uint32(), Uuid: curMember.UUID.String(), @@ -2033,6 +2036,31 @@ func TestServer_MgmtSvc_Join(t *testing.T) { }, expErr: errors.New("does not match"), }, + "group update resumed": { + pauseGroupUpdate: true, + req: &mgmtpb.JoinReq{ + Rank: curMember.Rank.Uint32(), + Uuid: curMember.UUID.String(), + Uri: curMember.PrimaryFabricURI, + Incarnation: curMember.Incarnation + 1, + }, + expGuReq: &mgmtpb.GroupUpdateReq{ + MapVersion: 3, + Engines: []*mgmtpb.GroupUpdateReq_Engine{ + { + Rank: curMember.Rank.Uint32(), + Uri: curMember.PrimaryFabricURI, + Incarnation: curMember.Incarnation + 1, + }, + }, + }, + expResp: &mgmtpb.JoinResp{ + Status: 0, + Rank: curMember.Rank.Uint32(), + State: mgmtpb.JoinResp_IN, + MapVersion: 2, + }, + }, "new host (non local)": { req: &mgmtpb.JoinReq{ Rank: uint32(ranklist.NilRank), @@ -2094,6 +2122,9 @@ func TestServer_MgmtSvc_Join(t *testing.T) { curCopy.Rank = ranklist.NilRank // ensure that db.data.NextRank is incremented svc := mgmtSystemTestSetup(t, log, system.Members{curCopy}, nil) + if tc.pauseGroupUpdate { + svc.pauseGroupUpdate() + } if tc.req.Sys == "" { tc.req.Sys = build.DefaultSystemName @@ -2165,14 +2196,149 @@ func TestServer_MgmtSvc_Join(t *testing.T) { } } +func TestServer_MgmtSvc_doGroupUpdate(t *testing.T) { + mockMembers := func(t *testing.T, count int, state string) system.Members { + result := system.Members{} + for i := 0; i < count; i++ { + result = append(result, mockMember(t, int32(i), int32(i), state)) + } + return result + } + + defaultMemberCount := 3 + defaultTestMS := func(t *testing.T, l logging.Logger) *mgmtSvc { + return mgmtSystemTestSetup(t, l, mockMembers(t, defaultMemberCount, "joined"), nil) + } + + uri := func(idx int) string { + return "tcp://" + test.MockHostAddr(int32(idx)).String() + } + + getGroupUpdateReq := func(mapVer, count int) *mgmtpb.GroupUpdateReq { + req := &mgmtpb.GroupUpdateReq{ + MapVersion: uint32(mapVer), + } + for i := 0; i < count; i++ { + req.Engines = 
append(req.Engines, &mgmtpb.GroupUpdateReq_Engine{ + Rank: uint32(i), + Uri: uri(i), + Incarnation: uint64(i), + }) + } + return req + } + + getDefaultGroupUpdateReq := func() *mgmtpb.GroupUpdateReq { + return getGroupUpdateReq(defaultMemberCount, defaultMemberCount) + } + + for name, tc := range map[string]struct { + getSvc func(*testing.T, logging.Logger) *mgmtSvc + force bool + expDrpcReq *mgmtpb.GroupUpdateReq + drpcResp *mgmtpb.GroupUpdateResp + drpcErr error + expErr error + }{ + "group update paused": { + getSvc: func(t *testing.T, l logging.Logger) *mgmtSvc { + svc := defaultTestMS(t, l) + svc.pauseGroupUpdate() + return svc + }, + }, + "group update paused with force": { + getSvc: func(t *testing.T, l logging.Logger) *mgmtSvc { + svc := defaultTestMS(t, l) + svc.pauseGroupUpdate() + return svc + }, + force: true, + }, + "no ranks": { + getSvc: func(t *testing.T, l logging.Logger) *mgmtSvc { + svc := mgmtSystemTestSetup(t, l, system.Members{}, nil) + return svc + }, + expErr: system.ErrEmptyGroupMap, + }, + "map version already updated": { + getSvc: func(t *testing.T, l logging.Logger) *mgmtSvc { + svc := defaultTestMS(t, l) + svc.lastMapVer = uint32(defaultMemberCount) + return svc + }, + }, + "drpc failed": { + drpcErr: errors.New("mock drpc"), + expDrpcReq: getDefaultGroupUpdateReq(), + expErr: errors.New("mock drpc"), + }, + "drpc bad status": { + drpcResp: &mgmtpb.GroupUpdateResp{Status: daos.MiscError.Int32()}, + expDrpcReq: getDefaultGroupUpdateReq(), + expErr: daos.MiscError, + }, + "success": { + drpcResp: &mgmtpb.GroupUpdateResp{}, + expDrpcReq: getDefaultGroupUpdateReq(), + }, + "force": { + force: true, + drpcResp: &mgmtpb.GroupUpdateResp{}, + expDrpcReq: getGroupUpdateReq(defaultMemberCount+1, defaultMemberCount), + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + if tc.getSvc == nil { + tc.getSvc = func(t *testing.T, l logging.Logger) *mgmtSvc { + svc := defaultTestMS(t, l) + return svc + } + } + svc := tc.getSvc(t, log) + mockDrpc := getMockDrpcClient(tc.drpcResp, tc.drpcErr) + setMockDrpcClient(svc, mockDrpc) + + err := svc.doGroupUpdate(test.Context(t), tc.force) + + test.CmpErr(t, tc.expErr, err) + + gotDrpcCalls := mockDrpc.calls.get() + if tc.expDrpcReq == nil { + test.AssertEqual(t, 0, len(gotDrpcCalls), "no dRPC calls expected") + } else { + test.AssertEqual(t, 1, len(gotDrpcCalls), "expected a GroupUpdate dRPC call") + + gotReq := new(mgmtpb.GroupUpdateReq) + if err := proto.Unmarshal(gotDrpcCalls[0].Body, gotReq); err != nil { + t.Fatal(err) + } + + // Order of engines in the actual req is arbitrary -- sort for comparison + sort.Slice(gotReq.Engines, func(i, j int) bool { + return gotReq.Engines[i].Rank < gotReq.Engines[j].Rank + }) + if diff := cmp.Diff(tc.expDrpcReq, gotReq, cmpopts.IgnoreUnexported(mgmtpb.GroupUpdateReq{}, mgmtpb.GroupUpdateReq_Engine{})); diff != "" { + t.Fatalf("want-, got+:\n%s", diff) + } + } + }) + } +} + func TestMgmtSvc_updateFabricProviders(t *testing.T) { for name, tc := range map[string]struct { - getSvc func(*testing.T, logging.Logger) *mgmtSvc - oldProv string - provs []string - expErr error - expProv string - expNumEvents int + getSvc func(*testing.T, logging.Logger) *mgmtSvc + oldProv string + provs []string + expErr error + expProv string + expNumEvents int + expGroupUpdatePaused bool }{ "no change": { oldProv: "tcp", @@ -2180,10 +2346,11 @@ func TestMgmtSvc_updateFabricProviders(t *testing.T) { expProv: "tcp", }, "successful change": { - 
oldProv: "tcp", - provs: []string{"verbs"}, - expProv: "verbs", - expNumEvents: 1, + oldProv: "tcp", + provs: []string{"verbs"}, + expProv: "verbs", + expNumEvents: 1, + expGroupUpdatePaused: true, }, "fails getting prop": { getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { @@ -2286,6 +2453,15 @@ func TestMgmtSvc_updateFabricProviders(t *testing.T) { test.AssertEqual(t, events.RASSystemFabricProvChanged, gotEvent.ID, "") test.AssertEqual(t, events.RASSeverityNotice, gotEvent.Severity, "") } + + if tc.expProv != "" { + curProv, err := svc.getFabricProvider() + if err != nil { + t.Fatal(err) + } + test.AssertEqual(t, tc.expProv, curProv, "") + } + test.AssertEqual(t, tc.expGroupUpdatePaused, svc.isGroupUpdatePaused(), "") }) } } @@ -2366,3 +2542,176 @@ func TestMgmtSvc_checkReqFabricProvider(t *testing.T) { }) } } + +func TestMgmtSvc_isGroupUpdatePaused(t *testing.T) { + for name, tc := range map[string]struct { + getSvc func(*testing.T, logging.Logger) *mgmtSvc + propVal string + expResult bool + }{ + "not leader": { + getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { + svc := newTestMgmtSvcMulti(t, log, maxEngines, false) + svc.sysdb = raft.MockDatabaseWithCfg(t, log, &raft.DatabaseConfig{ + SystemName: build.DefaultSystemName, + }) + + return svc + }, + }, + "never set": { + getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { + return mgmtSystemTestSetup(t, log, system.Members{}, []*control.HostResponse{}) + }, + }, + "empty string": {}, + "true": { + propVal: "true", + expResult: true, + }, + "true numeric": { + propVal: "1", + expResult: true, + }, + "false": { + propVal: "false", + }, + "false numeric": { + propVal: "0", + }, + "garbage": { + propVal: "blah blah blah", + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + if tc.getSvc == nil { + tc.getSvc = func(t *testing.T, l logging.Logger) *mgmtSvc { + ms := mgmtSystemTestSetup(t, l, system.Members{}, []*control.HostResponse{}) + if err := system.SetMgmtProperty(ms.sysdb, groupUpdatePauseProp, tc.propVal); err != nil { + t.Fatal(err) + } + return ms + } + } + + svc := tc.getSvc(t, log) + + test.AssertEqual(t, tc.expResult, svc.isGroupUpdatePaused(), "") + }) + } +} + +func TestMgmtSvc_pauseGroupUpdate(t *testing.T) { + for name, tc := range map[string]struct { + getSvc func(*testing.T, logging.Logger) *mgmtSvc + startVal string + expErr error + }{ + "not leader": { + getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { + svc := newTestMgmtSvcMulti(t, log, maxEngines, false) + svc.sysdb = raft.MockDatabaseWithCfg(t, log, &raft.DatabaseConfig{ + SystemName: build.DefaultSystemName, + }) + + return svc + }, + expErr: &system.ErrNotReplica{}, + }, + "never set": { + getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { + return mgmtSystemTestSetup(t, log, system.Members{}, []*control.HostResponse{}) + }, + }, + "true": { + startVal: "true", + }, + "false": { + startVal: "false", + }, + "garbage": { + startVal: "blah blah blah", + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + if tc.getSvc == nil { + tc.getSvc = func(t *testing.T, l logging.Logger) *mgmtSvc { + ms := mgmtSystemTestSetup(t, l, system.Members{}, []*control.HostResponse{}) + if err := system.SetMgmtProperty(ms.sysdb, groupUpdatePauseProp, tc.startVal); err != nil { + t.Fatal(err) + } + return ms + } + } + + svc := tc.getSvc(t, log) + + err := svc.pauseGroupUpdate() + + 
test.CmpErr(t, tc.expErr, err) + test.AssertEqual(t, tc.expErr == nil, svc.isGroupUpdatePaused(), "") + }) + } +} + +func TestMgmtSvc_resumeGroupUpdate(t *testing.T) { + for name, tc := range map[string]struct { + getSvc func(*testing.T, logging.Logger) *mgmtSvc + startVal string + expErr error + }{ + "not leader": { + getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { + svc := newTestMgmtSvcMulti(t, log, maxEngines, false) + svc.sysdb = raft.MockDatabaseWithCfg(t, log, &raft.DatabaseConfig{ + SystemName: build.DefaultSystemName, + }) + + return svc + }, + expErr: &system.ErrNotReplica{}, + }, + "never set": { + getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { + return mgmtSystemTestSetup(t, log, system.Members{}, []*control.HostResponse{}) + }, + }, + "true": { + startVal: "true", + }, + "false": { + startVal: "false", + }, + "garbage": { + startVal: "blah blah blah", + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + if tc.getSvc == nil { + tc.getSvc = func(t *testing.T, l logging.Logger) *mgmtSvc { + ms := mgmtSystemTestSetup(t, l, system.Members{}, []*control.HostResponse{}) + if err := system.SetMgmtProperty(ms.sysdb, groupUpdatePauseProp, tc.startVal); err != nil { + t.Fatal(err) + } + return ms + } + } + + svc := tc.getSvc(t, log) + + err := svc.resumeGroupUpdate() + + test.CmpErr(t, tc.expErr, err) + test.AssertFalse(t, svc.isGroupUpdatePaused(), "") + }) + } +} diff --git a/src/control/server/server.go b/src/control/server/server.go index 4a3efd655fc..fd353171a7d 100644 --- a/src/control/server/server.go +++ b/src/control/server/server.go @@ -442,6 +442,7 @@ func (srv *server) registerEvents() { if err := srv.mgmtSvc.updateFabricProviders([]string{srv.cfg.Fabric.Provider}, srv.pubSub); err != nil { srv.log.Errorf(err.Error()) + return err } srv.mgmtSvc.startLeaderLoops(ctx) diff --git a/src/control/server/util_test.go b/src/control/server/util_test.go index 67037a17253..f2d54c8e585 100644 --- a/src/control/server/util_test.go +++ b/src/control/server/util_test.go @@ -171,17 +171,29 @@ func newMockDrpcClient(cfg *mockDrpcClientConfig) *mockDrpcClient { // setupMockDrpcClientBytes sets up the dRPC client for the mgmtSvc to return // a set of bytes as a response. func setupMockDrpcClientBytes(svc *mgmtSvc, respBytes []byte, err error) { - mi := svc.harness.instances[0] + setMockDrpcClient(svc, getMockDrpcClientBytes(respBytes, err)) +} + +func getMockDrpcClientBytes(respBytes []byte, err error) *mockDrpcClient { cfg := &mockDrpcClientConfig{} cfg.setSendMsgResponse(drpc.Status_SUCCESS, respBytes, err) - mi.(*EngineInstance).setDrpcClient(newMockDrpcClient(cfg)) + return newMockDrpcClient(cfg) } // setupMockDrpcClient sets up the dRPC client for the mgmtSvc to return // a valid protobuf message as a response. func setupMockDrpcClient(svc *mgmtSvc, resp proto.Message, err error) { + setMockDrpcClient(svc, getMockDrpcClient(resp, err)) +} + +func getMockDrpcClient(resp proto.Message, err error) *mockDrpcClient { respBytes, _ := proto.Marshal(resp) - setupMockDrpcClientBytes(svc, respBytes, err) + return getMockDrpcClientBytes(respBytes, err) +} + +func setMockDrpcClient(svc *mgmtSvc, mdc *mockDrpcClient) { + mi := svc.harness.instances[0] + mi.(*EngineInstance).setDrpcClient(mdc) } // newTestEngine returns an EngineInstance configured for testing.