diff --git a/main.go b/main.go index 0f1cb0e..2621512 100644 --- a/main.go +++ b/main.go @@ -7,9 +7,9 @@ import ( "errors" "flag" "fmt" - "log" "net/http" "os" + "os/exec" "os/signal" "strings" "sync" @@ -45,6 +45,8 @@ var ( tokenRevalidationPeriod time.Duration logFile string logLevel string + daemonize bool + pidFile string ) type wsConnection struct { @@ -101,14 +103,39 @@ var ( lm sync.Mutex ) -// /////////////////// -// RATE LIMITER -// /////////////////// type limiter struct { tokens int last time.Time } +// parseFlags parses command-line flags into global configuration variables. +// It supports database driver, DSN, server address, log file/level, PID file, and various WS server limits. +func parseFlags() { + var origins string + + flag.StringVar(&dbDriver, "db", "sqlite", "Database driver") + flag.StringVar(&dbDSN, "dsn", "file:ws_tokens.db?cache=shared", "Database DSN") + flag.StringVar(&serverAddr, "addr", ":8080", "Server address") + flag.StringVar(&origins, "origins", "", "Allowed WS origins") + flag.StringVar(&logFile, "log-file", "", "Path to log file") + flag.StringVar(&logLevel, "log-level", "info", "Log level") + flag.StringVar(&pidFile, "pid-file", "", "Path to PID file") + flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "") + flag.IntVar(&rateLimit, "rate-limit", 10, "") + flag.DurationVar(&ratePeriod, "rate-period", time.Second, "") + flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "") + flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "") + flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "") + flag.BoolVar(&daemonize, "daemon", false, "") + flag.Parse() + + if origins != "" { + allowedOrigins = strings.Split(origins, ",") + } +} + +// allow implements a simple token-based rate limiter for a given key. +// Returns true if the action is allowed, false if the rate limit has been exceeded. func allow(key string, rate int, per time.Duration) bool { lm.Lock() defer lm.Unlock() @@ -161,9 +188,8 @@ type pendingMessage struct { timestamp time.Time } -// /////////////////// -// DB -// /////////////////// +// initDB initializes the global database connection based on dbDriver and dbDSN. +// Returns an error if the driver is unsupported or if the connection cannot be established. func initDB() error { var err error @@ -187,7 +213,8 @@ func initDB() error { return db.Ping() } -// validateToken checks token validity in DB +// validateToken checks whether a given token is valid in the database. +// Returns the associated player/subject ID and true if valid, or empty string and false if invalid. func validateToken(token string, isServer bool) (string, bool) { const q = ` SELECT IFNULL(player_id, subject_id) @@ -208,11 +235,8 @@ func validateToken(token string, isServer bool) (string, bool) { return "", false } -// /////////////////// -// CONNECTION MANAGEMENT -// /////////////////// - -// registerConnection registers a WS connection and stores its token +// registerConnection registers a websocket connection for a player, storing the associated token. +// It also increments the active connections metric and flushes any pending messages to the new connection. func registerConnection(playerID string, c *websocket.Conn, token string) { mu.Lock() if players[playerID] == nil { @@ -226,7 +250,8 @@ func registerConnection(playerID string, c *websocket.Conn, token string) { flushPendingMessages(playerID, c) } -// unregisterConnection removes a WS connection +// unregisterConnection removes a websocket connection for a player and decrements the active connections metric. +// If no connections remain for the player, the player's entry is removed from the players map. func unregisterConnection(playerID string, c *websocket.Conn) { mu.Lock() defer mu.Unlock() @@ -237,7 +262,8 @@ func unregisterConnection(playerID string, c *websocket.Conn) { connections.Dec() } -// closeAllConnections closes all WS connections (on shutdown) +// closeAllConnections closes all active websocket connections for all players. +// Typically used during server shutdown. func closeAllConnections() { mu.Lock() defer mu.Unlock() @@ -248,7 +274,8 @@ func closeAllConnections() { } } -// flushPendingMessages sends queued messages to a newly connected player +// flushPendingMessages sends any queued offline messages to a newly connected websocket. +// Messages older than offlineTTL are ignored and removed. func flushPendingMessages(playerID string, c *websocket.Conn) { pendingMu.Lock() msgs := pendingMessages[playerID] @@ -271,9 +298,9 @@ func flushPendingMessages(playerID string, c *websocket.Conn) { pendingMu.Unlock() } -// /////////////////// -// WEBSOCKET HANDLER -// /////////////////// +// wsHandler handles incoming websocket upgrade requests from clients. +// Validates the token, enforces connection limits, sets up heartbeat, and reads messages. +// Connections are automatically unregistered on disconnect. func wsHandler(w http.ResponseWriter, r *http.Request) { token := r.URL.Query().Get("token") if token == "" { @@ -345,9 +372,9 @@ func wsHandler(w http.ResponseWriter, r *http.Request) { }).Info("Player disconnected") } -// /////////////////// -// PUBLISH HANDLER -// /////////////////// +// publishHandler handles incoming messages from authorized servers to a specific player. +// Validates server token, enforces rate limits, delivers message immediately if the player is connected, +// or queues the message for offline delivery if not. func publishHandler(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { @@ -416,9 +443,9 @@ func publishHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -// /////////////////// -// BROADCAST HANDLER -// /////////////////// +// broadcastHandler handles incoming broadcast messages from authorized servers. +// Can target a specific player or all connected players. +// Enforces rate limits and increments metrics for delivered messages. func broadcastHandler(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { @@ -467,18 +494,13 @@ func broadcastHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -// /////////////////// -// METRICS -// /////////////////// +// initMetrics registers Prometheus metrics for connections, messages published, and messages delivered. func initMetrics() { prometheus.MustRegister(connections, messagesPublished, messagesDelivered) } -// /////////////////// -// TOKEN REVALIDATION -// /////////////////// - -// startTokenRevalidation periodically checks all WS tokens and closes invalid ones +// startTokenRevalidation periodically validates all active websocket tokens. +// Invalid tokens cause connections to be closed and removed. func startTokenRevalidation(interval time.Duration) { ticker := time.NewTicker(interval) go func() { @@ -503,36 +525,76 @@ func startTokenRevalidation(interval time.Duration) { }() } -// /////////////////// -// MAIN -// /////////////////// -func main() { - var origins string +// writePIDFile writes the current process PID to the specified file path. +// Returns an error if writing fails. +func writePIDFile(path string) error { + pid := os.Getpid() + data := []byte(fmt.Sprintf("%d\n", pid)) + return os.WriteFile(path, data, 0644) +} - flag.StringVar(&dbDriver, "db", "sqlite", "Database driver") - flag.StringVar(&dbDSN, "dsn", "file:ws_tokens.db?cache=shared", "Database DSN") - flag.StringVar(&serverAddr, "addr", ":8080", "Server address") - flag.StringVar(&origins, "origins", "", "Allowed WS origins") - flag.StringVar(&logFile, "log-file", "", "Path to log file (default: stdout)") - flag.StringVar(&logLevel, "log-level", "info", "Log level (panic, fatal, error, warn, info, debug, trace)") - flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "Maximum queued messages per player") - flag.IntVar(&rateLimit, "rate-limit", 10, "Number of messages allowed per rate-period per server token") - flag.DurationVar(&ratePeriod, "rate-period", time.Second, "Duration for rate limiting (e.g., 1s, 500ms)") - flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "Maximum concurrent WebSocket connections per player") - flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "Period for WS token revalidation (e.g., 30s, 1m)") - flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "Duration that messages will be stored offline (e.g., 30s, 1m)") - flag.Parse() +// removePIDFile deletes the PID file at the specified path. +// Any errors are ignored. +func removePIDFile(path string) { + _ = os.Remove(path) +} - if origins != "" { - allowedOrigins = strings.Split(origins, ",") +// pidFileExists checks if the PID file exists and reads its PID. +// Returns the PID and true if the file exists and contains a valid integer, otherwise 0 and false. +func pidFileExists(path string) (int, bool) { + data, err := os.ReadFile(path) + if err != nil { + return 0, false } + var pid int + if _, err := fmt.Sscanf(string(data), "%d", &pid); err != nil { + return 0, false + } + return pid, true +} + +// daemonizeSelf re-launches the current executable as a background daemon process. +// It returns an error if the executable cannot be determined or if the child process fails to start. +// If successful, the parent process will exit immediately using os.Exit(0) to allow the daemon to continue independently. +func daemonizeSelf() error { + if os.Getenv("DAEMONIZED") == "1" { + return nil + } + + exe, err := os.Executable() + if err != nil { + return fmt.Errorf("cannot get executable path: %w", err) + } + + args := []string{} + for _, a := range os.Args[1:] { + if a != "-daemon" { + args = append(args, a) + } + } + + cmd := exec.Command(exe, args...) + cmd.Env = append(os.Environ(), "DAEMONIZED=1") + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to daemonize: %w", err) + } + + os.Exit(0) // safe here, because no defers in daemon parent context + return nil // unreachable, but satisfies compiler +} + +// setupLogging configures logrus logging for the application. +// It sets the output destination and log level based on global flags. +// Returns an error if the log file cannot be opened or if the log level is invalid. +func setupLogging() error { logrus.SetFormatter(&logrus.JSONFormatter{}) if logFile != "" { f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { - log.Fatalf("failed to open log file %s: %v", logFile, err) + return fmt.Errorf("failed to open log file %s: %w", logFile, err) } logrus.SetOutput(f) } else { @@ -541,17 +603,75 @@ func main() { level, err := logrus.ParseLevel(strings.ToLower(logLevel)) if err != nil { - log.Fatalf("invalid log level: %s", logLevel) + return fmt.Errorf("invalid log level: %s", logLevel) } logrus.SetLevel(level) + return nil +} + +// handlePIDFile ensures that the PID file is created and removed properly. +// If the PID file already exists, it returns an error. +// The PID file is automatically removed when the function that called this defers cleanup. +// Returns an error if writing the PID file fails. +func handlePIDFile() error { + if pidFile == "" { + return nil + } + + if pid, ok := pidFileExists(pidFile); ok { + return fmt.Errorf("pid file already exists for PID: %d", pid) + } + + if err := writePIDFile(pidFile); err != nil { + return fmt.Errorf("failed to write pid file: %w", err) + } + + // The caller should defer removePIDFile(pidFile) to ensure cleanup + return nil +} + +// /////////////////// +// MAIN +// /////////////////// +func main() { + if err := run(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func run() error { + parseFlags() + + // Handle PID file + if err := handlePIDFile(); err != nil { + return err + } + if pidFile != "" { + defer removePIDFile(pidFile) + } + + // Setup logging + if err := setupLogging(); err != nil { + return fmt.Errorf("failed to setup logging: %w", err) + } + + // Initialize DB if err := initDB(); err != nil { - log.Fatal(err) + return fmt.Errorf("failed to init DB: %w", err) + } + + // Daemonize if needed + if daemonize { + if err := daemonizeSelf(); err != nil { + return fmt.Errorf("failed to daemonize: %w", err) + } } initMetrics() - // cleanup expired offline messages + // Start offline message cleanup go func() { ticker := time.NewTicker(30 * time.Second) for range ticker.C { @@ -585,7 +705,7 @@ func main() { server := &http.Server{Addr: serverAddr, Handler: mux} - // graceful shutdown + // Graceful shutdown quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) go func() { @@ -598,5 +718,10 @@ func main() { }() logrus.Infof("Server listening on %s", serverAddr) - log.Fatal(server.ListenAndServe()) + err := server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("server error: %w", err) + } + + return nil }