From c7930a36b95874662bf50edb940fe266f1c7b012 Mon Sep 17 00:00:00 2001 From: Pantelis Roditis Date: Sat, 20 Dec 2025 01:05:31 +0200 Subject: [PATCH 1/7] daemonize support --- main.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/main.go b/main.go index 0f1cb0e..1205f63 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "log" "net/http" "os" + "os/exec" "os/signal" "strings" "sync" @@ -45,6 +46,7 @@ var ( tokenRevalidationPeriod time.Duration logFile string logLevel string + daemonize bool ) type wsConnection struct { @@ -503,6 +505,34 @@ func startTokenRevalidation(interval time.Duration) { }() } +func daemonizeSelf() { + if os.Getenv("DAEMONIZED") == "1" { + return + } + + exe, err := os.Executable() + if err != nil { + log.Fatalf("cannot get executable: %v", err) + } + + args := []string{} + for _, a := range os.Args[1:] { + if a != "-daemon" { + args = append(args, a) + } + } + + cmd := exec.Command(exe, args...) + cmd.Env = append(os.Environ(), "DAEMONIZED=1") + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + + if err := cmd.Start(); err != nil { + log.Fatalf("failed to daemonize: %v", err) + } + + os.Exit(0) +} + // /////////////////// // MAIN // /////////////////// @@ -521,6 +551,7 @@ func main() { flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "Maximum concurrent WebSocket connections per player") flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "Period for WS token revalidation (e.g., 30s, 1m)") flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "Duration that messages will be stored offline (e.g., 30s, 1m)") + flag.BoolVar(&daemonize, "daemon", false, "Run as daemon (background process)") flag.Parse() if origins != "" { @@ -545,6 +576,10 @@ func main() { } logrus.SetLevel(level) + if daemonize { + daemonizeSelf() + } + if err := initDB(); err != nil { log.Fatal(err) } From 9b64e33e64c410af716cd169dc5866fb92c41480 Mon Sep 17 00:00:00 2001 From: Pantelis Roditis Date: Sat, 20 Dec 2025 01:05:52 +0200 Subject: [PATCH 2/7] add pid writting support --- main.go | 48 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 1205f63..d34d8c1 100644 --- a/main.go +++ b/main.go @@ -47,6 +47,7 @@ var ( logFile string logLevel string daemonize bool + pidFile string ) type wsConnection struct { @@ -505,6 +506,28 @@ func startTokenRevalidation(interval time.Duration) { }() } +func writePIDFile(path string) error { + pid := os.Getpid() + data := []byte(fmt.Sprintf("%d\n", pid)) + return os.WriteFile(path, data, 0644) +} + +func removePIDFile(path string) { + _ = os.Remove(path) +} + +func pidFileExists(path string) (int, bool) { + data, err := os.ReadFile(path) + if err != nil { + return 0, false + } + var pid int + if _, err := fmt.Sscanf(string(data), "%d", &pid); err != nil { + return 0, false + } + return pid, true +} + func daemonizeSelf() { if os.Getenv("DAEMONIZED") == "1" { return @@ -545,6 +568,7 @@ func main() { flag.StringVar(&origins, "origins", "", "Allowed WS origins") flag.StringVar(&logFile, "log-file", "", "Path to log file (default: stdout)") flag.StringVar(&logLevel, "log-level", "info", "Log level (panic, fatal, error, warn, info, debug, trace)") + flag.StringVar(&pidFile, "pid-file", "", "Path to PID file (daemon mode only)") flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "Maximum queued messages per player") flag.IntVar(&rateLimit, "rate-limit", 10, "Number of messages allowed per rate-period per server token") flag.DurationVar(&ratePeriod, "rate-period", time.Second, "Duration for rate limiting (e.g., 1s, 500ms)") @@ -576,10 +600,29 @@ func main() { } logrus.SetLevel(level) + if pidFile != "" { + if pid, ok := pidFileExists(pidFile); ok { + log.Fatalf("pid file already exists for PID: %d", pid) + } + } + if daemonize { daemonizeSelf() } + if pidFile != "" { + if err := writePIDFile(pidFile); err != nil { + logrus.Fatalf("failed to write pid file: %v", err) + } + defer removePIDFile(pidFile) + } + + if pidFile != "" { + if err := writePIDFile(pidFile); err != nil { + logrus.Fatalf("failed to write pid file: %v", err) + } + } + if err := initDB(); err != nil { log.Fatal(err) } @@ -633,5 +676,8 @@ func main() { }() logrus.Infof("Server listening on %s", serverAddr) - log.Fatal(server.ListenAndServe()) + err = server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + logrus.Fatalf("server error: %v", err) + } } From b194117faeda111c1d1173b5583401ccd3dc7783 Mon Sep 17 00:00:00 2001 From: Pantelis Roditis Date: Sat, 20 Dec 2025 01:42:14 +0200 Subject: [PATCH 3/7] change order to make defer work --- main.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index d34d8c1..17d8595 100644 --- a/main.go +++ b/main.go @@ -610,6 +610,10 @@ func main() { daemonizeSelf() } + if err := initDB(); err != nil { + log.Fatal(err) + } + if pidFile != "" { if err := writePIDFile(pidFile); err != nil { logrus.Fatalf("failed to write pid file: %v", err) @@ -623,10 +627,6 @@ func main() { } } - if err := initDB(); err != nil { - log.Fatal(err) - } - initMetrics() // cleanup expired offline messages From 480ea9000416c1bf5d29f9f70c71f651a8d214c2 Mon Sep 17 00:00:00 2001 From: Pantelis Roditis Date: Sat, 20 Dec 2025 01:57:00 +0200 Subject: [PATCH 4/7] attempt to reduce cyclomatic complexity --- main.go | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/main.go b/main.go index 17d8595..38bf0a1 100644 --- a/main.go +++ b/main.go @@ -112,6 +112,30 @@ type limiter struct { last time.Time } +func parseFlags() { + var origins string + + flag.StringVar(&dbDriver, "db", "sqlite", "Database driver") + flag.StringVar(&dbDSN, "dsn", "file:ws_tokens.db?cache=shared", "Database DSN") + flag.StringVar(&serverAddr, "addr", ":8080", "Server address") + flag.StringVar(&origins, "origins", "", "Allowed WS origins") + flag.StringVar(&logFile, "log-file", "", "Path to log file") + flag.StringVar(&logLevel, "log-level", "info", "Log level") + flag.StringVar(&pidFile, "pid-file", "", "Path to PID file") + flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "") + flag.IntVar(&rateLimit, "rate-limit", 10, "") + flag.DurationVar(&ratePeriod, "rate-period", time.Second, "") + flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "") + flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "") + flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "") + flag.BoolVar(&daemonize, "daemon", false, "") + flag.Parse() + + if origins != "" { + allowedOrigins = strings.Split(origins, ",") + } +} + func allow(key string, rate int, per time.Duration) bool { lm.Lock() defer lm.Unlock() @@ -560,27 +584,7 @@ func daemonizeSelf() { // MAIN // /////////////////// func main() { - var origins string - - flag.StringVar(&dbDriver, "db", "sqlite", "Database driver") - flag.StringVar(&dbDSN, "dsn", "file:ws_tokens.db?cache=shared", "Database DSN") - flag.StringVar(&serverAddr, "addr", ":8080", "Server address") - flag.StringVar(&origins, "origins", "", "Allowed WS origins") - flag.StringVar(&logFile, "log-file", "", "Path to log file (default: stdout)") - flag.StringVar(&logLevel, "log-level", "info", "Log level (panic, fatal, error, warn, info, debug, trace)") - flag.StringVar(&pidFile, "pid-file", "", "Path to PID file (daemon mode only)") - flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "Maximum queued messages per player") - flag.IntVar(&rateLimit, "rate-limit", 10, "Number of messages allowed per rate-period per server token") - flag.DurationVar(&ratePeriod, "rate-period", time.Second, "Duration for rate limiting (e.g., 1s, 500ms)") - flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "Maximum concurrent WebSocket connections per player") - flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "Period for WS token revalidation (e.g., 30s, 1m)") - flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "Duration that messages will be stored offline (e.g., 30s, 1m)") - flag.BoolVar(&daemonize, "daemon", false, "Run as daemon (background process)") - flag.Parse() - - if origins != "" { - allowedOrigins = strings.Split(origins, ",") - } + parseFlags() logrus.SetFormatter(&logrus.JSONFormatter{}) From 212d79b9ec6058634b2d42a9d2ab42aa78ede3c6 Mon Sep 17 00:00:00 2001 From: Pantelis Roditis Date: Sat, 20 Dec 2025 02:16:49 +0200 Subject: [PATCH 5/7] move setupLogging into its own function --- main.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/main.go b/main.go index 38bf0a1..4c3e19d 100644 --- a/main.go +++ b/main.go @@ -579,13 +579,7 @@ func daemonizeSelf() { os.Exit(0) } - -// /////////////////// -// MAIN -// /////////////////// -func main() { - parseFlags() - +func setupLogging() { logrus.SetFormatter(&logrus.JSONFormatter{}) if logFile != "" { @@ -603,6 +597,14 @@ func main() { log.Fatalf("invalid log level: %s", logLevel) } logrus.SetLevel(level) +} + +// /////////////////// +// MAIN +// /////////////////// +func main() { + parseFlags() + setupLogging() if pidFile != "" { if pid, ok := pidFileExists(pidFile); ok { @@ -680,7 +682,7 @@ func main() { }() logrus.Infof("Server listening on %s", serverAddr) - err = server.ListenAndServe() + err := server.ListenAndServe() if err != nil && err != http.ErrServerClosed { logrus.Fatalf("server error: %v", err) } From 7a154a3da09f7e4cfc4022f3a72ee9e4fd2499fa Mon Sep 17 00:00:00 2001 From: Pantelis Roditis Date: Sat, 20 Dec 2025 02:19:59 +0200 Subject: [PATCH 6/7] move pid file handling --- main.go | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/main.go b/main.go index 4c3e19d..adac4b8 100644 --- a/main.go +++ b/main.go @@ -599,18 +599,32 @@ func setupLogging() { logrus.SetLevel(level) } +func handlePIDFile() { + if pidFile == "" { + return + } + + if pid, ok := pidFileExists(pidFile); ok { + log.Fatalf("pid file already exists for PID: %d", pid) + } + + if err := writePIDFile(pidFile); err != nil { + logrus.Fatalf("failed to write pid file: %v", err) + } + + go func() { + <-context.Background().Done() + removePIDFile(pidFile) + }() +} + // /////////////////// // MAIN // /////////////////// func main() { parseFlags() setupLogging() - - if pidFile != "" { - if pid, ok := pidFileExists(pidFile); ok { - log.Fatalf("pid file already exists for PID: %d", pid) - } - } + handlePIDFile() if daemonize { daemonizeSelf() @@ -620,19 +634,6 @@ func main() { log.Fatal(err) } - if pidFile != "" { - if err := writePIDFile(pidFile); err != nil { - logrus.Fatalf("failed to write pid file: %v", err) - } - defer removePIDFile(pidFile) - } - - if pidFile != "" { - if err := writePIDFile(pidFile); err != nil { - logrus.Fatalf("failed to write pid file: %v", err) - } - } - initMetrics() // cleanup expired offline messages From 96e6092a52a256aa7145795c9e7b50a2a5173c5a Mon Sep 17 00:00:00 2001 From: Pantelis Roditis Date: Sat, 20 Dec 2025 14:18:39 +0200 Subject: [PATCH 7/7] Split certain operations so that we can display certain errors to stderr instead without losing our defer methods --- main.go | 151 +++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 94 insertions(+), 57 deletions(-) diff --git a/main.go b/main.go index adac4b8..2621512 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "errors" "flag" "fmt" - "log" "net/http" "os" "os/exec" @@ -104,14 +103,13 @@ var ( lm sync.Mutex ) -// /////////////////// -// RATE LIMITER -// /////////////////// type limiter struct { tokens int last time.Time } +// parseFlags parses command-line flags into global configuration variables. +// It supports database driver, DSN, server address, log file/level, PID file, and various WS server limits. func parseFlags() { var origins string @@ -136,6 +134,8 @@ func parseFlags() { } } +// allow implements a simple token-based rate limiter for a given key. +// Returns true if the action is allowed, false if the rate limit has been exceeded. func allow(key string, rate int, per time.Duration) bool { lm.Lock() defer lm.Unlock() @@ -188,9 +188,8 @@ type pendingMessage struct { timestamp time.Time } -// /////////////////// -// DB -// /////////////////// +// initDB initializes the global database connection based on dbDriver and dbDSN. +// Returns an error if the driver is unsupported or if the connection cannot be established. func initDB() error { var err error @@ -214,7 +213,8 @@ func initDB() error { return db.Ping() } -// validateToken checks token validity in DB +// validateToken checks whether a given token is valid in the database. +// Returns the associated player/subject ID and true if valid, or empty string and false if invalid. func validateToken(token string, isServer bool) (string, bool) { const q = ` SELECT IFNULL(player_id, subject_id) @@ -235,11 +235,8 @@ func validateToken(token string, isServer bool) (string, bool) { return "", false } -// /////////////////// -// CONNECTION MANAGEMENT -// /////////////////// - -// registerConnection registers a WS connection and stores its token +// registerConnection registers a websocket connection for a player, storing the associated token. +// It also increments the active connections metric and flushes any pending messages to the new connection. func registerConnection(playerID string, c *websocket.Conn, token string) { mu.Lock() if players[playerID] == nil { @@ -253,7 +250,8 @@ func registerConnection(playerID string, c *websocket.Conn, token string) { flushPendingMessages(playerID, c) } -// unregisterConnection removes a WS connection +// unregisterConnection removes a websocket connection for a player and decrements the active connections metric. +// If no connections remain for the player, the player's entry is removed from the players map. func unregisterConnection(playerID string, c *websocket.Conn) { mu.Lock() defer mu.Unlock() @@ -264,7 +262,8 @@ func unregisterConnection(playerID string, c *websocket.Conn) { connections.Dec() } -// closeAllConnections closes all WS connections (on shutdown) +// closeAllConnections closes all active websocket connections for all players. +// Typically used during server shutdown. func closeAllConnections() { mu.Lock() defer mu.Unlock() @@ -275,7 +274,8 @@ func closeAllConnections() { } } -// flushPendingMessages sends queued messages to a newly connected player +// flushPendingMessages sends any queued offline messages to a newly connected websocket. +// Messages older than offlineTTL are ignored and removed. func flushPendingMessages(playerID string, c *websocket.Conn) { pendingMu.Lock() msgs := pendingMessages[playerID] @@ -298,9 +298,9 @@ func flushPendingMessages(playerID string, c *websocket.Conn) { pendingMu.Unlock() } -// /////////////////// -// WEBSOCKET HANDLER -// /////////////////// +// wsHandler handles incoming websocket upgrade requests from clients. +// Validates the token, enforces connection limits, sets up heartbeat, and reads messages. +// Connections are automatically unregistered on disconnect. func wsHandler(w http.ResponseWriter, r *http.Request) { token := r.URL.Query().Get("token") if token == "" { @@ -372,9 +372,9 @@ func wsHandler(w http.ResponseWriter, r *http.Request) { }).Info("Player disconnected") } -// /////////////////// -// PUBLISH HANDLER -// /////////////////// +// publishHandler handles incoming messages from authorized servers to a specific player. +// Validates server token, enforces rate limits, delivers message immediately if the player is connected, +// or queues the message for offline delivery if not. func publishHandler(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { @@ -443,9 +443,9 @@ func publishHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -// /////////////////// -// BROADCAST HANDLER -// /////////////////// +// broadcastHandler handles incoming broadcast messages from authorized servers. +// Can target a specific player or all connected players. +// Enforces rate limits and increments metrics for delivered messages. func broadcastHandler(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { @@ -494,18 +494,13 @@ func broadcastHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -// /////////////////// -// METRICS -// /////////////////// +// initMetrics registers Prometheus metrics for connections, messages published, and messages delivered. func initMetrics() { prometheus.MustRegister(connections, messagesPublished, messagesDelivered) } -// /////////////////// -// TOKEN REVALIDATION -// /////////////////// - -// startTokenRevalidation periodically checks all WS tokens and closes invalid ones +// startTokenRevalidation periodically validates all active websocket tokens. +// Invalid tokens cause connections to be closed and removed. func startTokenRevalidation(interval time.Duration) { ticker := time.NewTicker(interval) go func() { @@ -530,16 +525,22 @@ func startTokenRevalidation(interval time.Duration) { }() } +// writePIDFile writes the current process PID to the specified file path. +// Returns an error if writing fails. func writePIDFile(path string) error { pid := os.Getpid() data := []byte(fmt.Sprintf("%d\n", pid)) return os.WriteFile(path, data, 0644) } +// removePIDFile deletes the PID file at the specified path. +// Any errors are ignored. func removePIDFile(path string) { _ = os.Remove(path) } +// pidFileExists checks if the PID file exists and reads its PID. +// Returns the PID and true if the file exists and contains a valid integer, otherwise 0 and false. func pidFileExists(path string) (int, bool) { data, err := os.ReadFile(path) if err != nil { @@ -552,14 +553,17 @@ func pidFileExists(path string) (int, bool) { return pid, true } -func daemonizeSelf() { +// daemonizeSelf re-launches the current executable as a background daemon process. +// It returns an error if the executable cannot be determined or if the child process fails to start. +// If successful, the parent process will exit immediately using os.Exit(0) to allow the daemon to continue independently. +func daemonizeSelf() error { if os.Getenv("DAEMONIZED") == "1" { - return + return nil } exe, err := os.Executable() if err != nil { - log.Fatalf("cannot get executable: %v", err) + return fmt.Errorf("cannot get executable path: %w", err) } args := []string{} @@ -574,18 +578,23 @@ func daemonizeSelf() { cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} if err := cmd.Start(); err != nil { - log.Fatalf("failed to daemonize: %v", err) + return fmt.Errorf("failed to daemonize: %w", err) } - os.Exit(0) + os.Exit(0) // safe here, because no defers in daemon parent context + return nil // unreachable, but satisfies compiler } -func setupLogging() { + +// setupLogging configures logrus logging for the application. +// It sets the output destination and log level based on global flags. +// Returns an error if the log file cannot be opened or if the log level is invalid. +func setupLogging() error { logrus.SetFormatter(&logrus.JSONFormatter{}) if logFile != "" { f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { - log.Fatalf("failed to open log file %s: %v", logFile, err) + return fmt.Errorf("failed to open log file %s: %w", logFile, err) } logrus.SetOutput(f) } else { @@ -594,49 +603,75 @@ func setupLogging() { level, err := logrus.ParseLevel(strings.ToLower(logLevel)) if err != nil { - log.Fatalf("invalid log level: %s", logLevel) + return fmt.Errorf("invalid log level: %s", logLevel) } logrus.SetLevel(level) + + return nil } -func handlePIDFile() { +// handlePIDFile ensures that the PID file is created and removed properly. +// If the PID file already exists, it returns an error. +// The PID file is automatically removed when the function that called this defers cleanup. +// Returns an error if writing the PID file fails. +func handlePIDFile() error { if pidFile == "" { - return + return nil } if pid, ok := pidFileExists(pidFile); ok { - log.Fatalf("pid file already exists for PID: %d", pid) + return fmt.Errorf("pid file already exists for PID: %d", pid) } if err := writePIDFile(pidFile); err != nil { - logrus.Fatalf("failed to write pid file: %v", err) + return fmt.Errorf("failed to write pid file: %w", err) } - go func() { - <-context.Background().Done() - removePIDFile(pidFile) - }() + // The caller should defer removePIDFile(pidFile) to ensure cleanup + return nil } // /////////////////// // MAIN // /////////////////// func main() { + if err := run(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func run() error { parseFlags() - setupLogging() - handlePIDFile() - if daemonize { - daemonizeSelf() + // Handle PID file + if err := handlePIDFile(); err != nil { + return err + } + if pidFile != "" { + defer removePIDFile(pidFile) } + // Setup logging + if err := setupLogging(); err != nil { + return fmt.Errorf("failed to setup logging: %w", err) + } + + // Initialize DB if err := initDB(); err != nil { - log.Fatal(err) + return fmt.Errorf("failed to init DB: %w", err) + } + + // Daemonize if needed + if daemonize { + if err := daemonizeSelf(); err != nil { + return fmt.Errorf("failed to daemonize: %w", err) + } } initMetrics() - // cleanup expired offline messages + // Start offline message cleanup go func() { ticker := time.NewTicker(30 * time.Second) for range ticker.C { @@ -670,7 +705,7 @@ func main() { server := &http.Server{Addr: serverAddr, Handler: mux} - // graceful shutdown + // Graceful shutdown quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) go func() { @@ -685,6 +720,8 @@ func main() { logrus.Infof("Server listening on %s", serverAddr) err := server.ListenAndServe() if err != nil && err != http.ErrServerClosed { - logrus.Fatalf("server error: %v", err) + return fmt.Errorf("server error: %w", err) } + + return nil }