Files
techircd/rate_limiter.go

193 lines
4.6 KiB
Go

package main
import (
"sync"
"time"
)
// AdvancedRateLimiter implements token-bucket rate limiting.
//
// The bucket holds at most maxTokens tokens and starts full. One token is
// earned back every refillRate, and each allowed action consumes one token.
// All state is guarded by mu; the exported methods take the lock themselves.
type AdvancedRateLimiter struct {
	maxTokens  int           // bucket capacity (largest burst allowed)
	refillRate time.Duration // time needed to earn back a single token
	tokens     int           // tokens currently available
	lastRefill time.Time     // instant up to which refills have been credited
	mu         sync.Mutex
}

// ClientRateLimiter manages per-client rate limiting.
//
// limiters maps a client ID to that client's bucket and is guarded by mu.
// cleanup carries client IDs queued for removal by the background goroutine.
type ClientRateLimiter struct {
	limiters map[string]*AdvancedRateLimiter
	mu       sync.RWMutex
	cleanup  chan string
}

// NewAdvancedRateLimiter creates a rate limiter whose bucket starts full with
// maxTokens tokens and regains one token every refillRate.
func NewAdvancedRateLimiter(maxTokens int, refillRate time.Duration) *AdvancedRateLimiter {
	return &AdvancedRateLimiter{
		maxTokens:  maxTokens,
		refillRate: refillRate,
		tokens:     maxTokens,
		lastRefill: time.Now(),
	}
}

// Allow reports whether an action is allowed right now. If it is, one token
// is consumed.
func (rl *AdvancedRateLimiter) Allow() bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.refill()
	if rl.tokens <= 0 {
		return false
	}
	rl.tokens--
	return true
}

// refill credits the tokens earned since lastRefill. Callers must hold rl.mu.
//
// Only whole intervals are credited. lastRefill advances by exactly the
// credited intervals, so partial progress toward the next token carries over
// to the next call. Once the bucket is full there is nothing
// left to carry, so lastRefill is reset to now.
//
// A non-positive refillRate is treated as an instantaneous refill (the bucket
// is always topped up); dividing by it would otherwise panic or never refill.
func (rl *AdvancedRateLimiter) refill() {
	now := time.Now()
	if rl.refillRate <= 0 {
		rl.tokens = rl.maxTokens
		rl.lastRefill = now
		return
	}

	earned := now.Sub(rl.lastRefill) / rl.refillRate
	if earned <= 0 {
		return
	}

	// Compare against the remaining capacity before adding, so a very long
	// idle period with a tiny refillRate cannot overflow rl.tokens.
	if missing := rl.maxTokens - rl.tokens; earned >= time.Duration(missing) {
		rl.tokens = rl.maxTokens
		rl.lastRefill = now
		return
	}

	rl.tokens += int(earned)
	rl.lastRefill = rl.lastRefill.Add(earned * rl.refillRate)
}

