package main

import (
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"
)

// Server is the IRC server. It owns the client/server listeners, the client
// and channel tables, and any linked servers. The maps and listener fields
// are guarded by mu.
type Server struct {
	config         *Config
	clients        map[string]*Client       // keyed by client ID
	channels       map[string]*Channel      // keyed by lower-cased channel name
	listener       net.Listener             // plain-text client listener
	sslListener    net.Listener             // TLS client listener (optional)
	serverListener net.Listener             // server-to-server link listener (optional)
	linkedServers  map[string]*LinkedServer // keyed by server name
	pingMessage    string                   // PING line sent to idle clients
	mu             sync.RWMutex
	shutdown       chan bool      // closed by Shutdown to stop background goroutines
	shutdownOnce   sync.Once      // makes closing shutdown safe under concurrent Shutdown calls
	healthMonitor  *HealthMonitor
}

// NewServer builds a Server from config. Nothing listens until Start is called.
func NewServer(config *Config) *Server {
	server := &Server{
		config:   config,
		clients:  make(map[string]*Client, config.Limits.MaxClients),
		channels: make(map[string]*Channel, config.Limits.MaxChannels),
		// Must be non-nil: connectToServer/AddLinkedServer write into it.
		linkedServers: make(map[string]*LinkedServer),
		// Pre-computed once so the ping routine does not format it per client.
		pingMessage: fmt.Sprintf("PING :%s", config.Server.Name),
		shutdown:    make(chan bool),
	}
	server.healthMonitor = NewHealthMonitor(server)
	return server
}

// Start binds the plain-text client listener, creates the auto-join
// channels, starts the optional SSL and linking listeners plus background
// routines, and then serves client connections until shutdown. It returns an
// error only if the main listener cannot be bound.
func (s *Server) Start() error {
	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.Port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}
	if !s.trackListener(&s.listener, listener) {
		return nil
	}
	log.Printf("IRC server listening on %s", addr)

	// Create channels before any goroutine that might touch s.channels runs.
	s.createAutoJoinChannels()

	s.healthMonitor.Start()

	if s.config.Server.Listen.EnableSSL {
		go s.startSSLListener()
	}

	if s.config.Linking.Enable {
		go s.startServerListener()
		go s.startAutoConnections()
	}

	go s.pingRoutine()

	s.acceptLoop(listener, s.handleClientConn)
	return nil
}

// createAutoJoinChannels registers every configured auto-join channel with
// the configured default modes, replacing any existing entry of that name.
func (s *Server) createAutoJoinChannels() {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, channelName := range s.config.Channels.AutoJoin {
		s.channels[strings.ToLower(channelName)] = s.newChannelWithDefaults(channelName)
	}
}

// newChannelWithDefaults creates a channel and applies the configured default
// modes, skipping a literal '+' prefix. It does not register the channel.
func (s *Server) newChannelWithDefaults(name string) *Channel {
	channel := NewChannel(name)
	for _, mode := range s.config.Channels.DefaultModes {
		if mode != '+' {
			channel.SetMode(rune(mode), true)
		}
	}
	return channel
}

// isShuttingDown reports whether the shutdown signal has been given.
func (s *Server) isShuttingDown() bool {
	select {
	case <-s.shutdown:
		return true
	default:
		return false
	}
}

// signalShutdown closes the shutdown channel exactly once. The inner check
// also tolerates the channel having been closed by other code.
func (s *Server) signalShutdown() {
	s.shutdownOnce.Do(func() {
		select {
		case <-s.shutdown:
			// Already closed.
		default:
			close(s.shutdown)
		}
	})
}

// sleepOrShutdown waits for d and returns true. It returns false early if the
// server begins shutting down in the meantime.
func (s *Server) sleepOrShutdown(d time.Duration) bool {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-s.shutdown:
		return false
	case <-timer.C:
		return true
	}
}

// trackListener stores l in *slot (one of the Server listener fields) under
// s.mu so that Shutdown can close it. If shutdown was already signaled, it
// closes l and returns false, so a late-starting listener does not leak.
func (s *Server) trackListener(slot *net.Listener, l net.Listener) bool {
	s.mu.Lock()
	*slot = l
	s.mu.Unlock()
	if s.isShuttingDown() {
		l.Close()
		return false
	}
	return true
}

