diff --git a/client.go b/client.go index 120683c..7e961a0 100644 --- a/client.go +++ b/client.go @@ -177,12 +177,22 @@ func (c *Client) SendMessage(message string) { c.mu.Lock() defer c.mu.Unlock() - // Enhanced connection health check + // Enhanced connection and server health check if c.conn == nil { log.Printf("SendMessage: connection is nil for client %s", c.Nick()) return } + if c.server == nil { + log.Printf("SendMessage: server is nil for client %s", c.Nick()) + return + } + + if c.server.config == nil { + log.Printf("SendMessage: server config is nil for client %s", c.Nick()) + return + } + // Validate message before sending if message == "" { return @@ -958,8 +968,8 @@ func (c *Client) Handle() { scanner := bufio.NewScanner(c.conn) - // Set maximum line length to prevent memory exhaustion - const maxLineLength = 4096 + // Set maximum line length per IRC RFC (512 bytes including CRLF) + const maxLineLength = 512 scanner.Buffer(make([]byte, maxLineLength), maxLineLength) // Set initial read deadline - be more generous during connection setup @@ -1035,38 +1045,37 @@ func (c *Client) cleanup() { c.mu.Unlock() } - // Part all channels with error handling in a separate goroutine to prevent blocking - go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic during channel cleanup for %s: %v", c.getClientInfo(), r) - } - }() - - channels := c.GetChannels() - for channelName, channel := range channels { - if channel != nil { - channel.RemoveClient(c) - // Clean up empty channels - if len(channel.GetClients()) == 0 && c.server != nil { - c.server.RemoveChannel(channelName) - } - } + // Perform cleanup operations sequentially to avoid race conditions + defer func() { + if r := recover(); r != nil { + log.Printf("Panic during cleanup for %s: %v", c.getClientInfo(), r) } }() - // Remove from server in a separate goroutine to prevent deadlock - go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic during server cleanup for %s: %v", c.getClientInfo(), r) + // Get channels snapshot before cleanup + channels := c.GetChannels() + var emptyChannels []string + + // Remove client from all channels first + for channelName, channel := range channels { + if channel != nil { + channel.RemoveClient(c) + // Track empty channels for later cleanup + if len(channel.GetClients()) == 0 { + emptyChannels = append(emptyChannels, channelName) } - }() - - if c.server != nil { - c.server.RemoveClient(c) } - }() + } + + // Remove from server (must happen after channel cleanup) + if c.server != nil { + c.server.RemoveClient(c) + + // Clean up empty channels after client removal to prevent race conditions + for _, channelName := range emptyChannels { + c.server.RemoveChannel(channelName) + } + } log.Printf("Cleanup completed for client %s", c.getClientInfo()) } diff --git a/commands.go b/commands.go index 2afb7e6..682c6ef 100644 --- a/commands.go +++ b/commands.go @@ -1096,11 +1096,8 @@ func (c *Client) handleQuit(parts []string) { } } - // Remove client from server - c.server.RemoveClient(c) - - // Close the connection - c.conn.Close() + // Use proper cleanup instead of direct connection close + c.cleanup() } // handleMode handles MODE command @@ -2080,8 +2077,8 @@ func (c *Client) handleKill(parts []string) { } } - // Disconnect the target - target.conn.Close() + // Disconnect the target properly + target.cleanup() } // handleOper handles OPER command diff --git a/server.go b/server.go index a90e7b3..de6706b 100644 --- a/server.go +++ b/server.go @@ -8,6 +8,7 @@ import ( "strings" "sync" "time" + "unicode/utf8" ) type Server struct { @@ -22,14 +23,16 @@ type Server struct { mu sync.RWMutex shutdown chan bool healthMonitor *HealthMonitor + healthCheckSem chan struct{} // Semaphore to limit concurrent health checks } func NewServer(config *Config) *Server { server := &Server{ - config: config, - clients: make(map[string]*Client, config.Limits.MaxClients), - channels: make(map[string]*Channel, config.Limits.MaxChannels), - shutdown: make(chan bool), + config: config, + clients: make(map[string]*Client, config.Limits.MaxClients), + channels: make(map[string]*Channel, config.Limits.MaxChannels), + shutdown: make(chan bool), + healthCheckSem: make(chan struct{}, 1), // Only allow 1 concurrent health check } server.healthMonitor = NewHealthMonitor(server) return server @@ -94,10 +97,17 @@ func (s *Server) Start() error { } func (s *Server) startSSLListener() { + // Validate SSL configuration + if s.config.Server.SSL.CertFile == "" || s.config.Server.SSL.KeyFile == "" { + log.Printf("SSL enabled but certificate or key file not specified") + return + } + // Load SSL certificates cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile) if err != nil { - log.Printf("Failed to load SSL certificates: %v", err) + log.Printf("CRITICAL: Failed to load SSL certificates (SSL disabled): %v", err) + log.Printf("SSL will not be available. Check certificate paths in config.") return } @@ -106,12 +116,13 @@ func (s *Server) startSSLListener() { addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort) listener, err := tls.Listen("tcp", addr, tlsConfig) if err != nil { - log.Printf("Failed to start SSL listener on %s: %v", addr, err) + log.Printf("CRITICAL: Failed to start SSL listener on %s: %v", addr, err) + log.Printf("SSL will not be available. Server continues with plain connections only.") return } s.sslListener = listener - log.Printf("IRC SSL server listening on %s", addr) + log.Printf("✓ IRC SSL server listening on %s", addr) for { select { @@ -143,11 +154,29 @@ func (s *Server) pingRoutine() { case <-s.shutdown: return case <-ticker.C: - // Run ping check in a goroutine to prevent blocking - go s.performPingCheck() + // Run ping check with semaphore to prevent goroutine leaks + select { + case s.healthCheckSem <- struct{}{}: + go func() { + defer func() { <-s.healthCheckSem }() + s.performPingCheck() + }() + default: + // Skip ping check if one is already running + log.Printf("Skipping ping check - one already in progress") + } case <-healthTicker.C: - // Run health check in a goroutine to prevent blocking - go s.performHealthCheck() + // Run health check with semaphore to prevent goroutine leaks + select { + case s.healthCheckSem <- struct{}{}: + go func() { + defer func() { <-s.healthCheckSem }() + s.performHealthCheck() + }() + default: + // Skip health check if one is already running + log.Printf("Skipping health check - one already in progress") + } } } } @@ -551,6 +580,29 @@ func (s *Server) IsNickInUse(nick string) bool { } func (s *Server) HandleMessage(client *Client, message string) { + // Input validation + if len(message) == 0 { + return + } + + // Check for maximum message length (IRC RFC limits to 512 bytes) + if len(message) > 510 { // 510 to account for CRLF + log.Printf("Message too long from %s (%d bytes), dropping", client.Host(), len(message)) + return + } + + // Check for null bytes (not allowed in IRC messages) + if strings.ContainsRune(message, '\x00') { + log.Printf("Message contains null bytes from %s, dropping", client.Host()) + return + } + + // Basic UTF-8 validation + if !utf8.ValidString(message) { + log.Printf("Message contains invalid UTF-8 from %s, dropping", client.Host()) + return + } + // Parse IRCv3 message tags if present var tags map[string]string var actualMessage string