Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 184 additions & 59 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"errors"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/exec"
"os/signal"
"strings"
"sync"
Expand Down Expand Up @@ -45,6 +45,8 @@ var (
tokenRevalidationPeriod time.Duration
logFile string
logLevel string
daemonize bool
pidFile string
)

type wsConnection struct {
Expand Down Expand Up @@ -101,14 +103,39 @@ var (
lm sync.Mutex
)

// ///////////////////
// RATE LIMITER
// ///////////////////
type limiter struct {
tokens int
last time.Time
}

// parseFlags parses command-line flags into global configuration variables.
// It supports database driver, DSN, server address, log file/level, PID file, and various WS server limits.
func parseFlags() {
var origins string

flag.StringVar(&dbDriver, "db", "sqlite", "Database driver")
flag.StringVar(&dbDSN, "dsn", "file:ws_tokens.db?cache=shared", "Database DSN")
flag.StringVar(&serverAddr, "addr", ":8080", "Server address")
flag.StringVar(&origins, "origins", "", "Allowed WS origins")
flag.StringVar(&logFile, "log-file", "", "Path to log file")
flag.StringVar(&logLevel, "log-level", "info", "Log level")
flag.StringVar(&pidFile, "pid-file", "", "Path to PID file")
flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "")
flag.IntVar(&rateLimit, "rate-limit", 10, "")
flag.DurationVar(&ratePeriod, "rate-period", time.Second, "")
flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "")
flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "")
flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "")
flag.BoolVar(&daemonize, "daemon", false, "")
flag.Parse()

if origins != "" {
allowedOrigins = strings.Split(origins, ",")
}
}