// closeListeners closes every listener that has been started.
func (s *Server) closeListeners() {
	s.mu.RLock()
	listeners := []net.Listener{s.listener, s.sslListener, s.serverListener}
	s.mu.RUnlock()
	for _, l := range listeners {
		if l != nil {
			l.Close()
		}
	}
}

// acceptLoop accepts connections from listener and passes each one to handle
// until shutdown. It backs off exponentially (5ms up to 1s) on Accept errors
// instead of spinning. When shutdown is signaled and the listener is closed,
// Accept's error ends the loop.
func (s *Server) acceptLoop(listener net.Listener, handle func(net.Conn)) {
	var backoff time.Duration
	for {
		if s.isShuttingDown() {
			return
		}
		conn, err := listener.Accept()
		if err != nil {
			if s.isShuttingDown() {
				return
			}
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff *= 2
			}
			if backoff > time.Second {
				backoff = time.Second
			}
			log.Printf("Accept error on %s: %v (retrying in %v)", listener.Addr(), err, backoff)
			if !s.sleepOrShutdown(backoff) {
				return
			}
			continue
		}
		backoff = 0
		handle(conn)
	}
}

// handleClientConn wraps a freshly accepted connection in a Client, registers
// it, and starts its handler goroutine.
func (s *Server) handleClientConn(conn net.Conn) {
	client := NewClient(conn, s)
	s.AddClient(client)
	go client.Handle()
}

// startSSLListener loads the configured certificate and serves TLS client
// connections on the SSL port until shutdown. Failures are logged, not fatal.
func (s *Server) startSSLListener() {
	cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile)
	if err != nil {
		log.Printf("Failed to load SSL certificates: %v", err)
		return
	}

	tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort)
	listener, err := tls.Listen("tcp", addr, tlsConfig)
	if err != nil {
		log.Printf("Failed to start SSL listener on %s: %v", addr, err)
		return
	}
	if !s.trackListener(&s.sslListener, listener) {
		return
	}
	log.Printf("IRC SSL server listening on %s", addr)

	s.acceptLoop(listener, s.handleClientConn)
}

// pingRoutine schedules the periodic ping check (every 60s) and the health
// check (every 5m). Each check runs in its own goroutine so this loop never
// blocks, and the loop stops on shutdown.
func (s *Server) pingRoutine() {
	ticker := time.NewTicker(60 * time.Second) // Reduced frequency to prevent contention
	defer ticker.Stop()

	healthTicker := time.NewTicker(5 * time.Minute) // Much less frequent health checks
	defer healthTicker.Stop()

	for {
		select {
		case <-s.shutdown:
			return
		case <-ticker.C:
			go s.performPingCheck()
		case <-healthTicker.C:
			go s.performHealthCheck()
		}
	}
}

// clientIDsSnapshot returns the current client IDs. It holds the read lock
// only while copying them.
func (s *Server) clientIDsSnapshot() []string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	ids := make([]string, 0, len(s.clients))
	for id := range s.clients {
		ids = append(ids, id)
	}
	return ids
}

// performPingCheck sends s.pingMessage to every registered, connected client
// whose last activity is more than 120 seconds old. Each send happens in its
// own goroutine so a slow client cannot stall the sweep.
func (s *Server) performPingCheck() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Panic in performPingCheck: %v", r)
		}
	}()

	for _, clientID := range s.clientIDsSnapshot() {
		client := s.GetClientByID(clientID)
		if client == nil || !client.IsRegistered() || !client.IsConnected() {
			continue
		}

		client.mu.RLock()
		lastActivity := client.lastActivity
		client.mu.RUnlock()

		if time.Since(lastActivity) > 120*time.Second {
			go func(c *Client) {
				defer func() {
					if r := recover(); r != nil {
						log.Printf("Panic sending ping to %s: %v", c.getClientInfo(), r)
					}
				}()
				c.SendMessage(s.pingMessage)
			}(client)
		}
	}
}

