Skip to content

Commit

Permalink
feat(api): add middleware for handling route availability
Browse files Browse the repository at this point in the history
  • Loading branch information
gacevicljubisa committed Oct 19, 2024
1 parent 9451270 commit 5809000
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 97 deletions.
2 changes: 2 additions & 0 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ type Service struct {
redistributionAgent *storageincentives.Agent

statusService *status.Service

isFullApiAvailable bool
}

func (s *Service) SetP2P(p2p p2p.DebugService) {
Expand Down
12 changes: 5 additions & 7 deletions pkg/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket.
erc20 := erc20mock.New(o.Erc20Opts...)
backend := backendmock.New(o.BackendOpts...)

var extraOpts = api.ExtraOptions{
extraOpts := api.ExtraOptions{
TopologyDriver: topologyDriver,
Accounting: acc,
Pseudosettle: recipient,
Expand Down Expand Up @@ -231,9 +231,8 @@ func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket.
WsPingPeriod: o.WsPingPeriod,
}, extraOpts, 1, erc20)

s.MountTechnicalDebug()
s.MountDebug()
s.MountAPI()
s.EnableFullAPIAvailability()

if o.DirectUpload {
chanStore = newChanStore(o.Storer.PusherFeed())
Expand Down Expand Up @@ -316,7 +315,7 @@ func TestParseName(t *testing.T) {
const bzzHash = "89c17d0d8018a19057314aa035e61c9d23c47581a61dd3a79a7839692c617e4d"
log := log.Noop

var errInvalidNameOrAddress = errors.New("invalid name or bzz address")
errInvalidNameOrAddress := errors.New("invalid name or bzz address")

testCases := []struct {
desc string
Expand Down Expand Up @@ -378,6 +377,7 @@ func TestParseName(t *testing.T) {
s := api.New(pk.PublicKey, pk.PublicKey, common.Address{}, nil, log, nil, nil, 1, false, false, nil, []string{"*"}, inmemstore.New())
s.Configure(signer, nil, api.Options{}, api.ExtraOptions{Resolver: tC.res}, 1, nil)
s.MountAPI()
s.EnableFullAPIAvailability()

tC := tC
t.Run(tC.desc, func(t *testing.T) {
Expand Down Expand Up @@ -503,9 +503,7 @@ func TestPostageHeaderError(t *testing.T) {
func TestOptions(t *testing.T) {
t.Parallel()

var (
client, _, _, _ = newTestServer(t, testServerOptions{})
)
client, _, _, _ := newTestServer(t, testServerOptions{})
for _, tc := range []struct {
endpoint string
expectedMethods string // expectedMethods contains HTTP methods like GET, POST, HEAD, PATCH, DELETE, OPTIONS. These are in alphabetical sorted order
Expand Down
140 changes: 67 additions & 73 deletions pkg/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,16 @@ const (
rootPath = "/" + apiVersion
)

func (s *Service) MountTechnicalDebug() {
func (s *Service) MountAPI() {
router := mux.NewRouter()

router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonhttp.ServiceUnavailable(w, "Node is syncing. This endpoint is unavailable. Try again later.")
})
router.NotFoundHandler = http.HandlerFunc(jsonhttp.NotFoundHandler)

s.router = router

s.mountTechnicalDebug()

s.Handler = web.ChainHandlers(
httpaccess.NewHTTPAccessLogHandler(s.logger, s.tracer, "api access"),
handlers.CompressHandler,
s.corsHandler,
web.NoCacheHeadersHandler,
web.FinalHandler(router),
)
}

func (s *Service) MountDebug() {
if s.router == nil {
s.router = mux.NewRouter()
}

s.mountBusinessDebug()

s.router.NotFoundHandler = http.HandlerFunc(jsonhttp.NotFoundHandler)
s.mountAPI()

s.Handler = web.ChainHandlers(
httpaccess.NewHTTPAccessLogHandler(s.logger, s.tracer, "api access"),
Expand All @@ -63,59 +45,55 @@ func (s *Service) MountDebug() {
)
}

func (s *Service) MountAPI() {
if s.router == nil {
s.router = mux.NewRouter()
}

s.mountAPI()

s.router.NotFoundHandler = http.HandlerFunc(jsonhttp.NotFoundHandler)

compressHandler := func(h http.Handler) http.Handler {
downloadEndpoints := []string{
"/bzz",
"/bytes",
"/chunks",
"/feeds",
"/soc",
rootPath + "/bzz",
rootPath + "/bytes",
rootPath + "/chunks",
rootPath + "/feeds",
rootPath + "/soc",
}
func (s *Service) EnableFullAPIAvailability() {
if s != nil {
s.isFullApiAvailable = true

compressHandler := func(h http.Handler) http.Handler {
downloadEndpoints := []string{
"/bzz",
"/bytes",
"/chunks",
"/feeds",
"/soc",
rootPath + "/bzz",
rootPath + "/bytes",
rootPath + "/chunks",
rootPath + "/feeds",
rootPath + "/soc",
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip compression for GET requests on download endpoints.
// This is done in order to preserve Content-Length header in response,
// because CompressHandler is always removing it.
if r.Method == http.MethodGet {
for _, endpoint := range downloadEndpoints {
if strings.HasPrefix(r.URL.Path, endpoint) {
h.ServeHTTP(w, r)
return
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip compression for GET requests on download endpoints.
// This is done in order to preserve Content-Length header in response,
// because CompressHandler is always removing it.
if r.Method == http.MethodGet {
for _, endpoint := range downloadEndpoints {
if strings.HasPrefix(r.URL.Path, endpoint) {
h.ServeHTTP(w, r)
return
}
}
}
}

if r.Method == http.MethodHead {
h.ServeHTTP(w, r)
return
}
if r.Method == http.MethodHead {
h.ServeHTTP(w, r)
return
}

handlers.CompressHandler(h).ServeHTTP(w, r)
})
}
handlers.CompressHandler(h).ServeHTTP(w, r)
})
}

s.Handler = web.ChainHandlers(
httpaccess.NewHTTPAccessLogHandler(s.logger, s.tracer, "api access"),
compressHandler,
s.responseCodeMetricsHandler,
s.pageviewMetricsHandler,
s.corsHandler,
web.FinalHandler(s.router),
)
s.Handler = web.ChainHandlers(
httpaccess.NewHTTPAccessLogHandler(s.logger, s.tracer, "api access"),
compressHandler,
s.responseCodeMetricsHandler,
s.pageviewMetricsHandler,
s.corsHandler,
web.FinalHandler(s.router),
)
}
}

func (s *Service) mountTechnicalDebug() {
Expand Down Expand Up @@ -151,11 +129,11 @@ func (s *Service) mountTechnicalDebug() {
u.Path += "/"
http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
}))

s.router.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline))
s.router.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile))
s.router.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol))
s.router.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace))

s.router.PathPrefix("/debug/pprof/").Handler(http.HandlerFunc(pprof.Index))
s.router.Handle("/debug/vars", expvar.Handler())

Expand All @@ -165,12 +143,14 @@ func (s *Service) mountTechnicalDebug() {
web.FinalHandlerFunc(s.loggerGetHandler),
),
})

