package main

import (
	"sync"
	"time"
)

// rateLimiterCleanupInterval is how often ClientRateLimiter scans for idle
// per-client limiters.
const rateLimiterCleanupInterval = 5 * time.Minute

// rateLimiterIdleTimeout is how long a per-client limiter must go without an
// Allow call before the periodic scan may evict it.
const rateLimiterIdleTimeout = 10 * time.Minute

// rateLimiterCleanupQueueSize is the capacity of the asynchronous removal
// queue used by RemoveClient.
const rateLimiterCleanupQueueSize = 1000

// globalBurstSeconds is the number of seconds' worth of tokens a global
// limiter may hold, i.e. its bucket capacity is rate * globalBurstSeconds.
const globalBurstSeconds = 10

// AdvancedRateLimiter implements token bucket rate limiting for client
// connections. The bucket starts full, every allowed action consumes one
// token, and one token is added back per refillRate of elapsed time, up to
// maxTokens. All methods are safe for concurrent use.
type AdvancedRateLimiter struct {
	// maxTokens is the bucket capacity.
	maxTokens int

	// refillRate is the time it takes to earn one token. A non-positive value
	// means the bucket is topped up on every access.
	refillRate time.Duration

	// tokens is the number of tokens currently available.
	tokens int

	// lastRefill is the instant up to which elapsed time has already been
	// converted into tokens.
	lastRefill time.Time

	// lastUsed is the time of the most recent Allow call. It drives idle
	// eviction and is deliberately not touched by GetTokens.
	lastUsed time.Time

	// mu guards every field above except the immutable configuration.
	mu sync.Mutex
}

// ClientRateLimiter manages per-client rate limiting. It keeps one
// AdvancedRateLimiter per client ID and runs a background goroutine that
// processes queued removals and evicts idle limiters. Call Stop to end that
// goroutine when the limiter is no longer needed.
type ClientRateLimiter struct {
	// limiters maps a client ID to its bucket; guarded by mu.
	limiters map[string]*AdvancedRateLimiter

	// mu guards limiters. Lock order is always mu first, then a limiter's own
	// mutex.
	mu sync.RWMutex

	// cleanup queues client IDs for asynchronous removal.
	cleanup chan string

	// done is closed by Stop to terminate the cleanup goroutine.
	done chan struct{}

	// stopOnce makes Stop idempotent.
	stopOnce sync.Once
}

// NewAdvancedRateLimiter creates a new rate limiter whose bucket holds at
// most maxTokens tokens and earns one token per refillRate. The bucket starts
// full.
//
// A non-positive refillRate does not panic: the bucket is simply refilled to
// maxTokens on every access. A non-positive maxTokens yields a limiter that
// never allows anything.
func NewAdvancedRateLimiter(maxTokens int, refillRate time.Duration) *AdvancedRateLimiter {
	now := time.Now()
	return &AdvancedRateLimiter{
		maxTokens:  maxTokens,
		refillRate: refillRate,
		tokens:     maxTokens,
		lastRefill: now,
		lastUsed:   now,
	}
}

// Allow checks if an action is allowed and, if so, consumes a token. It
// reports false when the bucket is empty.
func (rl *AdvancedRateLimiter) Allow() bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	rl.lastUsed = now
	rl.refillAt(now)

	if rl.tokens <= 0 {
		return false
	}
	rl.tokens--
	return true
}

// refill adds tokens based on the time elapsed since the last refill.
// The caller must hold rl.mu.
func (rl *AdvancedRateLimiter) refill() {
	rl.refillAt(time.Now())
}

