From 5809000ec3126ff401345ccfab196ddeae346376 Mon Sep 17 00:00:00 2001 From: Ljubisa Date: Sat, 19 Oct 2024 23:39:08 +0200 Subject: [PATCH] feat(api): add middleware for handling route availability --- pkg/api/api.go | 2 + pkg/api/api_test.go | 12 ++-- pkg/api/router.go | 140 +++++++++++++++++++++----------------------- pkg/node/devnode.go | 23 +++----- pkg/node/node.go | 6 +- 5 files changed, 86 insertions(+), 97 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index 4568ad96bf5..61f46c4e4a8 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -216,6 +216,8 @@ type Service struct { redistributionAgent *storageincentives.Agent statusService *status.Service + + isFullApiAvailable bool } func (s *Service) SetP2P(p2p p2p.DebugService) { diff --git a/pkg/api/api_test.go b/pkg/api/api_test.go index 7bd20c07e87..1e3352a062f 100644 --- a/pkg/api/api_test.go +++ b/pkg/api/api_test.go @@ -179,7 +179,7 @@ func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket. erc20 := erc20mock.New(o.Erc20Opts...) backend := backendmock.New(o.BackendOpts...) - var extraOpts = api.ExtraOptions{ + extraOpts := api.ExtraOptions{ TopologyDriver: topologyDriver, Accounting: acc, Pseudosettle: recipient, @@ -231,9 +231,8 @@ func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket. WsPingPeriod: o.WsPingPeriod, }, extraOpts, 1, erc20) - s.MountTechnicalDebug() - s.MountDebug() s.MountAPI() + s.EnableFullAPIAvailability() if o.DirectUpload { chanStore = newChanStore(o.Storer.PusherFeed()) @@ -316,7 +315,7 @@ func TestParseName(t *testing.T) { const bzzHash = "89c17d0d8018a19057314aa035e61c9d23c47581a61dd3a79a7839692c617e4d" log := log.Noop - var errInvalidNameOrAddress = errors.New("invalid name or bzz address") + errInvalidNameOrAddress := errors.New("invalid name or bzz address") testCases := []struct { desc string @@ -378,6 +377,7 @@ func TestParseName(t *testing.T) { s := api.New(pk.PublicKey, pk.PublicKey, common.Address{}, nil, log, nil, nil, 1, false, false, nil, []string{"*"}, inmemstore.New()) s.Configure(signer, nil, api.Options{}, api.ExtraOptions{Resolver: tC.res}, 1, nil) s.MountAPI() + s.EnableFullAPIAvailability() tC := tC t.Run(tC.desc, func(t *testing.T) { @@ -503,9 +503,7 @@ func TestPostageHeaderError(t *testing.T) { func TestOptions(t *testing.T) { t.Parallel() - var ( - client, _, _, _ = newTestServer(t, testServerOptions{}) - ) + client, _, _, _ := newTestServer(t, testServerOptions{}) for _, tc := range []struct { endpoint string expectedMethods string // expectedMethods contains HTTP methods like GET, POST, HEAD, PATCH, DELETE, OPTIONS. These are in alphabetical sorted order diff --git a/pkg/api/router.go b/pkg/api/router.go index 6899b37e27f..055b3035c87 100644 --- a/pkg/api/router.go +++ b/pkg/api/router.go @@ -25,34 +25,16 @@ const ( rootPath = "/" + apiVersion ) -func (s *Service) MountTechnicalDebug() { +func (s *Service) MountAPI() { router := mux.NewRouter() - router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonhttp.ServiceUnavailable(w, "Node is syncing. This endpoint is unavailable. Try again later.") - }) + router.NotFoundHandler = http.HandlerFunc(jsonhttp.NotFoundHandler) s.router = router s.mountTechnicalDebug() - - s.Handler = web.ChainHandlers( - httpaccess.NewHTTPAccessLogHandler(s.logger, s.tracer, "api access"), - handlers.CompressHandler, - s.corsHandler, - web.NoCacheHeadersHandler, - web.FinalHandler(router), - ) -} - -func (s *Service) MountDebug() { - if s.router == nil { - s.router = mux.NewRouter() - } - s.mountBusinessDebug() - - s.router.NotFoundHandler = http.HandlerFunc(jsonhttp.NotFoundHandler) + s.mountAPI() s.Handler = web.ChainHandlers( httpaccess.NewHTTPAccessLogHandler(s.logger, s.tracer, "api access"), @@ -63,59 +45,55 @@ func (s *Service) MountDebug() { ) } -func (s *Service) MountAPI() { - if s.router == nil { - s.router = mux.NewRouter() - } - - s.mountAPI() - - s.router.NotFoundHandler = http.HandlerFunc(jsonhttp.NotFoundHandler) - - compressHandler := func(h http.Handler) http.Handler { - downloadEndpoints := []string{ - "/bzz", - "/bytes", - "/chunks", - "/feeds", - "/soc", - rootPath + "/bzz", - rootPath + "/bytes", - rootPath + "/chunks", - rootPath + "/feeds", - rootPath + "/soc", - } +func (s *Service) EnableFullAPIAvailability() { + if s != nil { + s.isFullApiAvailable = true + + compressHandler := func(h http.Handler) http.Handler { + downloadEndpoints := []string{ + "/bzz", + "/bytes", + "/chunks", + "/feeds", + "/soc", + rootPath + "/bzz", + rootPath + "/bytes", + rootPath + "/chunks", + rootPath + "/feeds", + rootPath + "/soc", + } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Skip compression for GET requests on download endpoints. - // This is done in order to preserve Content-Length header in response, - // because CompressHandler is always removing it. - if r.Method == http.MethodGet { - for _, endpoint := range downloadEndpoints { - if strings.HasPrefix(r.URL.Path, endpoint) { - h.ServeHTTP(w, r) - return + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Skip compression for GET requests on download endpoints. + // This is done in order to preserve Content-Length header in response, + // because CompressHandler is always removing it. + if r.Method == http.MethodGet { + for _, endpoint := range downloadEndpoints { + if strings.HasPrefix(r.URL.Path, endpoint) { + h.ServeHTTP(w, r) + return + } } } - } - if r.Method == http.MethodHead { - h.ServeHTTP(w, r) - return - } + if r.Method == http.MethodHead { + h.ServeHTTP(w, r) + return + } - handlers.CompressHandler(h).ServeHTTP(w, r) - }) - } + handlers.CompressHandler(h).ServeHTTP(w, r) + }) + } - s.Handler = web.ChainHandlers( - httpaccess.NewHTTPAccessLogHandler(s.logger, s.tracer, "api access"), - compressHandler, - s.responseCodeMetricsHandler, - s.pageviewMetricsHandler, - s.corsHandler, - web.FinalHandler(s.router), - ) + s.Handler = web.ChainHandlers( + httpaccess.NewHTTPAccessLogHandler(s.logger, s.tracer, "api access"), + compressHandler, + s.responseCodeMetricsHandler, + s.pageviewMetricsHandler, + s.corsHandler, + web.FinalHandler(s.router), + ) + } } func (s *Service) mountTechnicalDebug() { @@ -151,11 +129,11 @@ func (s *Service) mountTechnicalDebug() { u.Path += "/" http.Redirect(w, r, u.String(), http.StatusPermanentRedirect) })) + s.router.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) s.router.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile)) s.router.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol)) s.router.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace)) - s.router.PathPrefix("/debug/pprof/").Handler(http.HandlerFunc(pprof.Index)) s.router.Handle("/debug/vars", expvar.Handler()) @@ -165,12 +143,14 @@ func (s *Service) mountTechnicalDebug() { web.FinalHandlerFunc(s.loggerGetHandler), ), }) + s.router.Handle("/loggers/{exp}", jsonhttp.MethodHandler{ "GET": web.ChainHandlers( httpaccess.NewHTTPAccessSuppressLogHandler(), web.FinalHandlerFunc(s.loggerGetHandler), ), }) + s.router.Handle("/loggers/{exp}/{verbosity}", jsonhttp.MethodHandler{ "PUT": web.ChainHandlers( httpaccess.NewHTTPAccessSuppressLogHandler(), @@ -189,6 +169,16 @@ func (s *Service) mountTechnicalDebug() { )) } +func (s *Service) checkRouteAvailability(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !s.isFullApiAvailable { + jsonhttp.ServiceUnavailable(w, "Node is syncing. This endpoint is unavailable. Try again later.") + return + } + handler.ServeHTTP(w, r) + }) +} + func (s *Service) mountAPI() { subdomainRouter := s.router.Host("{subdomain:.*}.swarm.localhost").Subrouter() @@ -208,8 +198,9 @@ func (s *Service) mountAPI() { // handle is a helper closure which simplifies the router setup. handle := func(path string, handler http.Handler) { - s.router.Handle(path, handler) - s.router.Handle(rootPath+path, handler) + routeHandler := s.checkRouteAvailability(handler) + s.router.Handle(path, routeHandler) + s.router.Handle(rootPath+path, routeHandler) } handle("/bytes", jsonhttp.MethodHandler{ @@ -395,14 +386,16 @@ func (s *Service) mountAPI() { func (s *Service) mountBusinessDebug() { handle := func(path string, handler http.Handler) { - s.router.Handle(path, handler) - s.router.Handle(rootPath+path, handler) + routeHandler := s.checkRouteAvailability(handler) + s.router.Handle(path, routeHandler) + s.router.Handle(rootPath+path, routeHandler) } if s.transaction != nil { handle("/transactions", jsonhttp.MethodHandler{ "GET": http.HandlerFunc(s.transactionListHandler), }) + handle("/transactions/{hash}", jsonhttp.MethodHandler{ "GET": http.HandlerFunc(s.transactionDetailHandler), "POST": http.HandlerFunc(s.transactionResendHandler), @@ -518,6 +511,7 @@ func (s *Service) mountBusinessDebug() { handle("/wallet", jsonhttp.MethodHandler{ "GET": http.HandlerFunc(s.walletHandler), }) + if s.swapEnabled { handle("/wallet/withdraw/{coin}", jsonhttp.MethodHandler{ "POST": web.ChainHandlers( diff --git a/pkg/node/devnode.go b/pkg/node/devnode.go index 5439fe2d2a7..f2aab84d5a9 100644 --- a/pkg/node/devnode.go +++ b/pkg/node/devnode.go @@ -137,7 +137,7 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) { return nil, fmt.Errorf("blockchain address: %w", err) } - var mockTransaction = transactionmock.New(transactionmock.WithPendingTransactionsFunc(func() ([]common.Hash, error) { + mockTransaction := transactionmock.New(transactionmock.WithPendingTransactionsFunc(func() ([]common.Hash, error) { return []common.Hash{common.HexToHash("abcd")}, nil }), transactionmock.WithResendTransactionFunc(func(ctx context.Context, txHash common.Hash) error { return nil @@ -303,13 +303,11 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) { )) ) - var ( - // syncStatusFn mocks sync status because complete sync is required in order to curl certain apis e.g. /stamps. - // this allows accessing those apis by passing true to isDone in devNode. - syncStatusFn = func() (isDone bool, err error) { - return true, nil - } - ) + // syncStatusFn mocks sync status because complete sync is required in order to curl certain apis e.g. /stamps. + // this allows accessing those apis by passing true to isDone in devNode. + syncStatusFn := func() (isDone bool, err error) { + return true, nil + } mockFeeds := factory.New(localStore.Download(true)) mockResolver := resolverMock.NewResolver() @@ -351,7 +349,7 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) { SyncStatus: syncStatusFn, } - var erc20 = erc20mock.New( + erc20 := erc20mock.New( erc20mock.WithBalanceOfFunc(func(ctx context.Context, address common.Address) (*big.Int, error) { return big.NewInt(0), nil }), @@ -366,10 +364,9 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) { CORSAllowedOrigins: o.CORSAllowedOrigins, WsPingPeriod: 60 * time.Second, }, debugOpts, 1, erc20) - apiService.MountTechnicalDebug() - apiService.MountDebug() - apiService.MountAPI() + apiService.MountAPI() + apiService.EnableFullAPIAvailability() apiService.SetProbe(probe) apiService.SetP2P(p2ps) apiService.SetSwarmAddress(&swarmAddress) @@ -444,7 +441,6 @@ func pong(_ context.Context, _ swarm.Address, _ ...string) (rtt time.Duration, e } func randomAddress() (swarm.Address, error) { - b := make([]byte, 32) _, err := rand.Read(b) @@ -453,5 +449,4 @@ func randomAddress() (swarm.Address, error) { } return swarm.NewAddress(b), nil - } diff --git a/pkg/node/node.go b/pkg/node/node.go index b8498aa868a..e9c490804c4 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -450,7 +450,8 @@ func NewBee( o.CORSAllowedOrigins, stamperStore, ) - apiService.MountTechnicalDebug() + + apiService.MountAPI() apiService.SetProbe(probe) apiService.SetSwarmAddress(&swarmAddress) @@ -1184,8 +1185,7 @@ func NewBee( WsPingPeriod: 60 * time.Second, }, extraOpts, chainID, erc20Service) - apiService.MountDebug() - apiService.MountAPI() + apiService.EnableFullAPIAvailability() apiService.SetRedistributionAgent(agent) }