Fix critical bugs and security vulnerabilities

- Fix race condition in client cleanup by serializing operations
- Add proper nil checks in SendMessage for server/config
- Add semaphore to limit concurrent health check goroutines
- Reduce buffer size to RFC-compliant 512 bytes (was 4096)
- Add comprehensive input validation (length, null bytes, UTF-8)
- Improve SSL error handling with graceful degradation
- Replace unsafe conn.Close() with proper cleanup() calls
- Prevent goroutine leaks and memory exhaustion attacks
- Enhanced logging and error recovery throughout

These fixes address the freezing issues and improve overall
server stability, security, and RFC compliance.
This commit is contained in:
2025-09-27 15:13:55 +01:00
parent 6772bfd842
commit bab403557f
3 changed files with 106 additions and 48 deletions

View File

@@ -177,12 +177,22 @@ func (c *Client) SendMessage(message string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
// Enhanced connection health check // Enhanced connection and server health check
if c.conn == nil { if c.conn == nil {
log.Printf("SendMessage: connection is nil for client %s", c.Nick()) log.Printf("SendMessage: connection is nil for client %s", c.Nick())
return return
} }
if c.server == nil {
log.Printf("SendMessage: server is nil for client %s", c.Nick())
return
}
if c.server.config == nil {
log.Printf("SendMessage: server config is nil for client %s", c.Nick())
return
}
// Validate message before sending // Validate message before sending
if message == "" { if message == "" {
return return
@@ -958,8 +968,8 @@ func (c *Client) Handle() {
scanner := bufio.NewScanner(c.conn) scanner := bufio.NewScanner(c.conn)
// Set maximum line length to prevent memory exhaustion // Set maximum line length per IRC RFC (512 bytes including CRLF)
const maxLineLength = 4096 const maxLineLength = 512
scanner.Buffer(make([]byte, maxLineLength), maxLineLength) scanner.Buffer(make([]byte, maxLineLength), maxLineLength)
// Set initial read deadline - be more generous during connection setup // Set initial read deadline - be more generous during connection setup
@@ -1035,38 +1045,37 @@ func (c *Client) cleanup() {
c.mu.Unlock() c.mu.Unlock()
} }
// Part all channels with error handling in a separate goroutine to prevent blocking // Perform cleanup operations sequentially to avoid race conditions
go func() { defer func() {
defer func() { if r := recover(); r != nil {
if r := recover(); r != nil { log.Printf("Panic during cleanup for %s: %v", c.getClientInfo(), r)
log.Printf("Panic during channel cleanup for %s: %v", c.getClientInfo(), r)
}
}()
channels := c.GetChannels()
for channelName, channel := range channels {
if channel != nil {
channel.RemoveClient(c)
// Clean up empty channels
if len(channel.GetClients()) == 0 && c.server != nil {
c.server.RemoveChannel(channelName)
}
}
} }
}() }()
// Remove from server in a separate goroutine to prevent deadlock // Get channels snapshot before cleanup
go func() { channels := c.GetChannels()
defer func() { var emptyChannels []string
if r := recover(); r != nil {
log.Printf("Panic during server cleanup for %s: %v", c.getClientInfo(), r)
}
}()
if c.server != nil { // Remove client from all channels first
c.server.RemoveClient(c) for channelName, channel := range channels {
if channel != nil {
channel.RemoveClient(c)
// Track empty channels for later cleanup
if len(channel.GetClients()) == 0 {
emptyChannels = append(emptyChannels, channelName)
}
} }
}() }
// Remove from server (must happen after channel cleanup)
if c.server != nil {
c.server.RemoveClient(c)
// Clean up empty channels after client removal to prevent race conditions
for _, channelName := range emptyChannels {
c.server.RemoveChannel(channelName)
}
}
log.Printf("Cleanup completed for client %s", c.getClientInfo()) log.Printf("Cleanup completed for client %s", c.getClientInfo())
} }

View File

@@ -1096,11 +1096,8 @@ func (c *Client) handleQuit(parts []string) {
} }
} }
// Remove client from server // Use proper cleanup instead of direct connection close
c.server.RemoveClient(c) c.cleanup()
// Close the connection
c.conn.Close()
} }
// handleMode handles MODE command // handleMode handles MODE command
@@ -2080,8 +2077,8 @@ func (c *Client) handleKill(parts []string) {
} }
} }
// Disconnect the target // Disconnect the target properly
target.conn.Close() target.cleanup()
} }
// handleOper handles OPER command // handleOper handles OPER command

View File