// refillAt credits the tokens earned between lastRefill and now.
// The caller must hold rl.mu.
func (rl *AdvancedRateLimiter) refillAt(now time.Time) {
	if rl.refillRate <= 0 {
		// A zero interval would divide by zero below; treat it as the limit
		// of an ever-faster refill and keep the bucket full.
		rl.tokens = rl.maxTokens
		rl.lastRefill = now
		return
	}

	elapsed := now.Sub(rl.lastRefill)
	if elapsed < rl.refillRate {
		return
	}

	earned := elapsed / rl.refillRate
	missing := rl.maxTokens - rl.tokens
	if missing <= 0 || earned >= time.Duration(missing) {
		// The bucket is full. A full bucket banks no time, so restart the
		// clock. Comparing before adding also avoids integer overflow after
		// a very long idle period.
		rl.tokens = rl.maxTokens
		rl.lastRefill = now
		return
	}

	rl.tokens += int(earned)
	// Advance by whole intervals only, so the partial interval already
	// elapsed still counts towards the next token.
	rl.lastRefill = rl.lastRefill.Add(earned * rl.refillRate)
}

// evictable reports whether the limiter can be discarded without losing
// state: Allow has not been called for longer than idle, and the bucket has
// refilled to capacity, so a freshly created limiter would behave the same.
func (rl *AdvancedRateLimiter) evictable(now time.Time, idle time.Duration) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if now.Sub(rl.lastUsed) <= idle {
		return false
	}
	rl.refillAt(now)
	return rl.tokens >= rl.maxTokens
}

// GetTokens returns the current token count (for monitoring). It applies any
// pending refill first but does not count as use of the limiter.
func (rl *AdvancedRateLimiter) GetTokens() int {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	rl.refill()
	return rl.tokens
}

// NewClientRateLimiter creates a new client rate limiter and starts its
// background cleanup goroutine. Call Stop to terminate that goroutine.
func NewClientRateLimiter() *ClientRateLimiter {
	crl := &ClientRateLimiter{
		limiters: make(map[string]*AdvancedRateLimiter),
		cleanup:  make(chan string, rateLimiterCleanupQueueSize),
		done:     make(chan struct{}),
	}

	// Start cleanup routine
	go crl.cleanupRoutine()

	return crl
}

// Allow checks if a client action is allowed, consuming one token from the
// client's bucket. The bucket is created on first use with the given
// maxTokens and refillRate; on later calls for the same clientID those two
// arguments are ignored and the existing bucket is used.
func (crl *ClientRateLimiter) Allow(clientID string, maxTokens int, refillRate time.Duration) bool {
	crl.mu.RLock()
	limiter, exists := crl.limiters[clientID]
	crl.mu.RUnlock()

	if !exists {
		crl.mu.Lock()
		// Re-check under the write lock: another goroutine may have created
		// the limiter between the two lock acquisitions.
		limiter, exists = crl.limiters[clientID]
		if !exists {
			limiter = NewAdvancedRateLimiter(maxTokens, refillRate)
			crl.limiters[clientID] = limiter
		}
		crl.mu.Unlock()
	}

	return limiter.Allow()
}

