- Fix race condition in client cleanup by serializing operations
- Add proper nil checks in SendMessage for server/config
- Add semaphore to limit concurrent health check goroutines
- Reduce buffer size to RFC-compliant 512 bytes (was 4096)
- Add comprehensive input validation (length, null bytes, UTF-8)
- Improve SSL error handling with graceful degradation
- Replace unsafe conn.Close() with proper cleanup() calls
- Prevent goroutine leaks and memory exhaustion attacks
- Enhance logging and error recovery throughout

These fixes address the freezing issues and improve overall server stability, security, and RFC compliance.
package main

import (
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"strings"
	"sync"
	"time"
	"unicode/utf8"
)

// Server is the central state of the IRC daemon: connected clients,
// channels, listeners, and links to other servers.
type Server struct {
	config         *Config
	clients        map[string]*Client
	channels       map[string]*Channel
	listener       net.Listener
	sslListener    net.Listener
	serverListener net.Listener
	linkedServers  map[string]*LinkedServer
	pingMessage    string
	mu             sync.RWMutex
	shutdown       chan bool
	healthMonitor  *HealthMonitor
	healthCheckSem chan struct{} // Semaphore to limit concurrent health checks
}

// NewServer builds a Server from the given configuration.
func NewServer(config *Config) *Server {
	server := &Server{
		config:   config,
		clients:  make(map[string]*Client, config.Limits.MaxClients),
		channels: make(map[string]*Channel, config.Limits.MaxChannels),
		// linkedServers must be initialized here: connectToServer and
		// AddLinkedServer write to it, and assignment to a nil map panics.
		linkedServers: make(map[string]*LinkedServer),
		// Pre-computed ping line reused by performPingCheck. The format is
		// assumed here (standard server PING with the server name).
		pingMessage:    fmt.Sprintf("PING :%s", config.Server.Name),
		shutdown:       make(chan bool),
		healthCheckSem: make(chan struct{}, 1), // Only allow 1 concurrent health check
	}
	server.healthMonitor = NewHealthMonitor(server)
	return server
}
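
// Typical startup wiring, as a sketch only: the real entry point lives
// elsewhere and may differ. LoadConfig and its "config.json" path are taken
// from ReloadConfig below.
//
//	cfg, err := LoadConfig("config.json")
//	if err != nil {
//		log.Fatal(err)
//	}
//	server := NewServer(cfg)
//	log.Fatal(server.Start())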

func (s *Server) Start() error {
	// Start regular listener
	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.Port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %v", addr, err)
	}
	s.listener = listener

	log.Printf("IRC server listening on %s", addr)

	// Start health monitoring
	s.healthMonitor.Start()

	// Auto-create configured channels before any accept loop starts, so the
	// unlocked map writes below cannot race with a client joining a channel.
	for _, channelName := range s.config.Channels.AutoJoin {
		channel := NewChannel(channelName)
		// Set default modes
		for _, mode := range s.config.Channels.DefaultModes {
			if mode != '+' {
				channel.SetMode(rune(mode), true)
			}
		}
		s.channels[strings.ToLower(channelName)] = channel
	}

	// Start SSL listener if enabled
	if s.config.Server.Listen.EnableSSL {
		go s.startSSLListener()
	}

	// Start server linking if enabled
	if s.config.Linking.Enable {
		go s.startServerListener()
		go s.startAutoConnections()
	}

	// Start ping routine
	go s.pingRoutine()

	// Accept connections
	for {
		select {
		case <-s.shutdown:
			return nil
		default:
			conn, err := listener.Accept()
			if err != nil {
				continue
			}

			client := NewClient(conn, s)
			s.AddClient(client)
			go client.Handle()
		}
	}
}

func (s *Server) startSSLListener() {
	// Validate SSL configuration
	if s.config.Server.SSL.CertFile == "" || s.config.Server.SSL.KeyFile == "" {
		log.Printf("SSL enabled but certificate or key file not specified")
		return
	}

	// Load SSL certificates
	cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile)
	if err != nil {
		log.Printf("CRITICAL: Failed to load SSL certificates (SSL disabled): %v", err)
		log.Printf("SSL will not be available. Check certificate paths in config.")
		return
	}

	tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}

	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort)
	listener, err := tls.Listen("tcp", addr, tlsConfig)
	if err != nil {
		log.Printf("CRITICAL: Failed to start SSL listener on %s: %v", addr, err)
		log.Printf("SSL will not be available. Server continues with plain connections only.")
		return
	}
	s.sslListener = listener

	log.Printf("✓ IRC SSL server listening on %s", addr)

	for {
		select {
		case <-s.shutdown:
			return
		default:
			conn, err := listener.Accept()
			if err != nil {
				continue
			}

			client := NewClient(conn, s)
			s.AddClient(client)
			go client.Handle()
		}
	}
}

func (s *Server) pingRoutine() {
	ticker := time.NewTicker(60 * time.Second) // Reduced frequency to prevent contention
	defer ticker.Stop()

	// Health checks run on a separate, much less frequent ticker
	healthTicker := time.NewTicker(5 * time.Minute)
	defer healthTicker.Stop()

	for {
		select {
		case <-s.shutdown:
			return
		case <-ticker.C:
			// Run ping check with semaphore to prevent goroutine leaks
			select {
			case s.healthCheckSem <- struct{}{}:
				go func() {
					defer func() { <-s.healthCheckSem }()
					s.performPingCheck()
				}()
			default:
				// Skip ping check if one is already running
				log.Printf("Skipping ping check - one already in progress")
			}
		case <-healthTicker.C:
			// Run health check with semaphore to prevent goroutine leaks
			select {
			case s.healthCheckSem <- struct{}{}:
				go func() {
					defer func() { <-s.healthCheckSem }()
					s.performHealthCheck()
				}()
			default:
				// Skip health check if one is already running
				log.Printf("Skipping health check - one already in progress")
			}
		}
	}
}
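
// tryRun generalizes the non-blocking semaphore pattern used in pingRoutine
// above: run fn on a fresh goroutine only if the single slot in sem is free,
// otherwise report false and do nothing. Included as an illustrative sketch;
// nothing else in this file calls it.
func tryRun(sem chan struct{}, fn func()) bool {
	select {
	case sem <- struct{}{}:
		go func() {
			defer func() { <-sem }() // release the slot when fn returns
			fn()
		}()
		return true
	default:
		return false // a previous run still holds the slot
	}
}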

// performPingCheck handles the periodic ping checking
func (s *Server) performPingCheck() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Panic in performPingCheck: %v", r)
		}
	}()

	// Get a snapshot of clients without holding the lock for long
	s.mu.RLock()
	clientIDs := make([]string, 0, len(s.clients))
	for clientID := range s.clients {
		clientIDs = append(clientIDs, clientID)
	}
	s.mu.RUnlock()

	// Process clients individually to prevent blocking
	for _, clientID := range clientIDs {
		func() {
			// Get client safely
			s.mu.RLock()
			client := s.clients[clientID]
			s.mu.RUnlock()

			if client == nil || !client.IsRegistered() || !client.IsConnected() {
				return
			}

			// Check activity without holding the client lock for long
			client.mu.RLock()
			lastActivity := client.lastActivity
			client.mu.RUnlock()

			// Only send a ping if the client hasn't been active recently
			if time.Since(lastActivity) > 120*time.Second {
				// Send ping in a non-blocking way
				go func() {
					defer func() {
						if r := recover(); r != nil {
							log.Printf("Panic sending ping to %s: %v", client.getClientInfo(), r)
						}
					}()
					client.SendMessage(s.pingMessage) // Use pre-computed message
				}()
			}
		}()
	}
}

// performHealthCheck checks all clients for health issues
func (s *Server) performHealthCheck() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Panic in performHealthCheck: %v", r)
		}
	}()

	// Get a snapshot of clients without holding the lock for long
	s.mu.RLock()
	clientIDs := make([]string, 0, len(s.clients))
	for clientID := range s.clients {
		clientIDs = append(clientIDs, clientID)
	}
	totalClients := len(s.clients)
	s.mu.RUnlock()

	unhealthyClients := 0
	disconnectedClients := []string{} // Store client IDs instead of pointers

	// Process clients in batches to prevent overwhelming the system
	batchSize := 50
	for i := 0; i < len(clientIDs); i += batchSize {
		end := i + batchSize
		if end > len(clientIDs) {
			end = len(clientIDs)
		}

		batch := clientIDs[i:end]
		for _, clientID := range batch {
			func() {
				// Get client safely
				s.mu.RLock()
				client := s.clients[clientID]
				s.mu.RUnlock()

				if client == nil {
					return
				}

				healthy, reason := client.HealthCheck()
				if !healthy {
					unhealthyClients++

					// Force disconnect clients that are definitely problematic
					if strings.Contains(reason, "disconnected") ||
						strings.Contains(reason, "nil") ||
						strings.Contains(reason, "write errors") ||
						strings.Contains(reason, "registration timeout") {
						disconnectedClients = append(disconnectedClients, clientID)
						log.Printf("Marking client %s for disconnection: %s", client.getClientInfo(), reason)
					}
				}
			}()
		}

		// Small delay between batches to prevent overwhelming
		time.Sleep(10 * time.Millisecond)
	}

	// Disconnect problematic clients in a separate goroutine
	if len(disconnectedClients) > 0 {
		go func() {
			defer func() {
				if r := recover(); r != nil {
					log.Printf("Panic during client disconnection: %v", r)
				}
			}()

			for _, clientID := range disconnectedClients {
				s.mu.RLock()
				client := s.clients[clientID]
				s.mu.RUnlock()

				if client != nil {
					client.ForceDisconnect("Connection health check failed")
				}
			}
		}()
	}

	// Log health statistics
	if totalClients > 0 {
		healthyPercentage := float64(totalClients-unhealthyClients) / float64(totalClients) * 100
		log.Printf("Client health check: %d total, %d healthy (%.1f%%), %d marked for disconnection",
			totalClients, totalClients-unhealthyClients, healthyPercentage, len(disconnectedClients))
	}

	// Log detailed server statistics (runs with each 5-minute health check)
	s.LogServerStats()
}
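
// forEachBatch generalizes the batching loop in performHealthCheck above:
// visit ids in chunks of size n, pausing between chunks so a large client
// table cannot monopolize the scheduler. Included as an illustrative sketch;
// nothing else in this file calls it.
func forEachBatch(ids []string, n int, pause time.Duration, visit func(string)) {
	if n <= 0 {
		n = 1 // guard against a non-positive batch size
	}
	for i := 0; i < len(ids); i += n {
		end := i + n
		if end > len(ids) {
			end = len(ids)
		}
		for _, id := range ids[i:end] {
			visit(id) // process one item within the current batch
		}
		time.Sleep(pause) // brief pause before the next batch
	}
}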

// AddClient registers a new client connection after validation, enforcing
// the configured client limit.
func (s *Server) AddClient(client *Client) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Enhanced validation before adding client
	if client == nil {
		log.Printf("Attempted to add nil client")
		return
	}

	if client.conn == nil {
		log.Printf("Attempted to add client with nil connection")
		return
	}

	// Check client limit
	if len(s.clients) >= s.config.Limits.MaxClients {
		log.Printf("Server full, rejecting client from %s", client.getClientInfo())
		client.SendMessage("ERROR :Server full")
		if client.conn != nil {
			client.conn.Close()
		}
		return
	}

	// Check for duplicate client IDs (shouldn't happen, but defensive programming)
	if _, exists := s.clients[client.clientID]; exists {
		log.Printf("Duplicate client ID detected: %s, generating new ID", client.clientID)
		// Generate a new unique ID
		client.clientID = fmt.Sprintf("%s_%d_%d", client.host, time.Now().Unix(), len(s.clients))
	}

	s.clients[client.clientID] = client
	log.Printf("Added client %s (total clients: %d/%d)", client.getClientInfo(), len(s.clients), s.config.Limits.MaxClients)
}

// RemoveClient drops a client from the server's client table and notifies
// operators watching the 'c' snomask.
func (s *Server) RemoveClient(client *Client) {
	s.mu.Lock()
	delete(s.clients, client.clientID)
	s.mu.Unlock()

	// Send snomask notification for client disconnect (after releasing the lock)
	if client.IsRegistered() {
		s.sendSnomask('c', fmt.Sprintf("Client disconnect: %s (%s@%s)",
			client.Nick(), client.User(), client.Host()))
	}
}

// GetServerStats returns detailed server statistics for monitoring
func (s *Server) GetServerStats() map[string]interface{} {
	s.mu.RLock()
	defer s.mu.RUnlock()

	stats := make(map[string]interface{})

	// Basic counts
	stats["total_clients"] = len(s.clients)
	stats["total_channels"] = len(s.channels)

	// Client statistics
	registeredClients := 0
	operatorClients := 0
	sslClients := 0
	unhealthyClients := 0

	for _, client := range s.clients {
		if client.IsRegistered() {
			registeredClients++
		}
		if client.IsOper() {
			operatorClients++
		}
		if client.IsSSL() {
			sslClients++
		}

		// Quick health check
		if healthy, _ := client.HealthCheck(); !healthy {
			unhealthyClients++
		}
	}

	stats["registered_clients"] = registeredClients
	stats["operator_clients"] = operatorClients
	stats["ssl_clients"] = sslClients
	stats["unhealthy_clients"] = unhealthyClients

	// Channel statistics
	totalChannelUsers := 0
	for _, channel := range s.channels {
		totalChannelUsers += len(channel.GetClients())
	}
	stats["total_channel_users"] = totalChannelUsers

	// Server linking statistics
	stats["linked_servers"] = len(s.linkedServers)

	return stats
}

// LogServerStats logs current server statistics
func (s *Server) LogServerStats() {
	stats := s.GetServerStats()
	log.Printf("Server Statistics: %d clients (%d registered, %d operators, %d SSL, %d unhealthy), %d channels, %d linked servers",
		stats["total_clients"], stats["registered_clients"], stats["operator_clients"],
		stats["ssl_clients"], stats["unhealthy_clients"], stats["total_channels"], stats["linked_servers"])
}

// sendSnomask sends a server notice to operators watching a specific snomask
func (s *Server) sendSnomask(snomask rune, message string) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	for _, client := range s.clients {
		if client.IsOper() && client.HasSnomask(snomask) {
			client.SendMessage(fmt.Sprintf(":%s NOTICE %s :*** %s",
				s.config.Server.Name, client.Nick(), message))
		}
	}
}
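
// For example, sendSnomask('c', "Client disconnect: alice (alice@host)")
// emits, to every oper whose snomask set includes 'c':
//
//	:<server name> NOTICE <oper nick> :*** Client disconnect: alice (alice@host)
//
// (the nick and host values above are illustrative)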

// ReloadConfig reloads the server configuration
func (s *Server) ReloadConfig() error {
	config, err := LoadConfig("config.json")
	if err != nil {
		return fmt.Errorf("failed to load config: %v", err)
	}

	s.mu.Lock()
	s.config = config
	s.mu.Unlock()

	return nil
}

func (s *Server) GetClient(nick string) *Client {
	s.mu.RLock()
	defer s.mu.RUnlock()

	for _, client := range s.clients {
		if strings.EqualFold(client.Nick(), nick) {
			return client
		}
	}
	return nil
}

func (s *Server) GetClientByHost(host string) *Client {
	s.mu.RLock()
	defer s.mu.RUnlock()

	for _, client := range s.clients {
		if client.Host() == host {
			return client
		}
	}
	return nil
}

func (s *Server) GetClientByID(clientID string) *Client {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.clients[clientID]
}

// GetClients returns a copy of the client table so callers can iterate
// without holding the server lock.
func (s *Server) GetClients() map[string]*Client {
	s.mu.RLock()
	defer s.mu.RUnlock()

	clients := make(map[string]*Client)
	for clientID, client := range s.clients {
		clients[clientID] = client
	}
	return clients
}

func (s *Server) GetChannel(name string) *Channel {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.channels[strings.ToLower(name)]
}

func (s *Server) GetOrCreateChannel(name string) *Channel {
	s.mu.Lock()
	defer s.mu.Unlock()

	channelName := strings.ToLower(name)
	if channel, exists := s.channels[channelName]; exists {
		return channel
	}

	// Create new channel
	channel := NewChannel(name)
	// Set default modes
	for _, mode := range s.config.Channels.DefaultModes {
		if mode != '+' {
			channel.SetMode(rune(mode), true)
		}
	}
	s.channels[channelName] = channel
	return channel
}

func (s *Server) RemoveChannel(name string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.channels, strings.ToLower(name))
}

// GetChannels returns a copy of the channel table so callers can iterate
// without holding the server lock.
func (s *Server) GetChannels() map[string]*Channel {
	s.mu.RLock()
	defer s.mu.RUnlock()

	channels := make(map[string]*Channel)
	for name, channel := range s.channels {
		channels[name] = channel
	}
	return channels
}

// CreateChannel creates a channel unconditionally, returning nil when the
// configured channel limit has been reached.
func (s *Server) CreateChannel(name string) *Channel {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.channels) >= s.config.Limits.MaxChannels {
		return nil
	}

	channel := NewChannel(name)
	s.channels[strings.ToLower(name)] = channel
	return channel
}

func (s *Server) GetClientCount() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return len(s.clients)
}

func (s *Server) GetChannelCount() int {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return len(s.channels)
}

// IsNickInUse reports whether any connected client holds nick
// (case-insensitive).
func (s *Server) IsNickInUse(nick string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()

	for _, client := range s.clients {
		if strings.EqualFold(client.Nick(), nick) {
			return true
		}
	}
	return false
}

// HandleMessage validates a raw line from a client, strips IRCv3 message
// tags, and dispatches the command to the appropriate handler.
func (s *Server) HandleMessage(client *Client, message string) {
	// Input validation
	if len(message) == 0 {
		return
	}

	// Enforce the RFC 1459 512-byte line limit: 510 bytes of payload plus
	// the trailing CRLF, which has already been stripped by the reader
	if len(message) > 510 {
		log.Printf("Message too long from %s (%d bytes), dropping", client.Host(), len(message))
		return
	}

	// Check for null bytes (not allowed in IRC messages)
	if strings.ContainsRune(message, '\x00') {
		log.Printf("Message contains null bytes from %s, dropping", client.Host())
		return
	}

	// Basic UTF-8 validation
	if !utf8.ValidString(message) {
		log.Printf("Message contains invalid UTF-8 from %s, dropping", client.Host())
		return
	}

	// Parse IRCv3 message tags if present
	var tags map[string]string
	var actualMessage string

	if strings.HasPrefix(message, "@") {
		// Message has IRCv3 tags
		spaceIndex := strings.Index(message, " ")
		if spaceIndex == -1 {
			// Message is only tags (like typing indicators); this is valid - just ignore it
			if DebugMode {
				log.Printf("<<< RECV from %s: %s (tags-only message, ignoring)", client.Host(), message)
			}
			return
		}

		tagString := message[1:spaceIndex] // Remove @ prefix
		actualMessage = strings.TrimSpace(message[spaceIndex+1:])

		// Parse tags
		tags = make(map[string]string)
		tagPairs := strings.Split(tagString, ";")
		for _, pair := range tagPairs {
			if strings.Contains(pair, "=") {
				kv := strings.SplitN(pair, "=", 2)
				tags[kv[0]] = kv[1]
			} else {
				tags[pair] = ""
			}
		}

		// If actualMessage is empty or just a colon, it's a tags-only message
		if actualMessage == "" || actualMessage == ":" {
			if DebugMode {
				log.Printf("<<< RECV from %s: %s (tags-only message, ignoring)", client.Host(), message)
			}
			return
		}
	} else {
		actualMessage = message
	}
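
	// Example: the line "@msgid=abc;typing=active PRIVMSG #chat :hi" parses
	// to tags {"msgid": "abc", "typing": "active"} and actualMessage
	// "PRIVMSG #chat :hi". (Illustrative values; escaping of tag values per
	// the IRCv3 spec is not handled above.)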

	parts := strings.Fields(actualMessage)
	if len(parts) == 0 {
		return
	}

	command := strings.ToUpper(parts[0])

	// Log the command for debugging
	if DebugMode {
		log.Printf("<<< RECV from %s: %s", client.Host(), message)
	} else {
		log.Printf("Client %s: %s", client.Host(), message)
	}

	switch command {
	case "CAP":
		client.handleCap(parts)
	case "NICK":
		client.handleNick(parts)
	case "USER":
		client.handleUser(parts)
	case "PING":
		client.handlePing(parts)
	case "PONG":
		client.handlePong(parts)
	case "JOIN":
		client.handleJoin(parts)
	case "PART":
		client.handlePart(parts)
	case "PRIVMSG":
		client.handlePrivmsg(parts)
	case "NOTICE":
		client.handleNotice(parts)
	case "TAGMSG":
		client.handleTagmsg(parts, tags)
	case "WHO":
		client.handleWho(parts)
	case "WHOIS":
		client.handleWhois(parts)
	case "NAMES":
		client.handleNames(parts)
	case "MODE":
		client.handleMode(parts)
	case "OPER":
		client.handleOper(parts)
	case "SNOMASK":
		client.handleSnomask(parts)
	case "GLOBALNOTICE":
		client.handleGlobalNotice(parts)
	case "OPERWALL":
		client.handleOperWall(parts)
	case "WALLOPS":
		client.handleWallops(parts)
	case "REHASH":
		client.handleRehash()
	case "TRACE":
		client.handleTrace(parts)
	case "HELPOP":
		client.handleHelpop(parts)
	case "TOPIC":
		client.handleTopic(parts)
	case "KICK":
		client.handleKick(parts)
	case "INVITE":
		client.handleInvite(parts)
	case "AWAY":
		client.handleAway(parts)
	case "LIST":
		client.handleList()
	case "KILL":
		client.handleKill(parts)
	case "QUIT":
		client.handleQuit(parts)
	case "CONNECT":
		client.handleConnect(parts)
	case "SQUIT":
		client.handleSquit(parts)
	case "LINKS":
		client.handleLinks()
	case "USERHOST":
		client.handleUserhost(parts)
	case "ISON":
		client.handleIson(parts)
	case "TIME":
		client.handleTime()
	case "VERSION":
		client.handleVersion()
	case "ADMIN":
		client.handleAdmin()
	case "INFO":
		client.handleInfo()
	case "LUSERS":
		client.handleLusers()
	case "STATS":
		client.handleStats(parts)
	case "SILENCE":
		client.handleSilence(parts)
	case "MONITOR":
		client.handleMonitor(parts)
	case "AUTHENTICATE":
		client.handleAuthenticate(parts)
	// Services/Admin Commands
	case "CHGHOST":
		client.handleChghost(parts)
	case "SVSNICK":
		client.handleSvsnick(parts)
	case "SVSMODE":
		client.handleSvsmode(parts)
	case "SAMODE":
		client.handleSamode(parts)
	case "SANICK":
		client.handleSanick(parts)
	case "SAKICK":
		client.handleSakick(parts)
	case "SAPART":
		client.handleSapart(parts)
	case "SAJOIN":
		client.handleSajoin(parts)
	case "WHOWAS":
		client.handleWhowas(parts)
	case "MOTD":
		client.handleMotd()
	case "RULES":
		client.handleRules()
	case "MAP":
		client.handleMap()
	case "KNOCK":
		client.handleKnock(parts)
	case "SETNAME":
		client.handleSetname(parts)
	case "DIE":
		client.handleDie()
	default:
		client.SendNumeric(ERR_UNKNOWNCOMMAND, command+" :Unknown command")
	}
}

// Shutdown gracefully shuts down the server
func (s *Server) Shutdown() {
	log.Println("Initiating graceful shutdown...")

	// Stop health monitoring first to prevent interference
	if s.healthMonitor != nil {
		s.healthMonitor.Stop()
	}

	// Close listeners immediately to stop accepting new connections
	go func() {
		if s.listener != nil {
			s.listener.Close()
		}
		if s.sslListener != nil {
			s.sslListener.Close()
		}
		if s.serverListener != nil {
			s.serverListener.Close()
		}
	}()

	// Signal shutdown to all goroutines (non-blocking)
	select {
	case <-s.shutdown:
		// Already closed
	default:
		close(s.shutdown)
	}

	// Disconnect all linked servers in background
	go func() {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Panic during server disconnection: %v", r)
			}
		}()

		s.mu.RLock()
		linkedServers := make([]*LinkedServer, 0, len(s.linkedServers))
		for _, server := range s.linkedServers {
			linkedServers = append(linkedServers, server)
		}
		s.mu.RUnlock()

		for _, linkedServer := range linkedServers {
			linkedServer.Disconnect()
		}
	}()

	// Notify and disconnect all clients with timeout
	go func() {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Panic during client disconnection: %v", r)
			}
		}()

		// Get client IDs without holding the lock for long
		s.mu.RLock()
		clientIDs := make([]string, 0, len(s.clients))
		for clientID := range s.clients {
			clientIDs = append(clientIDs, clientID)
		}
		s.mu.RUnlock()

		// Disconnect clients in batches to prevent overwhelming
		batchSize := 10
		for i := 0; i < len(clientIDs); i += batchSize {
			end := i + batchSize
			if end > len(clientIDs) {
				end = len(clientIDs)
			}

			batch := clientIDs[i:end]
			for _, clientID := range batch {
				s.mu.RLock()
				client := s.clients[clientID]
				s.mu.RUnlock()

				if client != nil {
					// Send shutdown message with timeout
					go func(c *Client) {
						defer func() {
							if r := recover(); r != nil {
								log.Printf("Panic notifying client during shutdown: %v", r)
							}
						}()
						c.SendMessage("ERROR :Server shutting down")
						// Give the client time to process the message
						time.Sleep(100 * time.Millisecond)
						c.ForceDisconnect("Server shutdown")
					}(client)
				}
			}

			// Small delay between batches
			time.Sleep(50 * time.Millisecond)
		}
	}()

	// Give everything time to shut down gracefully
	time.Sleep(2 * time.Second)

	log.Println("Server shutdown complete")
}
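
// A typical entry point (sketch only; the actual main may differ) would wire
// Shutdown to process signals roughly like this:
//
//	sig := make(chan os.Signal, 1)
//	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
//	go func() {
//		<-sig
//		server.Shutdown()
//		os.Exit(0)
//	}()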

// Server linking methods

// startServerListener starts listening for incoming server connections
func (s *Server) startServerListener() {
	if !s.config.Linking.Enable {
		return
	}

	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Linking.ServerPort)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		log.Printf("Failed to start server listener on %s: %v", addr, err)
		return
	}

	s.serverListener = listener
	log.Printf("IRC server listening for server links on %s", addr)

	for {
		select {
		case <-s.shutdown:
			return
		default:
			conn, err := listener.Accept()
			if err != nil {
				continue
			}

			log.Printf("Incoming server connection from %s", conn.RemoteAddr())
			go s.handleIncomingServer(conn)
		}
	}
}

// startAutoConnections attempts to connect to configured auto-connect servers
func (s *Server) startAutoConnections() {
	if !s.config.Linking.Enable {
		return
	}

	// Wait a bit before starting auto-connections
	time.Sleep(5 * time.Second)

	for _, link := range s.config.Linking.Links {
		if link.AutoConnect {
			go s.connectToServer(link.Name, link.Host, link.Port, link.Password, link.Hub, link.Description)
		}
	}
}

// connectToServer connects to a remote server
func (s *Server) connectToServer(name, host string, port int, password string, hub bool, description string) {
	linkedServer := NewLinkedServer(name, host, port, password, hub, description, s)

	s.mu.Lock()
	s.linkedServers[name] = linkedServer
	s.mu.Unlock()

	for {
		select {
		case <-s.shutdown:
			return
		default:
			if !linkedServer.IsConnected() {
				log.Printf("Attempting to connect to server %s at %s:%d", name, host, port)
				if err := linkedServer.Connect(); err != nil {
					log.Printf("Failed to connect to server %s: %v", name, err)
					time.Sleep(30 * time.Second) // Wait before retry
					continue
				}
				log.Printf("Successfully connected to server %s", name)
			}

			// Sleep before checking again
			time.Sleep(60 * time.Second)
		}
	}
}

// handleIncomingServer handles an incoming server connection
func (s *Server) handleIncomingServer(conn net.Conn) {
	defer conn.Close()

	log.Printf("Handling incoming server connection from %s", conn.RemoteAddr())

	// Create a temporary linked server for authentication
	tempServer := &LinkedServer{
		conn:      conn,
		server:    s,
		connected: true,
	}

	// Handle the connection (this will process authentication)
	tempServer.Handle()
}

// AddLinkedServer adds a linked server to the server
func (s *Server) AddLinkedServer(linkedServer *LinkedServer) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.linkedServers[linkedServer.Name()] = linkedServer
}

// RemoveLinkedServer removes a linked server
func (s *Server) RemoveLinkedServer(name string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if linkedServer, exists := s.linkedServers[name]; exists {
		linkedServer.Disconnect()
		delete(s.linkedServers, name)
	}
}

// GetLinkedServer returns a linked server by name
func (s *Server) GetLinkedServer(name string) *LinkedServer {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.linkedServers[name]
}

// GetLinkedServers returns all linked servers
func (s *Server) GetLinkedServers() map[string]*LinkedServer {
	s.mu.RLock()
	defer s.mu.RUnlock()

	servers := make(map[string]*LinkedServer)
	for name, server := range s.linkedServers {
		servers[name] = server
	}
	return servers
}

// BroadcastToServers sends a message to all linked servers
func (s *Server) BroadcastToServers(message string) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	for _, linkedServer := range s.linkedServers {
		if linkedServer.IsConnected() {
			linkedServer.SendMessage(message)
		}
	}
}

// Ban management functions

// parseDuration parses duration strings like "1d", "2h", "30m", or "0"
// (permanent). Malformed input falls back to the default of one day.
func parseDuration(durationStr string) time.Duration {
	if durationStr == "0" {
		return 0 // Permanent
	}

	if len(durationStr) < 2 {
		return 24 * time.Hour // Default to 1 day
	}

	unit := durationStr[len(durationStr)-1]
	valueStr := durationStr[:len(durationStr)-1]

	var value int
	if _, err := fmt.Sscanf(valueStr, "%d", &value); err != nil || value < 0 {
		return 24 * time.Hour // Default to 1 day on a malformed value
	}

	switch unit {
	case 's':
		return time.Duration(value) * time.Second
	case 'm':
		return time.Duration(value) * time.Minute
	case 'h':
		return time.Duration(value) * time.Hour
	case 'd':
		return time.Duration(value) * 24 * time.Hour
	case 'w':
		return time.Duration(value) * 7 * 24 * time.Hour
	default:
		return 24 * time.Hour // Default to 1 day
	}
}
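
// Examples (for reference):
//
//	parseDuration("30m") // 30 * time.Minute
//	parseDuration("2h")  // 2 * time.Hour
//	parseDuration("1w")  // 7 * 24 * time.Hour
//	parseDuration("0")   // 0 (permanent)
//	parseDuration("xyz") // 24 * time.Hour (fallback for malformed input)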