1017 lines
24 KiB
Go
1017 lines
24 KiB
Go
package main
|
|
|
|
import (
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"
)
|
|
|
|
type Server struct {
|
|
config *Config
|
|
clients map[string]*Client
|
|
channels map[string]*Channel
|
|
listener net.Listener
|
|
sslListener net.Listener
|
|
serverListener net.Listener
|
|
linkedServers map[string]*LinkedServer
|
|
pingMessage string
|
|
mu sync.RWMutex
|
|
shutdown chan bool
|
|
healthMonitor *HealthMonitor
|
|
}
|
|
|
|
func NewServer(config *Config) *Server {
|
|
server := &Server{
|
|
config: config,
|
|
clients: make(map[string]*Client, config.Limits.MaxClients),
|
|
channels: make(map[string]*Channel, config.Limits.MaxChannels),
|
|
shutdown: make(chan bool),
|
|
}
|
|
server.healthMonitor = NewHealthMonitor(server)
|
|
return server
|
|
}
|
|
|
|
func (s *Server) Start() error {
|
|
// Start regular listener
|
|
addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.Port)
|
|
listener, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen on %s: %v", addr, err)
|
|
}
|
|
s.listener = listener
|
|
|
|
log.Printf("IRC server listening on %s", addr)
|
|
|
|
// Start health monitoring
|
|
s.healthMonitor.Start()
|
|
|
|
// Start SSL listener if enabled
|
|
if s.config.Server.Listen.EnableSSL {
|
|
go s.startSSLListener()
|
|
}
|
|
|
|
// Start server linking if enabled
|
|
if s.config.Linking.Enable {
|
|
go s.startServerListener()
|
|
go s.startAutoConnections()
|
|
}
|
|
|
|
// Auto-create configured channels
|
|
for _, channelName := range s.config.Channels.AutoJoin {
|
|
channel := NewChannel(channelName)
|
|
// Set default modes
|
|
for _, mode := range s.config.Channels.DefaultModes {
|
|
if mode != '+' {
|
|
channel.SetMode(rune(mode), true)
|
|
}
|
|
}
|
|
s.channels[strings.ToLower(channelName)] = channel
|
|
}
|
|
|
|
// Start ping routine
|
|
go s.pingRoutine()
|
|
|
|
// Accept connections
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return nil
|
|
default:
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
client := NewClient(conn, s)
|
|
s.AddClient(client)
|
|
go client.Handle()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) startSSLListener() {
|
|
// Load SSL certificates
|
|
cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile)
|
|
if err != nil {
|
|
log.Printf("Failed to load SSL certificates: %v", err)
|
|
return
|
|
}
|
|
|
|
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
|
|
|
|
addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort)
|
|
listener, err := tls.Listen("tcp", addr, tlsConfig)
|
|
if err != nil {
|
|
log.Printf("Failed to start SSL listener on %s: %v", addr, err)
|
|
return
|
|
}
|
|
s.sslListener = listener
|
|
|
|
log.Printf("IRC SSL server listening on %s", addr)
|
|
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
default:
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
client := NewClient(conn, s)
|
|
s.AddClient(client)
|
|
go client.Handle()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) pingRoutine() {
|
|
ticker := time.NewTicker(60 * time.Second) // Reduced frequency to prevent contention
|
|
defer ticker.Stop()
|
|
|
|
// Add a health check ticker that runs less frequently
|
|
healthTicker := time.NewTicker(5 * time.Minute) // Much less frequent health checks
|
|
defer healthTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
case <-ticker.C:
|
|
// Run ping check in a goroutine to prevent blocking
|
|
go s.performPingCheck()
|
|
case <-healthTicker.C:
|
|
// Run health check in a goroutine to prevent blocking
|
|
go s.performHealthCheck()
|
|
}
|
|
}
|
|
}
|
|
|
|
// performPingCheck handles the periodic ping checking
|
|
func (s *Server) performPingCheck() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Panic in performPingCheck: %v", r)
|
|
}
|
|
}()
|
|
|
|
// Get a snapshot of clients without holding the lock for long
|
|
s.mu.RLock()
|
|
clientIDs := make([]string, 0, len(s.clients))
|
|
for clientID := range s.clients {
|
|
clientIDs = append(clientIDs, clientID)
|
|
}
|
|
s.mu.RUnlock()
|
|
|
|
// Process clients individually to prevent blocking
|
|
for _, clientID := range clientIDs {
|
|
func() {
|
|
// Get client safely
|
|
s.mu.RLock()
|
|
client := s.clients[clientID]
|
|
s.mu.RUnlock()
|
|
|
|
if client == nil || !client.IsRegistered() || !client.IsConnected() {
|
|
return
|
|
}
|
|
|
|
// Check activity without holding client lock for long
|
|
client.mu.RLock()
|
|
lastActivity := client.lastActivity
|
|
client.mu.RUnlock()
|
|
|
|
// Only send ping if client hasn't been active recently
|
|
if time.Since(lastActivity) > 120*time.Second {
|
|
// Send ping in a non-blocking way
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Panic sending ping to %s: %v", client.getClientInfo(), r)
|
|
}
|
|
}()
|
|
client.SendMessage(s.pingMessage) // Use pre-computed message
|
|
}()
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
// performHealthCheck checks all clients for health issues
|
|
func (s *Server) performHealthCheck() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Panic in performHealthCheck: %v", r)
|
|
}
|
|
}()
|
|
|
|
// Get a snapshot of clients without holding the lock for long
|
|
s.mu.RLock()
|
|
clientIDs := make([]string, 0, len(s.clients))
|
|
for clientID := range s.clients {
|
|
clientIDs = append(clientIDs, clientID)
|
|
}
|
|
totalClients := len(s.clients)
|
|
s.mu.RUnlock()
|
|
|
|
unhealthyClients := 0
|
|
disconnectedClients := []string{} // Store client IDs instead of pointers
|
|
|
|
// Process clients in batches to prevent overwhelming the system
|
|
batchSize := 50
|
|
for i := 0; i < len(clientIDs); i += batchSize {
|
|
end := i + batchSize
|
|
if end > len(clientIDs) {
|
|
end = len(clientIDs)
|
|
}
|
|
|
|
batch := clientIDs[i:end]
|
|
for _, clientID := range batch {
|
|
func() {
|
|
// Get client safely
|
|
s.mu.RLock()
|
|
client := s.clients[clientID]
|
|
s.mu.RUnlock()
|
|
|
|
if client == nil {
|
|
return
|
|
}
|
|
|
|
healthy, reason := client.HealthCheck()
|
|
if !healthy {
|
|
unhealthyClients++
|
|
|
|
// Force disconnect clients that are definitely problematic
|
|
if strings.Contains(reason, "disconnected") ||
|
|
strings.Contains(reason, "nil") ||
|
|
strings.Contains(reason, "write errors") ||
|
|
strings.Contains(reason, "registration timeout") {
|
|
disconnectedClients = append(disconnectedClients, clientID)
|
|
log.Printf("Marking client %s for disconnection: %s", client.getClientInfo(), reason)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Small delay between batches to prevent overwhelming
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
// Disconnect problematic clients in a separate goroutine
|
|
if len(disconnectedClients) > 0 {
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Panic during client disconnection: %v", r)
|
|
}
|
|
}()
|
|
|
|
for _, clientID := range disconnectedClients {
|
|
s.mu.RLock()
|
|
client := s.clients[clientID]
|
|
s.mu.RUnlock()
|
|
|
|
if client != nil {
|
|
client.ForceDisconnect("Connection health check failed")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Log health statistics
|
|
if totalClients > 0 {
|
|
healthyPercentage := float64(totalClients-unhealthyClients) / float64(totalClients) * 100
|
|
log.Printf("Client health check: %d total, %d healthy (%.1f%%), %d marked for disconnection",
|
|
totalClients, totalClients-unhealthyClients, healthyPercentage, len(disconnectedClients))
|
|
}
|
|
|
|
// Log detailed server statistics every 5 minutes for monitoring
|
|
s.LogServerStats()
|
|
}
|
|
|
|
func (s *Server) AddClient(client *Client) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Enhanced validation before adding client
|
|
if client == nil {
|
|
log.Printf("Attempted to add nil client")
|
|
return
|
|
}
|
|
|
|
if client.conn == nil {
|
|
log.Printf("Attempted to add client with nil connection")
|
|
return
|
|
}
|
|
|
|
// Check client limit
|
|
if len(s.clients) >= s.config.Limits.MaxClients {
|
|
log.Printf("Server full, rejecting client from %s", client.getClientInfo())
|
|
client.SendMessage("ERROR :Server full")
|
|
if client.conn != nil {
|
|
client.conn.Close()
|
|
}
|
|
return
|
|
}
|
|
|
|
// Check for duplicate client IDs (shouldn't happen but defensive programming)
|
|
if _, exists := s.clients[client.clientID]; exists {
|
|
log.Printf("Duplicate client ID detected: %s, generating new ID", client.clientID)
|
|
// Generate a new unique ID
|
|
client.clientID = fmt.Sprintf("%s_%d_%d", client.host, time.Now().Unix(), len(s.clients))
|
|
}
|
|
|
|
s.clients[client.clientID] = client
|
|
log.Printf("Added client %s (total clients: %d/%d)", client.getClientInfo(), len(s.clients), s.config.Limits.MaxClients)
|
|
}
|
|
|
|
func (s *Server) RemoveClient(client *Client) {
|
|
s.mu.Lock()
|
|
delete(s.clients, client.clientID)
|
|
s.mu.Unlock()
|
|
|
|
// Send snomask notification for client disconnect (after releasing the lock)
|
|
if client.IsRegistered() {
|
|
s.sendSnomask('c', fmt.Sprintf("Client disconnect: %s (%s@%s)",
|
|
client.Nick(), client.User(), client.Host()))
|
|
}
|
|
}
|
|
|
|
// GetServerStats returns detailed server statistics for monitoring
|
|
func (s *Server) GetServerStats() map[string]interface{} {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
stats := make(map[string]interface{})
|
|
|
|
// Basic counts
|
|
stats["total_clients"] = len(s.clients)
|
|
stats["total_channels"] = len(s.channels)
|
|
|
|
// Client statistics
|
|
registeredClients := 0
|
|
operatorClients := 0
|
|
sslClients := 0
|
|
unhealthyClients := 0
|
|
|
|
for _, client := range s.clients {
|
|
if client.IsRegistered() {
|
|
registeredClients++
|
|
}
|
|
if client.IsOper() {
|
|
operatorClients++
|
|
}
|
|
if client.IsSSL() {
|
|
sslClients++
|
|
}
|
|
|
|
// Quick health check
|
|
if healthy, _ := client.HealthCheck(); !healthy {
|
|
unhealthyClients++
|
|
}
|
|
}
|
|
|
|
stats["registered_clients"] = registeredClients
|
|
stats["operator_clients"] = operatorClients
|
|
stats["ssl_clients"] = sslClients
|
|
stats["unhealthy_clients"] = unhealthyClients
|
|
|
|
// Channel statistics
|
|
totalChannelUsers := 0
|
|
for _, channel := range s.channels {
|
|
totalChannelUsers += len(channel.GetClients())
|
|
}
|
|
stats["total_channel_users"] = totalChannelUsers
|
|
|
|
// Server linking statistics
|
|
stats["linked_servers"] = len(s.linkedServers)
|
|
|
|
return stats
|
|
}
|
|
|
|
// LogServerStats logs current server statistics
|
|
func (s *Server) LogServerStats() {
|
|
stats := s.GetServerStats()
|
|
log.Printf("Server Statistics: %d clients (%d registered, %d operators, %d SSL, %d unhealthy), %d channels, %d linked servers",
|
|
stats["total_clients"], stats["registered_clients"], stats["operator_clients"],
|
|
stats["ssl_clients"], stats["unhealthy_clients"], stats["total_channels"], stats["linked_servers"])
|
|
}
|
|
|
|
// sendSnomask sends a server notice to operators watching a specific snomask
|
|
func (s *Server) sendSnomask(snomask rune, message string) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
for _, client := range s.clients {
|
|
if client.IsOper() && client.HasSnomask(snomask) {
|
|
client.SendMessage(fmt.Sprintf(":%s NOTICE %s :*** %s",
|
|
s.config.Server.Name, client.Nick(), message))
|
|
}
|
|
}
|
|
}
|
|
|
|
// ReloadConfig reloads the server configuration
|
|
func (s *Server) ReloadConfig() error {
|
|
config, err := LoadConfig("config.json")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load config: %v", err)
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.config = config
|
|
s.mu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) GetClient(nick string) *Client {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
for _, client := range s.clients {
|
|
if strings.EqualFold(client.Nick(), nick) {
|
|
return client
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) GetClientByHost(host string) *Client {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
for _, client := range s.clients {
|
|
if client.Host() == host {
|
|
return client
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) GetClientByID(clientID string) *Client {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.clients[clientID]
|
|
}
|
|
|
|
func (s *Server) GetClients() map[string]*Client {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
clients := make(map[string]*Client)
|
|
for clientID, client := range s.clients {
|
|
clients[clientID] = client
|
|
}
|
|
return clients
|
|
}
|
|
|
|
func (s *Server) GetChannel(name string) *Channel {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.channels[strings.ToLower(name)]
|
|
}
|
|
|
|
func (s *Server) GetOrCreateChannel(name string) *Channel {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
channelName := strings.ToLower(name)
|
|
if channel, exists := s.channels[channelName]; exists {
|
|
return channel
|
|
}
|
|
|
|
// Create new channel
|
|
channel := NewChannel(name)
|
|
// Set default modes
|
|
for _, mode := range s.config.Channels.DefaultModes {
|
|
if mode != '+' {
|
|
channel.SetMode(rune(mode), true)
|
|
}
|
|
}
|
|
s.channels[channelName] = channel
|
|
return channel
|
|
}
|
|
|
|
func (s *Server) RemoveChannel(name string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
delete(s.channels, strings.ToLower(name))
|
|
}
|
|
|
|
func (s *Server) GetChannels() map[string]*Channel {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
channels := make(map[string]*Channel)
|
|
for name, channel := range s.channels {
|
|
channels[name] = channel
|
|
}
|
|
return channels
|
|
}
|
|
|
|
func (s *Server) CreateChannel(name string) *Channel {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if len(s.channels) >= s.config.Limits.MaxChannels {
|
|
return nil
|
|
}
|
|
|
|
channel := NewChannel(name)
|
|
s.channels[strings.ToLower(name)] = channel
|
|
return channel
|
|
}
|
|
|
|
func (s *Server) GetClientCount() int {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return len(s.clients)
|
|
}
|
|
|
|
func (s *Server) GetChannelCount() int {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return len(s.channels)
|
|
}
|
|
|
|
func (s *Server) IsNickInUse(nick string) bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
for _, client := range s.clients {
|
|
if strings.EqualFold(client.Nick(), nick) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s *Server) HandleMessage(client *Client, message string) {
|
|
// Parse IRCv3 message tags if present
|
|
var tags map[string]string
|
|
var actualMessage string
|
|
|
|
if strings.HasPrefix(message, "@") {
|
|
// Message has IRCv3 tags
|
|
spaceIndex := strings.Index(message, " ")
|
|
if spaceIndex == -1 {
|
|
// Message is only tags (like typing indicators), this is valid - just ignore
|
|
if DebugMode {
|
|
log.Printf("<<< RECV from %s: %s (tags-only message, ignoring)", client.Host(), message)
|
|
}
|
|
return
|
|
}
|
|
|
|
tagString := message[1:spaceIndex] // Remove @ prefix
|
|
actualMessage = strings.TrimSpace(message[spaceIndex+1:])
|
|
|
|
// Parse tags
|
|
tags = make(map[string]string)
|
|
tagPairs := strings.Split(tagString, ";")
|
|
for _, pair := range tagPairs {
|
|
if strings.Contains(pair, "=") {
|
|
kv := strings.SplitN(pair, "=", 2)
|
|
tags[kv[0]] = kv[1]
|
|
} else {
|
|
tags[pair] = ""
|
|
}
|
|
}
|
|
|
|
// If actualMessage is empty or just a colon, it's a tags-only message
|
|
if actualMessage == "" || actualMessage == ":" {
|
|
if DebugMode {
|
|
log.Printf("<<< RECV from %s: %s (tags-only message, ignoring)", client.Host(), message)
|
|
}
|
|
return
|
|
}
|
|
} else {
|
|
actualMessage = message
|
|
}
|
|
|
|
parts := strings.Fields(actualMessage)
|
|
if len(parts) == 0 {
|
|
return
|
|
}
|
|
|
|
command := strings.ToUpper(parts[0])
|
|
|
|
// Log the command for debugging
|
|
if DebugMode {
|
|
log.Printf("<<< RECV from %s: %s", client.Host(), message)
|
|
} else {
|
|
log.Printf("Client %s: %s", client.Host(), message)
|
|
}
|
|
|
|
switch command {
|
|
case "CAP":
|
|
client.handleCap(parts)
|
|
case "NICK":
|
|
client.handleNick(parts)
|
|
case "USER":
|
|
client.handleUser(parts)
|
|
case "PING":
|
|
client.handlePing(parts)
|
|
case "PONG":
|
|
client.handlePong(parts)
|
|
case "JOIN":
|
|
client.handleJoin(parts)
|
|
case "PART":
|
|
client.handlePart(parts)
|
|
case "PRIVMSG":
|
|
client.handlePrivmsg(parts)
|
|
case "NOTICE":
|
|
client.handleNotice(parts)
|
|
case "TAGMSG":
|
|
client.handleTagmsg(parts, tags)
|
|
case "WHO":
|
|
client.handleWho(parts)
|
|
case "WHOIS":
|
|
client.handleWhois(parts)
|
|
case "NAMES":
|
|
client.handleNames(parts)
|
|
case "MODE":
|
|
client.handleMode(parts)
|
|
case "OPER":
|
|
client.handleOper(parts)
|
|
case "SNOMASK":
|
|
client.handleSnomask(parts)
|
|
case "GLOBALNOTICE":
|
|
client.handleGlobalNotice(parts)
|
|
case "OPERWALL":
|
|
client.handleOperWall(parts)
|
|
case "WALLOPS":
|
|
client.handleWallops(parts)
|
|
case "REHASH":
|
|
client.handleRehash()
|
|
case "TRACE":
|
|
client.handleTrace(parts)
|
|
case "HELPOP":
|
|
client.handleHelpop(parts)
|
|
case "TOPIC":
|
|
client.handleTopic(parts)
|
|
case "KICK":
|
|
client.handleKick(parts)
|
|
case "INVITE":
|
|
client.handleInvite(parts)
|
|
case "AWAY":
|
|
client.handleAway(parts)
|
|
case "LIST":
|
|
client.handleList()
|
|
case "KILL":
|
|
client.handleKill(parts)
|
|
case "QUIT":
|
|
client.handleQuit(parts)
|
|
case "CONNECT":
|
|
client.handleConnect(parts)
|
|
case "SQUIT":
|
|
client.handleSquit(parts)
|
|
case "LINKS":
|
|
client.handleLinks()
|
|
case "USERHOST":
|
|
client.handleUserhost(parts)
|
|
case "ISON":
|
|
client.handleIson(parts)
|
|
case "TIME":
|
|
client.handleTime()
|
|
case "VERSION":
|
|
client.handleVersion()
|
|
case "ADMIN":
|
|
client.handleAdmin()
|
|
case "INFO":
|
|
client.handleInfo()
|
|
case "LUSERS":
|
|
client.handleLusers()
|
|
case "STATS":
|
|
client.handleStats(parts)
|
|
case "SILENCE":
|
|
client.handleSilence(parts)
|
|
case "MONITOR":
|
|
client.handleMonitor(parts)
|
|
case "AUTHENTICATE":
|
|
client.handleAuthenticate(parts)
|
|
// Services/Admin Commands
|
|
case "CHGHOST":
|
|
client.handleChghost(parts)
|
|
case "SVSNICK":
|
|
client.handleSvsnick(parts)
|
|
case "SVSMODE":
|
|
client.handleSvsmode(parts)
|
|
case "SAMODE":
|
|
client.handleSamode(parts)
|
|
case "SANICK":
|
|
client.handleSanick(parts)
|
|
case "SAKICK":
|
|
client.handleSakick(parts)
|
|
case "SAPART":
|
|
client.handleSapart(parts)
|
|
case "SAJOIN":
|
|
client.handleSajoin(parts)
|
|
case "WHOWAS":
|
|
client.handleWhowas(parts)
|
|
case "MOTD":
|
|
client.handleMotd()
|
|
case "RULES":
|
|
client.handleRules()
|
|
case "MAP":
|
|
client.handleMap()
|
|
case "KNOCK":
|
|
client.handleKnock(parts)
|
|
case "SETNAME":
|
|
client.handleSetname(parts)
|
|
case "DIE":
|
|
client.handleDie()
|
|
default:
|
|
client.SendNumeric(ERR_UNKNOWNCOMMAND, command+" :Unknown command")
|
|
}
|
|
}
|
|
|
|
// Shutdown gracefully shuts down the server
|
|
func (s *Server) Shutdown() {
|
|
log.Println("Initiating graceful shutdown...")
|
|
|
|
// Stop health monitoring first to prevent interference
|
|
if s.healthMonitor != nil {
|
|
s.healthMonitor.Stop()
|
|
}
|
|
|
|
// Close listeners immediately to stop accepting new connections
|
|
go func() {
|
|
if s.listener != nil {
|
|
s.listener.Close()
|
|
}
|
|
if s.sslListener != nil {
|
|
s.sslListener.Close()
|
|
}
|
|
if s.serverListener != nil {
|
|
s.serverListener.Close()
|
|
}
|
|
}()
|
|
|
|
// Signal shutdown to all goroutines (non-blocking)
|
|
select {
|
|
case <-s.shutdown:
|
|
// Already closed
|
|
default:
|
|
close(s.shutdown)
|
|
}
|
|
|
|
// Disconnect all linked servers in background
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Panic during server disconnection: %v", r)
|
|
}
|
|
}()
|
|
|
|
s.mu.RLock()
|
|
linkedServers := make([]*LinkedServer, 0, len(s.linkedServers))
|
|
for _, server := range s.linkedServers {
|
|
linkedServers = append(linkedServers, server)
|
|
}
|
|
s.mu.RUnlock()
|
|
|
|
for _, linkedServer := range linkedServers {
|
|
linkedServer.Disconnect()
|
|
}
|
|
}()
|
|
|
|
// Notify and disconnect all clients with timeout
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Panic during client disconnection: %v", r)
|
|
}
|
|
}()
|
|
|
|
// Get client IDs without holding lock for long
|
|
s.mu.RLock()
|
|
clientIDs := make([]string, 0, len(s.clients))
|
|
for clientID := range s.clients {
|
|
clientIDs = append(clientIDs, clientID)
|
|
}
|
|
s.mu.RUnlock()
|
|
|
|
// Disconnect clients in batches to prevent overwhelming
|
|
batchSize := 10
|
|
for i := 0; i < len(clientIDs); i += batchSize {
|
|
end := i + batchSize
|
|
if end > len(clientIDs) {
|
|
end = len(clientIDs)
|
|
}
|
|
|
|
batch := clientIDs[i:end]
|
|
for _, clientID := range batch {
|
|
s.mu.RLock()
|
|
client := s.clients[clientID]
|
|
s.mu.RUnlock()
|
|
|
|
if client != nil {
|
|
// Send shutdown message with timeout
|
|
go func(c *Client) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("Panic notifying client during shutdown: %v", r)
|
|
}
|
|
}()
|
|
c.SendMessage("ERROR :Server shutting down")
|
|
// Give client time to process message
|
|
time.Sleep(100 * time.Millisecond)
|
|
c.ForceDisconnect("Server shutdown")
|
|
}(client)
|
|
}
|
|
}
|
|
|
|
// Small delay between batches
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
}()
|
|
|
|
// Give everything time to shut down gracefully
|
|
time.Sleep(2 * time.Second)
|
|
|
|
log.Println("Server shutdown complete")
|
|
}
|
|
|
|
// Server linking methods
|
|
|
|
// startServerListener starts listening for incoming server connections
|
|
func (s *Server) startServerListener() {
|
|
if !s.config.Linking.Enable {
|
|
return
|
|
}
|
|
|
|
addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Linking.ServerPort)
|
|
listener, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
log.Printf("Failed to start server listener on %s: %v", addr, err)
|
|
return
|
|
}
|
|
|
|
s.serverListener = listener
|
|
log.Printf("IRC server listening for server links on %s", addr)
|
|
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
default:
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
log.Printf("Incoming server connection from %s", conn.RemoteAddr())
|
|
go s.handleIncomingServer(conn)
|
|
}
|
|
}
|
|
}
|
|
|
|
// startAutoConnections attempts to connect to configured auto-connect servers
|
|
func (s *Server) startAutoConnections() {
|
|
if !s.config.Linking.Enable {
|
|
return
|
|
}
|
|
|
|
// Wait a bit before starting auto-connections
|
|
time.Sleep(5 * time.Second)
|
|
|
|
for _, link := range s.config.Linking.Links {
|
|
if link.AutoConnect {
|
|
go s.connectToServer(link.Name, link.Host, link.Port, link.Password, link.Hub, link.Description)
|
|
}
|
|
}
|
|
}
|
|
|
|
// connectToServer connects to a remote server
|
|
func (s *Server) connectToServer(name, host string, port int, password string, hub bool, description string) {
|
|
linkedServer := NewLinkedServer(name, host, port, password, hub, description, s)
|
|
|
|
s.mu.Lock()
|
|
s.linkedServers[name] = linkedServer
|
|
s.mu.Unlock()
|
|
|
|
for {
|
|
select {
|
|
case <-s.shutdown:
|
|
return
|
|
default:
|
|
if !linkedServer.IsConnected() {
|
|
log.Printf("Attempting to connect to server %s at %s:%d", name, host, port)
|
|
if err := linkedServer.Connect(); err != nil {
|
|
log.Printf("Failed to connect to server %s: %v", name, err)
|
|
time.Sleep(30 * time.Second) // Wait before retry
|
|
continue
|
|
}
|
|
log.Printf("Successfully connected to server %s", name)
|
|
}
|
|
|
|
// Sleep before checking again
|
|
time.Sleep(60 * time.Second)
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleIncomingServer handles an incoming server connection
|
|
func (s *Server) handleIncomingServer(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
log.Printf("Handling incoming server connection from %s", conn.RemoteAddr())
|
|
|
|
// Create a temporary linked server for authentication
|
|
tempServer := &LinkedServer{
|
|
conn: conn,
|
|
server: s,
|
|
connected: true,
|
|
}
|
|
|
|
// Handle the connection (this will process authentication)
|
|
tempServer.Handle()
|
|
}
|
|
|
|
// AddLinkedServer adds a linked server to the server
|
|
func (s *Server) AddLinkedServer(linkedServer *LinkedServer) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.linkedServers[linkedServer.Name()] = linkedServer
|
|
}
|
|
|
|
// RemoveLinkedServer removes a linked server
|
|
func (s *Server) RemoveLinkedServer(name string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if linkedServer, exists := s.linkedServers[name]; exists {
|
|
linkedServer.Disconnect()
|
|
delete(s.linkedServers, name)
|
|
}
|
|
}
|
|
|
|
// GetLinkedServer returns a linked server by name
|
|
func (s *Server) GetLinkedServer(name string) *LinkedServer {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.linkedServers[name]
|
|
}
|
|
|
|
// GetLinkedServers returns all linked servers
|
|
func (s *Server) GetLinkedServers() map[string]*LinkedServer {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
servers := make(map[string]*LinkedServer)
|
|
for name, server := range s.linkedServers {
|
|
servers[name] = server
|
|
}
|
|
return servers
|
|
}
|
|
|
|
// BroadcastToServers sends a message to all linked servers
|
|
func (s *Server) BroadcastToServers(message string) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
for _, linkedServer := range s.linkedServers {
|
|
if linkedServer.IsConnected() {
|
|
linkedServer.SendMessage(message)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Ban management functions
|
|
|
|
// parseDuration parses ban-duration strings such as "30s", "30m", "2h", "1d"
// or "1w": a non-negative integer count followed by a single unit letter.
//
// "0" means permanent and is returned as 0. A zero count with a unit (e.g.
// "0h") likewise returns 0. Malformed input falls back to the default of one
// day. That covers:
//   - a string shorter than two characters;
//   - an unknown unit letter;
//   - a count that is not a plain integer, or is negative;
//   - a count large enough to overflow time.Duration.
//
// The count is parsed with strconv.Atoi so that junk like "xh" is rejected
// instead of being read as 0, which would turn a typo into a permanent ban.
func parseDuration(durationStr string) time.Duration {
	const (
		defaultDuration = 24 * time.Hour
		maxDuration     = time.Duration(1<<63 - 1)
	)

	if durationStr == "0" {
		return 0 // Permanent
	}
	if len(durationStr) < 2 {
		return defaultDuration
	}

	var unit time.Duration
	switch durationStr[len(durationStr)-1] {
	case 's':
		unit = time.Second
	case 'm':
		unit = time.Minute
	case 'h':
		unit = time.Hour
	case 'd':
		unit = 24 * time.Hour
	case 'w':
		unit = 7 * 24 * time.Hour
	default:
		return defaultDuration
	}

	value, err := strconv.Atoi(durationStr[:len(durationStr)-1])
	if err != nil || value < 0 {
		return defaultDuration
	}
	if time.Duration(value) > maxDuration/unit {
		return defaultDuration // value*unit would overflow int64 nanoseconds
	}
	return time.Duration(value) * unit
}
|
|
|