Files
techircd/server.go

1017 lines
24 KiB
Go

package main
import (
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"
)
// Server is the IRC server core. It owns the client, SSL and server-link
// listeners, the registries of connected clients, channels and linked
// servers, and the shutdown signal used by its background goroutines.
type Server struct {
// config is the active configuration; ReloadConfig replaces it under mu.
config *Config
// clients maps a client's clientID to the client; normally accessed under mu.
clients map[string]*Client
// channels maps a lower-cased channel name to the channel; normally accessed under mu.
channels map[string]*Channel
// listener accepts plaintext client connections (assigned in Start).
listener net.Listener
// sslListener accepts TLS client connections (assigned in startSSLListener).
sslListener net.Listener
// serverListener accepts incoming server-link connections (assigned in startServerListener).
serverListener net.Listener
// linkedServers maps a linked server's name to its link; accessed under mu.
linkedServers map[string]*LinkedServer
// pingMessage is the pre-computed line performPingCheck sends to idle clients.
pingMessage string
// mu guards clients, channels, linkedServers and replacement of config.
mu sync.RWMutex
// shutdown is closed by Shutdown to tell accept loops and background
// goroutines to stop.
shutdown chan bool
// healthMonitor is created in NewServer, started in Start and stopped in Shutdown.
healthMonitor *HealthMonitor
}
// NewServer constructs a Server for config.
//
// Every map the server later writes into is allocated here. In particular,
// linkedServers must be non-nil, because connectToServer and AddLinkedServer
// assign into it and a write to a nil map panics. The keep-alive line sent by
// performPingCheck is also pre-computed here from the configured server name.
// Without it, idle clients would be sent an empty line that elicits no PONG.
func NewServer(config *Config) *Server {
	server := &Server{
		config:        config,
		clients:       make(map[string]*Client, config.Limits.MaxClients),
		channels:      make(map[string]*Channel, config.Limits.MaxChannels),
		linkedServers: make(map[string]*LinkedServer),
		pingMessage:   "PING :" + config.Server.Name,
		shutdown:      make(chan bool),
	}
	server.healthMonitor = NewHealthMonitor(server)
	return server
}
// Start binds the plaintext client listener and serves clients until Shutdown
// is called.
//
// Besides accepting plaintext clients on Server.Listen.Host:Port, it:
//   - starts the health monitor;
//   - starts the optional SSL listener and server-link goroutines;
//   - creates the configured auto-join channels with the default modes;
//   - launches the ping routine.
//
// It returns an error only when the plaintext listener cannot be bound. It
// returns nil once shutdown has been signalled.
func (s *Server) Start() error {
	// Start regular listener
	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.Port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", addr, err)
	}
	s.listener = listener
	log.Printf("IRC server listening on %s", addr)

	// Start health monitoring
	s.healthMonitor.Start()

	// Start SSL listener if enabled
	if s.config.Server.Listen.EnableSSL {
		go s.startSSLListener()
	}

	// Start server linking if enabled
	if s.config.Linking.Enable {
		go s.startServerListener()
		go s.startAutoConnections()
	}

	// Auto-create configured channels. GetOrCreateChannel performs the map
	// write under s.mu, which matters because the SSL listener started above
	// may already be registering clients concurrently. It also applies
	// config.Channels.DefaultModes the same way on-demand creation does.
	for _, channelName := range s.config.Channels.AutoJoin {
		s.GetOrCreateChannel(channelName)
	}

	// Start ping routine
	go s.pingRoutine()

	// Accept connections until shutdown.
	for {
		select {
		case <-s.shutdown:
			return nil
		default:
		}

		conn, err := listener.Accept()
		if err != nil {
			select {
			case <-s.shutdown:
				// Shutdown closed the listener; this error is expected.
				return nil
			default:
			}
			// Back off so a persistent error (e.g. file-descriptor exhaustion)
			// does not turn this loop into a CPU-bound spin.
			log.Printf("Accept error on %s: %v", addr, err)
			time.Sleep(100 * time.Millisecond)
			continue
		}

		client := NewClient(conn, s)
		s.AddClient(client)
		go client.Handle()
	}
}
// startSSLListener serves TLS client connections until shutdown is signalled.
//
// It loads the certificate/key pair from Server.SSL and listens on
// Server.Listen.Host:SSLPort. Each accepted connection is registered through
// AddClient and handled in its own goroutine. If the certificates cannot be
// loaded or the port cannot be bound, it logs the problem and returns without
// affecting the plaintext listener.
func (s *Server) startSSLListener() {
	// Load SSL certificates
	cert, err := tls.LoadX509KeyPair(s.config.Server.SSL.CertFile, s.config.Server.SSL.KeyFile)
	if err != nil {
		log.Printf("Failed to load SSL certificates: %v", err)
		return
	}
	tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}

	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Server.Listen.SSLPort)
	listener, err := tls.Listen("tcp", addr, tlsConfig)
	if err != nil {
		log.Printf("Failed to start SSL listener on %s: %v", addr, err)
		return
	}
	s.sslListener = listener
	log.Printf("IRC SSL server listening on %s", addr)

	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		conn, err := listener.Accept()
		if err != nil {
			select {
			case <-s.shutdown:
				// Shutdown closed the listener; this error is expected.
				return
			default:
			}
			// Back off instead of spinning on a persistent Accept error.
			log.Printf("SSL accept error on %s: %v", addr, err)
			time.Sleep(100 * time.Millisecond)
			continue
		}

		client := NewClient(conn, s)
		s.AddClient(client)
		go client.Handle()
	}
}
// pingRoutine drives periodic maintenance until s.shutdown is closed.
//
// Every 60 seconds it launches performPingCheck, and every 5 minutes it
// launches performHealthCheck. Each check runs in its own goroutine, so a slow
// pass never holds up the ticker loop.
func (s *Server) pingRoutine() {
	pingTicker := time.NewTicker(60 * time.Second)
	healthTicker := time.NewTicker(5 * time.Minute)
	defer func() {
		pingTicker.Stop()
		healthTicker.Stop()
	}()

	for {
		select {
		case <-pingTicker.C:
			go s.performPingCheck()
		case <-healthTicker.C:
			go s.performHealthCheck()
		case <-s.shutdown:
			return
		}
	}
}
// performPingCheck sends a keep-alive PING to every registered, connected
// client whose last recorded activity is more than 120 seconds old.
//
// The client list and the PING line are captured under a single short read
// lock. The original re-locked s.mu once per client. Each send happens in its
// own goroutine, so one slow socket cannot stall the pass. Panics are
// recovered and logged so a misbehaving client cannot kill the caller.
func (s *Server) performPingCheck() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Panic in performPingCheck: %v", r)
		}
	}()

	s.mu.RLock()
	pingMessage := s.pingMessage
	if pingMessage == "" {
		// No pre-computed line was set. An empty line would not elicit a PONG,
		// so fall back to a standard server PING.
		pingMessage = "PING :" + s.config.Server.Name
	}
	clients := make([]*Client, 0, len(s.clients))
	for _, client := range s.clients {
		clients = append(clients, client)
	}
	s.mu.RUnlock()

	for _, client := range clients {
		if client == nil || !client.IsRegistered() || !client.IsConnected() {
			continue
		}

		// Read activity without holding the client lock any longer than needed.
		client.mu.RLock()
		lastActivity := client.lastActivity
		client.mu.RUnlock()

		// Only ping clients that have been quiet for a while.
		if time.Since(lastActivity) <= 120*time.Second {
			continue
		}

		go func(c *Client) {
			defer func() {
				if r := recover(); r != nil {
					log.Printf("Panic sending ping to %s: %v", c.getClientInfo(), r)
				}
			}()
			c.SendMessage(pingMessage)
		}(client)
	}
}
// performHealthCheck runs Client.HealthCheck over a snapshot of the client IDs.
//
// Clients are processed in batches of 50, with a 10ms pause after each batch.
// Unhealthy clients are counted. Those whose failure reason mentions
// "disconnected", "nil", "write errors" or "registration timeout" are queued
// and force-disconnected in a background goroutine. Finally it logs a health
// summary followed by the full server statistics. Panics are recovered and
// logged.
func (s *Server) performHealthCheck() {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("Panic in performHealthCheck: %v", r)
		}
	}()

	const batchSize = 50
	fatalMarkers := []string{"disconnected", "nil", "write errors", "registration timeout"}

	s.mu.RLock()
	totalClients := len(s.clients)
	ids := make([]string, 0, totalClients)
	for id := range s.clients {
		ids = append(ids, id)
	}
	s.mu.RUnlock()

	unhealthy := 0
	var doomed []string

	for start := 0; start < len(ids); start += batchSize {
		stop := start + batchSize
		if stop > len(ids) {
			stop = len(ids)
		}

		for _, id := range ids[start:stop] {
			client := s.GetClientByID(id)
			if client == nil {
				continue
			}
			ok, reason := client.HealthCheck()
			if ok {
				continue
			}
			unhealthy++
			for _, marker := range fatalMarkers {
				if strings.Contains(reason, marker) {
					doomed = append(doomed, id)
					log.Printf("Marking client %s for disconnection: %s", client.getClientInfo(), reason)
					break
				}
			}
		}

		// Pause between batches so a large client list does not monopolise the CPU.
		time.Sleep(10 * time.Millisecond)
	}

	if len(doomed) > 0 {
		go func(pending []string) {
			defer func() {
				if r := recover(); r != nil {
					log.Printf("Panic during client disconnection: %v", r)
				}
			}()
			for _, id := range pending {
				if victim := s.GetClientByID(id); victim != nil {
					victim.ForceDisconnect("Connection health check failed")
				}
			}
		}(doomed)
	}

	if totalClients > 0 {
		healthy := totalClients - unhealthy
		percent := float64(healthy) / float64(totalClients) * 100
		log.Printf("Client health check: %d total, %d healthy (%.1f%%), %d marked for disconnection",
			totalClients, healthy, percent, len(doomed))
	}

	s.LogServerStats()
}
// AddClient registers a newly accepted connection in the client map.
//
// A nil client, or one without a connection, is logged and ignored. If the
// server already holds config.Limits.MaxClients clients, the client is sent
// "ERROR :Server full" and its connection is closed. If the client's clientID
// is already taken, a new ID of the form host_unixtime_n is generated,
// bumping n until it is unique.
//
// s.mu is held only for the limit check and the map insert. The rejection
// message and log lines are emitted after releasing it, so a slow socket in
// SendMessage (or a callback into the server) cannot stall every other user
// of s.mu.
//
// NOTE(review): a rejected client is not reported to the caller; Start still
// launches client.Handle() on the closed connection.
func (s *Server) AddClient(client *Client) {
	if client == nil {
		log.Printf("Attempted to add nil client")
		return
	}
	if client.conn == nil {
		log.Printf("Attempted to add client with nil connection")
		return
	}

	s.mu.Lock()
	maxClients := s.config.Limits.MaxClients
	if len(s.clients) >= maxClients {
		s.mu.Unlock()
		log.Printf("Server full, rejecting client from %s", client.getClientInfo())
		client.SendMessage("ERROR :Server full")
		client.conn.Close()
		return
	}

	if _, exists := s.clients[client.clientID]; exists {
		log.Printf("Duplicate client ID detected: %s, generating new ID", client.clientID)
		// host + current second + map size can itself collide (same host,
		// same second, after removals), so keep bumping the counter until
		// the ID is free.
		for n := len(s.clients); ; n++ {
			candidate := fmt.Sprintf("%s_%d_%d", client.host, time.Now().Unix(), n)
			if _, taken := s.clients[candidate]; !taken {
				client.clientID = candidate
				break
			}
		}
	}

	s.clients[client.clientID] = client
	total := len(s.clients)
	s.mu.Unlock()

	log.Printf("Added client %s (total clients: %d/%d)", client.getClientInfo(), total, maxClients)
}
// RemoveClient deletes client from the client map.
//
// If the client had completed registration, it then sends a 'c' (client
// disconnect) snomask notice to interested operators. The notice is sent
// after the server lock has been released.
func (s *Server) RemoveClient(client *Client) {
	s.mu.Lock()
	delete(s.clients, client.clientID)
	s.mu.Unlock()

	if !client.IsRegistered() {
		return
	}
	notice := fmt.Sprintf("Client disconnect: %s (%s@%s)", client.Nick(), client.User(), client.Host())
	s.sendSnomask('c', notice)
}
// GetServerStats returns a point-in-time snapshot of server statistics.
//
// The map has these keys, each holding an int: total_clients,
// total_channels, registered_clients, operator_clients, ssl_clients,
// unhealthy_clients, total_channel_users and linked_servers.
//
// The client and channel collections (and the linked-server count) are
// copied under a short read lock. All per-client and per-channel method calls,
// including Client.HealthCheck, happen after the lock is released. Calling
// into other objects while holding s.mu blocks every writer of s.mu (and,
// with a writer pending, every new reader) for the whole scan. It could also
// deadlock against code that takes a client or channel lock before s.mu.
func (s *Server) GetServerStats() map[string]interface{} {
	s.mu.RLock()
	clients := make([]*Client, 0, len(s.clients))
	for _, client := range s.clients {
		clients = append(clients, client)
	}
	channels := make([]*Channel, 0, len(s.channels))
	for _, channel := range s.channels {
		channels = append(channels, channel)
	}
	linkedServers := len(s.linkedServers)
	s.mu.RUnlock()

	registeredClients, operatorClients, sslClients, unhealthyClients := 0, 0, 0, 0
	for _, client := range clients {
		if client.IsRegistered() {
			registeredClients++
		}
		if client.IsOper() {
			operatorClients++
		}
		if client.IsSSL() {
			sslClients++
		}
		if healthy, _ := client.HealthCheck(); !healthy {
			unhealthyClients++
		}
	}

	totalChannelUsers := 0
	for _, channel := range channels {
		totalChannelUsers += len(channel.GetClients())
	}

	return map[string]interface{}{
		"total_clients":       len(clients),
		"total_channels":      len(channels),
		"registered_clients":  registeredClients,
		"operator_clients":    operatorClients,
		"ssl_clients":         sslClients,
		"unhealthy_clients":   unhealthyClients,
		"total_channel_users": totalChannelUsers,
		"linked_servers":      linkedServers,
	}
}
// LogServerStats writes one log line summarising GetServerStats: client
// totals broken down by registered, operator, SSL and unhealthy, then the
// channel count and the number of linked servers.
func (s *Server) LogServerStats() {
	snapshot := s.GetServerStats()
	fields := []string{
		"total_clients", "registered_clients", "operator_clients",
		"ssl_clients", "unhealthy_clients", "total_channels", "linked_servers",
	}
	values := make([]interface{}, len(fields))
	for i, key := range fields {
		values[i] = snapshot[key]
	}
	log.Printf("Server Statistics: %d clients (%d registered, %d operators, %d SSL, %d unhealthy), %d channels, %d linked servers", values...)
}
// sendSnomask sends message as a server NOTICE to every operator who has the
// given snomask enabled. The wire format is
// ":<server> NOTICE <nick> :*** <message>".
//
// The recipient list and server name are captured under a short read lock.
// The actual writes happen after it is released, so a slow or blocked client
// socket cannot hold s.mu and stall the rest of the server.
func (s *Server) sendSnomask(snomask rune, message string) {
	s.mu.RLock()
	serverName := s.config.Server.Name
	clients := make([]*Client, 0, len(s.clients))
	for _, client := range s.clients {
		clients = append(clients, client)
	}
	s.mu.RUnlock()

	for _, client := range clients {
		if client.IsOper() && client.HasSnomask(snomask) {
			client.SendMessage(fmt.Sprintf(":%s NOTICE %s :*** %s",
				serverName, client.Nick(), message))
		}
	}
}
// ReloadConfig re-reads "config.json" via LoadConfig and installs it as the
// active configuration under s.mu. The path is presumably resolved relative
// to the working directory; confirm in LoadConfig.
//
// On a load error the current configuration is kept. The returned error wraps
// the cause with %w, so callers can use errors.Is / errors.As on it.
//
// NOTE(review): only s.config is swapped; listeners and other state derived
// from the old config at startup are not rebuilt here.
func (s *Server) ReloadConfig() error {
	config, err := LoadConfig("config.json")
	if err != nil {
		return fmt.Errorf("failed to load config: %w", err)
	}

	s.mu.Lock()
	s.config = config
	s.mu.Unlock()
	return nil
}
// GetClient returns a connected client whose nickname equals nick under
// Unicode case folding (strings.EqualFold), or nil if no client matches.
func (s *Server) GetClient(nick string) *Client {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var found *Client
	for _, candidate := range s.clients {
		if !strings.EqualFold(candidate.Nick(), nick) {
			continue
		}
		found = candidate
		break
	}
	return found
}
// GetClientByHost returns a connected client whose Host() is exactly host
// (case-sensitive), or nil if none matches. If several clients share the
// host, which one is returned is unspecified.
func (s *Server) GetClientByHost(host string) *Client {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var match *Client
	for _, candidate := range s.clients {
		if candidate.Host() != host {
			continue
		}
		match = candidate
		break
	}
	return match
}
// GetClientByID looks up a client by its internal clientID; it returns nil
// when no client is registered under that ID.
func (s *Server) GetClientByID(clientID string) *Client {
	s.mu.RLock()
	client := s.clients[clientID]
	s.mu.RUnlock()
	return client
}
// GetClients returns a shallow copy of the clientID -> client map. Callers
// may iterate or modify the returned map without holding s.mu; the *Client
// values themselves are shared.
func (s *Server) GetClients() map[string]*Client {
	s.mu.RLock()
	defer s.mu.RUnlock()

	snapshot := make(map[string]*Client, len(s.clients))
	for id, c := range s.clients {
		snapshot[id] = c
	}
	return snapshot
}
// GetChannel returns the channel registered under name, compared
// case-insensitively via strings.ToLower, or nil if it does not exist.
func (s *Server) GetChannel(name string) *Channel {
	key := strings.ToLower(name)
	s.mu.RLock()
	channel := s.channels[key]
	s.mu.RUnlock()
	return channel
}
// GetOrCreateChannel returns the channel registered under name
// (case-insensitive). If none exists, it creates and registers one first.
//
// A freshly created channel has every mode character listed in
// config.Channels.DefaultModes switched on, skipping '+'. The lookup and
// insert happen under a single write lock, so concurrent callers get the same
// channel.
func (s *Server) GetOrCreateChannel(name string) *Channel {
	key := strings.ToLower(name)

	s.mu.Lock()
	defer s.mu.Unlock()

	if existing, found := s.channels[key]; found {
		return existing
	}

	created := NewChannel(name)
	for _, modeChar := range s.config.Channels.DefaultModes {
		if modeChar == '+' {
			continue
		}
		created.SetMode(rune(modeChar), true)
	}
	s.channels[key] = created
	return created
}
// RemoveChannel unregisters the channel called name (case-insensitive). It is
// a no-op if no such channel exists.
func (s *Server) RemoveChannel(name string) {
	key := strings.ToLower(name)
	s.mu.Lock()
	delete(s.channels, key)
	s.mu.Unlock()
}
// GetChannels returns a shallow copy of the lower-cased-name -> channel map.
// The copy can be ranged over without holding s.mu; the *Channel values are
// shared.
func (s *Server) GetChannels() map[string]*Channel {
	s.mu.RLock()
	defer s.mu.RUnlock()

	out := make(map[string]*Channel, len(s.channels))
	for key, ch := range s.channels {
		out[key] = ch
	}
	return out
}
// CreateChannel registers a new channel called name and returns it.
//
// If a channel with the same case-insensitive name already exists, that
// channel is returned unchanged. The original code replaced it with a fresh
// one, which orphaned its members and state. When the channel does not exist
// yet and the server already holds config.Limits.MaxChannels channels, nil is
// returned.
//
// NOTE(review): unlike GetOrCreateChannel, this does not apply
// config.Channels.DefaultModes to the new channel — confirm that is intended.
func (s *Server) CreateChannel(name string) *Channel {
	key := strings.ToLower(name)

	s.mu.Lock()
	defer s.mu.Unlock()

	if existing, ok := s.channels[key]; ok {
		return existing
	}
	if len(s.channels) >= s.config.Limits.MaxChannels {
		return nil
	}

	channel := NewChannel(name)
	s.channels[key] = channel
	return channel
}
// GetClientCount reports how many clients are currently in the client map,
// including connections that have not finished registration.
func (s *Server) GetClientCount() int {
	s.mu.RLock()
	n := len(s.clients)
	s.mu.RUnlock()
	return n
}
// GetChannelCount reports how many channels are currently registered.
func (s *Server) GetChannelCount() int {
	s.mu.RLock()
	n := len(s.channels)
	s.mu.RUnlock()
	return n
}
// IsNickInUse reports whether any connected client already uses nick,
// compared case-insensitively. It is the boolean form of GetClient.
func (s *Server) IsNickInUse(nick string) bool {
	return s.GetClient(nick) != nil
}
// HandleMessage parses one raw protocol line received from client and
// dispatches it to the matching client.handleXxx method.
//
// A leading '@' marks an IRCv3 tag section. It is split on ';' into a
// key -> value map (a key without '=' maps to ""), which is passed to the
// TAGMSG handler. Values are stored as received; IRCv3 escape sequences are
// not decoded here. The rest of the line is then parsed as the command. A
// line containing only tags, or tags followed by nothing but ":", is
// ignored. The command word is matched case-insensitively; anything
// unrecognised gets ERR_UNKNOWNCOMMAND.
//
// Every dispatched line is logged. For OPER, PASS and AUTHENTICATE the
// parameters are replaced by "<redacted>", because they carry passwords or
// SASL credentials.
func (s *Server) HandleMessage(client *Client, message string) {
	var tags map[string]string
	actualMessage := message

	if strings.HasPrefix(message, "@") {
		spaceIndex := strings.Index(message, " ")
		if spaceIndex == -1 {
			// Message is only tags (like typing indicators), this is valid - just ignore
			if DebugMode {
				log.Printf("<<< RECV from %s: %s (tags-only message, ignoring)", client.Host(), message)
			}
			return
		}
		tagString := message[1:spaceIndex] // Remove @ prefix
		actualMessage = strings.TrimSpace(message[spaceIndex+1:])

		tags = make(map[string]string)
		for _, pair := range strings.Split(tagString, ";") {
			key, value := pair, ""
			if eq := strings.Index(pair, "="); eq >= 0 {
				key, value = pair[:eq], pair[eq+1:]
			}
			// Skip empty keys produced by stray separators ("a=1;;b", a
			// trailing ';', or "=x") rather than storing a bogus "" tag.
			if key == "" {
				continue
			}
			tags[key] = value
		}

		// If actualMessage is empty or just a colon, it's a tags-only message
		if actualMessage == "" || actualMessage == ":" {
			if DebugMode {
				log.Printf("<<< RECV from %s: %s (tags-only message, ignoring)", client.Host(), message)
			}
			return
		}
	}

	parts := strings.Fields(actualMessage)
	if len(parts) == 0 {
		return
	}
	command := strings.ToUpper(parts[0])

	// Never write credentials to the log. OPER carries an operator password,
	// PASS a connection password and AUTHENTICATE SASL payloads.
	logLine := message
	switch command {
	case "OPER", "PASS", "AUTHENTICATE":
		logLine = command + " <redacted>"
	}
	if DebugMode {
		log.Printf("<<< RECV from %s: %s", client.Host(), logLine)
	} else {
		log.Printf("Client %s: %s", client.Host(), logLine)
	}

	switch command {
	case "CAP":
		client.handleCap(parts)
	case "NICK":
		client.handleNick(parts)
	case "USER":
		client.handleUser(parts)
	case "PING":
		client.handlePing(parts)
	case "PONG":
		client.handlePong(parts)
	case "JOIN":
		client.handleJoin(parts)
	case "PART":
		client.handlePart(parts)
	case "PRIVMSG":
		client.handlePrivmsg(parts)
	case "NOTICE":
		client.handleNotice(parts)
	case "TAGMSG":
		client.handleTagmsg(parts, tags)
	case "WHO":
		client.handleWho(parts)
	case "WHOIS":
		client.handleWhois(parts)
	case "NAMES":
		client.handleNames(parts)
	case "MODE":
		client.handleMode(parts)
	case "OPER":
		client.handleOper(parts)
	case "SNOMASK":
		client.handleSnomask(parts)
	case "GLOBALNOTICE":
		client.handleGlobalNotice(parts)
	case "OPERWALL":
		client.handleOperWall(parts)
	case "WALLOPS":
		client.handleWallops(parts)
	case "REHASH":
		client.handleRehash()
	case "TRACE":
		client.handleTrace(parts)
	case "HELPOP":
		client.handleHelpop(parts)
	case "TOPIC":
		client.handleTopic(parts)
	case "KICK":
		client.handleKick(parts)
	case "INVITE":
		client.handleInvite(parts)
	case "AWAY":
		client.handleAway(parts)
	case "LIST":
		client.handleList()
	case "KILL":
		client.handleKill(parts)
	case "QUIT":
		client.handleQuit(parts)
	case "CONNECT":
		client.handleConnect(parts)
	case "SQUIT":
		client.handleSquit(parts)
	case "LINKS":
		client.handleLinks()
	case "USERHOST":
		client.handleUserhost(parts)
	case "ISON":
		client.handleIson(parts)
	case "TIME":
		client.handleTime()
	case "VERSION":
		client.handleVersion()
	case "ADMIN":
		client.handleAdmin()
	case "INFO":
		client.handleInfo()
	case "LUSERS":
		client.handleLusers()
	case "STATS":
		client.handleStats(parts)
	case "SILENCE":
		client.handleSilence(parts)
	case "MONITOR":
		client.handleMonitor(parts)
	case "AUTHENTICATE":
		client.handleAuthenticate(parts)
	// Services/Admin Commands
	case "CHGHOST":
		client.handleChghost(parts)
	case "SVSNICK":
		client.handleSvsnick(parts)
	case "SVSMODE":
		client.handleSvsmode(parts)
	case "SAMODE":
		client.handleSamode(parts)
	case "SANICK":
		client.handleSanick(parts)
	case "SAKICK":
		client.handleSakick(parts)
	case "SAPART":
		client.handleSapart(parts)
	case "SAJOIN":
		client.handleSajoin(parts)
	case "WHOWAS":
		client.handleWhowas(parts)
	case "MOTD":
		client.handleMotd()
	case "RULES":
		client.handleRules()
	case "MAP":
		client.handleMap()
	case "KNOCK":
		client.handleKnock(parts)
	case "SETNAME":
		client.handleSetname(parts)
	case "DIE":
		client.handleDie()
	default:
		client.SendNumeric(ERR_UNKNOWNCOMMAND, command+" :Unknown command")
	}
}
// Shutdown stops the server. In order, it:
//   - stops the health monitor;
//   - closes the client, SSL and server-link listeners in a background
//     goroutine, so no new connections are accepted;
//   - closes s.shutdown to signal the accept loops and background goroutines;
//   - disconnects every linked server in the background;
//   - also in the background, sends each client "ERROR :Server shutting down"
//     and force-disconnects it, in batches of 10 with a 50ms pause between
//     batches.
//
// It then sleeps a fixed 2 seconds and returns. It does not wait for the
// background goroutines to finish.
//
// s.shutdown is closed under s.mu after an "already closed" check. Without
// the lock, two concurrent calls could both pass the check and the second
// close would panic.
func (s *Server) Shutdown() {
	log.Println("Initiating graceful shutdown...")

	// Stop health monitoring first to prevent interference
	if s.healthMonitor != nil {
		s.healthMonitor.Stop()
	}

	// Close listeners immediately to stop accepting new connections
	go func() {
		if s.listener != nil {
			s.listener.Close()
		}
		if s.sslListener != nil {
			s.sslListener.Close()
		}
		if s.serverListener != nil {
			s.serverListener.Close()
		}
	}()

	// Signal shutdown to all goroutines exactly once.
	s.mu.Lock()
	select {
	case <-s.shutdown:
		// Already closed by an earlier call.
	default:
		close(s.shutdown)
	}
	s.mu.Unlock()

	// Disconnect all linked servers in background
	go func() {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Panic during server disconnection: %v", r)
			}
		}()
		s.mu.RLock()
		linkedServers := make([]*LinkedServer, 0, len(s.linkedServers))
		for _, server := range s.linkedServers {
			linkedServers = append(linkedServers, server)
		}
		s.mu.RUnlock()
		for _, linkedServer := range linkedServers {
			linkedServer.Disconnect()
		}
	}()

	// Notify and disconnect all clients with timeout
	go func() {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Panic during client disconnection: %v", r)
			}
		}()
		// Get client IDs without holding lock for long
		s.mu.RLock()
		clientIDs := make([]string, 0, len(s.clients))
		for clientID := range s.clients {
			clientIDs = append(clientIDs, clientID)
		}
		s.mu.RUnlock()

		// Disconnect clients in batches to prevent overwhelming
		batchSize := 10
		for i := 0; i < len(clientIDs); i += batchSize {
			end := i + batchSize
			if end > len(clientIDs) {
				end = len(clientIDs)
			}
			for _, clientID := range clientIDs[i:end] {
				s.mu.RLock()
				client := s.clients[clientID]
				s.mu.RUnlock()
				if client == nil {
					continue
				}
				go func(c *Client) {
					defer func() {
						if r := recover(); r != nil {
							log.Printf("Panic notifying client during shutdown: %v", r)
						}
					}()
					c.SendMessage("ERROR :Server shutting down")
					// Give client time to process message
					time.Sleep(100 * time.Millisecond)
					c.ForceDisconnect("Server shutdown")
				}(client)
			}
			// Small delay between batches
			time.Sleep(50 * time.Millisecond)
		}
	}()

	// Give everything time to shut down gracefully
	time.Sleep(2 * time.Second)
	log.Println("Server shutdown complete")
}
// Server linking methods
// startServerListener listens on Server.Listen.Host:Linking.ServerPort for
// incoming server-link connections. Each one is handed to
// handleIncomingServer in its own goroutine until shutdown is signalled.
//
// It returns immediately if linking is disabled. If the port cannot be bound,
// it logs the error and returns.
func (s *Server) startServerListener() {
	if !s.config.Linking.Enable {
		return
	}

	addr := fmt.Sprintf("%s:%d", s.config.Server.Listen.Host, s.config.Linking.ServerPort)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		log.Printf("Failed to start server listener on %s: %v", addr, err)
		return
	}
	s.serverListener = listener
	log.Printf("IRC server listening for server links on %s", addr)

	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		conn, err := listener.Accept()
		if err != nil {
			select {
			case <-s.shutdown:
				// Shutdown closed the listener; this error is expected.
				return
			default:
			}
			// Back off instead of spinning on a persistent Accept error.
			log.Printf("Server link accept error on %s: %v", addr, err)
			time.Sleep(100 * time.Millisecond)
			continue
		}

		log.Printf("Incoming server connection from %s", conn.RemoteAddr())
		go s.handleIncomingServer(conn)
	}
}
// startAutoConnections waits 5 seconds, then starts a connectToServer
// goroutine for every configured link with AutoConnect set.
//
// It returns immediately if linking is disabled. The initial wait is aborted
// if the server shuts down in the meantime; a plain time.Sleep would start
// connection attempts after shutdown had already begun.
func (s *Server) startAutoConnections() {
	if !s.config.Linking.Enable {
		return
	}

	// Give the listeners a moment to come up before dialling out.
	delay := time.NewTimer(5 * time.Second)
	defer delay.Stop()
	select {
	case <-s.shutdown:
		return
	case <-delay.C:
	}

	for _, link := range s.config.Linking.Links {
		if link.AutoConnect {
			go s.connectToServer(link.Name, link.Host, link.Port, link.Password, link.Hub, link.Description)
		}
	}
}
// connectToServer registers an outgoing link to the named remote server and
// keeps it connected until shutdown.
//
// While disconnected it attempts to Connect, retrying 30 seconds after a
// failure. Once connected it re-checks the link every 60 seconds. Both waits
// return early when s.shutdown is closed, so the goroutine exits promptly on
// shutdown instead of lingering for up to a minute.
func (s *Server) connectToServer(name, host string, port int, password string, hub bool, description string) {
	linkedServer := NewLinkedServer(name, host, port, password, hub, description, s)

	s.mu.Lock()
	if s.linkedServers == nil {
		// Allocate lazily: writing into a nil map panics.
		s.linkedServers = make(map[string]*LinkedServer)
	}
	s.linkedServers[name] = linkedServer
	s.mu.Unlock()

	// waitOrShutdown sleeps for d and reports false if shutdown fired first.
	waitOrShutdown := func(d time.Duration) bool {
		timer := time.NewTimer(d)
		defer timer.Stop()
		select {
		case <-s.shutdown:
			return false
		case <-timer.C:
			return true
		}
	}

	for {
		select {
		case <-s.shutdown:
			return
		default:
		}

		if !linkedServer.IsConnected() {
			log.Printf("Attempting to connect to server %s at %s:%d", name, host, port)
			if err := linkedServer.Connect(); err != nil {
				log.Printf("Failed to connect to server %s: %v", name, err)
				if !waitOrShutdown(30 * time.Second) { // Wait before retry
					return
				}
				continue
			}
			log.Printf("Successfully connected to server %s", name)
		}

		// Sleep before checking again
		if !waitOrShutdown(60 * time.Second) {
			return
		}
	}
}
// handleIncomingServer wraps an accepted server-link connection in a
// provisional LinkedServer (marked connected, bound to this server) and runs
// its Handle loop, which is expected to perform authentication. The
// connection is always closed when Handle returns.
func (s *Server) handleIncomingServer(conn net.Conn) {
	defer conn.Close()

	log.Printf("Handling incoming server connection from %s", conn.RemoteAddr())

	provisional := &LinkedServer{
		server:    s,
		conn:      conn,
		connected: true,
	}
	provisional.Handle()
}
// AddLinkedServer registers linkedServer under its Name(), replacing any
// existing entry with that name. A nil linkedServer is ignored. The map is
// allocated on first use, because writing into a nil map panics.
func (s *Server) AddLinkedServer(linkedServer *LinkedServer) {
	if linkedServer == nil {
		return
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if s.linkedServers == nil {
		s.linkedServers = make(map[string]*LinkedServer)
	}
	s.linkedServers[linkedServer.Name()] = linkedServer
}
// RemoveLinkedServer unregisters the linked server called name and
// disconnects it. Unknown names are ignored.
//
// The map entry is removed under s.mu, but Disconnect is called only after
// the lock is released. Calling it under the (non-reentrant) write lock would
// deadlock if Disconnect re-entered the server, e.g. via RemoveLinkedServer
// or any accessor. It would also block the whole server on link I/O.
func (s *Server) RemoveLinkedServer(name string) {
	s.mu.Lock()
	linkedServer, exists := s.linkedServers[name]
	if exists {
		delete(s.linkedServers, name)
	}
	s.mu.Unlock()

	if exists && linkedServer != nil {
		linkedServer.Disconnect()
	}
}
// GetLinkedServer looks up a linked server by its exact name, returning nil
// when none is registered.
func (s *Server) GetLinkedServer(name string) *LinkedServer {
	s.mu.RLock()
	link := s.linkedServers[name]
	s.mu.RUnlock()
	return link
}
// GetLinkedServers returns a shallow copy of the name -> linked-server map,
// safe to iterate without holding s.mu. The returned map is never nil.
func (s *Server) GetLinkedServers() map[string]*LinkedServer {
	s.mu.RLock()
	defer s.mu.RUnlock()

	links := make(map[string]*LinkedServer, len(s.linkedServers))
	for linkName, link := range s.linkedServers {
		links[linkName] = link
	}
	return links
}
// BroadcastToServers sends message to every currently connected linked server.
//
// It works on a snapshot from GetLinkedServers, so the sends happen without
// holding s.mu. The original wrote to every link while holding the read lock,
// letting one slow link stall all writers (and, once a writer is waiting, all
// new readers) of s.mu.
func (s *Server) BroadcastToServers(message string) {
	for _, linkedServer := range s.GetLinkedServers() {
		if linkedServer.IsConnected() {
			linkedServer.SendMessage(message)
		}
	}
}
// Ban management functions
// parseDuration converts a ban-duration string of the form <count><unit> into
// a time.Duration.
//
// Units are s (seconds), m (minutes), h (hours), d (days) and w (weeks), e.g.
// "30m", "2h", "1d". The literal "0", and any zero count such as "0h", yields
// 0, meaning a permanent ban.
//
// Unparseable input falls back to one day. That covers strings shorter than
// two characters, unknown units, counts that are not non-negative base-10
// integers, and values too large for time.Duration.
func parseDuration(durationStr string) time.Duration {
	const defaultDuration = 24 * time.Hour // Default to 1 day
	const maxDuration = time.Duration(1<<63 - 1)

	if durationStr == "0" {
		return 0 // Permanent
	}
	if len(durationStr) < 2 {
		return defaultDuration
	}

	var unit time.Duration
	switch durationStr[len(durationStr)-1] {
	case 's':
		unit = time.Second
	case 'm':
		unit = time.Minute
	case 'h':
		unit = time.Hour
	case 'd':
		unit = 24 * time.Hour
	case 'w':
		unit = 7 * 24 * time.Hour
	default:
		return defaultDuration
	}

	// strconv.Atoi rejects non-numeric counts. The previous fmt.Sscanf call
	// ignored its error, so "xh" became 0, i.e. a permanent ban.
	value, err := strconv.Atoi(durationStr[:len(durationStr)-1])
	if err != nil || value < 0 {
		return defaultDuration
	}
	// Guard the multiplication so huge counts cannot wrap to a negative or
	// zero duration.
	if time.Duration(value) > maxDuration/unit {
		return defaultDuration
	}
	return time.Duration(value) * unit
}