diff --git a/openapi/Swarm.yaml b/openapi/Swarm.yaml
index b545f27d746..c2a08d6796d 100644
--- a/openapi/Swarm.yaml
+++ b/openapi/Swarm.yaml
@@ -818,6 +818,28 @@ paths:
         default:
           description: Default response

+  "/gsoc/subscribe/{address}":
+    get:
+      summary: Subscribe to GSOC payloads
+      tags:
+        - GSOC
+        - Subscribe
+        - Websocket
+      parameters:
+        - in: path
+          name: address
+          schema:
+            $ref: "SwarmCommon.yaml#/components/schemas/SwarmReference"
+          required: true
+          description: "Single Owner Chunk address (which may have multiple payloads)"
+      responses:
+        "200":
+          description: Returns a WebSocket with a subscription for incoming message data on the requested SOC address.
+        "500":
+          $ref: "SwarmCommon.yaml#/components/responses/500"
+        default:
+          description: Default response
+
   "/soc/{owner}/{id}":
     post:
       summary: Upload single owner chunk
diff --git a/pkg/api/api.go b/pkg/api/api.go
index 4568ad96bf5..bced503484c 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -33,6 +33,7 @@ import (
 	"github.com/ethersphere/bee/v2/pkg/file/pipeline"
 	"github.com/ethersphere/bee/v2/pkg/file/pipeline/builder"
 	"github.com/ethersphere/bee/v2/pkg/file/redundancy"
+	"github.com/ethersphere/bee/v2/pkg/gsoc"
 	"github.com/ethersphere/bee/v2/pkg/jsonhttp"
 	"github.com/ethersphere/bee/v2/pkg/log"
 	"github.com/ethersphere/bee/v2/pkg/p2p"
@@ -151,6 +152,7 @@ type Service struct {
 	storer   Storer
 	resolver resolver.Interface
 	pss      pss.Interface
+	gsoc     gsoc.Listener
 	steward  steward.Interface
 	logger   log.Logger
 	loggerV1 log.Logger
@@ -253,6 +255,7 @@ type ExtraOptions struct {
 	Storer        Storer
 	Resolver      resolver.Interface
 	Pss           pss.Interface
+	Gsoc          gsoc.Listener
 	FeedFactory   feeds.Factory
 	Post          postage.Service
 	AccessControl accesscontrol.Controller
@@ -336,6 +339,7 @@ func (s *Service) Configure(signer crypto.Signer, tracer *tracing.Tracer, o Opti
 	s.storer = e.Storer
 	s.resolver = e.Resolver
 	s.pss = e.Pss
+	s.gsoc = e.Gsoc
 	s.feedFactory = e.FeedFactory
 	s.post = e.Post
 	s.accesscontrol = e.AccessControl
diff --git a/pkg/api/api_test.go b/pkg/api/api_test.go
index 7bd20c07e87..dc78dd64077 100644
--- a/pkg/api/api_test.go
+++ b/pkg/api/api_test.go
@@ -32,6 +32,7 @@ import (
 	"github.com/ethersphere/bee/v2/pkg/file/pipeline"
 	"github.com/ethersphere/bee/v2/pkg/file/pipeline/builder"
 	"github.com/ethersphere/bee/v2/pkg/file/redundancy"
+	"github.com/ethersphere/bee/v2/pkg/gsoc"
 	"github.com/ethersphere/bee/v2/pkg/jsonhttp/jsonhttptest"
 	"github.com/ethersphere/bee/v2/pkg/log"
 	p2pmock "github.com/ethersphere/bee/v2/pkg/p2p/mock"
@@ -93,6 +94,7 @@ type testServerOptions struct {
 	StateStorer  storage.StateStorer
 	Resolver     resolver.Interface
 	Pss          pss.Interface
+	Gsoc         gsoc.Listener
 	WsPath       string
 	WsPingPeriod time.Duration
 	Logger       log.Logger
@@ -191,6 +193,7 @@ func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket.
 		Storer:        o.Storer,
 		Resolver:      o.Resolver,
 		Pss:           o.Pss,
+		Gsoc:          o.Gsoc,
 		FeedFactory:   o.Feeds,
 		Post:          o.Post,
 		AccessControl: o.AccessControl,
diff --git a/pkg/api/gsoc.go b/pkg/api/gsoc.go
new file mode 100644
index 00000000000..ea9aad5271e
--- /dev/null
+++ b/pkg/api/gsoc.go
@@ -0,0 +1,119 @@
+// Copyright 2024 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package api
+
+import (
+	"net/http"
+	"time"
+
+	"github.com/ethersphere/bee/v2/pkg/jsonhttp"
+	"github.com/ethersphere/bee/v2/pkg/swarm"
+	"github.com/gorilla/mux"
+	"github.com/gorilla/websocket"
+)
+
+func (s *Service) gsocWsHandler(w http.ResponseWriter, r *http.Request) {
+	logger := s.logger.WithName("gsoc_subscribe").Build()
+
+	paths := struct {
+		Address []byte `map:"address" validate:"required"`
+	}{}
+	if response := s.mapStructure(mux.Vars(r), &paths); response != nil {
+		response("invalid path params", logger, w)
+		return
+	}
+
+	upgrader := websocket.Upgrader{
+		ReadBufferSize:  swarm.ChunkSize,
+		WriteBufferSize: swarm.ChunkSize,
+		CheckOrigin:     s.checkOrigin,
+	}
+
+	conn, err := upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		logger.Debug("upgrade failed", "error", err)
+		logger.Error(nil, "upgrade failed")
+		jsonhttp.InternalServerError(w, "upgrade failed")
+		return
+	}
+
+	s.wsWg.Add(1)
+	go s.gsocListeningWs(conn, paths.Address)
+}
+
+func (s *Service) gsocListeningWs(conn *websocket.Conn, socAddress []byte) {
+	defer s.wsWg.Done()
+
+	var (
+		dataC  = make(chan []byte)
+		gone   = make(chan struct{})
+		ticker = time.NewTicker(s.WsPingPeriod)
+		err    error
+	)
+	defer func() {
+		ticker.Stop()
+		_ = conn.Close()
+	}()
+	cleanup := s.gsoc.Subscribe([32]byte(socAddress), func(m []byte) {
+		select {
+		case dataC <- m:
+		case <-gone:
+			return
+		case <-s.quit:
+			return
+		}
+	})
+	defer cleanup()
+
+	conn.SetCloseHandler(func(code int, text string) error {
+		s.logger.Debug("gsoc ws: client gone", "code", code, "message", text)
+		close(gone)
+		return nil
+	})
+
+	for {
+		select {
+		case b := <-dataC:
+			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
+			if err != nil {
+				s.logger.Debug("gsoc ws: set write deadline failed", "error", err)
+				return
+			}
+
+			err = conn.WriteMessage(websocket.BinaryMessage, b)
+			if err != nil {
+				s.logger.Debug("gsoc ws: write message failed", "error", err)
+				return
+			}
+
+		case <-s.quit:
+			// shutdown
+			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
+			if err != nil {
+				s.logger.Debug("gsoc ws: set write deadline failed", "error", err)
+				return
+			}
+			err = conn.WriteMessage(websocket.CloseMessage, []byte{})
+			if err != nil {
+				s.logger.Debug("gsoc ws: write close message failed", "error", err)
+			}
+			return
+		case <-gone:
+			// client gone
+			return
+		case <-ticker.C:
+			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
+			if err != nil {
+				s.logger.Debug("gsoc ws: set write deadline failed", "error", err)
+				return
+			}
+			if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
+				// error encountered while pinging client. client probably gone
+				return
+			}
+		}
+	}
+}
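For reference, this is roughly how a client would consume the new endpoint. A minimal sketch, assuming a node whose API listens on `localhost:1633` and a hex-encoded 32-byte SOC address (both placeholders, adjust for your setup):

```go
package main

import (
	"fmt"
	"log"

	"github.com/gorilla/websocket"
)

func main() {
	// hypothetical SOC address; 64 hex characters (32 bytes)
	socAddress := "6aa47f0d31e20784005cb2148b6fed85e538f829698ef552bb590be1bfa7e643"
	url := fmt.Sprintf("ws://localhost:1633/gsoc/subscribe/%s", socAddress)

	conn, _, err := websocket.DefaultDialer.Dial(url, nil)
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	// Each binary frame carries one GSOC payload (the wrapped chunk's data
	// without the span prefix, as written by gsocListeningWs above). The
	// server's periodic pings are answered automatically by gorilla's
	// default ping handler, keeping the connection alive.
	for {
		_, payload, err := conn.ReadMessage()
		if err != nil {
			log.Fatalf("read: %v", err)
		}
		fmt.Printf("gsoc payload: %s\n", payload)
	}
}
```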
diff --git a/pkg/api/gsoc_test.go b/pkg/api/gsoc_test.go
new file mode 100644
index 00000000000..5d0a70cc9c4
--- /dev/null
+++ b/pkg/api/gsoc_test.go
@@ -0,0 +1,172 @@
+// Copyright 2024 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package api_test
+
+import (
+	"encoding/hex"
+	"fmt"
+	"net/url"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/ethersphere/bee/v2/pkg/cac"
+	"github.com/ethersphere/bee/v2/pkg/crypto"
+	"github.com/ethersphere/bee/v2/pkg/gsoc"
+	"github.com/ethersphere/bee/v2/pkg/log"
+	mockbatchstore "github.com/ethersphere/bee/v2/pkg/postage/batchstore/mock"
+	"github.com/ethersphere/bee/v2/pkg/soc"
+	mockstorer "github.com/ethersphere/bee/v2/pkg/storer/mock"
+	"github.com/ethersphere/bee/v2/pkg/swarm"
+	"github.com/ethersphere/bee/v2/pkg/util/testutil"
+	"github.com/gorilla/websocket"
+)
+
+// TestGsocWebsocketSingleHandler creates a single websocket handler on a chunk address, and receives a message
+func TestGsocWebsocketSingleHandler(t *testing.T) {
+	t.Parallel()
+
+	var (
+		id               = make([]byte, 32)
+		g, cl, signer, _ = newGsocTest(t, id, 0)
+		respC            = make(chan error, 1)
+		payload          = []byte("hello there!")
+	)
+
+	err := cl.SetReadDeadline(time.Now().Add(2 * time.Second))
+	if err != nil {
+		t.Fatal(err)
+	}
+	cl.SetReadLimit(swarm.ChunkSize)
+
+	ch, _ := cac.New(payload)
+	socCh := soc.New(id, ch)
+	ch, _ = socCh.Sign(signer)
+	socCh, _ = soc.FromChunk(ch)
+	g.Handle(*socCh)
+
+	go expectMessage(t, cl, respC, payload)
+	if err := <-respC; err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestGsocWebsocketMultiHandler(t *testing.T) {
+	t.Parallel()
+
+	var (
+		id                      = make([]byte, 32)
+		g, cl, signer, listener = newGsocTest(t, make([]byte, 32), 0)
+		owner, _                = signer.EthereumAddress()
+		chunkAddr, _            = soc.CreateAddress(id, owner.Bytes())
+		u                       = url.URL{Scheme: "ws", Host: listener, Path: fmt.Sprintf("/gsoc/subscribe/%s", hex.EncodeToString(chunkAddr.Bytes()))}
+		cl2, _, err             = websocket.DefaultDialer.Dial(u.String(), nil)
+		respC                   = make(chan error, 2)
+		payload                 = []byte("hello there!")
+	)
+	if err != nil {
+		t.Fatalf("dial: %v. url %v", err, u.String())
+	}
+	testutil.CleanupCloser(t, cl2)
+
+	err = cl.SetReadDeadline(time.Now().Add(2 * time.Second))
+	if err != nil {
+		t.Fatal(err)
+	}
+	cl.SetReadLimit(swarm.ChunkSize)
+
+	ch, _ := cac.New(payload)
+	socCh := soc.New(id, ch)
+	ch, _ = socCh.Sign(signer)
+	socCh, _ = soc.FromChunk(ch)
+
+	// close the websocket before calling GSOC with the message
+	err = cl.WriteMessage(websocket.CloseMessage, []byte{})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	g.Handle(*socCh)
+
+	go expectMessage(t, cl, respC, payload)
+	go expectMessage(t, cl2, respC, payload)
+	if err := <-respC; err != nil {
+		t.Fatal(err)
+	}
+	if err := <-respC; err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestGsocPong tests that the websocket api adheres to the websocket standard
+// and sends ping-pong messages to keep the connection alive.
+// The test opens a websocket, keeps it alive for 500ms, then receives a GSOC message.
+func TestGsocPong(t *testing.T) {
+	t.Parallel()
+	id := make([]byte, 32)
+
+	var (
+		g, cl, signer, _ = newGsocTest(t, id, 90*time.Millisecond)
+
+		respC    = make(chan error, 1)
+		pongWait = 1 * time.Millisecond
+	)
+
+	cl.SetReadLimit(swarm.ChunkSize)
+	err := cl.SetReadDeadline(time.Now().Add(pongWait))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
+	ch, _ := cac.New([]byte("hello there!"))
+	socCh := soc.New(id, ch)
+	ch, _ = socCh.Sign(signer)
+	socCh, _ = soc.FromChunk(ch)
+
+	g.Handle(*socCh)
+
+	go expectMessage(t, cl, respC, nil)
+	if err := <-respC; err == nil || !strings.Contains(err.Error(), "i/o timeout") {
+		// note: error has *websocket.netError type so we need to check error by checking message
+		t.Fatal("want timeout error")
+	}
+}
+
+func newGsocTest(t *testing.T, socId []byte, pingPeriod time.Duration) (gsoc.Listener, *websocket.Conn, crypto.Signer, string) {
+	t.Helper()
+	if pingPeriod == 0 {
+		pingPeriod = 10 * time.Second
+	}
+	var (
+		batchStore = mockbatchstore.New()
+		storer     = mockstorer.New()
+	)
+
+	privKey, err := crypto.GenerateSecp256k1Key()
+	if err != nil {
+		t.Fatal(err)
+	}
+	signer := crypto.NewDefaultSigner(privKey)
+	owner, err := signer.EthereumAddress()
+	if err != nil {
+		t.Fatal(err)
+	}
+	chunkAddr, _ := soc.CreateAddress(socId, owner.Bytes())
+
+	gsoc := gsoc.New(log.NewLogger("test"))
+	testutil.CleanupCloser(t, gsoc)
+
+	_, cl, listener, _ := newTestServer(t, testServerOptions{
+		Gsoc:         gsoc,
+		WsPath:       fmt.Sprintf("/gsoc/subscribe/%s", hex.EncodeToString(chunkAddr.Bytes())),
+		Storer:       storer,
+		BatchStore:   batchStore,
+		Logger:       log.Noop,
+		WsPingPeriod: pingPeriod,
+	})
+
+	return gsoc, cl, signer, listener
+}
diff --git a/pkg/api/router.go b/pkg/api/router.go
index d7511c1c52e..aa1f1514622 100644
--- a/pkg/api/router.go
+++ b/pkg/api/router.go
@@ -321,6 +321,10 @@ func (s *Service) mountAPI() {
 		web.FinalHandlerFunc(s.pssWsHandler),
 	))

+	handle("/gsoc/subscribe/{address}", web.ChainHandlers(
+		web.FinalHandlerFunc(s.gsocWsHandler),
+	))
+
 	handle("/tags", web.ChainHandlers(
 		web.FinalHandler(jsonhttp.MethodHandler{
 			"GET": http.HandlerFunc(s.listTagsHandler),
diff --git a/pkg/api/soc.go b/pkg/api/soc.go
index 09d4f1d3ae2..a6e0c37e187 100644
--- a/pkg/api/soc.go
+++ b/pkg/api/soc.go
@@ -124,6 +124,8 @@ func (s *Service) socUploadHandler(w http.ResponseWriter, r *http.Request) {
 		jsonhttp.NotFound(w, "batch with id not found")
 	case errors.Is(err, errInvalidPostageBatch):
 		jsonhttp.BadRequest(w, "invalid batch id")
+	case errors.Is(err, errUnsupportedDevNodeOperation):
+		jsonhttp.NotImplemented(w, "operation is not supported in dev mode")
 	default:
 		jsonhttp.BadRequest(w, nil)
 	}
diff --git a/pkg/gsoc/gsoc.go b/pkg/gsoc/gsoc.go
new file mode 100644
index 00000000000..c2f88fc551c
--- /dev/null
+++ b/pkg/gsoc/gsoc.go
@@ -0,0 +1,97 @@
+// Copyright 2024 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gsoc
+
+import (
+	"sync"
+
+	"github.com/ethersphere/bee/v2/pkg/log"
+	"github.com/ethersphere/bee/v2/pkg/soc"
+	"github.com/ethersphere/bee/v2/pkg/swarm"
+)
+
+type Listener interface {
+	Subscribe(address [32]byte, handler handler) (cleanup func())
+	Handle(c soc.SOC)
+	Close() error
+}
+
+type listener struct {
+	handlers   map[[32]byte][]*handler
+	handlersMu sync.Mutex
+	quit       chan struct{}
+	logger     log.Logger
+}
+
+// New returns a new GSOC listener service.
+func New(logger log.Logger) Listener {
+	return &listener{
+		logger:   logger,
+		handlers: make(map[[32]byte][]*handler),
+		quit:     make(chan struct{}),
+	}
+}
+
+// Subscribe registers a handler func for a specific SOC address on the listener struct.
+func (l *listener) Subscribe(address [32]byte, handler handler) (cleanup func()) {
+	l.handlersMu.Lock()
+	defer l.handlersMu.Unlock()
+
+	l.handlers[address] = append(l.handlers[address], &handler)
+
+	return func() {
+		l.handlersMu.Lock()
+		defer l.handlersMu.Unlock()
+
+		h := l.handlers[address]
+		for i := 0; i < len(h); i++ {
+			if h[i] == &handler {
+				l.handlers[address] = append(h[:i], h[i+1:]...)
+				return
+			}
+		}
+	}
+}
+
+// Handle is called by push/pull sync and passes the chunk to its registered handlers.
+func (l *listener) Handle(c soc.SOC) {
+	addr, err := c.Address()
+	if err != nil {
+		return // invalid address
+	}
+	h := l.getHandlers([32]byte(addr.Bytes()))
+	if h == nil {
+		return // no handler
+	}
+	l.logger.Info("new incoming GSOC message",
+		"GSOC Address", addr,
+		"wrapped chunk address", c.WrappedChunk().Address())
+
+	for _, hh := range h {
+		go func(hh handler) {
+			hh(c.WrappedChunk().Data()[swarm.SpanSize:])
+		}(*hh)
+	}
+}
+
+func (l *listener) getHandlers(address [32]byte) []*handler {
+	l.handlersMu.Lock()
+	defer l.handlersMu.Unlock()
+
+	return l.handlers[address]
+}
+
+func (l *listener) Close() error {
+	close(l.quit)
+	l.handlersMu.Lock()
+	defer l.handlersMu.Unlock()
+
+	l.handlers = make(map[[32]byte][]*handler) // unset handlers on shutdown
+
+	return nil
+}
+
+// handler defines code to be executed upon reception of a GSOC sub message.
+type handler func([]byte)
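Taken together, the listener is used roughly like this. A minimal in-memory sketch (no network and no postage stamps involved; only `gsoc.New`, `Subscribe` and `Handle` as defined above, everything else standard bee helpers):

```go
package main

import (
	"fmt"

	"github.com/ethersphere/bee/v2/pkg/cac"
	"github.com/ethersphere/bee/v2/pkg/crypto"
	"github.com/ethersphere/bee/v2/pkg/gsoc"
	"github.com/ethersphere/bee/v2/pkg/log"
	"github.com/ethersphere/bee/v2/pkg/soc"
)

func main() {
	listener := gsoc.New(log.Noop)
	defer listener.Close()

	privKey, _ := crypto.GenerateSecp256k1Key()
	signer := crypto.NewDefaultSigner(privKey)
	owner, _ := signer.EthereumAddress()

	id := make([]byte, 32)
	address, _ := soc.CreateAddress(id, owner.Bytes())

	// handlers are keyed by the 32-byte SOC address
	done := make(chan struct{})
	cleanup := listener.Subscribe([32]byte(address.Bytes()), func(payload []byte) {
		fmt.Printf("payload: %s\n", payload)
		close(done)
	})
	defer cleanup()

	// build and sign a SOC, then feed it to the listener as push/pull sync would
	ch, _ := cac.New([]byte("hello gsoc"))
	sch, _ := soc.New(id, ch).Sign(signer)
	parsed, _ := soc.FromChunk(sch)
	listener.Handle(*parsed)

	<-done // Handle dispatches to handlers asynchronously
}
```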
diff --git a/pkg/gsoc/gsoc_test.go b/pkg/gsoc/gsoc_test.go
new file mode 100644
index 00000000000..0bab1b39bfb
--- /dev/null
+++ b/pkg/gsoc/gsoc_test.go
@@ -0,0 +1,124 @@
+// Copyright 2024 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gsoc_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/ethersphere/bee/v2/pkg/cac"
+	"github.com/ethersphere/bee/v2/pkg/crypto"
+	"github.com/ethersphere/bee/v2/pkg/gsoc"
+	"github.com/ethersphere/bee/v2/pkg/log"
+	"github.com/ethersphere/bee/v2/pkg/soc"
+	"github.com/ethersphere/bee/v2/pkg/util/testutil"
+)
+
+// TestRegister verifies that handler funcs are able to be registered correctly with the gsoc listener
+func TestRegister(t *testing.T) {
+	t.Parallel()
+
+	var (
+		g       = gsoc.New(log.NewLogger("test"))
+		h1Calls = 0
+		h2Calls = 0
+		h3Calls = 0
+		msgChan = make(chan struct{})
+
+		payload1 = []byte("Hello there!")
+		payload2 = []byte("General Kenobi. You are a bold one. Kill him!")
+		socId1   = testutil.RandBytes(t, 32)
+		socId2   = append([]byte{socId1[0] + 1}, socId1[1:]...)
+		privKey, _  = crypto.GenerateSecp256k1Key()
+		signer      = crypto.NewDefaultSigner(privKey)
+		owner, _    = signer.EthereumAddress()
+		address1, _ = soc.CreateAddress(socId1, owner.Bytes())
+		address2, _ = soc.CreateAddress(socId2, owner.Bytes())
+
+		h1 = func(m []byte) {
+			h1Calls++
+			msgChan <- struct{}{}
+		}
+
+		h2 = func(m []byte) {
+			h2Calls++
+			msgChan <- struct{}{}
+		}
+
+		h3 = func(m []byte) {
+			h3Calls++
+			msgChan <- struct{}{}
+		}
+	)
+	_ = g.Subscribe([32]byte(address1.Bytes()), h1)
+	_ = g.Subscribe([32]byte(address2.Bytes()), h2)
+
+	ch1, _ := cac.New(payload1)
+	socCh1 := soc.New(socId1, ch1)
+	ch1, _ = socCh1.Sign(signer)
+	socCh1, _ = soc.FromChunk(ch1)
+
+	ch2, _ := cac.New(payload2)
+	socCh2 := soc.New(socId2, ch2)
+	ch2, _ = socCh2.Sign(signer)
+	socCh2, _ = soc.FromChunk(ch2)
+
+	// trigger soc upload on address1, check that only h1 is called
+	g.Handle(*socCh1)
+
+	waitHandlerCallback(t, &msgChan, 1)
+
+	ensureCalls(t, &h1Calls, 1)
+	ensureCalls(t, &h2Calls, 0)
+
+	// register another handler on the first address
+	cleanup := g.Subscribe([32]byte(address1.Bytes()), h3)
+
+	g.Handle(*socCh1)
+
+	waitHandlerCallback(t, &msgChan, 2)
+
+	ensureCalls(t, &h1Calls, 2)
+	ensureCalls(t, &h2Calls, 0)
+	ensureCalls(t, &h3Calls, 1)
+
+	cleanup() // remove the last handler
+
+	g.Handle(*socCh1)
+
+	waitHandlerCallback(t, &msgChan, 1)
+
+	ensureCalls(t, &h1Calls, 3)
+	ensureCalls(t, &h2Calls, 0)
+	ensureCalls(t, &h3Calls, 1)
+
+	g.Handle(*socCh2)
+
+	waitHandlerCallback(t, &msgChan, 1)
+
+	ensureCalls(t, &h1Calls, 3)
+	ensureCalls(t, &h2Calls, 1)
+	ensureCalls(t, &h3Calls, 1)
+}
+
+func ensureCalls(t *testing.T, calls *int, exp int) {
+	t.Helper()
+
+	if exp != *calls {
+		t.Fatalf("expected %d calls, found %d", exp, *calls)
+	}
+}
+
+func waitHandlerCallback(t *testing.T, msgChan *chan struct{}, count int) {
+	t.Helper()
+
+	for received := 0; received < count; received++ {
+		select {
+		case <-*msgChan:
+		case <-time.After(1 * time.Second):
+			t.Fatal("reached timeout while waiting for handler message")
+		}
+	}
+}
diff --git a/pkg/node/devnode.go b/pkg/node/devnode.go
index 5439fe2d2a7..68f349327ed 100644
--- a/pkg/node/devnode.go
+++ b/pkg/node/devnode.go
@@ -23,6 +23,7 @@ import (
 	"github.com/ethersphere/bee/v2/pkg/bzz"
 	"github.com/ethersphere/bee/v2/pkg/crypto"
 	"github.com/ethersphere/bee/v2/pkg/feeds/factory"
+	"github.com/ethersphere/bee/v2/pkg/gsoc"
 	"github.com/ethersphere/bee/v2/pkg/log"
 	mockP2P "github.com/ethersphere/bee/v2/pkg/p2p/mock"
 	mockPingPong "github.com/ethersphere/bee/v2/pkg/pingpong/mock"
@@ -342,6 +343,7 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) {
 		Storer:        localStore,
 		Resolver:      mockResolver,
 		Pss:           pssService,
+		Gsoc:          gsoc.New(logger),
 		FeedFactory:   mockFeeds,
 		Post:          post,
 		AccessControl: accesscontrol,
diff --git a/pkg/node/node.go b/pkg/node/node.go
index 1fad6195e26..ec855d55fce 100644
--- a/pkg/node/node.go
+++ b/pkg/node/node.go
@@ -32,6 +32,7 @@ import (
 	"github.com/ethersphere/bee/v2/pkg/config"
 	"github.com/ethersphere/bee/v2/pkg/crypto"
 	"github.com/ethersphere/bee/v2/pkg/feeds/factory"
+	"github.com/ethersphere/bee/v2/pkg/gsoc"
 	"github.com/ethersphere/bee/v2/pkg/hive"
 	"github.com/ethersphere/bee/v2/pkg/log"
 	"github.com/ethersphere/bee/v2/pkg/metrics"
@@ -897,6 +898,7 @@ func NewBee(
 	pricing.SetPaymentThresholdObserver(acc)

 	pssService := pss.New(pssPrivateKey, logger)
+	gsocService := gsoc.New(logger)
 	b.pssCloser = pssService

 	validStamp := postage.ValidStamp(batchStore)
@@ -950,7 +952,7 @@ func NewBee(
 		}
 	}
-	pushSyncProtocol := pushsync.New(swarmAddress, networkID, nonce, p2ps, localStore, waitNetworkRFunc, kad, o.FullNodeMode && !o.BootnodeMode, pssService.TryUnwrap, validStamp, logger, acc, pricer, signer, tracer, warmupTime)
+	pushSyncProtocol := pushsync.New(swarmAddress, networkID, nonce, p2ps, localStore, waitNetworkRFunc, kad, o.FullNodeMode && !o.BootnodeMode, pssService.TryUnwrap, gsocService.Handle, validStamp, logger, acc, pricer, signer, tracer, warmupTime)
 	b.pushSyncCloser = pushSyncProtocol

 	// set the pushSyncer in the PSS
@@ -964,7 +966,7 @@ func NewBee(
 	pusherService.AddFeed(localStore.PusherFeed())

-	pullSyncProtocol := pullsync.New(p2ps, localStore, pssService.TryUnwrap, validStamp, logger, pullsync.DefaultMaxPage)
+	pullSyncProtocol := pullsync.New(p2ps, localStore, pssService.TryUnwrap, gsocService.Handle, validStamp, logger, pullsync.DefaultMaxPage)
 	b.pullSyncCloser = pullSyncProtocol

 	retrieveProtocolSpec := retrieval.Protocol()
@@ -1122,6 +1124,7 @@ func NewBee(
 		Storer:        localStore,
 		Resolver:      multiResolver,
 		Pss:           pssService,
+		Gsoc:          gsocService,
 		FeedFactory:   feedFactory,
 		Post:          post,
 		AccessControl: accesscontrol,
diff --git a/pkg/p2p/streamtest/streamtest.go b/pkg/p2p/streamtest/streamtest.go
index a9892687240..ae312624149 100644
--- a/pkg/p2p/streamtest/streamtest.go
+++ b/pkg/p2p/streamtest/streamtest.go
@@ -96,6 +96,13 @@ func New(opts ...Option) *Recorder {
 	return r
 }

+func (r *Recorder) Reset() {
+	r.recordsMu.Lock()
+	defer r.recordsMu.Unlock()
+
+	r.records = make(map[string][]*Record)
+}
+
 func (r *Recorder) SetProtocols(protocols ...p2p.ProtocolSpec) {
 	r.protocols = append(r.protocols, protocols...)
 }
diff --git a/pkg/pullsync/pullsync.go b/pkg/pullsync/pullsync.go
index afd2ee17fed..a917169e0d6 100644
--- a/pkg/pullsync/pullsync.go
+++ b/pkg/pullsync/pullsync.go
@@ -71,6 +71,7 @@ type Syncer struct {
 	store          storer.Reserve
 	quit           chan struct{}
 	unwrap         func(swarm.Chunk)
+	gsocHandler    func(soc.SOC)
 	validStamp     postage.ValidStampFn
 	intervalsSF    singleflight.Group[string, *collectAddrsResult]
 	syncInProgress atomic.Int32
@@ -87,21 +88,23 @@ func New(
 	streamer p2p.Streamer,
 	store storer.Reserve,
 	unwrap func(swarm.Chunk),
+	gsocHandler func(soc.SOC),
 	validStamp postage.ValidStampFn,
 	logger log.Logger,
 	maxPage uint64,
 ) *Syncer {
 	return &Syncer{
-		streamer:   streamer,
-		store:      store,
-		metrics:    newMetrics(),
-		unwrap:     unwrap,
-		validStamp: validStamp,
-		logger:     logger.WithName(loggerName).Register(),
-		quit:       make(chan struct{}),
-		maxPage:    maxPage,
-		limiter:    ratelimit.New(handleRequestsLimitRate, int(maxPage)),
+		streamer:    streamer,
+		store:       store,
+		metrics:     newMetrics(),
+		unwrap:      unwrap,
+		gsocHandler: gsocHandler,
+		validStamp:  validStamp,
+		logger:      logger.WithName(loggerName).Register(),
+		quit:        make(chan struct{}),
+		maxPage:     maxPage,
+		limiter:     ratelimit.New(handleRequestsLimitRate, int(maxPage)),
 	}
 }
@@ -356,7 +359,9 @@ func (s *Syncer) Sync(ctx context.Context, peer swarm.Address, bin uint8, start
 			if cac.Valid(chunk) {
 				go s.unwrap(chunk)
-			} else if !soc.Valid(chunk) {
+			} else if chunk, err := soc.FromChunk(chunk); err == nil {
+				s.gsocHandler(*chunk)
+			} else {
 				s.logger.Debug("invalid cac/soc chunk", "error", swarm.ErrInvalidChunk, "peer_address", peer, "chunk", chunk)
 				chunkErr = errors.Join(chunkErr, swarm.ErrInvalidChunk)
 				s.metrics.ReceivedInvalidChunk.Inc()
diff --git a/pkg/pullsync/pullsync_test.go b/pkg/pullsync/pullsync_test.go
index e77f54705ed..68d9e04ecbc 100644
--- a/pkg/pullsync/pullsync_test.go
+++ b/pkg/pullsync/pullsync_test.go
@@ -17,6 +17,7 @@ import (
"github.com/ethersphere/bee/v2/pkg/postage" postagetesting "github.com/ethersphere/bee/v2/pkg/postage/testing" "github.com/ethersphere/bee/v2/pkg/pullsync" + "github.com/ethersphere/bee/v2/pkg/soc" "github.com/ethersphere/bee/v2/pkg/storage" testingc "github.com/ethersphere/bee/v2/pkg/storage/testing" "github.com/ethersphere/bee/v2/pkg/storer" @@ -353,10 +354,12 @@ func newPullSyncWithStamperValidator( storage := mock.NewReserve(o...) logger := log.Noop unwrap := func(swarm.Chunk) {} + socHandler := func(soc.SOC) {} ps := pullsync.New( s, storage, unwrap, + socHandler, validStamp, logger, maxPage, diff --git a/pkg/pusher/inflight.go b/pkg/pusher/inflight.go index 788872d2652..99ec53c96ff 100644 --- a/pkg/pusher/inflight.go +++ b/pkg/pusher/inflight.go @@ -12,28 +12,33 @@ import ( type inflight struct { mtx sync.Mutex - inflight map[string]struct{} + inflight map[[64]byte]struct{} } func newInflight() *inflight { return &inflight{ - inflight: make(map[string]struct{}), + inflight: make(map[[64]byte]struct{}), } } -func (i *inflight) delete(ch swarm.Chunk) { - key := ch.Address().ByteString() + string(ch.Stamp().BatchID()) +func (i *inflight) delete(idAddress swarm.Address, batchID []byte) { + var key [64]byte + copy(key[:32], idAddress.Bytes()) + copy(key[32:], batchID) + i.mtx.Lock() delete(i.inflight, key) i.mtx.Unlock() } -func (i *inflight) set(ch swarm.Chunk) bool { +func (i *inflight) set(idAddress swarm.Address, batchID []byte) bool { + var key [64]byte + copy(key[:32], idAddress.Bytes()) + copy(key[32:], batchID) i.mtx.Lock() defer i.mtx.Unlock() - key := ch.Address().ByteString() + string(ch.Stamp().BatchID()) if _, ok := i.inflight[key]; ok { return true } @@ -50,16 +55,16 @@ type attempts struct { // try to log a chunk sync attempt. returns false when // maximum amount of attempts have been reached. 
-func (a *attempts) try(ch swarm.Address) bool {
+func (a *attempts) try(idAddress swarm.Address) bool {
 	a.mtx.Lock()
 	defer a.mtx.Unlock()
-	key := ch.ByteString()
+	key := idAddress.ByteString()
 	a.attempts[key]++
 	return a.attempts[key] < a.retryCount
 }

-func (a *attempts) delete(ch swarm.Address) {
+func (a *attempts) delete(idAddress swarm.Address) {
 	a.mtx.Lock()
-	delete(a.attempts, ch.ByteString())
+	delete(a.attempts, idAddress.ByteString())
 	a.mtx.Unlock()
 }
diff --git a/pkg/pusher/pusher.go b/pkg/pusher/pusher.go
index 7defcfb37cd..0945bdee05e 100644
--- a/pkg/pusher/pusher.go
+++ b/pkg/pusher/pusher.go
@@ -18,6 +18,7 @@ import (
 	"github.com/ethersphere/bee/v2/pkg/log"
 	"github.com/ethersphere/bee/v2/pkg/postage"
 	"github.com/ethersphere/bee/v2/pkg/pushsync"
+	"github.com/ethersphere/bee/v2/pkg/soc"
 	storage "github.com/ethersphere/bee/v2/pkg/storage"
 	"github.com/ethersphere/bee/v2/pkg/swarm"
 	"github.com/ethersphere/bee/v2/pkg/topology"
@@ -214,7 +215,15 @@ func (s *Service) chunksWorker(warmupTime time.Duration) {
 		for {
 			select {
 			case op := <-cc:
-				if s.inflight.set(op.Chunk) {
+				idAddress, err := soc.IdentityAddress(op.Chunk)
+				if err != nil {
+					select {
+					case op.Err <- err:
+					default:
+					}
+					continue
+				}
+				if s.inflight.set(idAddress, op.Chunk.Stamp().BatchID()) {
 					if op.Direct {
 						select {
 						case op.Err <- nil:
@@ -240,8 +249,12 @@ func (s *Service) chunksWorker(warmupTime time.Duration) {
 func (s *Service) pushDeferred(ctx context.Context, logger log.Logger, op *Op) (bool, error) {
 	loggerV1 := logger.V(1).Build()

+	idAddress, err := soc.IdentityAddress(op.Chunk)
+	if err != nil {
+		return true, err
+	}

-	defer s.inflight.delete(op.Chunk)
+	defer s.inflight.delete(idAddress, op.Chunk.Stamp().BatchID())

 	if _, err := s.validStamp(op.Chunk); err != nil {
 		loggerV1.Warning(
@@ -254,7 +267,7 @@ func (s *Service) pushDeferred(ctx context.Context, logger log.Logger, op *Op) (
 		return false, errors.Join(err, s.storer.Report(ctx, op.Chunk, storage.ChunkCouldNotSync))
 	}

-	switch receipt, err := s.pushSyncer.PushChunkToClosest(ctx, op.Chunk); {
+	switch _, err := s.pushSyncer.PushChunkToClosest(ctx, op.Chunk); {
 	case errors.Is(err, topology.ErrWantSelf):
 		// store the chunk
 		loggerV1.Debug("chunk stays here, i'm the closest node", "chunk_address", op.Chunk.Address())
@@ -269,7 +282,7 @@ func (s *Service) pushDeferred(ctx context.Context, logger log.Logger, op *Op) (
 			return true, err
 		}
 	case errors.Is(err, pushsync.ErrShallowReceipt):
-		if retry := s.shallowReceipt(receipt); retry {
+		if retry := s.shallowReceipt(idAddress); retry {
 			return true, err
 		}
 		if err := s.storer.Report(ctx, op.Chunk, storage.ChunkSynced); err != nil {
@@ -291,11 +304,13 @@ func (s *Service) pushDeferred(ctx context.Context, logger log.Logger, op *Op) (
 func (s *Service) pushDirect(ctx context.Context, logger log.Logger, op *Op) error {
 	loggerV1 := logger.V(1).Build()
-
-	var err error
+	idAddress, err := soc.IdentityAddress(op.Chunk)
+	if err != nil {
+		return err
+	}

 	defer func() {
-		s.inflight.delete(op.Chunk)
+		s.inflight.delete(idAddress, op.Chunk.Stamp().BatchID())
 		select {
 		case op.Err <- err:
 		default:
@@ -329,11 +344,11 @@ func (s *Service) pushDirect(ctx context.Context, logger log.Logger, op *Op) err
 	return err
 }

-func (s *Service) shallowReceipt(receipt *pushsync.Receipt) bool {
-	if s.attempts.try(receipt.Address) {
+func (s *Service) shallowReceipt(idAddress swarm.Address) bool {
+	if s.attempts.try(idAddress) {
 		return true
 	}
-	s.attempts.delete(receipt.Address)
+	s.attempts.delete(idAddress)
 	return false
 }
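The pusher now keys its inflight and retry bookkeeping by identity address rather than storage address. For two signed updates of the same SOC the storage address is identical, so the old string key collapsed them into one entry and the second update was dropped while the first was in flight; the identity address (`soc.IdentityAddress`, added in `pkg/soc/utils.go` further below) also hashes in the wrapped chunk's address, giving each payload version its own key. A standalone sketch of that difference, assuming only the bee packages shown in this diff:

```go
package main

import (
	"fmt"

	"github.com/ethersphere/bee/v2/pkg/cac"
	"github.com/ethersphere/bee/v2/pkg/crypto"
	"github.com/ethersphere/bee/v2/pkg/soc"
)

func main() {
	privKey, _ := crypto.GenerateSecp256k1Key()
	signer := crypto.NewDefaultSigner(privKey)
	id := make([]byte, 32)

	// two payload versions of the same single owner chunk
	c1, _ := cac.New([]byte("update 1"))
	c2, _ := cac.New([]byte("update 2"))
	v1, _ := soc.New(id, c1).Sign(signer)
	v2, _ := soc.New(id, c2).Sign(signer)

	// old inflight key component: the storage address -- identical for both
	// versions, so the second push would have been treated as a duplicate
	fmt.Println(v1.Address().Equal(v2.Address())) // true

	// new key component: the identity address -- distinct per payload version
	i1, _ := soc.IdentityAddress(v1)
	i2, _ := soc.IdentityAddress(v2)
	fmt.Println(i1.Equal(i2)) // false
}
```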
diff --git a/pkg/pushsync/pushsync.go b/pkg/pushsync/pushsync.go
index c687a544727..22348100dae 100644
--- a/pkg/pushsync/pushsync.go
+++ b/pkg/pushsync/pushsync.go
@@ -85,6 +85,7 @@ type PushSync struct {
 	store          Storer
 	topologyDriver topology.Driver
 	unwrap         func(swarm.Chunk)
+	gsocHandler    func(soc.SOC)
 	logger         log.Logger
 	accounting     accounting.Interface
 	pricer         pricer.Interface
@@ -114,6 +115,7 @@ func New(
 	topology topology.Driver,
 	fullNode bool,
 	unwrap func(swarm.Chunk),
+	gsocHandler func(soc.SOC),
 	validStamp postage.ValidStampFn,
 	logger log.Logger,
 	accounting accounting.Interface,
@@ -132,6 +134,7 @@ func New(
 		topologyDriver: topology,
 		fullNode:       fullNode,
 		unwrap:         unwrap,
+		gsocHandler:    gsocHandler,
 		logger:         logger.WithName(loggerName).Register(),
 		accounting:     accounting,
 		pricer:         pricer,
@@ -225,7 +228,9 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)

 	if cac.Valid(chunk) {
 		go ps.unwrap(chunk)
-	} else if !soc.Valid(chunk) {
+	} else if chunk, err := soc.FromChunk(chunk); err == nil {
+		ps.gsocHandler(*chunk)
+	} else {
 		return swarm.ErrInvalidChunk
 	}
@@ -319,7 +324,6 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
 			Nonce:   r.Nonce,
 		}, err
 	}
-
 	if err != nil {
 		return nil, err
 	}
@@ -357,6 +361,11 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk, origin bo
 		sentErrorsLeft = maxPushErrors
 	}

+	idAddress, err := soc.IdentityAddress(ch)
+	if err != nil {
+		return nil, err
+	}
+
 	resultChan := make(chan receiptResult)

 	retryC := make(chan struct{}, max(1, parallelForwards))
@@ -393,10 +402,10 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk, origin bo
 			// If no peer can be found from an origin peer, the origin peer may store the chunk.
 			// Non-origin peers store the chunk if the chunk is within depth.
 			// For non-origin peers, if the chunk is not within depth, they may store the chunk if they are the closest peer to the chunk.
-			fullSkip := append(skip.ChunkPeers(ch.Address()), ps.errSkip.ChunkPeers(ch.Address())...)
+			fullSkip := append(skip.ChunkPeers(idAddress), ps.errSkip.ChunkPeers(idAddress)...)
 			peer, err := ps.closestPeer(ch.Address(), origin, fullSkip)
 			if errors.Is(err, topology.ErrNotFound) {
-				if skip.PruneExpiresAfter(ch.Address(), overDraftRefresh) == 0 { //no overdraft peers, we have depleted ALL peers
+				if skip.PruneExpiresAfter(idAddress, overDraftRefresh) == 0 { //no overdraft peers, we have depleted ALL peers
 					if inflight == 0 {
 						if ps.fullNode {
 							if cac.Valid(ch) {
@@ -433,7 +442,7 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk, origin bo

 			// all future requests should land directly into the neighborhood
 			if neighborsOnly && peerPO < rad {
-				skip.Forever(ch.Address(), peer)
+				skip.Forever(idAddress, peer)
 				continue
 			}
@@ -450,10 +459,10 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk, origin bo
 			action, err := ps.prepareCredit(ctx, peer, ch, origin)
 			if err != nil {
 				retry()
-				skip.Add(ch.Address(), peer, overDraftRefresh)
+				skip.Add(idAddress, peer, overDraftRefresh)
 				continue
 			}
-			skip.Forever(ch.Address(), peer)
+			skip.Forever(idAddress, peer)

 			ps.metrics.TotalSendAttempts.Inc()
 			inflight++

 			go ps.push(ctx, resultChan, peer, ch, action)

 		case result := <-resultChan:
-
 			inflight--

 			ps.measurePushPeer(result.pushTime, result.err)
@@ -471,16 +479,16 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk, origin bo
 			case err == nil:
 				return result.receipt, nil
 			case errors.Is(err, ErrShallowReceipt):
-				ps.errSkip.Add(ch.Address(), result.peer, skiplistDur)
+				ps.errSkip.Add(idAddress, result.peer, skiplistDur)
 				return result.receipt, err
 			}

 			ps.metrics.TotalFailedSendAttempts.Inc()
-			ps.logger.Debug("could not push to peer", "chunk_address", ch.Address(), "peer_address", result.peer, "error", result.err)
+			ps.logger.Debug("could not push to peer", "chunk_address", ch.Address(), "id_address", idAddress, "peer_address", result.peer, "error", result.err)

 			sentErrorsLeft--
-			ps.errSkip.Add(ch.Address(), result.peer, skiplistDur)
+			ps.errSkip.Add(idAddress, result.peer, skiplistDur)

 			retry()
diff --git a/pkg/pushsync/pushsync_test.go b/pkg/pushsync/pushsync_test.go
index 9b1ee648d3f..73296496753 100644
--- a/pkg/pushsync/pushsync_test.go
+++ b/pkg/pushsync/pushsync_test.go
@@ -25,6 +25,7 @@ import (
 	pricermock "github.com/ethersphere/bee/v2/pkg/pricer/mock"
 	"github.com/ethersphere/bee/v2/pkg/pushsync"
 	"github.com/ethersphere/bee/v2/pkg/pushsync/pb"
+	"github.com/ethersphere/bee/v2/pkg/soc"
 	storage "github.com/ethersphere/bee/v2/pkg/storage"
 	testingc "github.com/ethersphere/bee/v2/pkg/storage/testing"
 	"github.com/ethersphere/bee/v2/pkg/swarm"
@@ -110,6 +111,96 @@ func TestPushClosest(t *testing.T) {
 	}
 }

+// TestSocListener subscribes to all payloads of a SOC address. This triggers sending a chunk to the closest node
+// and expects a receipt. The message is intercepted in the outgoing stream to check for correctness.
+func TestSocListener(t *testing.T) {
+	t.Parallel()
+
+	defaultSigner := cryptomock.New(cryptomock.WithSignFunc(func(addr []byte) ([]byte, error) {
+		key, _ := crypto.GenerateSecp256k1Key()
+		signer := crypto.NewDefaultSigner(key)
+		signature, _ := signer.Sign(addr)
+
+		return signature, nil
+	}))
+
+	// chunk data to upload
+	privKey, err := crypto.DecodeSecp256k1PrivateKey(swarm.MustParseHexAddress("b0baf37700000000000000000000000000000000000000000000000000000000").Bytes())
+	if err != nil {
+		t.Fatal(err)
+	}
+	signer := crypto.NewDefaultSigner(privKey)
+	chunk1 := testingc.FixtureChunk("7000")
+	chunk2 := testingc.FixtureChunk("0033")
+	id := make([]byte, swarm.HashSize)
+	s1 := soc.New(id, chunk1)
+	s2 := soc.New(id, chunk2)
+	sch1, err := s1.Sign(signer)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sch1 = sch1.WithStamp(chunk1.Stamp())
+	sch2, err := s2.Sign(signer)
+	if err != nil {
+		t.Fatal(err)
+	}
+	sch2 = sch2.WithStamp(chunk2.Stamp())
+	expectedPayload := chunk1.Data()
+	gsocListener := func(soc soc.SOC) {
+		if !bytes.Equal(soc.WrappedChunk().Data(), expectedPayload) {
+			t.Fatalf("unexpected SOC payload on GSOC listener. got %s, want %s", soc.WrappedChunk().Data(), expectedPayload)
+		}
+	}
+
+	// create a pivot node and a mocked closest node
+	pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000")   // base is 0000
+	closestPeer := swarm.MustParseHexAddress("8000000000000000000000000000000000000000000000000000000000000000") // binary 1000 -> po 1
+
+	// peer is the node responding to the chunk receipt message
+	// mock should return ErrWantSelf since there's no one to forward to
+	psPeer, _, _ := createGsocPushSyncNode(t, closestPeer, defaultPrices, nil, nil, defaultSigner, mock.WithClosestPeerErr(topology.ErrWantSelf))
+
+	recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()), streamtest.WithBaseAddr(pivotNode))
+
+	// pivot node needs the streamer since the chunk is intercepted by
+	// the chunk worker, then gets sent by opening a new stream
+	psPivot, _, _ := createGsocPushSyncNode(t, pivotNode, defaultPrices, recorder, gsocListener, defaultSigner, mock.WithClosestPeer(closestPeer))
+
+	// Trigger the sending of chunk to the closest node
+	receipt, err := psPivot.PushChunkToClosest(context.Background(), sch1)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !sch1.Address().Equal(receipt.Address) {
+		t.Fatal("invalid receipt")
+	}
+
+	// this intercepts the outgoing delivery message
+	waitOnRecordAndTest(t, closestPeer, recorder, sch1.Address(), sch1.Data())
+
+	// this intercepts the incoming receipt message
+	waitOnRecordAndTest(t, closestPeer, recorder, sch1.Address(), nil)
+
+	recorder.Reset()
+	expectedPayload = chunk2.Data()
+
+	// Trigger the sending of chunk to the closest node
+	receipt, err = psPivot.PushChunkToClosest(context.Background(), sch2)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !sch2.Address().Equal(receipt.Address) {
+		t.Fatal("invalid receipt")
+	}
+
+	// this intercepts the outgoing delivery message
+	waitOnRecordAndTest(t, closestPeer, recorder, sch2.Address(), sch2.Data())
+
+	// this intercepts the incoming receipt message
+	waitOnRecordAndTest(t, closestPeer, recorder, sch2.Address(), nil)
+}
+
 // TestShallowReceipt forces the peer to send back a shallow receipt to a pushsync request. In return, the origin node returns the error along with the received receipt.
 func TestShallowReceipt(t *testing.T) {
 	t.Parallel()
@@ -377,7 +468,7 @@ func TestPushChunkToClosestErrorAttemptRetry(t *testing.T) {
 		}),
 	)

-	psPivot, pivotStorer := createPushSyncNodeWithAccounting(t, pivotNode, defaultPrices, recorder, nil, defaultSigner(chunk), pivotAccounting, log.Noop, mock.WithPeers(peer1, peer2, peer3, peer4))
+	psPivot, pivotStorer := createPushSyncNodeWithAccounting(t, pivotNode, defaultPrices, recorder, nil, defaultSigner(chunk), pivotAccounting, log.Noop, func(soc.SOC) {}, mock.WithPeers(peer1, peer2, peer3, peer4))

 	// Trigger the sending of chunk to the closest node
 	receipt, err := psPivot.PushChunkToClosest(context.Background(), chunk)
@@ -554,15 +645,15 @@ func TestPropagateErrMsg(t *testing.T) {
 	captureLogger := log.NewLogger("test", log.WithSink(buf))

 	// Create the closest peer
-	psClosestPeer, _ := createPushSyncNodeWithAccounting(t, closestPeer, defaultPrices, nil, nil, faultySigner, accountingmock.NewAccounting(), log.Noop, mock.WithClosestPeerErr(topology.ErrWantSelf))
+	psClosestPeer, _ := createPushSyncNodeWithAccounting(t, closestPeer, defaultPrices, nil, nil, faultySigner, accountingmock.NewAccounting(), log.Noop, func(soc.SOC) {}, mock.WithClosestPeerErr(topology.ErrWantSelf))

 	// creating the pivot peer
-	psPivot, _ := createPushSyncNodeWithAccounting(t, pivotPeer, defaultPrices, nil, nil, defaultSigner(chunk), accountingmock.NewAccounting(), log.Noop, mock.WithPeers(closestPeer))
+	psPivot, _ := createPushSyncNodeWithAccounting(t, pivotPeer, defaultPrices, nil, nil, defaultSigner(chunk), accountingmock.NewAccounting(), log.Noop, func(soc.SOC) {}, mock.WithPeers(closestPeer))

 	combinedRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol(), psClosestPeer.Protocol()), streamtest.WithBaseAddr(triggerPeer))

 	// Creating the trigger peer
-	psTriggerPeer, _ := createPushSyncNodeWithAccounting(t, triggerPeer, defaultPrices, combinedRecorder, nil, defaultSigner(chunk), accountingmock.NewAccounting(), captureLogger, mock.WithPeers(pivotPeer))
+	psTriggerPeer, _ := createPushSyncNodeWithAccounting(t, triggerPeer, defaultPrices, combinedRecorder, nil, defaultSigner(chunk), accountingmock.NewAccounting(), captureLogger, func(soc.SOC) {}, mock.WithPeers(pivotPeer))

 	_, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk)
 	if err == nil {
@@ -738,7 +829,22 @@ func createPushSyncNode(
 ) (*pushsync.PushSync, *testStorer, accounting.Interface) {
 	t.Helper()
 	mockAccounting := accountingmock.NewAccounting()
-	ps, mstorer := createPushSyncNodeWithAccounting(t, addr, prices, recorder, unwrap, signer, mockAccounting, log.Noop, mockOpts...)
+	ps, mstorer := createPushSyncNodeWithAccounting(t, addr, prices, recorder, unwrap, signer, mockAccounting, log.Noop, func(soc.SOC) {}, mockOpts...)
+	return ps, mstorer, mockAccounting
+}
+
+func createGsocPushSyncNode(
+	t *testing.T,
+	addr swarm.Address,
+	prices pricerParameters,
+	recorder *streamtest.Recorder,
+	gsocListener func(soc.SOC),
+	signer crypto.Signer,
+	mockOpts ...mock.Option,
+) (*pushsync.PushSync, *testStorer, accounting.Interface) {
+	t.Helper()
+	mockAccounting := accountingmock.NewAccounting()
+	ps, mstorer := createPushSyncNodeWithAccounting(t, addr, prices, recorder, nil, signer, mockAccounting, log.Noop, gsocListener, mockOpts...)
 	return ps, mstorer, mockAccounting
 }

@@ -772,7 +878,7 @@ func createPushSyncNodeWithRadius(
 	radiusFunc := func() (uint8, error) { return radius, nil }

-	ps := pushsync.New(addr, 1, blockHash.Bytes(), recorderDisconnecter, storer, radiusFunc, mockTopology, true, unwrap, validStamp, log.Noop, accountingmock.NewAccounting(), mockPricer, signer, nil, -1)
+	ps := pushsync.New(addr, 1, blockHash.Bytes(), recorderDisconnecter, storer, radiusFunc, mockTopology, true, unwrap, func(soc.SOC) {}, validStamp, log.Noop, accountingmock.NewAccounting(), mockPricer, signer, nil, -1)
 	t.Cleanup(func() { ps.Close() })

 	return ps, storer
@@ -787,6 +893,7 @@ func createPushSyncNodeWithAccounting(
 	signer crypto.Signer,
 	acct accounting.Interface,
 	logger log.Logger,
+	gsocListener func(soc.SOC),
 	mockOpts ...mock.Option,
 ) (*pushsync.PushSync, *testStorer) {
 	t.Helper()
@@ -802,6 +909,9 @@ func createPushSyncNodeWithAccounting(
 	if unwrap == nil {
 		unwrap = func(swarm.Chunk) {}
 	}
+	if gsocListener == nil {
+		gsocListener = func(soc.SOC) {}
+	}

 	validStamp := func(ch swarm.Chunk) (swarm.Chunk, error) {
 		return ch, nil
@@ -809,7 +919,7 @@ func createPushSyncNodeWithAccounting(
 	radiusFunc := func() (uint8, error) { return 0, nil }

-	ps := pushsync.New(addr, 1, blockHash.Bytes(), recorderDisconnecter, storer, radiusFunc, mockTopology, true, unwrap, validStamp, logger, acct, mockPricer, signer, nil, -1)
+	ps := pushsync.New(addr, 1, blockHash.Bytes(), recorderDisconnecter, storer, radiusFunc, mockTopology, true, unwrap, gsocListener, validStamp, logger, acct, mockPricer, signer, nil, -1)
 	t.Cleanup(func() { ps.Close() })

 	return ps, storer
diff --git a/pkg/soc/utils.go b/pkg/soc/utils.go
new file mode 100644
index 00000000000..f07c2f2102d
--- /dev/null
+++ b/pkg/soc/utils.go
@@ -0,0 +1,31 @@
+// Copyright 2024 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package soc
+
+import "github.com/ethersphere/bee/v2/pkg/swarm"
+
+// IdentityAddress returns the internally used address for the chunk.
+func IdentityAddress(chunk swarm.Chunk) (swarm.Address, error) {
+	// check whether the chunk is a single owner chunk or a CAC
+	if sch, err := FromChunk(chunk); err == nil {
+		socAddress, err := sch.Address()
+		if err != nil {
+			return swarm.ZeroAddress, err
+		}
+		h := swarm.NewHasher()
+		_, err = h.Write(socAddress.Bytes())
+		if err != nil {
+			return swarm.ZeroAddress, err
+		}
+		_, err = h.Write(sch.WrappedChunk().Address().Bytes())
+		if err != nil {
+			return swarm.ZeroAddress, err
+		}
+
+		return swarm.NewAddress(h.Sum(nil)), nil
+	}
+
+	return chunk.Address(), nil
+}
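A quick demonstration of the semantics as a standalone sketch (in-memory chunks only): for a CAC the identity address falls through to the content address, while for a SOC it is the hash of the SOC address concatenated with the wrapped chunk's address, which we can recompute by hand:

```go
package main

import (
	"fmt"

	"github.com/ethersphere/bee/v2/pkg/cac"
	"github.com/ethersphere/bee/v2/pkg/crypto"
	"github.com/ethersphere/bee/v2/pkg/soc"
	"github.com/ethersphere/bee/v2/pkg/swarm"
)

func main() {
	// CAC: IdentityAddress returns the content address unchanged
	ch, _ := cac.New([]byte("plain content-addressed chunk"))
	idAddr, _ := soc.IdentityAddress(ch)
	fmt.Println(idAddr.Equal(ch.Address())) // true

	// SOC: IdentityAddress == H(socAddress || wrappedChunkAddress)
	privKey, _ := crypto.GenerateSecp256k1Key()
	signer := crypto.NewDefaultSigner(privKey)
	sch, _ := soc.New(make([]byte, 32), ch).Sign(signer)

	idAddr, _ = soc.IdentityAddress(sch)

	parsed, _ := soc.FromChunk(sch)
	socAddr, _ := parsed.Address()
	h := swarm.NewHasher()
	h.Write(socAddr.Bytes())
	h.Write(parsed.WrappedChunk().Address().Bytes())
	fmt.Println(idAddr.Equal(swarm.NewAddress(h.Sum(nil)))) // true
}
```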
diff --git a/pkg/storer/internal/reserve/reserve.go b/pkg/storer/internal/reserve/reserve.go
index 277e6b7c125..45af6919a4c 100644
--- a/pkg/storer/internal/reserve/reserve.go
+++ b/pkg/storer/internal/reserve/reserve.go
@@ -5,6 +5,7 @@
 package reserve

 import (
+	"bytes"
 	"context"
 	"encoding/binary"
 	"encoding/hex"
@@ -24,11 +25,37 @@ import (
 	"github.com/ethersphere/bee/v2/pkg/swarm"
 	"github.com/ethersphere/bee/v2/pkg/topology"
 	"golang.org/x/sync/errgroup"
-	"resenje.org/multex"
 )

 const reserveScope = "reserve"

+type multexLock struct {
+	mul map[string]chan struct{}
+	mu  chan struct{}
+}
+
+func (m *multexLock) Lock(id string) {
+	for {
+		m.mu <- struct{}{}
+		l, ok := m.mul[id]
+		if !ok {
+			m.mul[id] = make(chan struct{})
+			<-m.mu
+			return
+		}
+		<-m.mu
+		<-l // wait for the current holder to unlock, then retry
+	}
+}
+
+func (m *multexLock) Unlock(id string) {
+	m.mu <- struct{}{}
+	if l, ok := m.mul[id]; ok {
+		delete(m.mul, id)
+		close(l) // wake any waiters
+	}
+	<-m.mu
+}
+
 type Reserve struct {
 	baseAddr     swarm.Address
 	radiusSetter topology.SetStorageRadiuser
@@ -38,7 +55,7 @@ type Reserve struct {
 	size   atomic.Int64
 	radius atomic.Uint32

-	multx *multex.Multex
+	multx multexLock
 	st    transaction.Storage
 }

@@ -56,7 +73,7 @@ func New(
 		capacity:     capacity,
 		radiusSetter: radiusSetter,
 		logger:       logger.WithName(reserveScope).Register(),
-		multx:        multex.New(),
+		multx:        multexLock{mul: make(map[string]chan struct{}), mu: make(chan struct{}, 1)},
 	}

 	err := st.Run(context.Background(), func(s transaction.Store) error {
@@ -100,9 +117,12 @@ func New(
 // if the new chunk has a higher stamp timestamp (regardless of batch type and chunk type, eg CAC & SOC).
 func (r *Reserve) Put(ctx context.Context, chunk swarm.Chunk) error {

+	chunkType := storage.ChunkType(chunk)
+
 	// batchID lock, Put vs Eviction
-	r.multx.Lock(string(chunk.Stamp().BatchID()))
-	defer r.multx.Unlock(string(chunk.Stamp().BatchID()))
+	lockId := lockId(chunk.Stamp())
+	r.multx.Lock(lockId)
+	defer r.multx.Unlock(lockId)

 	stampHash, err := chunk.Stamp().Hash()
 	if err != nil {
@@ -118,7 +138,6 @@ func (r *Reserve) Put(ctx context.Context, chunk swarm.Chunk) error {
 		return nil
 	}

-	chunkType := storage.ChunkType(chunk)
 	bin := swarm.Proximity(r.baseAddr.Bytes(), chunk.Address().Bytes())

 	// bin lock
@@ -145,95 +164,77 @@ func (r *Reserve) Put(ctx context.Context, chunk swarm.Chunk) error {
 		if err != nil {
 			return err
 		}
-		prev := binary.BigEndian.Uint64(sameAddressOldStampIndex.StampTimestamp)
-		curr := binary.BigEndian.Uint64(chunk.Stamp().Timestamp())
-		if prev >= curr {
-			return fmt.Errorf("overwrite same chunk. prev %d cur %d batch %s: %w", prev, curr, hex.EncodeToString(chunk.Stamp().BatchID()), storage.ErrOverwriteNewerChunk)
-		}
-		// index collision with another chunk
-		if loadedStampIndex {
-			prev := binary.BigEndian.Uint64(oldStampIndex.StampTimestamp)
+		// same index
+		if bytes.Equal(sameAddressOldStamp.Index(), chunk.Stamp().Index()) {
+			prev := binary.BigEndian.Uint64(sameAddressOldStampIndex.StampTimestamp)
+			curr := binary.BigEndian.Uint64(chunk.Stamp().Timestamp())
 			if prev >= curr {
 				return fmt.Errorf("overwrite same chunk. prev %d cur %d batch %s: %w", prev, curr, hex.EncodeToString(chunk.Stamp().BatchID()), storage.ErrOverwriteNewerChunk)
 			}
-			if !chunk.Address().Equal(oldStampIndex.ChunkAddress) {
-				r.logger.Debug(
-					"replacing chunk stamp index",
-					"old_chunk", oldStampIndex.ChunkAddress,
-					"new_chunk", chunk.Address(),
-					"batch_id", hex.EncodeToString(chunk.Stamp().BatchID()),
-				)
-				// remove index items and chunk data
-				err = r.removeChunk(ctx, s, oldStampIndex.ChunkAddress, oldStampIndex.BatchID, oldStampIndex.StampHash)
-				if err != nil {
-					return fmt.Errorf("failed removing older chunk %s: %w", oldStampIndex.ChunkAddress, err)
-				}
-				shouldDecrReserveSize = true
+
+			oldBatchRadiusItem := &BatchRadiusItem{
+				Bin:       bin,
+				Address:   chunk.Address(),
+				BatchID:   sameAddressOldStampIndex.BatchID,
+				StampHash: sameAddressOldStampIndex.StampHash,
+			}
+			// load item to get the binID
+			err = s.IndexStore().Get(oldBatchRadiusItem)
+			if err != nil {
+				return err
 			}
-		}

-		oldBatchRadiusItem := &BatchRadiusItem{
-			Bin:       bin,
-			Address:   chunk.Address(),
-			BatchID:   sameAddressOldStampIndex.BatchID,
-			StampHash: sameAddressOldStampIndex.StampHash,
-		}
-		// load item to get the binID
-		err = s.IndexStore().Get(oldBatchRadiusItem)
-		if err != nil {
-			return err
-		}
+			// delete old chunk index items
+			err = errors.Join(
+				s.IndexStore().Delete(oldBatchRadiusItem),
+				s.IndexStore().Delete(&ChunkBinItem{Bin: oldBatchRadiusItem.Bin, BinID: oldBatchRadiusItem.BinID}),
+				stampindex.Delete(s.IndexStore(), reserveScope, sameAddressOldStamp),
+				chunkstamp.DeleteWithStamp(s.IndexStore(), reserveScope, oldBatchRadiusItem.Address, sameAddressOldStamp),
+			)
+			if err != nil {
+				return err
+			}

-		// delete old chunk index items
-		err = errors.Join(
-			s.IndexStore().Delete(oldBatchRadiusItem),
-			s.IndexStore().Delete(&ChunkBinItem{Bin: oldBatchRadiusItem.Bin, BinID: oldBatchRadiusItem.BinID}),
-			stampindex.Delete(s.IndexStore(), reserveScope, sameAddressOldStamp),
-			chunkstamp.DeleteWithStamp(s.IndexStore(), reserveScope, oldBatchRadiusItem.Address, sameAddressOldStamp),
-		)
-		if err != nil {
-			return err
-		}
+			binID, err := r.IncBinID(s.IndexStore(), bin)
+			if err != nil {
+				return err
+			}

-		binID, err := r.IncBinID(s.IndexStore(), bin)
-		if err != nil {
-			return err
-		}
+			err = errors.Join(
+				stampindex.Store(s.IndexStore(), reserveScope, chunk),
+				chunkstamp.Store(s.IndexStore(), reserveScope, chunk),
+				s.IndexStore().Put(&BatchRadiusItem{
+					Bin:       bin,
+					BinID:     binID,
+					Address:   chunk.Address(),
+					BatchID:   chunk.Stamp().BatchID(),
+					StampHash: stampHash,
+				}),
+				s.IndexStore().Put(&ChunkBinItem{
+					Bin:       bin,
+					BinID:     binID,
+					Address:   chunk.Address(),
+					BatchID:   chunk.Stamp().BatchID(),
+					ChunkType: chunkType,
+					StampHash: stampHash,
+				}),
+			)
+			if err != nil {
+				return err
+			}

-		err = errors.Join(
-			stampindex.Store(s.IndexStore(), reserveScope, chunk),
-			chunkstamp.Store(s.IndexStore(), reserveScope, chunk),
-			s.IndexStore().Put(&BatchRadiusItem{
-				Bin:       bin,
-				BinID:     binID,
-				Address:   chunk.Address(),
-				BatchID:   chunk.Stamp().BatchID(),
-				StampHash: stampHash,
-			}),
-			s.IndexStore().Put(&ChunkBinItem{
-				Bin:       bin,
-				BinID:     binID,
-				Address:   chunk.Address(),
-				BatchID:   chunk.Stamp().BatchID(),
-				ChunkType: chunkType,
-				StampHash: stampHash,
-			}),
-		)
-		if err != nil {
-			return err
-		}
+			if chunkType != swarm.ChunkTypeSingleOwner {
+				return nil
+			}

-		if chunkType != swarm.ChunkTypeSingleOwner {
-			return nil
+			r.logger.Debug("replacing soc in chunkstore", "address", chunk.Address())
+			return s.ChunkStore().Replace(ctx, chunk)
 		}
-
-		r.logger.Debug("replacing soc in chunkstore", "address", chunk.Address())
-		return s.ChunkStore().Replace(ctx, chunk)
 	}

 	// different address, same batch, index collision
-	if loadedStampIndex {
+	if loadedStampIndex && !chunk.Address().Equal(oldStampIndex.ChunkAddress) {
 		prev := binary.BigEndian.Uint64(oldStampIndex.StampTimestamp)
 		curr := binary.BigEndian.Uint64(chunk.Stamp().Timestamp())
 		if prev >= curr {
@@ -677,3 +678,7 @@ func (r *Reserve) IncBinID(store storage.IndexStore, bin uint8) (uint64, error)

 	return item.BinID, store.Put(item)
 }
+
+func lockId(stamp swarm.Stamp) string {
+	return fmt.Sprintf("%x-%x", stamp.BatchID(), stamp.Index())
+}
diff --git a/pkg/storer/internal/reserve/reserve_test.go b/pkg/storer/internal/reserve/reserve_test.go
index d13bb47ab75..7769af49389 100644
--- a/pkg/storer/internal/reserve/reserve_test.go
+++ b/pkg/storer/internal/reserve/reserve_test.go
@@ -197,12 +197,14 @@ func TestSameChunkAddress(t *testing.T) {
 		bin := swarm.Proximity(baseAddr.Bytes(), ch1.Address().Bytes())
 		binBinIDs[bin] += 1
 		err = r.Put(ctx, ch2)
-		if !errors.Is(err, storage.ErrOverwriteNewerChunk) {
-			t.Fatal("expected error")
+		if err != nil {
+			t.Fatal(err)
 		}
+		bin2 := swarm.Proximity(baseAddr.Bytes(), ch2.Address().Bytes())
+		binBinIDs[bin2] += 1
 		size2 := r.Size()
-		if size2-size1 != 1 {
-			t.Fatalf("expected reserve size to increase by 1, got %d", size2-size1)
+		if size2-size1 != 2 {
+			t.Fatalf("expected reserve size to increase by 2, got %d", size2-size1)
 		}
 	})

@@ -269,11 +271,20 @@ func TestSameChunkAddress(t *testing.T) {
 		s2 := soctesting.GenerateMockSocWithSigner(t, []byte("update"), signer)
 		ch2 := s2.Chunk().WithStamp(postagetesting.MustNewFields(batch.ID, 1, 6))
 		bin := swarm.Proximity(baseAddr.Bytes(), ch1.Address().Bytes())
+		err := r.Put(ctx, ch1)
+		if err != nil {
+			t.Fatal(err)
+		}
+		err = r.Put(ctx, ch2)
+		if err != nil {
+			t.Fatal(err)
+		}
 		binBinIDs[bin] += 2
-		replace(t, ch1, ch2, binBinIDs[bin]-1, binBinIDs[bin])
+		checkChunkInIndexStore(t, ts.IndexStore(), bin, binBinIDs[bin]-1, ch1)
+		checkChunkInIndexStore(t, ts.IndexStore(), bin, binBinIDs[bin], ch2)
 		size2 := r.Size()
-		if size2-size1 != 1 {
-			t.Fatalf("expected reserve size to increase by 1, got %d", size2-size1)
+		if size2-size1 != 2 {
+			t.Fatalf("expected reserve size to increase by 2, got %d", size2-size1)
 		}
 	})

@@ -435,16 +446,17 @@ func TestSameChunkAddress(t *testing.T) {
 		ch3BinID := binBinIDs[bin2]

 		checkStore(t, ts.IndexStore(), &reserve.BatchRadiusItem{Bin: bin1, BatchID: ch1.Stamp().BatchID(), Address: ch1.Address(), StampHash: ch1StampHash}, true)
-		checkStore(t, ts.IndexStore(), &reserve.BatchRadiusItem{Bin: bin2, BatchID: ch2.Stamp().BatchID(), Address: ch2.Address(), StampHash: ch2StampHash}, true)
+		// different index, same batch
+		checkStore(t, ts.IndexStore(), &reserve.BatchRadiusItem{Bin: bin2, BatchID: ch2.Stamp().BatchID(), Address: ch2.Address(), StampHash: ch2StampHash}, false)
 		checkStore(t, ts.IndexStore(), &reserve.BatchRadiusItem{Bin: bin2, BatchID: ch3.Stamp().BatchID(), Address: ch3.Address(), StampHash: ch3StampHash}, false)

 		checkStore(t, ts.IndexStore(), &reserve.ChunkBinItem{Bin: bin1, BinID: ch1BinID, StampHash: ch1StampHash}, true)
-		checkStore(t, ts.IndexStore(), &reserve.ChunkBinItem{Bin: bin2, BinID: ch2BinID, StampHash: ch2StampHash}, true)
+		checkStore(t, ts.IndexStore(), &reserve.ChunkBinItem{Bin: bin2, BinID: ch2BinID, StampHash: ch2StampHash}, false)
 		checkStore(t, ts.IndexStore(), &reserve.ChunkBinItem{Bin: bin2, BinID: ch3BinID, StampHash: ch3StampHash}, false)

 		size2 := r.Size()
-		// (ch1 + ch2) == 2 and then ch3 reduces reserve size by 1
-		if size2-size1 != 1 {
-			t.Fatalf("expected reserve size to increase by 1, got %d", size2-size1)
+		// (ch1 + ch2) == 2
+		if size2-size1 != 2 {
+			t.Fatalf("expected reserve size to increase by 2, got %d", size2-size1)
 		}
 	})
@@ -923,3 +935,14 @@ func getSigner(t *testing.T) crypto.Signer {
 	}
 	return crypto.NewDefaultSigner(privKey)
 }
+
+func checkChunkInIndexStore(t *testing.T, s storage.Reader, bin uint8, binId uint64, ch swarm.Chunk) {
+	t.Helper()
+	stampHash, err := ch.Stamp().Hash()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	checkStore(t, s, &reserve.BatchRadiusItem{Bin: bin, BatchID: ch.Stamp().BatchID(), Address: ch.Address(), StampHash: stampHash}, false)
+	checkStore(t, s, &reserve.ChunkBinItem{Bin: bin, BinID: binId, StampHash: stampHash}, false)
+}
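The `multexLock` above serializes reserve `Put` against eviction per batchID-plus-stamp-index key (see `lockId`), rather than the previous whole-batch lock, so updates to different stamp indexes of the same batch no longer contend. The pattern generalizes; an equivalent standalone sketch of such a keyed lock (illustrative names, not part of the reserve package):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

type keyedLock struct {
	mu   sync.Mutex
	held map[string]chan struct{}
}

func newKeyedLock() *keyedLock {
	return &keyedLock{held: make(map[string]chan struct{})}
}

func (k *keyedLock) Lock(id string) {
	for {
		k.mu.Lock()
		ch, ok := k.held[id]
		if !ok {
			k.held[id] = make(chan struct{})
			k.mu.Unlock()
			return
		}
		k.mu.Unlock()
		<-ch // wait for the holder to unlock, then retry
	}
}

func (k *keyedLock) Unlock(id string) {
	k.mu.Lock()
	if ch, ok := k.held[id]; ok {
		delete(k.held, id)
		close(ch) // wake all waiters
	}
	k.mu.Unlock()
}

func main() {
	l := newKeyedLock()
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			l.Lock("batch-1") // same key: these two serialize
			fmt.Println("worker", n, "holds batch-1")
			time.Sleep(10 * time.Millisecond)
			l.Unlock("batch-1")
		}(i)
	}
	wg.Wait()
}
```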