// shouldForceDisconnect reports whether a HealthCheck failure reason is
// severe enough to drop the client.
func shouldForceDisconnect(reason string) bool {
	for _, marker := range []string{"disconnected", "nil", "write errors", "registration timeout"} {
		if strings.Contains(reason, marker) {
			return true
		}
	}
	return false
}

// performHealthCheck runs HealthCheck on every client in batches of 50,
// pausing 10ms between batches. It force-disconnects clients with severe
// failures in the background, then logs a summary and the server statistics.
func (s *Server) performHealthCheck() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Panic in performHealthCheck: %v", r)
		}
	}()

	clientIDs := s.clientIDsSnapshot()
	totalClients := len(clientIDs)
	unhealthyClients := 0
	var disconnectedClients []string // IDs, so a stale pointer is never acted on

	const batchSize = 50
	for i := 0; i < len(clientIDs); i += batchSize {
		end := i + batchSize
		if end > len(clientIDs) {
			end = len(clientIDs)
		}

		for _, clientID := range clientIDs[i:end] {
			client := s.GetClientByID(clientID)
			if client == nil {
				continue
			}
			healthy, reason := client.HealthCheck()
			if healthy {
				continue
			}
			unhealthyClients++
			if shouldForceDisconnect(reason) {
				disconnectedClients = append(disconnectedClients, clientID)
				log.Printf("Marking client %s for disconnection: %s", client.getClientInfo(), reason)
			}
		}

		// Small delay between batches to prevent overwhelming
		time.Sleep(10 * time.Millisecond)
	}

	if len(disconnectedClients) > 0 {
		go func(ids []string) {
			defer func() {
				if r := recover(); r != nil {
					log.Printf("Panic during client disconnection: %v", r)
				}
			}()
			for _, clientID := range ids {
				if client := s.GetClientByID(clientID); client != nil {
					client.ForceDisconnect("Connection health check failed")
				}
			}
		}(disconnectedClients)
	}

	if totalClients > 0 {
		healthyPercentage := float64(totalClients-unhealthyClients) / float64(totalClients) * 100
		log.Printf("Client health check: %d total, %d healthy (%.1f%%), %d marked for disconnection",
			totalClients, totalClients-unhealthyClients, healthyPercentage, len(disconnectedClients))
	}

	// Runs on the 5-minute health tick, so this logs stats every 5 minutes.
	s.LogServerStats()
}

// AddClient registers client in the client table. When the server is at
// Limits.MaxClients, it sends "ERROR :Server full" and closes the connection
// instead. If the client's ID is already taken, a fresh unique ID is
// assigned. The rejection message and close happen outside the lock, so a
// slow peer cannot stall the server.
func (s *Server) AddClient(client *Client) {
	if client == nil {
		log.Printf("Attempted to add nil client")
		return
	}
	if client.conn == nil {
		log.Printf("Attempted to add client with nil connection")
		return
	}

	s.mu.Lock()
	if len(s.clients) >= s.config.Limits.MaxClients {
		s.mu.Unlock()
		log.Printf("Server full, rejecting client from %s", client.getClientInfo())
		client.SendMessage("ERROR :Server full")
		client.conn.Close()
		return
	}

	// Defensive: IDs should be unique, but if not, keep generating until free.
	if _, exists := s.clients[client.clientID]; exists {
		log.Printf("Duplicate client ID detected: %s, generating new ID", client.clientID)
		for suffix := len(s.clients); ; suffix++ {
			candidate := fmt.Sprintf("%s_%d_%d", client.host, time.Now().Unix(), suffix)
			if _, taken := s.clients[candidate]; !taken {
				client.clientID = candidate
				break
			}
		}
	}

	s.clients[client.clientID] = client
	total := len(s.clients)
	s.mu.Unlock()

	log.Printf("Added client %s (total clients: %d/%d)", client.getClientInfo(), total, s.config.Limits.MaxClients)
}

