Skip to content

Commit

Permalink
Merge pull request #688 from luraproject/avoid_panics
Browse files Browse the repository at this point in the history
Avoid panics
  • Loading branch information
kpacha authored Sep 29, 2023
2 parents 965dacf + 8362d2b commit 69ca25e
Show file tree
Hide file tree
Showing 24 changed files with 240 additions and 202 deletions.
40 changes: 30 additions & 10 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ type ServiceConfig struct {
// run lura in debug mode
Debug bool `mapstructure:"debug_endpoint"`
Echo bool `mapstructure:"echo_endpoint"`
uriParser URIParser
uriParser SafeURIParser

// SequentialStart flags if the agents should be started sequentially
// before starting the router
Expand Down Expand Up @@ -361,7 +361,7 @@ func (s *ServiceConfig) Hash() (string, error) {
// Init also sanitizes the values, applies the default ones whenever necessary and
// normalizes all the things.
func (s *ServiceConfig) Init() error {
s.uriParser = NewURIParser()
s.uriParser = NewSafeURIParser()

if s.Version != ConfigVersion {
return &UnsupportedVersionError{
Expand All @@ -370,9 +370,13 @@ func (s *ServiceConfig) Init() error {
}
}

s.initGlobalParams()
if err := s.initGlobalParams(); err != nil {
return err
}

s.initAsyncAgents()
if err := s.initAsyncAgents(); err != nil {
return err
}

return s.initEndpoints()
}
Expand All @@ -393,7 +397,7 @@ func (s *ServiceConfig) Normalize() {
}
}

func (s *ServiceConfig) initGlobalParams() {
func (s *ServiceConfig) initGlobalParams() error {
if s.Port == 0 {
s.Port = defaultPort
}
Expand All @@ -404,8 +408,13 @@ func (s *ServiceConfig) initGlobalParams() {
s.Timeout = DefaultTimeout
}

s.Host = s.uriParser.CleanHosts(s.Host)
var err error
s.Host, err = s.uriParser.SafeCleanHosts(s.Host)
if err != nil {
return err
}
s.ExtraConfig.sanitize()
return nil
}

func (s *ServiceConfig) initAsyncAgents() error {
Expand All @@ -418,7 +427,11 @@ func (s *ServiceConfig) initAsyncAgents() error {
if len(b.Host) == 0 {
b.Host = s.Host
} else if !b.HostSanitizationDisabled {
b.Host = s.uriParser.CleanHosts(b.Host)
var err error
b.Host, err = s.uriParser.SafeCleanHosts(b.Host)
if err != nil {
return err
}
}
if b.Method == "" {
b.Method = http.MethodGet
Expand Down Expand Up @@ -461,7 +474,9 @@ func (s *ServiceConfig) initEndpoints() error {
e.ExtraConfig.sanitize()

for j, b := range e.Backend {
s.initBackendDefaults(i, j)
if err := s.initBackendDefaults(i, j); err != nil {
return err
}

if err := s.initBackendURLMappings(i, j, inputSet); err != nil {
return err
Expand Down Expand Up @@ -525,13 +540,17 @@ func (s *ServiceConfig) initAsyncAgentDefaults(e int) {
}
}

func (s *ServiceConfig) initBackendDefaults(e, b int) {
func (s *ServiceConfig) initBackendDefaults(e, b int) error {
endpoint := s.Endpoints[e]
backend := endpoint.Backend[b]
if len(backend.Host) == 0 {
backend.Host = s.Host
} else if !backend.HostSanitizationDisabled {
backend.Host = s.uriParser.CleanHosts(backend.Host)
var err error
backend.Host, err = s.uriParser.SafeCleanHosts(backend.Host)
if err != nil {
return err
}
}
if backend.Method == "" {
backend.Method = endpoint.Method
Expand All @@ -549,6 +568,7 @@ func (s *ServiceConfig) initBackendDefaults(e, b int) {
if backend.SDScheme == "" {
backend.SDScheme = "http"
}
return nil
}

func (s *ServiceConfig) initBackendURLMappings(e, b int, inputParams map[string]interface{}) error {
Expand Down
23 changes: 14 additions & 9 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package config

import (
"errors"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -61,7 +62,7 @@ func TestConfig_initBackendURLMappings_ok(t *testing.T) {

backend := Backend{}
endpoint := EndpointConfig{Backend: []*Backend{&backend}}
subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewURIParser()}
subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewSafeURIParser()}

inputSet := map[string]interface{}{
"tupu": nil,
Expand Down Expand Up @@ -89,7 +90,7 @@ func TestConfig_initBackendURLMappings_tooManyOutput(t *testing.T) {
Endpoint: "/some/{tupu}",
Backend: []*Backend{&backend},
}
subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewURIParser()}
subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewSafeURIParser()}

inputSet := map[string]interface{}{
"tupu": nil,
Expand All @@ -106,7 +107,7 @@ func TestConfig_initBackendURLMappings_tooManyOutput(t *testing.T) {
func TestConfig_initBackendURLMappings_undefinedOutput(t *testing.T) {
backend := Backend{URLPattern: "supu/{tupu_56}/{supu-5t6}?a={foo}&b={foo}"}
endpoint := EndpointConfig{Endpoint: "/", Method: "GET", Backend: []*Backend{&backend}}
subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewURIParser()}
subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewSafeURIParser()}

inputSet := map[string]interface{}{
"tupu": nil,
Expand Down Expand Up @@ -261,11 +262,6 @@ func TestConfig_initKOMultipleBackendsForNoopEncoder(t *testing.T) {
}

func TestConfig_initKOInvalidHost(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The init process did not panic with an invalid host!")
}
}()
subject := ServiceConfig{
Version: ConfigVersion,
Host: []string{"http://127.0.0.1:8080http://127.0.0.1:8080"},
Expand All @@ -278,7 +274,16 @@ func TestConfig_initKOInvalidHost(t *testing.T) {
},
}

subject.Init()
err := subject.Init()
if err == nil {
t.Errorf("expected to fail with invalid host")
return
}

if !errors.Is(err, errInvalidHost) {
t.Errorf("expected 'errInvalidHost' got: %s", err.Error())
return
}
}

func TestConfig_initKOInvalidDebugPattern(t *testing.T) {
Expand Down
55 changes: 47 additions & 8 deletions config/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package config

import (
"fmt"
"regexp"
"strings"
)
Expand All @@ -20,34 +21,72 @@ type URIParser interface {
GetEndpointPath(string, []string) string
}

// Like URIParser but with safe versions of the clean host functionality that
// does not panic but returns an error.
type SafeURIParser interface {
SafeCleanHosts([]string) ([]string, error)
SafeCleanHost(string) (string, error)
CleanPath(string) string
GetEndpointPath(string, []string) string
}

// NewURIParser creates a new URIParser using the package variable RoutingPattern
func NewURIParser() URIParser {
return URI(RoutingPattern)
}

// NewSafeURIParser creates a safe URI parser that does not panic when cleaning hosts
func NewSafeURIParser() URI {
return URI(RoutingPattern)
}

// URI implements the URIParser interface
type URI int

// CleanHosts applies the CleanHost method to every member of the received array of hosts
func (u URI) CleanHosts(hosts []string) []string {
// SafeCleanHosts applies the SafeCleanHost method to every member of the received array of hosts
func (u URI) SafeCleanHosts(hosts []string) ([]string, error) {
cleaned := make([]string, 0, len(hosts))
for i := range hosts {
cleaned = append(cleaned, u.CleanHost(hosts[i]))
h, err := u.SafeCleanHost(hosts[i])
if err != nil {
return nil, fmt.Errorf("host %s not valid: %w", hosts[i], errInvalidHost)
}
cleaned = append(cleaned, h)
}
return cleaned
return cleaned, nil
}

// CleanHost sanitizes the received host
func (URI) CleanHost(host string) string {
// CleanHosts applies the CleanHost method to every member of the received array of hosts
// Panics in case of error.
func (u URI) CleanHosts(hosts []string) []string {
ss, e := u.SafeCleanHosts(hosts)
if e != nil {
panic(e)
}
return ss
}

// SafeCleanHost sanitizes the received host
func (URI) SafeCleanHost(host string) (string, error) {
matches := hostPattern.FindAllStringSubmatch(host, -1)
if len(matches) != 1 {
panic(errInvalidHost)
return "", errInvalidHost
}
keys := matches[0][1:]
if keys[0] == "" {
keys[0] = "http://"
}
return strings.Join(keys, "")
return strings.Join(keys, ""), nil
}

// CleanHost sanitizes the received host.
// Panics on error.
func (u URI) CleanHost(host string) string {
h, err := u.SafeCleanHost(host)
if err != nil {
panic(err)
}
return h
}

// CleanPath trims all the extra slashes from the received URI path
Expand Down
48 changes: 43 additions & 5 deletions proxy/balancing.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/luraproject/lura/v2/config"
"github.com/luraproject/lura/v2/logging"
"github.com/luraproject/lura/v2/sd"
)

Expand All @@ -20,7 +21,7 @@ func NewLoadBalancedMiddleware(remote *config.Backend) Middleware {
// NewLoadBalancedMiddlewareWithSubscriber creates proxy middleware adding the most perfomant balancer
// over the received subscriber
func NewLoadBalancedMiddlewareWithSubscriber(subscriber sd.Subscriber) Middleware {
return newLoadBalancedMiddleware(sd.NewBalancer(subscriber))
return newLoadBalancedMiddleware(logging.NoOp, sd.NewBalancer(subscriber))
}

// NewRoundRobinLoadBalancedMiddleware creates proxy middleware adding a round robin balancer
Expand All @@ -38,19 +39,56 @@ func NewRandomLoadBalancedMiddleware(remote *config.Backend) Middleware {
// NewRoundRobinLoadBalancedMiddlewareWithSubscriber creates proxy middleware adding a round robin
// balancer over the received subscriber
func NewRoundRobinLoadBalancedMiddlewareWithSubscriber(subscriber sd.Subscriber) Middleware {
return newLoadBalancedMiddleware(sd.NewRoundRobinLB(subscriber))
return newLoadBalancedMiddleware(logging.NoOp, sd.NewRoundRobinLB(subscriber))
}

// NewRandomLoadBalancedMiddlewareWithSubscriber creates proxy middleware adding a random
// balancer over the received subscriber
func NewRandomLoadBalancedMiddlewareWithSubscriber(subscriber sd.Subscriber) Middleware {
return newLoadBalancedMiddleware(sd.NewRandomLB(subscriber))
return newLoadBalancedMiddleware(logging.NoOp, sd.NewRandomLB(subscriber))
}

func newLoadBalancedMiddleware(lb sd.Balancer) Middleware {
// NewLoadBalancedMiddlewareWithLogger creates proxy middleware adding the most perfomant balancer
// over a default subscriber
func NewLoadBalancedMiddlewareWithLogger(l logging.Logger, remote *config.Backend) Middleware {
return NewLoadBalancedMiddlewareWithSubscriberAndLogger(l, sd.GetRegister().Get(remote.SD)(remote))
}

// NewLoadBalancedMiddlewareWithSubscriberAndLogger creates proxy middleware adding the most perfomant balancer
// over the received subscriber
func NewLoadBalancedMiddlewareWithSubscriberAndLogger(l logging.Logger, subscriber sd.Subscriber) Middleware {
return newLoadBalancedMiddleware(l, sd.NewBalancer(subscriber))
}

// NewRoundRobinLoadBalancedMiddlewareWithLogger creates proxy middleware adding a round robin balancer
// over a default subscriber
func NewRoundRobinLoadBalancedMiddlewareWithLogger(l logging.Logger, remote *config.Backend) Middleware {
return NewRoundRobinLoadBalancedMiddlewareWithSubscriberAndLogger(l, sd.GetRegister().Get(remote.SD)(remote))
}

// NewRandomLoadBalancedMiddlewareWithLogger creates proxy middleware adding a random balancer
// over a default subscriber
func NewRandomLoadBalancedMiddlewareWithLogger(l logging.Logger, remote *config.Backend) Middleware {
return NewRandomLoadBalancedMiddlewareWithSubscriberAndLogger(l, sd.GetRegister().Get(remote.SD)(remote))
}

// NewRoundRobinLoadBalancedMiddlewareWithSubscriberAndLogger creates proxy middleware adding a round robin
// balancer over the received subscriber
func NewRoundRobinLoadBalancedMiddlewareWithSubscriberAndLogger(l logging.Logger, subscriber sd.Subscriber) Middleware {
return newLoadBalancedMiddleware(l, sd.NewRoundRobinLB(subscriber))
}

// NewRandomLoadBalancedMiddlewareWithSubscriberAndLogger creates proxy middleware adding a random
// balancer over the received subscriber
func NewRandomLoadBalancedMiddlewareWithSubscriberAndLogger(l logging.Logger, subscriber sd.Subscriber) Middleware {
return newLoadBalancedMiddleware(l, sd.NewRandomLB(subscriber))
}

func newLoadBalancedMiddleware(l logging.Logger, lb sd.Balancer) Middleware {
return func(next ...Proxy) Proxy {
if len(next) > 1 {
panic(ErrTooManyProxies)
l.Fatal("too many proxies for this proxy middleware: newLoadBalancedMiddleware only accepts 1 proxy, got %d", len(next))
return nil
}
return func(ctx context.Context, request *Request) (*Response, error) {
host, err := lb.Host()
Expand Down
6 changes: 4 additions & 2 deletions proxy/balancing_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ import (
"context"
"strconv"
"testing"

"github.com/luraproject/lura/v2/logging"
)

const veryLargeString = "abcdefghijklmopqrstuvwxyzabcdefghijklmopqrstuvwxyzabcdefghijklmopqrstuvwxyzabcdefghijklmopqrstuvwxyz"

func BenchmarkNewLoadBalancedMiddleware(b *testing.B) {
for _, tc := range []int{3, 5, 9, 13, 17, 21, 25, 50, 100} {
b.Run(strconv.Itoa(tc), func(b *testing.B) {
proxy := newLoadBalancedMiddleware(dummyBalancer(veryLargeString[:tc]))(dummyProxy(&Response{}))
proxy := newLoadBalancedMiddleware(logging.NoOp, dummyBalancer(veryLargeString[:tc]))(dummyProxy(&Response{}))
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
Expand Down Expand Up @@ -63,7 +65,7 @@ func BenchmarkNewLoadBalancedMiddleware_parallel100(b *testing.B) {

func benchmarkNewLoadBalancedMiddleware_parallel(b *testing.B, subject string) {
b.RunParallel(func(pb *testing.PB) {
proxy := newLoadBalancedMiddleware(dummyBalancer(subject))(dummyProxy(&Response{}))
proxy := newLoadBalancedMiddleware(logging.NoOp, dummyBalancer(subject))(dummyProxy(&Response{}))
for pb.Next() {
proxy(context.Background(), &Request{
Path: subject,
Expand Down
Loading

0 comments on commit 69ca25e

Please sign in to comment.