// GetTokens returns the number of tokens currently available, after
// crediting any refill that is due. Intended for monitoring.
func (rl *AdvancedRateLimiter) GetTokens() int {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.refill()
	return rl.tokens
}
// NewClientRateLimiter returns an empty ClientRateLimiter and launches its
// background cleanup goroutine (cleanupRoutine). Up to 1000 removal requests
// can be queued before RemoveClient falls back to deleting synchronously.
//
// NOTE(review): cleanupRoutine has no exit path, so the goroutine started
// here lives for the rest of the process.
func NewClientRateLimiter() *ClientRateLimiter {
	limiter := new(ClientRateLimiter)
	limiter.limiters = make(map[string]*AdvancedRateLimiter)
	limiter.cleanup = make(chan string, 1000)

	go limiter.cleanupRoutine()

	return limiter
}
// Allow reports whether clientID may perform an action now, consuming one of
// its tokens if so.
//
// The client's limiter is created on first use from maxTokens and refillRate.
// Later calls for the same client reuse that limiter and ignore both arguments.
func (crl *ClientRateLimiter) Allow(clientID string, maxTokens int, refillRate time.Duration) bool {
	// Fast path: the limiter usually exists already, so a shared lock suffices.
	crl.mu.RLock()
	limiter, found := crl.limiters[clientID]
	crl.mu.RUnlock()
	if found {
		return limiter.Allow()
	}

	// Slow path: look again under the exclusive lock, because another
	// goroutine may have created the limiter between RUnlock and Lock.
	crl.mu.Lock()
	if limiter, found = crl.limiters[clientID]; !found {
		limiter = NewAdvancedRateLimiter(maxTokens, refillRate)
		crl.limiters[clientID] = limiter
	}
	crl.mu.Unlock()

	return limiter.Allow()
}
// RemoveClient discards clientID's limiter.
//
// Normally the ID is handed to the cleanup goroutine through the buffered
// cleanup channel. If that buffer is full, the entry is deleted right here
// instead, so the caller never waits for channel space.
func (crl *ClientRateLimiter) RemoveClient(clientID string) {
	select {
	case crl.cleanup <- clientID:
		return
	default:
	}

	crl.mu.Lock()
	defer crl.mu.Unlock()
	delete(crl.limiters, clientID)
}
// cleanupRoutine removes entries from crl.limiters, looping forever.
//
// It does two kinds of work:
//   - client IDs received on crl.cleanup are deleted immediately;
//   - every 5 minutes, limiters whose lastRefill is more than 10 minutes old
//     are deleted.
//
// lastRefill advances only when the limiter's refill credits a token, which
// happens from Allow and GetTokens. The periodic sweep therefore evicts
// limiters nobody has touched recently.
//
// NOTE(review): GetStats calls GetTokens, so frequent stats polling can keep
// otherwise idle limiters from being evicted. The loop also has no exit path.
func (crl *ClientRateLimiter) cleanupRoutine() {
	const (
		sweepInterval = 5 * time.Minute
		idleTimeout   = 10 * time.Minute
	)

	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()

	for {
		select {
		case clientID := <-crl.cleanup:
			crl.mu.Lock()
			delete(crl.limiters, clientID)
			crl.mu.Unlock()

		case <-ticker.C:
			crl.mu.Lock()
			for clientID, limiter := range crl.limiters {
				// lastRefill is written under limiter.mu by refill, so it must be
				// read under that lock to avoid a data race. Taking crl.mu before
				// limiter.mu matches the order used by GetStats.
				limiter.mu.Lock()
				idle := time.Since(limiter.lastRefill) > idleTimeout
				limiter.mu.Unlock()

				if idle {
					delete(crl.limiters, clientID) // safe during range in Go
				}
			}
			crl.mu.Unlock()
		}
	}
}
// GetStats returns a snapshot of per-client rate limiter statistics:
//   - "active_limiters": number of clients currently tracked (int)
//   - "total_tokens": sum of available tokens across those clients (int)
//
// Each count comes from GetTokens, which first credits any refill that is due.
func (crl *ClientRateLimiter) GetStats() map[string]interface{} {
	crl.mu.RLock()
	defer crl.mu.RUnlock()

	sum := 0
	for _, l := range crl.limiters {
		sum += l.GetTokens()
	}

	return map[string]interface{}{
		"active_limiters": len(crl.limiters),
		"total_tokens":    sum,
	}
}
// GlobalRateLimiter implements server-wide rate limiting with two independent
// token buckets: one for new connections and one for messages.
type GlobalRateLimiter struct {
	connectionLimiter *AdvancedRateLimiter
	messageLimiter    *AdvancedRateLimiter
}

// NewGlobalRateLimiter creates a server-wide limiter that allows, on average,
// connPerSecond new connections and msgPerSecond messages per second. Each
// bucket can absorb a burst of up to 10 seconds' worth of its rate.
//
// Rates below 1 are clamped to 1, because a zero rate would divide by zero
// when computing the refill interval. Rates above one per nanosecond are
// clamped to a 1ns interval, since a 0ns interval cannot be represented.
func NewGlobalRateLimiter(connPerSecond, msgPerSecond int) *GlobalRateLimiter {
	const burstSeconds = 10

	perSecond := func(rate int) *AdvancedRateLimiter {
		if rate < 1 {
			rate = 1
		}
		interval := time.Second / time.Duration(rate)
		if interval <= 0 {
			interval = time.Nanosecond
		}
		return NewAdvancedRateLimiter(rate*burstSeconds, interval)
	}

	return &GlobalRateLimiter{
		connectionLimiter: perSecond(connPerSecond),
		messageLimiter:    perSecond(msgPerSecond),
	}
}

// AllowConnection reports whether a new connection is allowed, consuming a
// connection token if so.
func (grl *GlobalRateLimiter) AllowConnection() bool {
	return grl.connectionLimiter.Allow()
}

// AllowMessage reports whether a message is allowed, consuming a message
// token if so.
func (grl *GlobalRateLimiter) AllowMessage() bool {
	return grl.messageLimiter.Allow()
}

// GetConnectionTokens returns the connection tokens currently available.
func (grl *GlobalRateLimiter) GetConnectionTokens() int {
	return grl.connectionLimiter.GetTokens()
}

// GetMessageTokens returns the message tokens currently available.
func (grl *GlobalRateLimiter) GetMessageTokens() int {
	return grl.messageLimiter.GetTokens()
}