// RemoveClient removes client from the client table. The entry is deleted
// only if it still refers to this *Client, so a rejected or stale client that
// shares an ID cannot evict a live one. A 'c' snomask notice is sent for
// registered clients, once, when they are actually removed.
func (s *Server) RemoveClient(client *Client) {
	if client == nil {
		return
	}

	s.mu.Lock()
	current, exists := s.clients[client.clientID]
	removed := exists && current == client
	if removed {
		delete(s.clients, client.clientID)
	}
	s.mu.Unlock()

	// Notify outside the lock; sendSnomask takes its own read lock.
	if removed && client.IsRegistered() {
		s.sendSnomask('c', fmt.Sprintf("Client disconnect: %s (%s@%s)", client.Nick(), client.User(), client.Host()))
	}
}

// GetServerStats returns detailed server statistics for monitoring. It
// snapshots the tables under the read lock and then queries clients and
// channels without holding it.
func (s *Server) GetServerStats() map[string]interface{} {
	s.mu.RLock()
	clients := make([]*Client, 0, len(s.clients))
	for _, client := range s.clients {
		clients = append(clients, client)
	}
	channels := make([]*Channel, 0, len(s.channels))
	for _, channel := range s.channels {
		channels = append(channels, channel)
	}
	linkedCount := len(s.linkedServers)
	s.mu.RUnlock()

	registeredClients, operatorClients, sslClients, unhealthyClients := 0, 0, 0, 0
	for _, client := range clients {
		if client.IsRegistered() {
			registeredClients++
		}
		if client.IsOper() {
			operatorClients++
		}
		if client.IsSSL() {
			sslClients++
		}
		if healthy, _ := client.HealthCheck(); !healthy {
			unhealthyClients++
		}
	}

	totalChannelUsers := 0
	for _, channel := range channels {
		totalChannelUsers += len(channel.GetClients())
	}

	stats := make(map[string]interface{})
	stats["total_clients"] = len(clients)
	stats["total_channels"] = len(channels)
	stats["registered_clients"] = registeredClients
	stats["operator_clients"] = operatorClients
	stats["ssl_clients"] = sslClients
	stats["unhealthy_clients"] = unhealthyClients
	stats["total_channel_users"] = totalChannelUsers
	stats["linked_servers"] = linkedCount
	return stats
}

// LogServerStats logs current server statistics
func (s *Server) LogServerStats() {
	stats := s.GetServerStats()
	log.Printf("Server Statistics: %d clients (%d registered, %d operators, %d SSL, %d unhealthy), %d channels, %d linked servers",
		stats["total_clients"], stats["registered_clients"], stats["operator_clients"],
		stats["ssl_clients"], stats["unhealthy_clients"], stats["total_channels"], stats["linked_servers"])
}

// sendSnomask sends a server NOTICE to every operator subscribed to snomask.
// Recipients are collected under the read lock, and messages are sent after
// it is released, so a blocking SendMessage cannot hold s.mu.
func (s *Server) sendSnomask(snomask rune, message string) {
	s.mu.RLock()
	var recipients []*Client
	for _, client := range s.clients {
		if client.IsOper() && client.HasSnomask(snomask) {
			recipients = append(recipients, client)
		}
	}
	serverName := s.config.Server.Name
	s.mu.RUnlock()

	for _, client := range recipients {
		client.SendMessage(fmt.Sprintf(":%s NOTICE %s :*** %s", serverName, client.Nick(), message))
	}
}

// ReloadConfig reloads the server configuration from "config.json" and swaps
// it in under the write lock.
// NOTE(review): most readers access s.config without taking s.mu, so a reload
// racing with them is still a data race; consider an atomic pointer.
func (s *Server) ReloadConfig() error {
	config, err := LoadConfig("config.json")
	if err != nil {
		return fmt.Errorf("failed to load config: %w", err)
	}

	s.mu.Lock()
	s.config = config
	s.mu.Unlock()

	return nil
}

// GetClient returns the client whose nick matches nick case-insensitively, or nil.
func (s *Server) GetClient(nick string) *Client {
	s.mu.RLock()
	defer s.mu.RUnlock()
	for _, client := range s.clients {
		if strings.EqualFold(client.Nick(), nick) {
			return client
		}
	}
	return nil
}

