Added all of the existing code
This commit is contained in:
192
rate_limiter.go
Normal file
192
rate_limiter.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AdvancedRateLimiter implements token bucket rate limiting for client connections
|
||||
type AdvancedRateLimiter struct {
|
||||
maxTokens int
|
||||
refillRate time.Duration
|
||||
tokens int
|
||||
lastRefill time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ClientRateLimiter manages per-client rate limiting
|
||||
type ClientRateLimiter struct {
|
||||
limiters map[string]*AdvancedRateLimiter
|
||||
mu sync.RWMutex
|
||||
cleanup chan string
|
||||
}
|
||||
|
||||
// NewAdvancedRateLimiter creates a new rate limiter
|
||||
func NewAdvancedRateLimiter(maxTokens int, refillRate time.Duration) *AdvancedRateLimiter {
|
||||
return &AdvancedRateLimiter{
|
||||
maxTokens: maxTokens,
|
||||
refillRate: refillRate,
|
||||
tokens: maxTokens,
|
||||
lastRefill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if an action is allowed (consumes a token)
|
||||
func (rl *AdvancedRateLimiter) Allow() bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
rl.refill()
|
||||
|
||||
if rl.tokens > 0 {
|
||||
rl.tokens--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// refill adds tokens based on elapsed time
|
||||
func (rl *AdvancedRateLimiter) refill() {
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(rl.lastRefill)
|
||||
|
||||
tokensToAdd := int(elapsed / rl.refillRate)
|
||||
if tokensToAdd > 0 {
|
||||
rl.tokens += tokensToAdd
|
||||
if rl.tokens > rl.maxTokens {
|
||||
rl.tokens = rl.maxTokens
|
||||
}
|
||||
rl.lastRefill = now
|
||||
}
|
||||
}
|
||||
|
||||
// GetTokens returns current token count (for monitoring)
|
||||
func (rl *AdvancedRateLimiter) GetTokens() int {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
rl.refill()
|
||||
return rl.tokens
|
||||
}
|
||||
|
||||
// NewClientRateLimiter creates a new client rate limiter
|
||||
func NewClientRateLimiter() *ClientRateLimiter {
|
||||
crl := &ClientRateLimiter{
|
||||
limiters: make(map[string]*AdvancedRateLimiter),
|
||||
cleanup: make(chan string, 1000),
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go crl.cleanupRoutine()
|
||||
|
||||
return crl
|
||||
}
|
||||
|
||||
// Allow checks if a client action is allowed
|
||||
func (crl *ClientRateLimiter) Allow(clientID string, maxTokens int, refillRate time.Duration) bool {
|
||||
crl.mu.RLock()
|
||||
limiter, exists := crl.limiters[clientID]
|
||||
crl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
crl.mu.Lock()
|
||||
// Double-check after acquiring write lock
|
||||
limiter, exists = crl.limiters[clientID]
|
||||
if !exists {
|
||||
limiter = NewAdvancedRateLimiter(maxTokens, refillRate)
|
||||
crl.limiters[clientID] = limiter
|
||||
}
|
||||
crl.mu.Unlock()
|
||||
}
|
||||
|
||||
return limiter.Allow()
|
||||
}
|
||||
|
||||
// RemoveClient removes a client's rate limiter
|
||||
func (crl *ClientRateLimiter) RemoveClient(clientID string) {
|
||||
select {
|
||||
case crl.cleanup <- clientID:
|
||||
// Queued for cleanup
|
||||
default:
|
||||
// Cleanup queue full, clean directly
|
||||
crl.mu.Lock()
|
||||
delete(crl.limiters, clientID)
|
||||
crl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupRoutine processes client cleanup requests
|
||||
func (crl *ClientRateLimiter) cleanupRoutine() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case clientID := <-crl.cleanup:
|
||||
crl.mu.Lock()
|
||||
delete(crl.limiters, clientID)
|
||||
crl.mu.Unlock()
|
||||
|
||||
case <-ticker.C:
|
||||
// Periodic cleanup of old rate limiters
|
||||
crl.mu.Lock()
|
||||
for clientID, limiter := range crl.limiters {
|
||||
// Remove limiters that haven't been used recently
|
||||
if time.Since(limiter.lastRefill) > 10*time.Minute {
|
||||
delete(crl.limiters, clientID)
|
||||
}
|
||||
}
|
||||
crl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats returns rate limiter statistics
|
||||
func (crl *ClientRateLimiter) GetStats() map[string]interface{} {
|
||||
crl.mu.RLock()
|
||||
defer crl.mu.RUnlock()
|
||||
|
||||
stats := make(map[string]interface{})
|
||||
stats["active_limiters"] = len(crl.limiters)
|
||||
|
||||
totalTokens := 0
|
||||
for _, limiter := range crl.limiters {
|
||||
totalTokens += limiter.GetTokens()
|
||||
}
|
||||
stats["total_tokens"] = totalTokens
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// GlobalRateLimiter implements server-wide rate limiting
|
||||
type GlobalRateLimiter struct {
|
||||
connectionLimiter *AdvancedRateLimiter
|
||||
messageLimiter *AdvancedRateLimiter
|
||||
}
|
||||
|
||||
// NewGlobalRateLimiter creates a new global rate limiter
|
||||
func NewGlobalRateLimiter(connPerSecond, msgPerSecond int) *GlobalRateLimiter {
|
||||
return &GlobalRateLimiter{
|
||||
connectionLimiter: NewAdvancedRateLimiter(connPerSecond*10, time.Second/time.Duration(connPerSecond)),
|
||||
messageLimiter: NewAdvancedRateLimiter(msgPerSecond*10, time.Second/time.Duration(msgPerSecond)),
|
||||
}
|
||||
}
|
||||
|
||||
// AllowConnection checks if a new connection is allowed
|
||||
func (grl *GlobalRateLimiter) AllowConnection() bool {
|
||||
return grl.connectionLimiter.Allow()
|
||||
}
|
||||
|
||||
// AllowMessage checks if a message is allowed
|
||||
func (grl *GlobalRateLimiter) AllowMessage() bool {
|
||||
return grl.messageLimiter.Allow()
|
||||
}
|
||||
|
||||
// GetConnectionTokens returns available connection tokens
|
||||
func (grl *GlobalRateLimiter) GetConnectionTokens() int {
|
||||
return grl.connectionLimiter.GetTokens()
|
||||
}
|
||||
|
||||
// GetMessageTokens returns available message tokens
|
||||
func (grl *GlobalRateLimiter) GetMessageTokens() int {
|
||||
return grl.messageLimiter.GetTokens()
|
||||
}
|
||||
Reference in New Issue
Block a user