Fix critical bugs and security vulnerabilities
- Fix race condition in client cleanup by serializing operations - Add proper nil checks in SendMessage for server/config - Add semaphore to limit concurrent health check goroutines - Reduce buffer size to RFC-compliant 512 bytes (was 4096) - Add comprehensive input validation (length, null bytes, UTF-8) - Improve SSL error handling with graceful degradation - Replace unsafe conn.Close() with proper cleanup() calls - Prevent goroutine leaks and memory exhaustion attacks - Enhanced logging and error recovery throughout These fixes address the freezing issues and improve overall server stability, security, and RFC compliance.
This commit is contained in:
69
client.go
69
client.go
@@ -177,12 +177,22 @@ func (c *Client) SendMessage(message string) {
|
|||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
// Enhanced connection health check
|
// Enhanced connection and server health check
|
||||||
if c.conn == nil {
|
if c.conn == nil {
|
||||||
log.Printf("SendMessage: connection is nil for client %s", c.Nick())
|
log.Printf("SendMessage: connection is nil for client %s", c.Nick())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.server == nil {
|
||||||
|
log.Printf("SendMessage: server is nil for client %s", c.Nick())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.server.config == nil {
|
||||||
|
log.Printf("SendMessage: server config is nil for client %s", c.Nick())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Validate message before sending
|
// Validate message before sending
|
||||||
if message == "" {
|
if message == "" {
|
||||||
return
|
return
|
||||||
@@ -958,8 +968,8 @@ func (c *Client) Handle() {
|
|||||||
|
|
||||||
scanner := bufio.NewScanner(c.conn)
|
scanner := bufio.NewScanner(c.conn)
|
||||||
|
|
||||||
// Set maximum line length to prevent memory exhaustion
|
// Set maximum line length per IRC RFC (512 bytes including CRLF)
|
||||||
const maxLineLength = 4096
|
const maxLineLength = 512
|
||||||
scanner.Buffer(make([]byte, maxLineLength), maxLineLength)
|
scanner.Buffer(make([]byte, maxLineLength), maxLineLength)
|
||||||
|
|
||||||
// Set initial read deadline - be more generous during connection setup
|
// Set initial read deadline - be more generous during connection setup
|
||||||
@@ -1035,38 +1045,37 @@ func (c *Client) cleanup() {
|
|||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Part all channels with error handling in a separate goroutine to prevent blocking
|
// Perform cleanup operations sequentially to avoid race conditions
|
||||||
go func() {
|
defer func() {
|
||||||
defer func() {
|
if r := recover(); r != nil {
|
||||||
if r := recover(); r != nil {
|
log.Printf("Panic during cleanup for %s: %v", c.getClientInfo(), r)
|
||||||
log.Printf("Panic during channel cleanup for %s: %v", c.getClientInfo(), r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
channels := c.GetChannels()
|
|
||||||
for channelName, channel := range channels {
|
|
||||||
if channel != nil {
|
|
||||||
channel.RemoveClient(c)
|
|
||||||
// Clean up empty channels
|
|
||||||
if len(channel.GetClients()) == 0 && c.server != nil {
|
|
||||||
c.server.RemoveChannel(channelName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Remove from server in a separate goroutine to prevent deadlock
|
// Get channels snapshot before cleanup
|
||||||
go func() {
|
channels := c.GetChannels()
|
||||||
defer func() {
|
var emptyChannels []string
|
||||||
if r := recover(); r != nil {
|
|
||||||
log.Printf("Panic during server cleanup for %s: %v", c.getClientInfo(), r)
|
// Remove client from all channels first
|
||||||
|
for channelName, channel := range channels {
|
||||||
|
if channel != nil {
|
||||||
|
channel.RemoveClient(c)
|
||||||
|
// Track empty channels for later cleanup
|
||||||
|
if len(channel.GetClients()) == 0 {
|
||||||
|
emptyChannels = append(emptyChannels, channelName)
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
if c.server != nil {
|
|
||||||
c.server.RemoveClient(c)
|
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
|
||||||
|
// Remove from server (must happen after channel cleanup)
|
||||||
|
if c.server != nil {
|
||||||
|
c.server.RemoveClient(c)
|
||||||
|
|
||||||
|
// Clean up empty channels after client removal to prevent race conditions
|
||||||
|
for _, channelName := range emptyChannels {
|
||||||
|
c.server.RemoveChannel(channelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Printf("Cleanup completed for client %s", c.getClientInfo())
|
log.Printf("Cleanup completed for client %s", c.getClientInfo())
|
||||||
}
|
}
|
||||||
|
|||||||
11
commands.go
11
commands.go
@@ -1096,11 +1096,8 @@ func (c *Client) handleQuit(parts []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove client from server
|
// Use proper cleanup instead of direct connection close
|
||||||
c.server.RemoveClient(c)
|
c.cleanup()
|
||||||
|
|
||||||
// Close the connection
|
|
||||||
c.conn.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleMode handles MODE command
|
// handleMode handles MODE command
|
||||||
@@ -2080,8 +2077,8 @@ func (c *Client) handleKill(parts []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect the target
|
// Disconnect the target properly
|
||||||
target.conn.Close()
|
target.cleanup()
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleOper handles OPER command
|
// handleOper handles OPER command
|
||||||
|
|||||||
74
server.go
74
server.go
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -22,14 +23,16 @@ type Server struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
shutdown chan bool
|
shutdown chan bool
|
||||||
healthMonitor *HealthMonitor
|
healthMonitor *HealthMonitor
|
||||||
|
healthCheckSem chan struct{} // Semaphore to limit concurrent health checks
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(config *Config) *Server {
|
func NewServer(config *Config) *Server {
|
||||||
server := &Server{
|
server := &Server{
|
||||||
config: config,
|
config: config,
|
||||||
clients: make(map[string]*Client, config.Limits.MaxClients),
|
clients: make(map[string]*Client, config.Limits.MaxClients),
|
||||||
channels: make(map[string]*Channel, config.Limits.MaxChannels),
|
channels: make(map[string]*Channel, config.Limits.MaxChannels),
|
||||||
shutdown: make(chan bool),
|
shutdown: make(chan bool),
|
||||||
|
healthCheckSem: make(chan struct{}, 1), // Only allow 1 concurrent health check
|
||||||
}
|
}
|
||||||
server.healthMonitor = NewHealthMonitor(server)
|
server.healthMonitor = NewHealthMonitor(server)
|
||||||
return server
|
return server
|
||||||
@@ -94,10 +97,17 @@ func (s *Server) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) startSSLListener() {
|
func (s *Server) startSSLListener() {
|
||||||
|
// Validate SSL configuration
|
||||||
|
if s.config.Server.SSL.CertFile == "" || s.config.Server.SSL.KeyFile == "" {
|
||||||
|
log.Printf("SSL enabled but certificate or key file not specified")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Load SSL certificates
|
// Load SSL certificates
|
||||||
cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile)
|
cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to load SSL certificates: %v", err)
|
log.Printf("CRITICAL: Failed to load SSL certificates (SSL disabled): %v", err)
|
||||||
|
log.Printf("SSL will not be available. Check certificate paths in config.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,12 +116,13 @@ func (s *Server) startSSLListener() {
|
|||||||
addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort)
|
addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort)
|
||||||
listener, err := tls.Listen("tcp", addr, tlsConfig)
|
listener, err := tls.Listen("tcp", addr, tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to start SSL listener on %s: %v", addr, err)
|
log.Printf("CRITICAL: Failed to start SSL listener on %s: %v", addr, err)
|
||||||
|
log.Printf("SSL will not be available. Server continues with plain connections only.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.sslListener = listener
|
s.sslListener = listener
|
||||||
|
|
||||||
log.Printf("IRC SSL server listening on %s", addr)
|
log.Printf("✓ IRC SSL server listening on %s", addr)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -143,11 +154,29 @@ func (s *Server) pingRoutine() {
|
|||||||
case <-s.shutdown:
|
case <-s.shutdown:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Run ping check in a goroutine to prevent blocking
|
// Run ping check with semaphore to prevent goroutine leaks
|
||||||
go s.performPingCheck()
|
select {
|
||||||
|
case s.healthCheckSem <- struct{}{}:
|
||||||
|
go func() {
|
||||||
|
defer func() { <-s.healthCheckSem }()
|
||||||
|
s.performPingCheck()
|
||||||
|
}()
|
||||||
|
default:
|
||||||
|
// Skip ping check if one is already running
|
||||||
|
log.Printf("Skipping ping check - one already in progress")
|
||||||
|
}
|
||||||
case <-healthTicker.C:
|
case <-healthTicker.C:
|
||||||
// Run health check in a goroutine to prevent blocking
|
// Run health check with semaphore to prevent goroutine leaks
|
||||||
go s.performHealthCheck()
|
select {
|
||||||
|
case s.healthCheckSem <- struct{}{}:
|
||||||
|
go func() {
|
||||||
|
defer func() { <-s.healthCheckSem }()
|
||||||
|
s.performHealthCheck()
|
||||||
|
}()
|
||||||
|
default:
|
||||||
|
// Skip health check if one is already running
|
||||||
|
log.Printf("Skipping health check - one already in progress")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -551,6 +580,29 @@ func (s *Server) IsNickInUse(nick string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) HandleMessage(client *Client, message string) {
|
func (s *Server) HandleMessage(client *Client, message string) {
|
||||||
|
// Input validation
|
||||||
|
if len(message) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for maximum message length (IRC RFC limits to 512 bytes)
|
||||||
|
if len(message) > 510 { // 510 to account for CRLF
|
||||||
|
log.Printf("Message too long from %s (%d bytes), dropping", client.Host(), len(message))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for null bytes (not allowed in IRC messages)
|
||||||
|
if strings.ContainsRune(message, '\x00') {
|
||||||
|
log.Printf("Message contains null bytes from %s, dropping", client.Host())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Basic UTF-8 validation
|
||||||
|
if !utf8.ValidString(message) {
|
||||||
|
log.Printf("Message contains invalid UTF-8 from %s, dropping", client.Host())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Parse IRCv3 message tags if present
|
// Parse IRCv3 message tags if present
|
||||||
var tags map[string]string
|
var tags map[string]string
|
||||||
var actualMessage string
|
var actualMessage string
|
||||||
|
|||||||
Reference in New Issue
Block a user