// allow implements a simple token-based rate limiter for a given key.
// Returns true if the action is allowed, false if the rate limit has been exceeded.
func allow(key string, rate int, per time.Duration) bool {
lm.Lock()
defer lm.Unlock()
Expand Down Expand Up @@ -161,9 +188,8 @@ type pendingMessage struct {
timestamp time.Time
}

// ///////////////////
// DB
// ///////////////////
// initDB initializes the global database connection based on dbDriver and dbDSN.
// Returns an error if the driver is unsupported or if the connection cannot be established.
func initDB() error {
var err error

Expand All @@ -187,7 +213,8 @@ func initDB() error {
return db.Ping()
}

// validateToken checks token validity in DB
// validateToken checks whether a given token is valid in the database.
// Returns the associated player/subject ID and true if valid, or empty string and false if invalid.
func validateToken(token string, isServer bool) (string, bool) {
const q = `
SELECT IFNULL(player_id, subject_id)
Expand All @@ -208,11 +235,8 @@ func validateToken(token string, isServer bool) (string, bool) {
return "", false
}

// ///////////////////
// CONNECTION MANAGEMENT
// ///////////////////

// registerConnection registers a WS connection and stores its token
// registerConnection registers a websocket connection for a player, storing the associated token.
// It also increments the active connections metric and flushes any pending messages to the new connection.
func registerConnection(playerID string, c *websocket.Conn, token string) {
mu.Lock()
if players[playerID] == nil {
Expand All @@ -226,7 +250,8 @@ func registerConnection(playerID string, c *websocket.Conn, token string) {
flushPendingMessages(playerID, c)
}

// unregisterConnection removes a WS connection
// unregisterConnection removes a websocket connection for a player and decrements the active connections metric.
// If no connections remain for the player, the player's entry is removed from the players map.
func unregisterConnection(playerID string, c *websocket.Conn) {
mu.Lock()
defer mu.Unlock()
Expand All @@ -237,7 +262,8 @@ func unregisterConnection(playerID string, c *websocket.Conn) {
connections.Dec()
}

// closeAllConnections closes all WS connections (on shutdown)
// closeAllConnections closes all active websocket connections for all players.
// Typically used during server shutdown.
func closeAllConnections() {
mu.Lock()
defer mu.Unlock()
Expand All @@ -248,7 +274,8 @@ func closeAllConnections() {
}
}

// flushPendingMessages sends queued messages to a newly connected player
// flushPendingMessages sends any queued offline messages to a newly connected websocket.
// Messages older than offlineTTL are ignored and removed.
func flushPendingMessages(playerID string, c *websocket.Conn) {
pendingMu.Lock()
msgs := pendingMessages[playerID]
Expand All @@ -271,9 +298,9 @@ func flushPendingMessages(playerID string, c *websocket.Conn) {
pendingMu.Unlock()
}

// ///////////////////
// WEBSOCKET HANDLER
// ///////////////////
// wsHandler handles incoming websocket upgrade requests from clients.
// Validates the token, enforces connection limits, sets up heartbeat, and reads messages.
// Connections are automatically unregistered on disconnect.
func wsHandler(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if token == "" {
Expand Down Expand Up @@ -345,9 +372,9 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
}).Info("Player disconnected")
}

// ///////////////////
// PUBLISH HANDLER
// ///////////////////
// publishHandler handles incoming messages from authorized servers to a specific player.
// Validates server token, enforces rate limits, delivers message immediately if the player is connected,
// or queues the message for offline delivery if not.
func publishHandler(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
Expand Down Expand Up @@ -416,9 +443,9 @@ func publishHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

// ///////////////////
// BROADCAST HANDLER
// ///////////////////
// broadcastHandler handles incoming broadcast messages from authorized servers.
// Can target a specific player or all connected players.
// Enforces rate limits and increments metrics for delivered messages.
func broadcastHandler(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
Expand Down Expand Up @@ -467,18 +494,13 @@ func broadcastHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

// ///////////////////
// METRICS
// ///////////////////
// initMetrics registers Prometheus metrics for connections, messages published, and messages delivered.
func initMetrics() {
prometheus.MustRegister(connections, messagesPublished, messagesDelivered)
}

// ///////////////////
// TOKEN REVALIDATION
// ///////////////////

// startTokenRevalidation periodically checks all WS tokens and closes invalid ones
// startTokenRevalidation periodically validates all active websocket tokens.
// Invalid tokens cause connections to be closed and removed.
func startTokenRevalidation(interval time.Duration) {
ticker := time.NewTicker(interval)
go func() {
Expand All @@ -503,36 +525,76 @@ func startTokenRevalidation(interval time.Duration) {
}()
}

// ///////////////////
// MAIN
// ///////////////////
func main() {
var origins string
// writePIDFile writes the current process PID to the specified file path.
// Returns an error if writing fails.
func writePIDFile(path string) error {
pid := os.Getpid()
data := []byte(fmt.Sprintf("%d\n", pid))
return os.WriteFile(path, data, 0644)
}

flag.StringVar(&dbDriver, "db", "sqlite", "Database driver")
flag.StringVar(&dbDSN, "dsn", "file:ws_tokens.db?cache=shared", "Database DSN")
flag.StringVar(&serverAddr, "addr", ":8080", "Server address")
flag.StringVar(&origins, "origins", "", "Allowed WS origins")
flag.StringVar(&logFile, "log-file", "", "Path to log file (default: stdout)")
flag.StringVar(&logLevel, "log-level", "info", "Log level (panic, fatal, error, warn, info, debug, trace)")
flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "Maximum queued messages per player")
flag.IntVar(&rateLimit, "rate-limit", 10, "Number of messages allowed per rate-period per server token")
flag.DurationVar(&ratePeriod, "rate-period", time.Second, "Duration for rate limiting (e.g., 1s, 500ms)")
flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "Maximum concurrent WebSocket connections per player")
flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "Period for WS token revalidation (e.g., 30s, 1m)")
flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "Duration that messages will be stored offline (e.g., 30s, 1m)")
flag.Parse()
// removePIDFile deletes the PID file at the specified path.
// Any errors are ignored.
func removePIDFile(path string) {
_ = os.Remove(path)
}

if origins != "" {
allowedOrigins = strings.Split(origins, ",")
// pidFileExists checks if the PID file exists and reads its PID.
// Returns the PID and true if the file exists and contains a valid integer, otherwise 0 and false.
func pidFileExists(path string) (int, bool) {
data, err := os.ReadFile(path)
if err != nil {
return 0, false
}
var pid int
if _, err := fmt.Sscanf(string(data), "%d", &pid); err != nil {
return 0, false
}
return pid, true
}

// daemonizeSelf re-launches the current executable as a background daemon process.
// It returns an error if the executable cannot be determined or if the child process fails to start.
// If successful, the parent process will exit immediately using os.Exit(0) to allow the daemon to continue independently.
func daemonizeSelf() error {
if os.Getenv("DAEMONIZED") == "1" {
return nil
}

exe, err := os.Executable()
if err != nil {
return fmt.Errorf("cannot get executable path: %w", err)
}

args := []string{}
for _, a := range os.Args[1:] {
if a != "-daemon" {
args = append(args, a)
}
}

cmd := exec.Command(exe, args...)
cmd.Env = append(os.Environ(), "DAEMONIZED=1")
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}

if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to daemonize: %w", err)
}

os.Exit(0) // safe here, because no defers in daemon parent context
return nil // unreachable, but satisfies compiler
}

// setupLogging configures logrus logging for the application.
// It sets the output destination and log level based on global flags.
// Returns an error if the log file cannot be opened or if the log level is invalid.
func setupLogging() error {
logrus.SetFormatter(&logrus.JSONFormatter{})

if logFile != "" {
f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
log.Fatalf("failed to open log file %s: %v", logFile, err)
return fmt.Errorf("failed to open log file %s: %w", logFile, err)
}
logrus.SetOutput(f)
} else {
Expand All @@ -541,17 +603,75 @@ func main() {

level, err := logrus.ParseLevel(strings.ToLower(logLevel))
if err != nil {
log.Fatalf("invalid log level: %s", logLevel)
return fmt.Errorf("invalid log level: %s", logLevel)
}
logrus.SetLevel(level)

return nil
}

// handlePIDFile ensures that the PID file is created and removed properly.
// If the PID file already exists, it returns an error.
// The PID file is automatically removed when the function that called this defers cleanup.
// Returns an error if writing the PID file fails.
func handlePIDFile() error {
if pidFile == "" {
return nil
}

if pid, ok := pidFileExists(pidFile); ok {
return fmt.Errorf("pid file already exists for PID: %d", pid)
}

if err := writePIDFile(pidFile); err != nil {
return fmt.Errorf("failed to write pid file: %w", err)
}

// The caller should defer removePIDFile(pidFile) to ensure cleanup
return nil
}

// ///////////////////
// MAIN
// ///////////////////
func main() {
if err := run(); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}

func run() error {
parseFlags()

// Handle PID file
if err := handlePIDFile(); err != nil {
return err
}
if pidFile != "" {
defer removePIDFile(pidFile)
}

// Setup logging
if err := setupLogging(); err != nil {
return fmt.Errorf("failed to setup logging: %w", err)
}

// Initialize DB
if err := initDB(); err != nil {
log.Fatal(err)
return fmt.Errorf("failed to init DB: %w", err)
}

// Daemonize if needed
if daemonize {
if err := daemonizeSelf(); err != nil {
return fmt.Errorf("failed to daemonize: %w", err)
}
}

initMetrics()

// cleanup expired offline messages
// Start offline message cleanup
go func() {
ticker := time.NewTicker(30 * time.Second)
for range ticker.C {
Expand Down Expand Up @@ -585,7 +705,7 @@ func main() {

server := &http.Server{Addr: serverAddr, Handler: mux}

// graceful shutdown
// Graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
go func() {
Expand All @@ -598,5 +718,10 @@ func main() {
}()

logrus.Infof("Server listening on %s", serverAddr)
log.Fatal(server.ListenAndServe())
err := server.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
return fmt.Errorf("server error: %w", err)
}

return nil
}