s.router.Handle("/loggers/{exp}", jsonhttp.MethodHandler{
"GET": web.ChainHandlers(
httpaccess.NewHTTPAccessSuppressLogHandler(),
web.FinalHandlerFunc(s.loggerGetHandler),
),
})

s.router.Handle("/loggers/{exp}/{verbosity}", jsonhttp.MethodHandler{
"PUT": web.ChainHandlers(
httpaccess.NewHTTPAccessSuppressLogHandler(),
Expand All @@ -189,6 +169,16 @@ func (s *Service) mountTechnicalDebug() {
))
}

func (s *Service) checkRouteAvailability(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !s.isFullApiAvailable {
jsonhttp.ServiceUnavailable(w, "Node is syncing. This endpoint is unavailable. Try again later.")
return
}
handler.ServeHTTP(w, r)
})
}

func (s *Service) mountAPI() {
subdomainRouter := s.router.Host("{subdomain:.*}.swarm.localhost").Subrouter()

Expand All @@ -208,8 +198,9 @@ func (s *Service) mountAPI() {

// handle is a helper closure which simplifies the router setup.
handle := func(path string, handler http.Handler) {
s.router.Handle(path, handler)
s.router.Handle(rootPath+path, handler)
routeHandler := s.checkRouteAvailability(handler)
s.router.Handle(path, routeHandler)
s.router.Handle(rootPath+path, routeHandler)
}

handle("/bytes", jsonhttp.MethodHandler{
Expand Down Expand Up @@ -395,14 +386,16 @@ func (s *Service) mountAPI() {

func (s *Service) mountBusinessDebug() {
handle := func(path string, handler http.Handler) {
s.router.Handle(path, handler)
s.router.Handle(rootPath+path, handler)
routeHandler := s.checkRouteAvailability(handler)
s.router.Handle(path, routeHandler)
s.router.Handle(rootPath+path, routeHandler)
}

if s.transaction != nil {
handle("/transactions", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.transactionListHandler),
})

handle("/transactions/{hash}", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.transactionDetailHandler),
"POST": http.HandlerFunc(s.transactionResendHandler),
Expand Down Expand Up @@ -518,6 +511,7 @@ func (s *Service) mountBusinessDebug() {
handle("/wallet", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.walletHandler),
})

if s.swapEnabled {
handle("/wallet/withdraw/{coin}", jsonhttp.MethodHandler{
"POST": web.ChainHandlers(
Expand Down
23 changes: 9 additions & 14 deletions pkg/node/devnode.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) {
return nil, fmt.Errorf("blockchain address: %w", err)
}

var mockTransaction = transactionmock.New(transactionmock.WithPendingTransactionsFunc(func() ([]common.Hash, error) {
mockTransaction := transactionmock.New(transactionmock.WithPendingTransactionsFunc(func() ([]common.Hash, error) {
return []common.Hash{common.HexToHash("abcd")}, nil
}), transactionmock.WithResendTransactionFunc(func(ctx context.Context, txHash common.Hash) error {
return nil
Expand Down Expand Up @@ -303,13 +303,11 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) {
))
)

var (
// syncStatusFn mocks sync status because complete sync is required in order to curl certain apis e.g. /stamps.
// this allows accessing those apis by passing true to isDone in devNode.
syncStatusFn = func() (isDone bool, err error) {
return true, nil
}
)
// syncStatusFn mocks sync status because complete sync is required in order to curl certain apis e.g. /stamps.
// this allows accessing those apis by passing true to isDone in devNode.
syncStatusFn := func() (isDone bool, err error) {
return true, nil
}

mockFeeds := factory.New(localStore.Download(true))
mockResolver := resolverMock.NewResolver()
Expand Down Expand Up @@ -351,7 +349,7 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) {
SyncStatus: syncStatusFn,
}

var erc20 = erc20mock.New(
erc20 := erc20mock.New(
erc20mock.WithBalanceOfFunc(func(ctx context.Context, address common.Address) (*big.Int, error) {
return big.NewInt(0), nil
}),
Expand All @@ -366,10 +364,9 @@ func NewDevBee(logger log.Logger, o *DevOptions) (b *DevBee, err error) {
CORSAllowedOrigins: o.CORSAllowedOrigins,
WsPingPeriod: 60 * time.Second,
}, debugOpts, 1, erc20)
apiService.MountTechnicalDebug()
apiService.MountDebug()
apiService.MountAPI()

apiService.MountAPI()
apiService.EnableFullAPIAvailability()
apiService.SetProbe(probe)
apiService.SetP2P(p2ps)
apiService.SetSwarmAddress(&swarmAddress)
Expand Down Expand Up @@ -444,7 +441,6 @@ func pong(_ context.Context, _ swarm.Address, _ ...string) (rtt time.Duration, e
}

func randomAddress() (swarm.Address, error) {

b := make([]byte, 32)

_, err := rand.Read(b)
Expand All @@ -453,5 +449,4 @@ func randomAddress() (swarm.Address, error) {
}

return swarm.NewAddress(b), nil

}
6 changes: 3 additions & 3 deletions pkg/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ func NewBee(
o.CORSAllowedOrigins,
stamperStore,
)
apiService.MountTechnicalDebug()

apiService.MountAPI()
apiService.SetProbe(probe)

apiService.SetSwarmAddress(&swarmAddress)
Expand Down Expand Up @@ -1184,8 +1185,7 @@ func NewBee(
WsPingPeriod: 60 * time.Second,
}, extraOpts, chainID, erc20Service)

apiService.MountDebug()
apiService.MountAPI()
apiService.EnableFullAPIAvailability()

apiService.SetRedistributionAgent(agent)
}
Expand Down

0 comments on commit 5809000

Please sign in to comment.