// RemoveClient removes a client's rate limiter. While the cleanup goroutine
// is running the removal is queued and happens asynchronously; if the queue
// is full, or Stop has been called, the limiter is removed before
// RemoveClient returns.
func (crl *ClientRateLimiter) RemoveClient(clientID string) {
	select {
	case <-crl.done:
		// Nobody is draining the queue any more; remove directly.
		crl.remove(clientID)
		return
	default:
	}

	select {
	case crl.cleanup <- clientID:
		// Queued for cleanup
	default:
		// Cleanup queue full, clean directly
		crl.remove(clientID)
	}
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once and from multiple goroutines. The limiter remains usable after
// Stop, but idle limiters are no longer evicted automatically.
func (crl *ClientRateLimiter) Stop() {
	crl.stopOnce.Do(func() {
		if crl.done != nil {
			close(crl.done)
		}
	})
}

// remove deletes the limiter for clientID, if any.
func (crl *ClientRateLimiter) remove(clientID string) {
	crl.mu.Lock()
	delete(crl.limiters, clientID)
	crl.mu.Unlock()
}

// evictIdle drops every limiter that reports itself evictable for the
// configured idle timeout.
func (crl *ClientRateLimiter) evictIdle() {
	now := time.Now()

	crl.mu.Lock()
	defer crl.mu.Unlock()

	for clientID, limiter := range crl.limiters {
		if limiter.evictable(now, rateLimiterIdleTimeout) {
			delete(crl.limiters, clientID)
		}
	}
}

// drainCleanup processes removals that are already queued, without blocking.
func (crl *ClientRateLimiter) drainCleanup() {
	for {
		select {
		case clientID := <-crl.cleanup:
			crl.remove(clientID)
		default:
			return
		}
	}
}

// cleanupRoutine processes client cleanup requests and periodically evicts
// idle limiters. It returns once Stop has been called.
func (crl *ClientRateLimiter) cleanupRoutine() {
	ticker := time.NewTicker(rateLimiterCleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-crl.done:
			// Honour removals that were queued before the stop request.
			crl.drainCleanup()
			return
		case clientID := <-crl.cleanup:
			crl.remove(clientID)
		case <-ticker.C:
			// Periodic cleanup of old rate limiters
			crl.evictIdle()
		}
	}
}

// GetStats returns rate limiter statistics: "active_limiters" is the number
// of tracked clients and "total_tokens" is the sum of their available
// tokens. Both values are ints.
func (crl *ClientRateLimiter) GetStats() map[string]interface{} {
	crl.mu.RLock()
	defer crl.mu.RUnlock()

	totalTokens := 0
	for _, limiter := range crl.limiters {
		totalTokens += limiter.GetTokens()
	}

	return map[string]interface{}{
		"active_limiters": len(crl.limiters),
		"total_tokens":    totalTokens,
	}
}

// GlobalRateLimiter implements server-wide rate limiting with one bucket for
// new connections and one for messages.
type GlobalRateLimiter struct {
	connectionLimiter *AdvancedRateLimiter
	messageLimiter    *AdvancedRateLimiter
}

// NewGlobalRateLimiter creates a new global rate limiter that sustains
// connPerSecond connections and msgPerSecond messages per second, each with a
// burst capacity of ten seconds' worth of tokens. A non-positive rate yields
// a limiter that rejects everything.
func NewGlobalRateLimiter(connPerSecond, msgPerSecond int) *GlobalRateLimiter {
	return &GlobalRateLimiter{
		connectionLimiter: newPerSecondLimiter(connPerSecond),
		messageLimiter:    newPerSecondLimiter(msgPerSecond),
	}
}

// newPerSecondLimiter builds a bucket that earns perSecond tokens per second
// and holds up to perSecond * globalBurstSeconds tokens.
func newPerSecondLimiter(perSecond int) *AdvancedRateLimiter {
	if perSecond <= 0 {
		// Dividing by a zero rate would panic; an empty bucket that can never
		// hold a token expresses "nothing is allowed".
		return NewAdvancedRateLimiter(0, time.Second)
	}

	interval := time.Second / time.Duration(perSecond)
	if interval <= 0 {
		// Rates above one token per nanosecond truncate to zero.
		interval = time.Nanosecond
	}
	return NewAdvancedRateLimiter(perSecond*globalBurstSeconds, interval)
}

// AllowConnection checks if a new connection is allowed.
func (grl *GlobalRateLimiter) AllowConnection() bool {
	return grl.connectionLimiter.Allow()
}

// AllowMessage checks if a message is allowed.
func (grl *GlobalRateLimiter) AllowMessage() bool {
	return grl.messageLimiter.Allow()
}

// GetConnectionTokens returns available connection tokens.
func (grl *GlobalRateLimiter) GetConnectionTokens() int {
	return grl.connectionLimiter.GetTokens()
}

// GetMessageTokens returns available message tokens.
func (grl *GlobalRateLimiter) GetMessageTokens() int {
	return grl.messageLimiter.GetTokens()
}