193 lines
4.6 KiB
Go
193 lines
4.6 KiB
Go
package main
|
|
|
|
import (
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// AdvancedRateLimiter is a token bucket: it holds at most maxTokens tokens,
// regains one token for every refillRate of elapsed time, and each allowed
// action spends one token.
type AdvancedRateLimiter struct {
	maxTokens  int           // bucket capacity
	refillRate time.Duration // time required to regain a single token
	tokens     int           // tokens currently available
	lastRefill time.Time     // reference instant for the next refill computation

	mu sync.Mutex // held by Allow and GetTokens while they read or update the fields above
}
|
|
|
|
// ClientRateLimiter manages per-client rate limiting
|
|
type ClientRateLimiter struct {
|
|
limiters map[string]*AdvancedRateLimiter
|
|
mu sync.RWMutex
|
|
cleanup chan string
|
|
}
|
|
|
|
// NewAdvancedRateLimiter creates a new rate limiter
|
|
func NewAdvancedRateLimiter(maxTokens int, refillRate time.Duration) *AdvancedRateLimiter {
|
|
return &AdvancedRateLimiter{
|
|
maxTokens: maxTokens,
|
|
refillRate: refillRate,
|
|
tokens: maxTokens,
|
|
lastRefill: time.Now(),
|
|
}
|
|
}
|
|
|
|
// Allow checks if an action is allowed (consumes a token)
|
|
func (rl *AdvancedRateLimiter) Allow() bool {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
rl.refill()
|
|
|
|
if rl.tokens > 0 {
|
|
rl.tokens--
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// refill adds tokens based on elapsed time
|
|
func (rl *AdvancedRateLimiter) refill() {
|
|
now := time.Now()
|
|
elapsed := now.Sub(rl.lastRefill)
|
|
|
|
tokensToAdd := int(elapsed / rl.refillRate)
|
|
if tokensToAdd > 0 {
|
|
rl.tokens += tokensToAdd
|
|
if rl.tokens > rl.maxTokens {
|
|
rl.tokens = rl.maxTokens
|
|
}
|
|
rl.lastRefill = now
|
|
}
|
|
}
|
|
|
|
// GetTokens returns current token count (for monitoring)
|
|
func (rl *AdvancedRateLimiter) GetTokens() int {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
rl.refill()
|
|
return rl.tokens
|
|
}
|
|
|
|
// NewClientRateLimiter creates a new client rate limiter
|
|
func NewClientRateLimiter() *ClientRateLimiter {
|
|
crl := &ClientRateLimiter{
|
|
limiters: make(map[string]*AdvancedRateLimiter),
|
|
cleanup: make(chan string, 1000),
|
|
}
|
|
|
|
// Start cleanup routine
|
|
go crl.cleanupRoutine()
|
|
|
|
return crl
|
|
}
|
|
|
|
// Allow checks if a client action is allowed
|
|
func (crl *ClientRateLimiter) Allow(clientID string, maxTokens int, refillRate time.Duration) bool {
|
|
crl.mu.RLock()
|
|
limiter, exists := crl.limiters[clientID]
|
|
crl.mu.RUnlock()
|
|
|
|
if !exists {
|
|
crl.mu.Lock()
|
|
// Double-check after acquiring write lock
|
|
limiter, exists = crl.limiters[clientID]
|
|
if !exists {
|
|
limiter = NewAdvancedRateLimiter(maxTokens, refillRate)
|
|
crl.limiters[clientID] = limiter
|
|
}
|
|
crl.mu.Unlock()
|
|
}
|
|
|
|
return limiter.Allow()
|
|
}
|
|
|
|
// RemoveClient removes a client's rate limiter
|
|
func (crl *ClientRateLimiter) RemoveClient(clientID string) {
|
|
select {
|
|
case crl.cleanup <- clientID:
|
|
// Queued for cleanup
|
|
default:
|
|
// Cleanup queue full, clean directly
|
|
crl.mu.Lock()
|
|
delete(crl.limiters, clientID)
|
|
crl.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// cleanupRoutine processes client cleanup requests
|
|
func (crl *ClientRateLimiter) cleanupRoutine() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case clientID := <-crl.cleanup:
|
|
crl.mu.Lock()
|
|
delete(crl.limiters, clientID)
|
|
crl.mu.Unlock()
|
|
|
|
case <-ticker.C:
|
|
// Periodic cleanup of old rate limiters
|
|
crl.mu.Lock()
|
|
for clientID, limiter := range crl.limiters {
|
|
// Remove limiters that haven't been used recently
|
|
if time.Since(limiter.lastRefill) > 10*time.Minute {
|
|
delete(crl.limiters, clientID)
|
|
}
|
|
}
|
|
crl.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetStats returns rate limiter statistics
|
|
func (crl *ClientRateLimiter) GetStats() map[string]interface{} {
|
|
crl.mu.RLock()
|
|
defer crl.mu.RUnlock()
|
|
|
|
stats := make(map[string]interface{})
|
|
stats["active_limiters"] = len(crl.limiters)
|
|
|
|
totalTokens := 0
|
|
for _, limiter := range crl.limiters {
|
|
totalTokens += limiter.GetTokens()
|
|
}
|
|
stats["total_tokens"] = totalTokens
|
|
|
|
return stats
|
|
}
|
|
|
|
// GlobalRateLimiter implements server-wide rate limiting
|
|
type GlobalRateLimiter struct {
|
|
connectionLimiter *AdvancedRateLimiter
|
|
messageLimiter *AdvancedRateLimiter
|
|
}
|
|
|
|
// NewGlobalRateLimiter creates a new global rate limiter
|
|
func NewGlobalRateLimiter(connPerSecond, msgPerSecond int) *GlobalRateLimiter {
|
|
return &GlobalRateLimiter{
|
|
connectionLimiter: NewAdvancedRateLimiter(connPerSecond*10, time.Second/time.Duration(connPerSecond)),
|
|
messageLimiter: NewAdvancedRateLimiter(msgPerSecond*10, time.Second/time.Duration(msgPerSecond)),
|
|
}
|
|
}
|
|
|
|
// AllowConnection checks if a new connection is allowed
|
|
func (grl *GlobalRateLimiter) AllowConnection() bool {
|
|
return grl.connectionLimiter.Allow()
|
|
}
|
|
|
|
// AllowMessage checks if a message is allowed
|
|
func (grl *GlobalRateLimiter) AllowMessage() bool {
|
|
return grl.messageLimiter.Allow()
|
|
}
|
|
|
|
// GetConnectionTokens returns available connection tokens
|
|
func (grl *GlobalRateLimiter) GetConnectionTokens() int {
|
|
return grl.connectionLimiter.GetTokens()
|
|
}
|
|
|
|
// GetMessageTokens returns available message tokens
|
|
func (grl *GlobalRateLimiter) GetMessageTokens() int {
|
|
return grl.messageLimiter.GetTokens()
|
|
}
|