// GetClientByHost returns some client whose host equals host exactly, or nil.
func (s *Server) GetClientByHost(host string) *Client {
	s.mu.RLock()
	defer s.mu.RUnlock()
	for _, client := range s.clients {
		if client.Host() == host {
			return client
		}
	}
	return nil
}

// GetClientByID returns the client registered under clientID, or nil.
func (s *Server) GetClientByID(clientID string) *Client {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.clients[clientID]
}

// GetClients returns a copy of the client table (the clients are shared).
func (s *Server) GetClients() map[string]*Client {
	s.mu.RLock()
	defer s.mu.RUnlock()
	clients := make(map[string]*Client, len(s.clients))
	for clientID, client := range s.clients {
		clients[clientID] = client
	}
	return clients
}

// GetChannel returns the channel named name (case-insensitive), or nil.
func (s *Server) GetChannel(name string) *Channel {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.channels[strings.ToLower(name)]
}

// GetOrCreateChannel returns the existing channel named name, or creates and
// registers one with the configured default modes.
func (s *Server) GetOrCreateChannel(name string) *Channel {
	s.mu.Lock()
	defer s.mu.Unlock()

	channelName := strings.ToLower(name)
	if channel, exists := s.channels[channelName]; exists {
		return channel
	}

	channel := s.newChannelWithDefaults(name)
	s.channels[channelName] = channel
	return channel
}

// RemoveChannel removes the channel named name (case-insensitive).
func (s *Server) RemoveChannel(name string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.channels, strings.ToLower(name))
}

// GetChannels returns a copy of the channel table (the channels are shared).
func (s *Server) GetChannels() map[string]*Channel {
	s.mu.RLock()
	defer s.mu.RUnlock()
	channels := make(map[string]*Channel, len(s.channels))
	for name, channel := range s.channels {
		channels[name] = channel
	}
	return channels
}

// CreateChannel registers a new channel named name without default modes.
// It returns nil when Limits.MaxChannels has been reached.
func (s *Server) CreateChannel(name string) *Channel {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.channels) >= s.config.Limits.MaxChannels {
		return nil
	}

	channel := NewChannel(name)
	s.channels[strings.ToLower(name)] = channel
	return channel
}

// GetClientCount returns the number of clients in the client table.
func (s *Server) GetClientCount() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return len(s.clients)
}

// GetChannelCount returns the number of registered channels.
func (s *Server) GetChannelCount() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return len(s.channels)
}

// IsNickInUse reports whether any client has nick (case-insensitive).
func (s *Server) IsNickInUse(nick string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	for _, client := range s.clients {
		if strings.EqualFold(client.Nick(), nick) {
			return true
		}
	}
	return false
}

// splitMessageTags separates an optional IRCv3 tag prefix ("@k=v;k2 CMD ...")
// from the rest of the line. tags is nil when the line has no tag prefix.
// ok is false for tags-only lines, which carry no command and are ignored.
// Tag values are returned verbatim (not unescaped).
func splitMessageTags(message string) (tags map[string]string, rest string, ok bool) {
	if !strings.HasPrefix(message, "@") {
		return nil, message, true
	}

	spaceIndex := strings.Index(message, " ")
	if spaceIndex == -1 {
		return nil, "", false
	}

	rest = strings.TrimSpace(message[spaceIndex+1:])
	tags = make(map[string]string)
	for _, pair := range strings.Split(message[1:spaceIndex], ";") {
		if pair == "" {
			continue // tolerate stray ';' separators
		}
		if kv := strings.SplitN(pair, "=", 2); len(kv) == 2 {
			tags[kv[0]] = kv[1]
		} else {
			tags[pair] = ""
		}
	}

	if rest == "" || rest == ":" {
		return tags, rest, false
	}
	return tags, rest, true
}

// redactForLog hides the parameters of commands that carry credentials (OPER
// passwords, SASL payloads, connection passwords), so they never reach the log.
func redactForLog(command, message string) string {
	switch command {
	case "OPER", "AUTHENTICATE", "PASS":
		return command + " [redacted]"
	}
	return message
}

