diff --git a/client.go b/client.go index 14321d1b48..d1c8549a55 100644 --- a/client.go +++ b/client.go @@ -380,7 +380,7 @@ func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, b // and AcquireResponse in performance-critical code. func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { req.timeout = timeout - if req.timeout < 0 { + if req.timeout <= 0 { return ErrTimeout } return c.Do(req, resp) @@ -412,7 +412,7 @@ func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) // and AcquireResponse in performance-critical code. func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error { req.timeout = time.Until(deadline) - if req.timeout < 0 { + if req.timeout <= 0 { return ErrTimeout } return c.Do(req, resp) @@ -1158,7 +1158,7 @@ func ReleaseResponse(resp *Response) { // and AcquireResponse in performance-critical code. func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { req.timeout = timeout - if req.timeout < 0 { + if req.timeout <= 0 { return ErrTimeout } return c.Do(req, resp) @@ -1185,7 +1185,7 @@ func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Durati // and AcquireResponse in performance-critical code. func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { req.timeout = time.Until(deadline) - if req.timeout < 0 { + if req.timeout <= 0 { return ErrTimeout } return c.Do(req, resp) @@ -1243,8 +1243,27 @@ func (c *HostClient) Do(req *Request, resp *Response) error { attempts := 0 hasBodyStream := req.IsBodyStream() + // If a request has a timeout we store the timeout + // and calculate a deadline so we can keep updating the + // timeout on each retry. + deadline := time.Time{} + timeout := req.timeout + if timeout > 0 { + deadline = time.Now().Add(timeout) + } + atomic.AddInt32(&c.pendingRequests, 1) for { + // If the original timeout was set, we need to update + // the one set on the request to reflect the remaining time. + if timeout > 0 { + req.timeout = time.Until(deadline) + if req.timeout <= 0 { + err = ErrTimeout + break + } + } + retry, err = c.do(req, resp) if err == nil || !retry { break @@ -1272,6 +1291,9 @@ func (c *HostClient) Do(req *Request, resp *Response) error { } atomic.AddInt32(&c.pendingRequests, -1) + // Restore the original timeout. + req.timeout = timeout + if err == io.EOF { err = ErrConnectionClosed } @@ -2288,7 +2310,7 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t c.init() timeout := time.Until(deadline) - if timeout < 0 { + if timeout <= 0 { return ErrTimeout } diff --git a/client_test.go b/client_test.go index 74b87a241d..9ff3bda955 100644 --- a/client_test.go +++ b/client_test.go @@ -146,6 +146,46 @@ func TestHostClientNegativeTimeout(t *testing.T) { ln.Close() } +func TestDoDeadlineRetry(t *testing.T) { + t.Parallel() + + tries := 0 + done := make(chan struct{}) + + ln := fasthttputil.NewInmemoryListener() + go func() { + for { + c, err := ln.Accept() + if err != nil { + close(done) + break + } + tries++ + br := bufio.NewReader(c) + (&RequestHeader{}).Read(br) //nolint:errcheck + (&Request{}).readBodyStream(br, 0, false, false) //nolint:errcheck + time.Sleep(time.Millisecond * 60) + c.Close() + } + }() + c := &HostClient{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req := AcquireRequest() + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://example.com") + if err := c.DoDeadline(req, nil, time.Now().Add(time.Millisecond*100)); err != ErrTimeout { + t.Fatalf("expected ErrTimeout error got: %+v", err) + } + ln.Close() + <-done + if tries != 2 { + t.Fatalf("expected 2 tries got %d", tries) + } +} + func TestPipelineClientIssue832(t *testing.T) { t.Parallel()