Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] token bucket fill (#9) #9

Merged
merged 2 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,28 @@ func main() {
r := gin.Default()

client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Addr: "localhost:6379",
})

// Call NewRateLimiter function from rrl package.
// First parameter is the Redis client.
// Second parameter is the rate (tokens per second).
// Third parameter is the maximum number of tokens.
limiter := rrl.NewRateLimiter(client, 1, 10)
// Fourth parameter is time duration, token refill is depending on x time interval
limiter := rrl.NewRateLimiter(client, 1, 10, 30*time.Second)

// Use RateLimiterMiddleware from rrl package and pass limiter.
// This middleware works for all routes in your application,
// including static files served when you open a web browser.
r.Use(rrl.RateLimiterMiddleware(limiter, 1))
r.Use(rrl.RateLimiterMiddleware(limiter))

r.GET("/", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Welcome!"})
c.JSON(http.StatusOK, gin.H{"message": "Welcome!"})
})

// Using this way allows the RateLimiterMiddleware to work for only specific routes.
r.GET("/some", rrl.RateLimiterMiddleware(limiter, 1), func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Some!"})
r.GET("/some", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Some!"})
})

r.Run(":8080")
Expand Down
69 changes: 31 additions & 38 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
b64 "encoding/base64"
"errors"
"log"
"math"
"net/http"
"os"
"sync"
Expand All @@ -15,7 +16,10 @@ import (
)

// define constant variable of keyPrefix to avoid duplicate key in Redis
const keyPrefix = "ls_prefix:"
const (
keyPrefix = "ls_prefix:"
lastRefillPrefix = "_lastRefillTime"
)

// RateLimiter is struct based on Redis
type RateLimiter struct {
Expand All @@ -29,7 +33,7 @@ type RateLimiter struct {
currentToken int64

// lastRefillTime represents time that this bucket fill operation was tried
lastRefillTime time.Time
refillInterval time.Duration

mutex sync.Mutex

Expand All @@ -46,12 +50,12 @@ func encodeKey(value string) string {
}

// NewRateLimiter to received and define new RateLimiter struct
func NewRateLimiter(client *redis.Client, rate, maxToken int64) *RateLimiter {
func NewRateLimiter(client *redis.Client, rate, maxToken int64, refillInterval time.Duration) *RateLimiter {
return &RateLimiter{
client: client,
rate: rate,
maxTokens: maxToken,
lastRefillTime: time.Now(),
refillInterval: refillInterval,
currentToken: maxToken,
logger: log.New(os.Stdout, "RateLimiter: ", log.Lmicroseconds),
}
Expand All @@ -67,23 +71,24 @@ func NewRateLimiter(client *redis.Client, rate, maxToken int64) *RateLimiter {
// Returns:
//
// bool: Returns true if the request is allowed, false otherwise.
func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool {
func (rl *RateLimiter) IsRequestAllowed(key string, token int64) bool {
// use mutex to avoid race condition
rl.mutex.Lock()
defer rl.mutex.Unlock()

// encode key
sEnc := keyPrefix + encodeKey(key)

// get current token count from Redis
tokenCount, err := rl.client.Get(context.Background(), sEnc).Int64()
if err != nil && !errors.Is(err, redis.Nil) {
rl.logger.Printf("Error getting token count from Redis: %v", err)
return false
}

// get last refill time from Redis
lastRefillTimeStr, err := rl.client.Get(context.Background(), sEnc+"_lastRefillTime").Result()
if errors.Is(err, redis.Nil) {
tokenCount = rl.maxTokens
}

lastRefillTimeStr, err := rl.client.Get(context.Background(), sEnc+lastRefillPrefix).Result()
var lastRefillTime time.Time
if err == nil {
lastRefillTime, err = time.Parse(time.RFC3339, lastRefillTimeStr)
Expand All @@ -94,27 +99,20 @@ func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool {
} else if !errors.Is(err, redis.Nil) {
rl.logger.Printf("Error getting last refill time from Redis: %v", err)
return false
} else {
lastRefillTime = time.Now()
}

// refill tokens
tokenCount, lastRefillTime = rl.refillBucket(lastRefillTime, tokenCount)

// update last refill time in Redis
rl.client.Set(context.Background(), sEnc+"_lastRefillTime", lastRefillTime.Format(time.RFC3339), 0)
tokenCount = rl.refill(tokenCount, lastRefillTime)

// check if enough tokens are available
if tokenCount > 0 {
// decrement token count
tokenCount--
// update token count in Redis
err = rl.client.Set(context.Background(), sEnc, tokenCount, 0).Err()
if err != nil {
rl.logger.Printf("Error setting token count in Redis: %v", err)
return false
}
if tokenCount >= token {
tokenCount -= token
rl.client.Set(context.Background(), sEnc, tokenCount, 0)
rl.client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0)
return true
}

rl.client.Set(context.Background(), sEnc+lastRefillPrefix, time.Now().Format(time.RFC3339), 0)
return false
}

Expand All @@ -128,11 +126,11 @@ func (rl *RateLimiter) IsRequestAllowed(key string, tokens int64) bool {
// Returns:
//
// gin.HandlerFunc: A Gin handler function that can be used as middleware in the Gin router.
func RateLimiterMiddleware(limiter *RateLimiter, tokens int64) gin.HandlerFunc {
func RateLimiterMiddleware(limiter *RateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()

if !limiter.IsRequestAllowed(ip, tokens) {
token := int64(1)
if !limiter.IsRequestAllowed(ip, token) {
limiter.logger.Printf("Rate limit exceeded for IP: %s", ip)
c.Header("X-RateLimit-Remaining", "0")
c.JSON(http.StatusTooManyRequests, gin.H{"error": "too many requests"})
Expand All @@ -144,18 +142,13 @@ func RateLimiterMiddleware(limiter *RateLimiter, tokens int64) gin.HandlerFunc {
}
}

// refillBucket function calculate time, when token bucket can refill
func (rl *RateLimiter) refillBucket(lastRefillTime time.Time, tokenCount int64) (int64, time.Time) {
func (rl *RateLimiter) refill(currentTokens int64, lastRefillTime time.Time) int64 {
now := time.Now()
duration := now.Sub(lastRefillTime)
elapsed := now.Sub(lastRefillTime)

// Calculate tokens to add based on elapsed time and rate
tokensToAdd := (duration.Nanoseconds() * rl.rate) / 1e9 // maybe this calculation isn't correct, but i try to avoid float64, because sometimes it not accuracy

tokenCount = tokenCount + tokensToAdd
if tokenCount > rl.maxTokens {
tokenCount = rl.maxTokens
}
// calculate time which each token needs to refill in token bucket
tokensToAdd := elapsed.Nanoseconds() / rl.refillInterval.Nanoseconds()
newTokens := int64(math.Min(float64(currentTokens+tokensToAdd), float64(rl.maxTokens)))

return tokenCount, now
return newTokens
}
2 changes: 1 addition & 1 deletion limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func setupRedisClient() *redis.Client {

func TestRateLimiter_Allow(t *testing.T) {
client := setupRedisClient()
limiter := NewRateLimiter(client, 1, 5)
limiter := NewRateLimiter(client, 1, 5, time.Second)

tests := []struct {
name string
Expand Down
Loading