// HandleMessage parses one raw line from client, including any IRCv3 tags,
// logs it (with credentials redacted), and dispatches it to the matching
// client handler. Unknown commands get ERR_UNKNOWNCOMMAND.
func (s *Server) HandleMessage(client *Client, message string) {
	tags, actualMessage, ok := splitMessageTags(message)
	if !ok {
		// Tags-only lines (e.g. typing indicators) are valid; nothing to dispatch.
		if DebugMode {
			log.Printf("<<< RECV from %s: %s (tags-only message, ignoring)", client.Host(), message)
		}
		return
	}

	parts := strings.Fields(actualMessage)
	if len(parts) == 0 {
		return
	}

	command := strings.ToUpper(parts[0])

	logged := redactForLog(command, message)
	if DebugMode {
		log.Printf("<<< RECV from %s: %s", client.Host(), logged)
	} else {
		log.Printf("Client %s: %s", client.Host(), logged)
	}

	switch command {
	case "CAP":
		client.handleCap(parts)
	case "NICK":
		client.handleNick(parts)
	case "USER":
		client.handleUser(parts)
	case "PING":
		client.handlePing(parts)
	case "PONG":
		client.handlePong(parts)
	case "JOIN":
		client.handleJoin(parts)
	case "PART":
		client.handlePart(parts)
	case "PRIVMSG":
		client.handlePrivmsg(parts)
	case "NOTICE":
		client.handleNotice(parts)
	case "TAGMSG":
		client.handleTagmsg(parts, tags)
	case "WHO":
		client.handleWho(parts)
	case "WHOIS":
		client.handleWhois(parts)
	case "NAMES":
		client.handleNames(parts)
	case "MODE":
		client.handleMode(parts)
	case "OPER":
		client.handleOper(parts)
	case "SNOMASK":
		client.handleSnomask(parts)
	case "GLOBALNOTICE":
		client.handleGlobalNotice(parts)
	case "OPERWALL":
		client.handleOperWall(parts)
	case "WALLOPS":
		client.handleWallops(parts)
	case "REHASH":
		client.handleRehash()
	case "TRACE":
		client.handleTrace(parts)
	case "HELPOP":
		client.handleHelpop(parts)
	case "TOPIC":
		client.handleTopic(parts)
	case "KICK":
		client.handleKick(parts)
	case "INVITE":
		client.handleInvite(parts)
	case "AWAY":
		client.handleAway(parts)
	case "LIST":
		client.handleList()
	case "KILL":
		client.handleKill(parts)
	case "QUIT":
		client.handleQuit(parts)
	case "CONNECT":
		client.handleConnect(parts)
	case "SQUIT":
		client.handleSquit(parts)
	case "LINKS":
		client.handleLinks()
	case "USERHOST":
		client.handleUserhost(parts)
	case "ISON":
		client.handleIson(parts)
	case "TIME":
		client.handleTime()
	case "VERSION":
		client.handleVersion()
	case "ADMIN":
		client.handleAdmin()
	case "INFO":
		client.handleInfo()
	case "LUSERS":
		client.handleLusers()
	case "STATS":
		client.handleStats(parts)
	case "SILENCE":
		client.handleSilence(parts)
	case "MONITOR":
		client.handleMonitor(parts)
	case "AUTHENTICATE":
		client.handleAuthenticate(parts)
	// Services/Admin Commands
	case "CHGHOST":
		client.handleChghost(parts)
	case "SVSNICK":
		client.handleSvsnick(parts)
	case "SVSMODE":
		client.handleSvsmode(parts)
	case "SAMODE":
		client.handleSamode(parts)
	case "SANICK":
		client.handleSanick(parts)
	case "SAKICK":
		client.handleSakick(parts)
	case "SAPART":
		client.handleSapart(parts)
	case "SAJOIN":
		client.handleSajoin(parts)
	case "WHOWAS":
		client.handleWhowas(parts)
	case "MOTD":
		client.handleMotd()
	case "RULES":
		client.handleRules()
	case "MAP":
		client.handleMap()
	case "KNOCK":
		client.handleKnock(parts)
	case "SETNAME":
		client.handleSetname(parts)
	case "DIE":
		client.handleDie()
	default:
		client.SendNumeric(ERR_UNKNOWNCOMMAND, command+" :Unknown command")
	}
}

