@@ -83,24 +83,53 @@ func monitorStream(mon net.Conn, ws *websocket.Conn) {
8383 }()
8484}
8585
86+ func splitOrigin(origin string) (scheme, host, port string, err error) {
87+ parts := strings.SplitN(origin, "://", 2)
88+ if len(parts) != 2 {
89+ return "", "", "", fmt.Errorf("invalid origin format: %s", origin)
90+ }
91+ scheme = parts[0]
92+ hostPort := parts[1]
93+ hostParts := strings.SplitN(hostPort, ":", 2)
94+ host = hostParts[0]
95+ if len(hostParts) == 2 {
96+ port = hostParts[1]
97+ } else {
98+ port = "*"
99+ }
100+ return scheme, host, port, nil
101+ }
102+
86103func checkOrigin(origin string, allowedOrigins []string) bool {
104+ scheme, host, port, err := splitOrigin(origin)
105+ if err != nil {
106+ slog.Error("WebSocket origin check failed", slog.String("origin", origin), slog.String("error", err.Error()))
107+ return false
108+ }
87109 for _, allowed := range allowedOrigins {
88- if strings.HasSuffix(allowed, "*") {
89- // String ends with *, match the prefix
90- if strings.HasPrefix(origin, strings.TrimSuffix(allowed, "*")) {
91- return true
92- }
93- } else {
94- // Exact match
95- if allowed == origin {
96- return true
97- }
110+ allowedScheme, allowedHost, allowedPort, err := splitOrigin(allowed)
111+ if err != nil {
112+ panic(err)
113+ }
114+ if allowedScheme != scheme {
115+ continue
98116 }
117+ if allowedHost != host && allowedHost != "*" {
118+ continue
119+ }
120+ if allowedPort != port && allowedPort != "*" {
121+ continue
122+ }
123+ return true
99124 }
125+ slog.Error("WebSocket origin check failed", slog.String("origin", origin))
100126 return false
101127}
102128
103129func HandleMonitorWS(allowedOrigins []string) http.HandlerFunc {
130+ // Do a dry-run of checkorigin, so it can panic if misconfigured now, not on first request
131+ _ = checkOrigin("http://example.com:8000", allowedOrigins)
132+
104133 upgrader := websocket.Upgrader{
105134 ReadBufferSize: 1024,
106135 WriteBufferSize: 1024,
0 commit comments