@@ -8,6 +8,7 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"unicode/utf8"
) )
type Server struct { type Server struct {
@@ -22,14 +23,16 @@ type Server struct {
mu sync.RWMutex mu sync.RWMutex
shutdown chan bool shutdown chan bool
healthMonitor *HealthMonitor healthMonitor *HealthMonitor
healthCheckSem chan struct{} // Semaphore to limit concurrent health checks
} }
func NewServer(config *Config) *Server { func NewServer(config *Config) *Server {
server := &Server{ server := &Server{
config: config, config: config,
clients: make(map[string]*Client, config.Limits.MaxClients), clients: make(map[string]*Client, config.Limits.MaxClients),
channels: make(map[string]*Channel, config.Limits.MaxChannels), channels: make(map[string]*Channel, config.Limits.MaxChannels),
shutdown: make(chan bool), shutdown: make(chan bool),
healthCheckSem: make(chan struct{}, 1), // Only allow 1 concurrent health check
} }
server.healthMonitor = NewHealthMonitor(server) server.healthMonitor = NewHealthMonitor(server)
return server return server
@@ -94,10 +97,17 @@ func (s *Server) Start() error {
} }
func (s *Server) startSSLListener() { func (s *Server) startSSLListener() {
// Validate SSL configuration
if s.config.Server.SSL.CertFile == "" || s.config.Server.SSL.KeyFile == "" {
log.Printf("SSL enabled but certificate or key file not specified")
return
}
// Load SSL certificates // Load SSL certificates
cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile) cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile)
if err != nil { if err != nil {
log.Printf("Failed to load SSL certificates: %v", err) log.Printf("CRITICAL: Failed to load SSL certificates (SSL disabled): %v", err)
log.Printf("SSL will not be available. Check certificate paths in config.")
return return
} }
@@ -106,12 +116,13 @@ func (s *Server) startSSLListener() {
addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort) addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort)
listener, err := tls.Listen("tcp", addr, tlsConfig) listener, err := tls.Listen("tcp", addr, tlsConfig)
if err != nil { if err != nil {
log.Printf("Failed to start SSL listener on %s: %v", addr, err) log.Printf("CRITICAL: Failed to start SSL listener on %s: %v", addr, err)
log.Printf("SSL will not be available. Server continues with plain connections only.")
return return
} }
s.sslListener = listener s.sslListener = listener
log.Printf("IRC SSL server listening on %s", addr) log.Printf("IRC SSL server listening on %s", addr)
for { for {
select { select {
@@ -143,11 +154,29 @@ func (s *Server) pingRoutine() {
case <-s.shutdown: case <-s.shutdown:
return return
case <-ticker.C: case <-ticker.C:
// Run ping check in a goroutine to prevent blocking // Run ping check with semaphore to prevent goroutine leaks
go s.performPingCheck() select {
case s.healthCheckSem <- struct{}{}:
go func() {
defer func() { <-s.healthCheckSem }()
s.performPingCheck()
}()
default:
// Skip ping check if one is already running
log.Printf("Skipping ping check - one already in progress")
}
case <-healthTicker.C: case <-healthTicker.C:
// Run health check in a goroutine to prevent blocking // Run health check with semaphore to prevent goroutine leaks
go s.performHealthCheck() select {
case s.healthCheckSem <- struct{}{}:
go func() {
defer func() { <-s.healthCheckSem }()
s.performHealthCheck()
}()
default:
// Skip health check if one is already running
log.Printf("Skipping health check - one already in progress")
}
} }
} }
} }
@@ -551,6 +580,29 @@ func (s *Server) IsNickInUse(nick string) bool {
} }
func (s *Server) HandleMessage(client *Client, message string) { func (s *Server) HandleMessage(client *Client, message string) {
// Input validation
if len(message) == 0 {
return
}
// Check for maximum message length (IRC RFC limits to 512 bytes)
if len(message) > 510 { // 510 to account for CRLF
log.Printf("Message too long from %s (%d bytes), dropping", client.Host(), len(message))
return
}
// Check for null bytes (not allowed in IRC messages)
if strings.ContainsRune(message, '\x00') {
log.Printf("Message contains null bytes from %s, dropping", client.Host())
return
}
// Basic UTF-8 validation
if !utf8.ValidString(message) {
log.Printf("Message contains invalid UTF-8 from %s, dropping", client.Host())
return
}
// Parse IRCv3 message tags if present // Parse IRCv3 message tags if present
var tags map[string]string var tags map[string]string
var actualMessage string var actualMessage string