// Shutdown gracefully shuts down the server. It stops the health monitor,
// signals shutdown (safe to call more than once), closes all listeners, and
// disconnects linked servers and clients in the background. It then waits 2
// seconds before returning.
func (s *Server) Shutdown() {
	log.Println("Initiating graceful shutdown...")

	// Stop health monitoring first to prevent interference
	if s.healthMonitor != nil {
		s.healthMonitor.Stop()
	}

	// Signal before closing listeners, so accept loops treat the resulting
	// Accept errors as a normal exit rather than retrying.
	s.signalShutdown()
	s.closeListeners()

	go s.disconnectLinkedServers()
	go s.disconnectAllClients()

	// Give everything time to shut down gracefully
	time.Sleep(2 * time.Second)
	log.Println("Server shutdown complete")
}

// disconnectLinkedServers disconnects every linked server, outside the lock.
func (s *Server) disconnectLinkedServers() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Panic during server disconnection: %v", r)
		}
	}()

	for _, linkedServer := range s.GetLinkedServers() {
		linkedServer.Disconnect()
	}
}

// disconnectAllClients tells every client the server is going down and then
// force-disconnects it. Clients are handled in batches of 10 with a 50ms
// pause between batches, and each client is handled in its own goroutine.
func (s *Server) disconnectAllClients() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Panic during client disconnection: %v", r)
		}
	}()

	clientIDs := s.clientIDsSnapshot()

	const batchSize = 10
	for i := 0; i < len(clientIDs); i += batchSize {
		end := i + batchSize
		if end > len(clientIDs) {
			end = len(clientIDs)
		}

		for _, clientID := range clientIDs[i:end] {
			client := s.GetClientByID(clientID)
			if client == nil {
				continue
			}
			go func(c *Client) {
				defer func() {
					if r := recover(); r != nil {
						log.Printf("Panic notifying client during shutdown: %v", r)
					}
				}()
				c.SendMessage("ERROR :Server shutting down")
				// Give client time to process message
				time.Sleep(100 * time.Millisecond)
				c.ForceDisconnect("Server shutdown")
			}(client)
		}

		// Small delay between batches
		time.Sleep(50 * time.Millisecond)
	}
}

// Server linking methods

// startServerListener listens for incoming server-to-server connections on
// Linking.ServerPort and handles each one in its own goroutine until shutdown.
func (s *Server) startServerListener() {
	if !s.config.Linking.Enable {
		return
	}

	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Linking.ServerPort)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		log.Printf("Failed to start server listener on %s: %v", addr, err)
		return
	}
	if !s.trackListener(&s.serverListener, listener) {
		return
	}
	log.Printf("IRC server listening for server links on %s", addr)

	s.acceptLoop(listener, func(conn net.Conn) {
		log.Printf("Incoming server connection from %s", conn.RemoteAddr())
		go s.handleIncomingServer(conn)
	})
}

// startAutoConnections waits 5 seconds (or until shutdown), then starts a
// connection goroutine for every configured link marked AutoConnect.
func (s *Server) startAutoConnections() {
	if !s.config.Linking.Enable {
		return
	}

	if !s.sleepOrShutdown(5 * time.Second) {
		return
	}

	for _, link := range s.config.Linking.Links {
		if link.AutoConnect {
			go s.connectToServer(link.Name, link.Host, link.Port, link.Password, link.Hub, link.Description)
		}
	}
}

// connectToServer registers a LinkedServer under name and keeps it connected
// until shutdown. When it is disconnected, it retries every 30s after a failed
// attempt; otherwise it re-checks every 60s. All waits end early on shutdown.
func (s *Server) connectToServer(name, host string, port int, password string, hub bool, description string) {
	linkedServer := NewLinkedServer(name, host, port, password, hub, description, s)

	s.mu.Lock()
	if s.linkedServers == nil {
		s.linkedServers = make(map[string]*LinkedServer)
	}
	s.linkedServers[name] = linkedServer
	s.mu.Unlock()

	for !s.isShuttingDown() {
		wait := 60 * time.Second
		if !linkedServer.IsConnected() {
			log.Printf("Attempting to connect to server %s at %s:%d", name, host, port)
			if err := linkedServer.Connect(); err != nil {
				log.Printf("Failed to connect to server %s: %v", name, err)
				wait = 30 * time.Second // shorter wait before retrying a failed link
			} else {
				log.Printf("Successfully connected to server %s", name)
			}
		}
		if !s.sleepOrShutdown(wait) {
			return
		}
	}
}

// handleIncomingServer wraps an accepted server connection in a temporary
// LinkedServer and runs its Handle loop (which performs authentication). The
// connection is closed when Handle returns.
func (s *Server) handleIncomingServer(conn net.Conn) {
	defer conn.Close()

	log.Printf("Handling incoming server connection from %s", conn.RemoteAddr())

	tempServer := &LinkedServer{
		conn:      conn,
		server:    s,
		connected: true,
	}

	tempServer.Handle()
}

// AddLinkedServer registers linkedServer under its Name().
func (s *Server) AddLinkedServer(linkedServer *LinkedServer) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.linkedServers == nil {
		s.linkedServers = make(map[string]*LinkedServer)
	}
	s.linkedServers[linkedServer.Name()] = linkedServer
}

// RemoveLinkedServer unregisters the linked server called name and
// disconnects it. Disconnect runs after the lock is released, so it may safely
// call back into the Server.
func (s *Server) RemoveLinkedServer(name string) {
	s.mu.Lock()
	linkedServer, exists := s.linkedServers[name]
	if exists {
		delete(s.linkedServers, name)
	}
	s.mu.Unlock()

	if exists {
		linkedServer.Disconnect()
	}
}

// GetLinkedServer returns the linked server called name, or nil.
func (s *Server) GetLinkedServer(name string) *LinkedServer {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.linkedServers[name]
}

// GetLinkedServers returns a copy of the linked-server table.
func (s *Server) GetLinkedServers() map[string]*LinkedServer {
	s.mu.RLock()
	defer s.mu.RUnlock()
	servers := make(map[string]*LinkedServer, len(s.linkedServers))
	for name, server := range s.linkedServers {
		servers[name] = server
	}
	return servers
}

// BroadcastToServers sends message to every connected linked server. It sends
// outside the lock so a slow link cannot block the server.
func (s *Server) BroadcastToServers(message string) {
	for _, linkedServer := range s.GetLinkedServers() {
		if linkedServer.IsConnected() {
			linkedServer.SendMessage(message)
		}
	}
}

// Ban management functions

// parseDuration parses ban durations such as "30s", "30m", "2h", "1d" and
// "1w". "0" (or any zero amount) means permanent and yields 0. Malformed,
// negative, or overflowing input falls back to the default of 1 day, so a typo
// can never silently become a permanent ban.
func parseDuration(durationStr string) time.Duration {
	const (
		defaultDuration = 24 * time.Hour
		maxDuration     = time.Duration(1<<63 - 1)
	)

	if durationStr == "0" {
		return 0 // Permanent
	}
	if len(durationStr) < 2 {
		return defaultDuration
	}

	var unit time.Duration
	switch durationStr[len(durationStr)-1] {
	case 's':
		unit = time.Second
	case 'm':
		unit = time.Minute
	case 'h':
		unit = time.Hour
	case 'd':
		unit = 24 * time.Hour
	case 'w':
		unit = 7 * 24 * time.Hour
	default:
		return defaultDuration
	}

	value, err := strconv.Atoi(durationStr[:len(durationStr)-1])
	if err != nil || value < 0 {
		return defaultDuration
	}
	if time.Duration(value) > maxDuration/unit {
		return defaultDuration // would overflow int64 nanoseconds
	}
	return time.Duration(value) * unit
}