diff --git a/.gitignore b/.gitignore index 78f683ee96..d4d10e14fe 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ out/ make/ artifacts/ mikework/ +manifests/ .env out diff --git a/Taskfile.yml b/Taskfile.yml index 7742bd1c25..3253590750 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -8,7 +8,6 @@ vars: BIN_DIR: "bin" VERSION: sh: node version.cjs - RM: '{{if eq OS "windows"}}powershell Remove-Item -Force -ErrorAction SilentlyContinue{{else}}rm -f{{end}}' RMRF: '{{if eq OS "windows"}}powershell Remove-Item -Force -Recurse -ErrorAction SilentlyContinue{{else}}rm -rf{{end}}' DATE: '{{if eq OS "windows"}}powershell Get-Date -UFormat{{else}}date{{end}}' ARTIFACTS_BUCKET: waveterm-github-artifacts/staging-w2 @@ -177,7 +176,7 @@ tasks: desc: Build the wavesrv component for macOS (Darwin) platforms (generates artifacts for both arm64 and amd64). platforms: [darwin] cmds: - - cmd: "{{.RM}} dist/bin/wavesrv*" + - cmd: rm -f dist/bin/wavesrv* ignore_error: true - task: build:server:internal vars: @@ -224,7 +223,7 @@ tasks: desc: Build the wavesrv component for Windows platforms (only generates artifacts for the current architecture). platforms: [windows] cmds: - - cmd: "{{.RM}} dist/bin/wavesrv*" + - cmd: powershell -Command "Remove-Item -Force -ErrorAction SilentlyContinue -Path dist/bin/wavesrv*" ignore_error: true - task: build:server:internal vars: @@ -237,7 +236,7 @@ tasks: desc: Build the wavesrv component for Linux platforms (only generates artifacts for the current architecture). platforms: [linux] cmds: - - cmd: "{{.RM}} dist/bin/wavesrv*" + - cmd: rm -f dist/bin/wavesrv* ignore_error: true - task: build:server:internal vars: @@ -261,7 +260,11 @@ tasks: build:wsh: desc: Build the wsh component for all possible targets. cmds: - - cmd: "{{.RM}} dist/bin/wsh*" + - cmd: rm -f dist/bin/wsh* + platforms: [darwin, linux] + ignore_error: true + - cmd: powershell -Command "Remove-Item -Force -ErrorAction SilentlyContinue -Path dist/bin/wsh*" + platforms: [windows] ignore_error: true - task: build:wsh:internal vars: @@ -527,7 +530,11 @@ tasks: desc: Create package.json for tsunami scaffold using npm commands dir: tsunami/frontend/scaffold cmds: - - cmd: "{{.RM}} package.json" + - cmd: rm -f package.json + platforms: [darwin, linux] + ignore_error: true + - cmd: powershell -Command "Remove-Item -Force -ErrorAction SilentlyContinue -Path package.json" + platforms: [windows] ignore_error: true - npm --no-workspaces init -y --init-license Apache-2.0 - npm pkg set name=tsunami-scaffold @@ -588,7 +595,11 @@ tasks: tsunami:build: desc: Build the tsunami binary. cmds: - - cmd: "{{.RM}} bin/tsunami*" + - cmd: rm -f bin/tsunami* + platforms: [darwin, linux] + ignore_error: true + - cmd: powershell -Command "Remove-Item -Force -ErrorAction SilentlyContinue -Path bin/tsunami*" + platforms: [windows] ignore_error: true - mkdir -p bin - cd tsunami && go build -ldflags "-X main.BuildTime=$({{.DATE}} +'%Y%m%d%H%M') -X main.TsunamiVersion={{.VERSION}}" -o ../bin/tsunami{{exeExt}} cmd/main-tsunami.go diff --git a/cmd/server/main-server.go b/cmd/server/main-server.go index c4c2c14649..a59661f0ba 100644 --- a/cmd/server/main-server.go +++ b/cmd/server/main-server.go @@ -389,11 +389,11 @@ func shutdownActivityUpdate() { func createMainWshClient() { rpc := wshserver.GetMainRpcClient() wshfs.RpcClient = rpc - wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc, true) + wshutil.DefaultRouter.RegisterTrustedLeaf(rpc, wshutil.DefaultRoute) wps.Broker.SetClient(wshutil.DefaultRouter) - localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}, "conn:local") + localConnWsh := wshutil.MakeWshRpc(wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{}, "conn:local") go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName) - wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh, true) + wshutil.DefaultRouter.RegisterTrustedLeaf(localConnWsh, wshutil.MakeConnectionRouteId(wshrpc.LocalConnName)) } func grabAndRemoveEnvVars() error { @@ -457,10 +457,12 @@ func maybeStartPprofServer() { } func main() { - log.SetFlags(log.LstdFlags | log.Lmicroseconds) + log.SetFlags(0) // disable timestamp since electron's winston logger already wraps with timestamp log.SetPrefix("[wavesrv] ") wavebase.WaveVersion = WaveVersion wavebase.BuildTime = BuildTime + wshutil.DefaultRouter = wshutil.NewWshRouter() + wshutil.DefaultRouter.SetAsRootRouter() err := grabAndRemoveEnvVars() if err != nil { @@ -546,6 +548,11 @@ func main() { log.Printf("error clearing temp files: %v\n", err) return } + err = wcore.InitMainServer() + if err != nil { + log.Printf("error initializing mainserver: %v\n", err) + return + } err = shellutil.FixupWaveZshHistory() if err != nil { diff --git a/cmd/wsh/cmd/wshcmd-connserver.go b/cmd/wsh/cmd/wshcmd-connserver.go index 0726ea066e..3fb6a10fcc 100644 --- a/cmd/wsh/cmd/wshcmd-connserver.go +++ b/cmd/wsh/cmd/wshcmd-connserver.go @@ -4,7 +4,7 @@ package cmd import ( - "encoding/json" + "encoding/base64" "fmt" "io" "log" @@ -13,15 +13,16 @@ import ( "path/filepath" "strings" "sync/atomic" - "syscall" "time" "github.com/spf13/cobra" + "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" "github.com/wavetermdev/waveterm/pkg/util/packetparser" "github.com/wavetermdev/waveterm/pkg/util/sigutil" "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/wavejwt" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote" @@ -37,11 +38,13 @@ var serverCmd = &cobra.Command{ } var connServerRouter bool -var singleServerRouter bool +var connServerConnName string +var connServerDev bool func init() { serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode") - serverCmd.Flags().BoolVar(&singleServerRouter, "single", false, "run in local single mode") + serverCmd.Flags().StringVar(&connServerConnName, "conn", "", "connection name") + serverCmd.Flags().BoolVar(&connServerDev, "dev", false, "enable dev mode with file logging and PID in logs") rootCmd.AddCommand(serverCmd) } @@ -63,8 +66,11 @@ func MakeRemoteUnixListener() (net.Listener, error) { } func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) { - var routeIdContainer atomic.Pointer[string] - proxy := wshutil.MakeRpcProxy() + defer func() { + panichandler.PanicHandler("handleNewListenerConn", recover()) + }() + var linkIdContainer atomic.Int32 + proxy := wshutil.MakeRpcProxy(fmt.Sprintf("connserver:%s", conn.RemoteAddr().String())) go func() { defer func() { panichandler.PanicHandler("handleNewListenerConn:AdaptOutputChToStream", recover()) @@ -81,31 +87,15 @@ func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) { }() defer func() { conn.Close() - routeIdPtr := routeIdContainer.Load() - if routeIdPtr != nil && *routeIdPtr != "" { - router.UnregisterRoute(*routeIdPtr) - disposeMsg := &wshutil.RpcMessage{ - Command: wshrpc.Command_Dispose, - Data: wshrpc.CommandDisposeData{ - RouteId: *routeIdPtr, - }, - Source: *routeIdPtr, - AuthToken: proxy.GetAuthToken(), - } - disposeBytes, _ := json.Marshal(disposeMsg) - router.InjectMessage(disposeBytes, *routeIdPtr) + linkId := linkIdContainer.Load() + if linkId != baseds.NoLinkId { + router.UnregisterLink(baseds.LinkId(linkId)) } }() wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh) }() - routeId, err := proxy.HandleClientProxyAuth(router) - if err != nil { - log.Printf("error handling client proxy auth: %v\n", err) - conn.Close() - return - } - router.RegisterRoute(routeId, proxy, false) - routeIdContainer.Store(&routeId) + linkId := router.RegisterUntrustedLink(proxy) + linkIdContainer.Store(int32(linkId)) } func runListener(listener net.Listener, router *wshutil.WshRouter) { @@ -127,29 +117,28 @@ func runListener(listener net.Listener, router *wshutil.WshRouter) { } } -func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter, jwtToken string) (*wshutil.WshRpc, error) { - rpcCtx, err := wshutil.ExtractUnverifiedRpcContext(jwtToken) - if err != nil { - return nil, fmt.Errorf("error extracting rpc context from JWT token: %v", err) +func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.WshRpc, error) { + routeId := wshutil.MakeConnectionRouteId(connServerConnName) + rpcCtx := wshrpc.RpcContext{ + RouteId: routeId, + Conn: connServerConnName, } - authRtn, err := router.HandleProxyAuth(jwtToken) - if err != nil { - return nil, fmt.Errorf("error handling proxy auth: %v", err) - } - inputCh := make(chan []byte, wshutil.DefaultInputChSize) - outputCh := make(chan []byte, wshutil.DefaultOutputChSize) - connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout}, authRtn.RouteId) - connServerClient.SetAuthToken(authRtn.AuthToken) - router.RegisterRoute(authRtn.RouteId, connServerClient, false) - wshclient.RouteAnnounceCommand(connServerClient, nil) + connServerClient := wshutil.MakeWshRpc(rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout}, routeId) + router.RegisterTrustedLeaf(connServerClient, routeId) return connServerClient, nil } -func serverRunRouter(jwtToken string) error { +func serverRunRouter() error { + log.Printf("starting connserver router") router := wshutil.NewWshRouter() - termProxy := wshutil.MakeRpcProxy() + termProxy := wshutil.MakeRpcProxy("connserver-term") rawCh := make(chan []byte, wshutil.DefaultOutputChSize) - go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh) + go func() { + defer func() { + panichandler.PanicHandler("serverRunRouter:Parse", recover()) + }() + packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh) + }() go func() { defer func() { panichandler.PanicHandler("serverRunRouter:WritePackets", recover()) @@ -159,63 +148,65 @@ func serverRunRouter(jwtToken string) error { } }() go func() { - // just ignore and drain the rawCh (stdin) - // when stdin is closed, shutdown - defer wshutil.DoShutdown("", 0, true) + defer func() { + panichandler.PanicHandler("serverRunRouter:drainRawCh", recover()) + }() + defer func() { + log.Printf("stdin closed, shutting down") + wshutil.DoShutdown("", 0, true) + }() for range rawCh { // ignore } }() - go func() { - for msg := range termProxy.FromRemoteCh { - // send this to the router - router.InjectMessage(msg, wshutil.UpstreamRoute) - } - }() - router.SetUpstreamClient(termProxy) - // now set up the domain socket - unixListener, err := MakeRemoteUnixListener() - if err != nil { - return fmt.Errorf("cannot create unix listener: %v", err) - } - client, err := setupConnServerRpcClientWithRouter(router, jwtToken) + router.RegisterUpstream(termProxy) + + // setup the connserver rpc client first + client, err := setupConnServerRpcClientWithRouter(router) if err != nil { return fmt.Errorf("error setting up connserver rpc client: %v", err) } wshfs.RpcClient = client - go runListener(unixListener, router) - // run the sysinfo loop - wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn) - select {} -} -func checkForUpdate() error { - remoteInfo := wshutil.GetInfo() - needsRestart, err := wshclient.ConnUpdateWshCommand(RpcClient, remoteInfo, &wshrpc.RpcOpts{Timeout: 60000}) + log.Printf("trying to get JWT public key") + + // fetch and set JWT public key + jwtPublicKeyB64, err := wshclient.GetJwtPublicKeyCommand(client, nil) if err != nil { - return fmt.Errorf("could not update: %w", err) - } - if needsRestart { - // run the restart command here - // how to get the correct path? - return syscall.Exec("~/.waveterm/bin/wsh", []string{"wsh", "connserver", "--single"}, []string{}) + return fmt.Errorf("error getting jwt public key: %v", err) } - return nil -} - -func serverRunSingle(jwtToken string) error { - err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout}, jwtToken) + jwtPublicKeyBytes, err := base64.StdEncoding.DecodeString(jwtPublicKeyB64) if err != nil { - return err + return fmt.Errorf("error decoding jwt public key: %v", err) } - WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn) - err = checkForUpdate() + err = wavejwt.SetPublicKey(jwtPublicKeyBytes) if err != nil { - return err + return fmt.Errorf("error setting jwt public key: %v", err) } - go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn) - select {} // run forever + log.Printf("got JWT public key") + + // now set up the domain socket + unixListener, err := MakeRemoteUnixListener() + if err != nil { + return fmt.Errorf("cannot create unix listener: %v", err) + } + log.Printf("unix listener started") + go func() { + defer func() { + panichandler.PanicHandler("serverRunRouter:runListener", recover()) + }() + runListener(unixListener, router) + }() + // run the sysinfo loop + go func() { + defer func() { + panichandler.PanicHandler("serverRunRouter:RunSysInfoLoop", recover()) + }() + wshremote.RunSysInfoLoop(client, connServerConnName) + }() + log.Printf("running server, successfully started") + select {} } func serverRunNormal(jwtToken string) error { @@ -225,7 +216,12 @@ func serverRunNormal(jwtToken string) error { } wshfs.RpcClient = RpcClient WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn) - go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn) + go func() { + defer func() { + panichandler.PanicHandler("serverRunNormal:RunSysInfoLoop", recover()) + }() + wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn) + }() select {} // run forever } @@ -250,22 +246,53 @@ func askForJwtToken() (string, error) { } func serverRun(cmd *cobra.Command, args []string) error { + var logFile *os.File + if connServerDev { + var err error + logFile, err = os.OpenFile("/tmp/connserver.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to open log file: %v\n", err) + log.SetFlags(log.LstdFlags | log.Lmicroseconds) + log.SetPrefix(fmt.Sprintf("[PID:%d] ", os.Getpid())) + } else { + defer logFile.Close() + logWriter := io.MultiWriter(os.Stderr, logFile) + log.SetOutput(logWriter) + log.SetFlags(log.LstdFlags | log.Lmicroseconds) + log.SetPrefix(fmt.Sprintf("[PID:%d] ", os.Getpid())) + } + } + if connServerConnName == "" { + if logFile != nil { + fmt.Fprintf(logFile, "--conn parameter is required\n") + } + return fmt.Errorf("--conn parameter is required") + } installErr := wshutil.InstallRcFiles() if installErr != nil { + if logFile != nil { + fmt.Fprintf(logFile, "error installing rc files: %v\n", installErr) + } log.Printf("error installing rc files: %v", installErr) } + sigutil.InstallSIGUSR1Handler() + if connServerRouter { + err := serverRunRouter() + if err != nil && logFile != nil { + fmt.Fprintf(logFile, "serverRunRouter error: %v\n", err) + } + return err + } jwtToken, err := askForJwtToken() if err != nil { + if logFile != nil { + fmt.Fprintf(logFile, "askForJwtToken error: %v\n", err) + } return err } - - sigutil.InstallSIGUSR1Handler() - - if singleServerRouter { - return serverRunSingle(jwtToken) - } else if connServerRouter { - return serverRunRouter(jwtToken) - } else { - return serverRunNormal(jwtToken) + err = serverRunNormal(jwtToken) + if err != nil && logFile != nil { + fmt.Fprintf(logFile, "serverRunNormal error: %v\n", err) } + return err } diff --git a/cmd/wsh/cmd/wshcmd-createblock.go b/cmd/wsh/cmd/wshcmd-createblock.go index 05ed221194..aaf153e232 100644 --- a/cmd/wsh/cmd/wshcmd-createblock.go +++ b/cmd/wsh/cmd/wshcmd-createblock.go @@ -34,12 +34,17 @@ func createBlockRun(cmd *cobra.Command, args []string) error { if len(args) > 1 { metaSetStrs = args[1:] } + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } meta, err := parseMetaSets(metaSetStrs) if err != nil { return err } meta["view"] = viewName data := wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: meta, }, diff --git a/cmd/wsh/cmd/wshcmd-debug.go b/cmd/wsh/cmd/wshcmd-debug.go index e28f5df177..9efac0ff87 100644 --- a/cmd/wsh/cmd/wshcmd-debug.go +++ b/cmd/wsh/cmd/wshcmd-debug.go @@ -31,33 +31,12 @@ var debugSendTelemetryCmd = &cobra.Command{ Hidden: true, } -var debugGetTabCmd = &cobra.Command{ - Use: "gettab", - Short: "get tab", - RunE: debugGetTabRun, - Hidden: true, -} - func init() { debugCmd.AddCommand(debugBlockIdsCmd) debugCmd.AddCommand(debugSendTelemetryCmd) - debugCmd.AddCommand(debugGetTabCmd) rootCmd.AddCommand(debugCmd) } -func debugGetTabRun(cmd *cobra.Command, args []string) error { - tab, err := wshclient.GetTabCommand(RpcClient, RpcContext.TabId, nil) - if err != nil { - return err - } - barr, err := json.MarshalIndent(tab, "", " ") - if err != nil { - return err - } - WriteStdout("%s\n", string(barr)) - return nil -} - func debugSendTelemetryRun(cmd *cobra.Command, args []string) error { err := wshclient.SendTelemetryCommand(RpcClient, nil) return err diff --git a/cmd/wsh/cmd/wshcmd-editconfig.go b/cmd/wsh/cmd/wshcmd-editconfig.go index 6dc9c13f6a..cbd4015bae 100644 --- a/cmd/wsh/cmd/wshcmd-editconfig.go +++ b/cmd/wsh/cmd/wshcmd-editconfig.go @@ -38,7 +38,13 @@ func editConfigRun(cmd *cobra.Command, args []string) (rtnErr error) { configFile = args[0] } + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + wshCmd := &wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: map[string]interface{}{ waveobj.MetaKey_View: "waveconfig", diff --git a/cmd/wsh/cmd/wshcmd-editor.go b/cmd/wsh/cmd/wshcmd-editor.go index 670011c7a0..4968b17509 100644 --- a/cmd/wsh/cmd/wshcmd-editor.go +++ b/cmd/wsh/cmd/wshcmd-editor.go @@ -54,7 +54,14 @@ func editorRun(cmd *cobra.Command, args []string) (rtnErr error) { if err != nil { return fmt.Errorf("getting file info: %w", err) } + + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + wshCmd := wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: map[string]any{ waveobj.MetaKey_View: "preview", diff --git a/cmd/wsh/cmd/wshcmd-launch.go b/cmd/wsh/cmd/wshcmd-launch.go index 679854e72f..3ec582a6cd 100644 --- a/cmd/wsh/cmd/wshcmd-launch.go +++ b/cmd/wsh/cmd/wshcmd-launch.go @@ -48,8 +48,14 @@ func launchRun(cmd *cobra.Command, args []string) (rtnErr error) { } } + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + // Create block data from widget config createBlockData := wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &widget.BlockDef, Magnified: magnifyBlock || widget.Magnified, Focused: true, diff --git a/cmd/wsh/cmd/wshcmd-root.go b/cmd/wsh/cmd/wshcmd-root.go index 18ef5c8cff..0edbcea79c 100644 --- a/cmd/wsh/cmd/wshcmd-root.go +++ b/cmd/wsh/cmd/wshcmd-root.go @@ -103,15 +103,6 @@ func getIsTty() bool { return false } -func getThisBlockMeta() (waveobj.MetaMapType, error) { - blockORef := waveobj.ORef{OType: waveobj.OType_Block, OID: RpcContext.BlockId} - resp, err := wshclient.GetMetaCommand(RpcClient, wshrpc.CommandGetMetaData{ORef: blockORef}, &wshrpc.RpcOpts{Timeout: 2000}) - if err != nil { - return nil, fmt.Errorf("getting metadata: %w", err) - } - return resp, nil -} - type RunEFnType = func(*cobra.Command, []string) error func activityWrap(activityStr string, origRunE RunEFnType) RunEFnType { @@ -141,18 +132,18 @@ func setupRpcClientWithToken(swapTokenStr string) (wshrpc.CommandAuthenticateRtn if err != nil { return rtn, fmt.Errorf("error unpacking token: %w", err) } - if token.SockName == "" { - return rtn, fmt.Errorf("no sockname in token") - } if token.RpcContext == nil { return rtn, fmt.Errorf("no rpccontext in token") } + if token.RpcContext.SockName == "" { + return rtn, fmt.Errorf("no sockname in token") + } RpcContext = *token.RpcContext - RpcClient, err = wshutil.SetupDomainSocketRpcClient(token.SockName, nil, "wshcmd") + RpcClient, err = wshutil.SetupDomainSocketRpcClient(token.RpcContext.SockName, nil, "wshcmd") if err != nil { return rtn, fmt.Errorf("error setting up domain socket rpc client: %w", err) } - return wshclient.AuthenticateTokenCommand(RpcClient, wshrpc.CommandAuthenticateTokenData{Token: token.Token}, nil) + return wshclient.AuthenticateTokenCommand(RpcClient, wshrpc.CommandAuthenticateTokenData{Token: token.Token}, &wshrpc.RpcOpts{Route: wshutil.ControlRoute}) } // returns the wrapped stdin and a new rpc client (that wraps the stdin input and stdout output) @@ -170,7 +161,15 @@ func setupRpcClient(serverImpl wshutil.ServerImpl, jwtToken string) error { if err != nil { return fmt.Errorf("error setting up domain socket rpc client: %v", err) } - wshclient.AuthenticateCommand(RpcClient, jwtToken, &wshrpc.RpcOpts{NoResponse: true}) + _, err = wshclient.AuthenticateCommand(RpcClient, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRoute}) + if err != nil { + return fmt.Errorf("error authenticating: %v", err) + } + blockId := os.Getenv("WAVETERM_BLOCKID") + if blockId != "" { + peerInfo := fmt.Sprintf("domain:block:%s", blockId) + wshclient.SetPeerInfoCommand(RpcClient, peerInfo, &wshrpc.RpcOpts{Route: wshutil.ControlRoute}) + } // note we don't modify WrappedStdin here (just use os.Stdin) return nil } @@ -188,7 +187,14 @@ func resolveSimpleId(id string) (*waveobj.ORef, error) { } return &orefObj, nil } - rtnData, err := wshclient.ResolveIdsCommand(RpcClient, wshrpc.CommandResolveIdsData{Ids: []string{id}}, &wshrpc.RpcOpts{Timeout: 2000}) + blockId := os.Getenv("WAVETERM_BLOCKID") + if blockId == "" { + return nil, fmt.Errorf("no WAVETERM_BLOCKID env var set") + } + rtnData, err := wshclient.ResolveIdsCommand(RpcClient, wshrpc.CommandResolveIdsData{ + BlockId: blockId, + Ids: []string{id}, + }, &wshrpc.RpcOpts{Timeout: 2000}) if err != nil { return nil, fmt.Errorf("error resolving ids: %v", err) } @@ -199,6 +205,10 @@ func resolveSimpleId(id string) (*waveobj.ORef, error) { return &oref, nil } +func getTabIdFromEnv() string { + return os.Getenv("WAVETERM_TABID") +} + // this will send wsh activity to the client running on *your* local machine (it does not contact any wave cloud infrastructure) // if you've turned off telemetry in your local client, this data never gets sent to us // no parameters or timestamps are sent, as you can see below, it just sends the name of the command (and if there was an error) diff --git a/cmd/wsh/cmd/wshcmd-run.go b/cmd/wsh/cmd/wshcmd-run.go index 783e87fd6b..6faf424c99 100644 --- a/cmd/wsh/cmd/wshcmd-run.go +++ b/cmd/wsh/cmd/wshcmd-run.go @@ -133,7 +133,13 @@ func runRun(cmd *cobra.Command, args []string) (rtnErr error) { createMeta[waveobj.MetaKey_Connection] = RpcContext.Conn } + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + createBlockData := wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: createMeta, Files: map[string]*waveobj.FileDef{ diff --git a/cmd/wsh/cmd/wshcmd-secret.go b/cmd/wsh/cmd/wshcmd-secret.go index 7d555c0dec..916e3ae4a5 100644 --- a/cmd/wsh/cmd/wshcmd-secret.go +++ b/cmd/wsh/cmd/wshcmd-secret.go @@ -176,7 +176,13 @@ func secretUiRun(cmd *cobra.Command, args []string) (rtnErr error) { sendActivity("secret", rtnErr == nil) }() + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + wshCmd := &wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: map[string]interface{}{ waveobj.MetaKey_View: "waveconfig", diff --git a/cmd/wsh/cmd/wshcmd-ssh.go b/cmd/wsh/cmd/wshcmd-ssh.go index e8859cf771..25dad3c098 100644 --- a/cmd/wsh/cmd/wshcmd-ssh.go +++ b/cmd/wsh/cmd/wshcmd-ssh.go @@ -54,6 +54,11 @@ func sshRun(cmd *cobra.Command, args []string) (rtnErr error) { wshclient.ConnConnectCommand(RpcClient, connOpts, &wshrpc.RpcOpts{Timeout: 60000}) if newBlock { + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + // Create a new block with the SSH connection createMeta := map[string]any{ waveobj.MetaKey_View: "term", @@ -64,6 +69,7 @@ func sshRun(cmd *cobra.Command, args []string) (rtnErr error) { createMeta[waveobj.MetaKey_Connection] = RpcContext.Conn } createBlockData := wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: createMeta, }, diff --git a/cmd/wsh/cmd/wshcmd-term.go b/cmd/wsh/cmd/wshcmd-term.go index ec1ad6a8c5..f2119ad5b7 100644 --- a/cmd/wsh/cmd/wshcmd-term.go +++ b/cmd/wsh/cmd/wshcmd-term.go @@ -55,6 +55,12 @@ func termRun(cmd *cobra.Command, args []string) (rtnErr error) { if err != nil { return fmt.Errorf("getting absolute path: %w", err) } + + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + createMeta := map[string]any{ waveobj.MetaKey_View: "term", waveobj.MetaKey_CmdCwd: cwd, @@ -64,6 +70,7 @@ func termRun(cmd *cobra.Command, args []string) (rtnErr error) { createMeta[waveobj.MetaKey_Connection] = RpcContext.Conn } createBlockData := wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: createMeta, }, diff --git a/cmd/wsh/cmd/wshcmd-view.go b/cmd/wsh/cmd/wshcmd-view.go index b0aafe148f..1ba84b516f 100644 --- a/cmd/wsh/cmd/wshcmd-view.go +++ b/cmd/wsh/cmd/wshcmd-view.go @@ -53,11 +53,16 @@ func viewRun(cmd *cobra.Command, args []string) (rtnErr error) { OutputHelpMessage(cmd) return fmt.Errorf("too many arguments. wsh %s requires exactly one argument", cmdName) } + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } fileArg := args[0] conn := RpcContext.Conn var wshCmd *wshrpc.CommandCreateBlockData if strings.HasPrefix(fileArg, "http://") || strings.HasPrefix(fileArg, "https://") { wshCmd = &wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: map[string]any{ waveobj.MetaKey_View: "web", @@ -84,6 +89,7 @@ func viewRun(cmd *cobra.Command, args []string) (rtnErr error) { return fmt.Errorf("getting file info: %w", err) } wshCmd = &wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: map[string]interface{}{ waveobj.MetaKey_View: "preview", diff --git a/cmd/wsh/cmd/wshcmd-wavepath.go b/cmd/wsh/cmd/wshcmd-wavepath.go index b7489d9842..9a5ad6af39 100644 --- a/cmd/wsh/cmd/wshcmd-wavepath.go +++ b/cmd/wsh/cmd/wshcmd-wavepath.go @@ -56,10 +56,16 @@ func wavepathRun(cmd *cobra.Command, args []string) (rtnErr error) { open, _ := cmd.Flags().GetBool("open") openExternal, _ := cmd.Flags().GetBool("open-external") + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + path, err := wshclient.PathCommand(RpcClient, wshrpc.PathCommandData{ PathType: pathType, Open: open, OpenExternal: openExternal, + TabId: tabId, }, nil) if err != nil { return fmt.Errorf("getting path: %w", err) diff --git a/cmd/wsh/cmd/wshcmd-web.go b/cmd/wsh/cmd/wshcmd-web.go index 9ad3fee486..bfda76b82c 100644 --- a/cmd/wsh/cmd/wshcmd-web.go +++ b/cmd/wsh/cmd/wshcmd-web.go @@ -111,7 +111,14 @@ func webOpenRun(cmd *cobra.Command, args []string) (rtnErr error) { if replaceBlockORef != nil && webOpenMagnified { return fmt.Errorf("cannot use --replace and --magnified together") } + + tabId := getTabIdFromEnv() + if tabId == "" { + return fmt.Errorf("no WAVETERM_TABID env var set") + } + wshCmd := wshrpc.CommandCreateBlockData{ + TabId: tabId, BlockDef: &waveobj.BlockDef{ Meta: map[string]any{ waveobj.MetaKey_View: "web", diff --git a/db/migrations-wstore/000009_mainserver.down.sql b/db/migrations-wstore/000009_mainserver.down.sql new file mode 100644 index 0000000000..1b3a3329f0 --- /dev/null +++ b/db/migrations-wstore/000009_mainserver.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS db_mainserver; diff --git a/db/migrations-wstore/000009_mainserver.up.sql b/db/migrations-wstore/000009_mainserver.up.sql new file mode 100644 index 0000000000..f025565364 --- /dev/null +++ b/db/migrations-wstore/000009_mainserver.up.sql @@ -0,0 +1,5 @@ +CREATE TABLE IF NOT EXISTS db_mainserver ( + oid varchar(36) PRIMARY KEY, + version int NOT NULL, + data json NOT NULL +); diff --git a/emain/emain-menu.ts b/emain/emain-menu.ts index 27b438c4a4..f4d45f8639 100644 --- a/emain/emain-menu.ts +++ b/emain/emain-menu.ts @@ -384,10 +384,12 @@ export function makeAndSetAppMenu() { }); } -waveEventSubscribe({ - eventType: "workspace:update", - handler: makeAndSetAppMenu, -}); +function initMenuEventSubscriptions() { + waveEventSubscribe({ + eventType: "workspace:update", + handler: makeAndSetAppMenu, + }); +} function getWebContentsByWorkspaceOrBuilderId(workspaceOrBuilderId: string): electron.WebContents { const ww = getWaveWindowByWorkspaceId(workspaceOrBuilderId); @@ -495,4 +497,4 @@ function makeDockTaskbar() { } } -export { makeDockTaskbar }; +export { initMenuEventSubscriptions, makeDockTaskbar }; diff --git a/emain/emain.ts b/emain/emain.ts index 093a74bd41..c8b0cfee28 100644 --- a/emain/emain.ts +++ b/emain/emain.ts @@ -23,7 +23,7 @@ import { } from "./emain-activity"; import { initIpcHandlers } from "./emain-ipc"; import { log } from "./emain-log"; -import { makeAndSetAppMenu, makeDockTaskbar } from "./emain-menu"; +import { initMenuEventSubscriptions, makeAndSetAppMenu, makeDockTaskbar } from "./emain-menu"; import { checkIfRunningUnderARM64Translation, getElectronAppBasePath, @@ -350,6 +350,7 @@ async function appMain() { try { initElectronWshClient(); initElectronWshrpc(ElectronWshClient, { authKey: AuthKey }); + initMenuEventSubscriptions(); } catch (e) { console.log("error initializing wshrpc", e); } diff --git a/frontend/app/store/wps.ts b/frontend/app/store/wps.ts index 605987e2d8..b28ec94c8d 100644 --- a/frontend/app/store/wps.ts +++ b/frontend/app/store/wps.ts @@ -1,9 +1,16 @@ // Copyright 2025, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 +import type { WshClient } from "@/app/store/wshclient"; +import { RpcApi } from "@/app/store/wshclientapi"; import { isBlank } from "@/util/util"; import { Subject } from "rxjs"; -import { sendRawRpcMessage } from "./ws"; + +let WpsRpcClient: WshClient; + +function setWpsRpcClient(client: WshClient) { + WpsRpcClient = client; +} type WaveEventSubject = { handler: (event: WaveEvent) => void; @@ -33,12 +40,13 @@ function wpsReconnectHandler() { } } -function makeWaveReSubCommand(eventType: string): RpcMessage { - let subjects = waveEventSubjects.get(eventType); +function updateWaveEventSub(eventType: string) { + const subjects = waveEventSubjects.get(eventType); if (subjects == null) { - return { command: "eventunsub", data: eventType }; + RpcApi.EventUnsubCommand(WpsRpcClient, eventType, { noresponse: true }); + return; } - let subreq: SubscriptionRequest = { event: eventType, scopes: [], allscopes: false }; + const subreq: SubscriptionRequest = { event: eventType, scopes: [], allscopes: false }; for (const scont of subjects) { if (isBlank(scont.scope)) { subreq.allscopes = true; @@ -47,13 +55,7 @@ function makeWaveReSubCommand(eventType: string): RpcMessage { } subreq.scopes.push(scont.scope); } - return { command: "eventsub", data: subreq }; -} - -function updateWaveEventSub(eventType: string) { - const command = makeWaveReSubCommand(eventType); - // console.log("updateWaveEventSub", eventType, command); - sendRawRpcMessage(command); + RpcApi.EventSubCommand(WpsRpcClient, subreq, { noresponse: true }); } function waveEventSubscribe(...subscriptions: WaveEventSubscription[]): () => void { @@ -143,4 +145,11 @@ function handleWaveEvent(event: WaveEvent) { } } -export { getFileSubject, handleWaveEvent, waveEventSubscribe, waveEventUnsubscribe, wpsReconnectHandler }; +export { + getFileSubject, + handleWaveEvent, + setWpsRpcClient, + waveEventSubscribe, + waveEventUnsubscribe, + wpsReconnectHandler, +}; diff --git a/frontend/app/store/ws.ts b/frontend/app/store/ws.ts index 54a739a22c..c7b260ca4e 100644 --- a/frontend/app/store/ws.ts +++ b/frontend/app/store/ws.ts @@ -19,7 +19,7 @@ function addWSReconnectHandler(handler: () => void) { } function removeWSReconnectHandler(handler: () => void) { - const index = this.reconnectHandlers.indexOf(handler); + const index = reconnectHandlers.indexOf(handler); if (index > -1) { reconnectHandlers.splice(index, 1); } @@ -37,7 +37,7 @@ class WSControl { opening: boolean = false; reconnectTimes: number = 0; msgQueue: any[] = []; - routeId: string; + stableId: string; messageCallback: WSEventCallback; watchSessionId: string = null; watchScreenId: string = null; @@ -50,13 +50,13 @@ class WSControl { constructor( baseHostPort: string, - routeId: string, + stableId: string, messageCallback: WSEventCallback, electronOverrideOpts?: ElectronOverrideOpts ) { this.baseHostPort = baseHostPort; this.messageCallback = messageCallback; - this.routeId = routeId; + this.stableId = stableId; this.open = false; this.eoOpts = electronOverrideOpts; setInterval(this.sendPing.bind(this), 5000); @@ -75,7 +75,7 @@ class WSControl { dlog("try reconnect:", desc); this.opening = true; this.wsConn = newWebSocket( - this.baseHostPort + "/ws?routeid=" + this.routeId, + this.baseHostPort + "/ws?stableid=" + encodeURIComponent(this.stableId), this.eoOpts ? { [AuthKeyHeader]: this.eoOpts.authKey, @@ -221,6 +221,12 @@ class WSControl { pushMessage(data: WSCommandType) { if (!this.open) { + if (data.wscommand === "rpc" && data.message) { + const cmd = data.message.command; + if (cmd === "routeannounce" || cmd === "routeunannounce") { + return; + } + } this.msgQueue.push(data); return; } @@ -231,11 +237,11 @@ class WSControl { let globalWS: WSControl; function initGlobalWS( baseHostPort: string, - routeId: string, + stableId: string, messageCallback: WSEventCallback, electronOverrideOpts?: ElectronOverrideOpts ) { - globalWS = new WSControl(baseHostPort, routeId, messageCallback, electronOverrideOpts); + globalWS = new WSControl(baseHostPort, stableId, messageCallback, electronOverrideOpts); } function sendRawRpcMessage(msg: RpcMessage) { diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts index 4b7ea1b53f..b503b43476 100644 --- a/frontend/app/store/wshclientapi.ts +++ b/frontend/app/store/wshclientapi.ts @@ -27,6 +27,11 @@ class RpcApiType { return client.wshRpcCall("authenticatetoken", data, opts); } + // command "authenticatetokenverify" [call] + AuthenticateTokenVerifyCommand(client: WshClient, data: CommandAuthenticateTokenData, opts?: RpcOpts): Promise { + return client.wshRpcCall("authenticatetokenverify", data, opts); + } + // command "blockinfo" [call] BlockInfoCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { return client.wshRpcCall("blockinfo", data, opts); @@ -307,6 +312,11 @@ class RpcApiType { return client.wshRpcCall("getfullconfig", null, opts); } + // command "getjwtpublickey" [call] + GetJwtPublicKeyCommand(client: WshClient, opts?: RpcOpts): Promise { + return client.wshRpcCall("getjwtpublickey", null, opts); + } + // command "getmeta" [call] GetMetaCommand(client: WshClient, data: CommandGetMetaData, opts?: RpcOpts): Promise { return client.wshRpcCall("getmeta", data, opts); @@ -537,6 +547,11 @@ class RpcApiType { return client.wshRpcCall("setmeta", data, opts); } + // command "setpeerinfo" [call] + SetPeerInfoCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + return client.wshRpcCall("setpeerinfo", data, opts); + } + // command "setrtinfo" [call] SetRTInfoCommand(client: WshClient, data: CommandSetRTInfoData, opts?: RpcOpts): Promise { return client.wshRpcCall("setrtinfo", data, opts); @@ -552,11 +567,6 @@ class RpcApiType { return client.wshRpcCall("setvar", data, opts); } - // command "setview" [call] - SetViewCommand(client: WshClient, data: CommandBlockSetViewData, opts?: RpcOpts): Promise { - return client.wshRpcCall("setview", data, opts); - } - // command "startbuilder" [call] StartBuilderCommand(client: WshClient, data: CommandStartBuilderData, opts?: RpcOpts): Promise { return client.wshRpcCall("startbuilder", data, opts); diff --git a/frontend/app/store/wshrouter.ts b/frontend/app/store/wshrouter.ts index b70031695c..0649be1f18 100644 --- a/frontend/app/store/wshrouter.ts +++ b/frontend/app/store/wshrouter.ts @@ -8,6 +8,7 @@ import debug from "debug"; const dlog = debug("wave:router"); const SysRouteName = "sys"; +const ControlRouteName = "$control"; type RouteInfo = { rpcId: string; @@ -47,6 +48,7 @@ class WshRouter { command: "routeannounce", data: routeId, source: routeId, + route: ControlRouteName, }; this.upstreamClient.recvRpcMessage(announceMsg); } @@ -135,6 +137,7 @@ class WshRouter { command: "routeannounce", data: routeId, source: routeId, + route: ControlRouteName, }; this.upstreamClient.recvRpcMessage(announceMsg); this.routeMap.set(routeId, client); @@ -147,6 +150,7 @@ class WshRouter { command: "routeunannounce", data: routeId, source: routeId, + route: ControlRouteName, }; this.upstreamClient?.recvRpcMessage(unannounceMsg); this.routeMap.delete(routeId); diff --git a/frontend/app/store/wshrpcutil-base.ts b/frontend/app/store/wshrpcutil-base.ts index 987fa2115e..5565416478 100644 --- a/frontend/app/store/wshrpcutil-base.ts +++ b/frontend/app/store/wshrpcutil-base.ts @@ -1,7 +1,7 @@ // Copyright 2025, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 -import { wpsReconnectHandler } from "@/app/store/wps"; +import { setWpsRpcClient, wpsReconnectHandler } from "@/app/store/wps"; import { WshClient } from "@/app/store/wshclient"; import { WshRouter } from "@/app/store/wshrouter"; import { getWSServerEndpoint } from "@/util/endpoints"; @@ -118,6 +118,7 @@ function initElectronWshrpc(electronClient: WshClient, eoOpts: ElectronOverrideO }; initGlobalWS(getWSServerEndpoint(), "electron", handleFn, eoOpts); globalWS.connectNow("connectWshrpc"); + setWpsRpcClient(electronClient); DefaultRouter.registerRoute(electronClient.routeId, electronClient); addWSReconnectHandler(() => { DefaultRouter.reannounceRoutes(); diff --git a/frontend/app/store/wshrpcutil.ts b/frontend/app/store/wshrpcutil.ts index 8ed4361802..5a5b1c6d69 100644 --- a/frontend/app/store/wshrpcutil.ts +++ b/frontend/app/store/wshrpcutil.ts @@ -1,7 +1,7 @@ // Copyright 2025, Command Line Inc. // SPDX-License-Identifier: Apache-2.0 -import { wpsReconnectHandler } from "@/app/store/wps"; +import { setWpsRpcClient, wpsReconnectHandler } from "@/app/store/wps"; import { TabClient } from "@/app/store/tabrpcclient"; import { WshRouter } from "@/app/store/wshrouter"; import { getWSServerEndpoint } from "@/util/endpoints"; @@ -19,6 +19,7 @@ function initWshrpc(routeId: string): WSControl { initGlobalWS(getWSServerEndpoint(), routeId, handleFn); globalWS.connectNow("connectWshrpc"); TabRpcClient = new TabClient(routeId); + setWpsRpcClient(TabRpcClient); DefaultRouter.registerRoute(TabRpcClient.routeId, TabRpcClient); addWSReconnectHandler(() => { DefaultRouter.reannounceRoutes(); diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 8a91d4d571..651e655222 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -196,10 +196,9 @@ declare global { // wshrpc.CommandAuthenticateRtnData type CommandAuthenticateRtnData = { - routeid: string; - authtoken?: string; env?: {[key: string]: string}; initscripttext?: string; + rpccontext?: RpcContext; }; // wshrpc.CommandAuthenticateTokenData @@ -215,12 +214,6 @@ declare global { termsize?: TermSize; }; - // wshrpc.CommandBlockSetViewData - type CommandBlockSetViewData = { - blockid: string; - view: string; - }; - // wshrpc.CommandCaptureBlockScreenshotData type CommandCaptureBlockScreenshotData = { blockid: string; @@ -378,7 +371,6 @@ declare global { // wshrpc.CommandMessageData type CommandMessageData = { - oref: ORef; message: string; }; @@ -1007,6 +999,15 @@ declare global { buildoutput: string; }; + // wshrpc.RpcContext + type RpcContext = { + sockname?: string; + routeid: string; + blockid?: string; + conn?: string; + isrouter?: boolean; + }; + // wshutil.RpcMessage type RpcMessage = { command?: string; @@ -1014,7 +1015,6 @@ declare global { resid?: string; timeout?: number; route?: string; - authtoken?: string; source?: string; cont?: boolean; cancel?: boolean; diff --git a/pkg/baseds/baseds.go b/pkg/baseds/baseds.go new file mode 100644 index 0000000000..8ede225968 --- /dev/null +++ b/pkg/baseds/baseds.go @@ -0,0 +1,14 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +// used for shared datastructures +package baseds + +type LinkId int32 + +const NoLinkId = 0 + +type RpcInputChType struct { + MsgBytes []byte + IngressLinkId LinkId +} diff --git a/pkg/blockcontroller/shellcontroller.go b/pkg/blockcontroller/shellcontroller.go index 5411c40404..bbdf38295d 100644 --- a/pkg/blockcontroller/shellcontroller.go +++ b/pkg/blockcontroller/shellcontroller.go @@ -420,12 +420,16 @@ func (bc *ShellController) setupAndStartShellProcess(logCtx context.Context, rc } } else { sockName := wslConn.GetDomainSocketName() - rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()} - jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName) + rpcContext := wshrpc.RpcContext{ + RouteId: wshutil.MakeRandomProcRouteId(), + SockName: sockName, + BlockId: bc.BlockId, + Conn: wslConn.GetName(), + } + jwtStr, err := wshutil.MakeClientJWTToken(rpcContext) if err != nil { return nil, fmt.Errorf("error making jwt token: %w", err) } - swapToken.SockName = sockName swapToken.RpcContext = &rpcContext swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn) @@ -449,12 +453,16 @@ func (bc *ShellController) setupAndStartShellProcess(logCtx context.Context, rc } } else { sockName := conn.GetDomainSocketName() - rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: conn.Opts.String()} - jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName) + rpcContext := wshrpc.RpcContext{ + RouteId: wshutil.MakeRandomProcRouteId(), + SockName: sockName, + BlockId: bc.BlockId, + Conn: conn.Opts.String(), + } + jwtStr, err := wshutil.MakeClientJWTToken(rpcContext) if err != nil { return nil, fmt.Errorf("error making jwt token: %w", err) } - swapToken.SockName = sockName swapToken.RpcContext = &rpcContext swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr shellProc, err = shellexec.StartRemoteShellProc(ctx, logCtx, rc.TermSize, cmdStr, cmdOpts, conn) @@ -472,12 +480,15 @@ func (bc *ShellController) setupAndStartShellProcess(logCtx context.Context, rc } else if connUnion.ConnType == ConnType_Local { if connUnion.WshEnabled { sockName := wavebase.GetDomainSocketName() - rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId} - jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName) + rpcContext := wshrpc.RpcContext{ + RouteId: wshutil.MakeRandomProcRouteId(), + SockName: sockName, + BlockId: bc.BlockId, + } + jwtStr, err := wshutil.MakeClientJWTToken(rpcContext) if err != nil { return nil, fmt.Errorf("error making jwt token: %w", err) } - swapToken.SockName = sockName swapToken.RpcContext = &rpcContext swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr } @@ -504,9 +515,11 @@ func (bc *ShellController) manageRunningShellProcess(shellProc *shellexec.ShellP // make esc sequence wshclient wshProxy // we don't need to authenticate this wshProxy since it is coming direct - wshProxy := wshutil.MakeRpcProxy() - wshProxy.SetRpcContext(&wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId}) - wshutil.DefaultRouter.RegisterRoute(wshutil.MakeControllerRouteId(bc.BlockId), wshProxy, true) + wshProxy := wshutil.MakeRpcProxy(fmt.Sprintf("controller:%s", bc.BlockId)) + controllerLinkId, err := wshutil.DefaultRouter.RegisterTrustedLeaf(wshProxy, wshutil.MakeControllerRouteId(bc.BlockId)) + if err != nil { + return fmt.Errorf("cannot register controller route: %w", err) + } ptyBuffer := wshutil.MakePtyBuffer(wshutil.WaveOSCPrefix, shellProc.Cmd, wshProxy.FromRemoteCh) go func() { // handles regular output from the pty (goes to the blockfile and xterm) @@ -584,7 +597,7 @@ func (bc *ShellController) manageRunningShellProcess(shellProc *shellexec.ShellP // wait for the shell to finish var exitCode int defer func() { - wshutil.DefaultRouter.UnregisterRoute(wshutil.MakeControllerRouteId(bc.BlockId)) + wshutil.DefaultRouter.UnregisterLink(controllerLinkId) bc.UpdateControllerAndSendUpdate(func() bool { if bc.ProcStatus == Status_Running { bc.ProcStatus = Status_Done diff --git a/pkg/remote/conncontroller/conncontroller.go b/pkg/remote/conncontroller/conncontroller.go index b5d8779a58..998d71a97d 100644 --- a/pkg/remote/conncontroller/conncontroller.go +++ b/pkg/remote/conncontroller/conncontroller.go @@ -74,7 +74,7 @@ type SSHConn struct { var ConnServerCmdTemplate = strings.TrimSpace( strings.Join([]string{ "%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm; exit 0);", - "exec %s connserver", + "exec %s connserver --conn %s %s", }, "\n")) func IsLocalConnName(connName string) bool { @@ -284,12 +284,13 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo } client := conn.GetClient() wshPath := conn.getWshPath() + sockName := conn.GetDomainSocketName() rpcCtx := wshrpc.RpcContext{ - ClientType: wshrpc.ClientType_ConnServer, - Conn: conn.GetName(), + RouteId: wshutil.MakeConnectionRouteId(conn.GetName()), + SockName: sockName, + Conn: conn.GetName(), } - sockName := conn.GetDomainSocketName() - jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName) + jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx) if err != nil { return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err) } @@ -305,7 +306,11 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo if err != nil { return false, "", "", fmt.Errorf("unable to get stdin pipe: %w", err) } - cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath) + devFlag := "" + if wavebase.IsDevMode() { + devFlag = "--dev" + } + cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath, shellutil.HardQuote(conn.GetName()), devFlag) log.Printf("starting conn controller: %q\n", cmdStr) shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr)) blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr) diff --git a/pkg/shellexec/shellexec.go b/pkg/shellexec/shellexec.go index 597aaab56a..d08250a4d8 100644 --- a/pkg/shellexec/shellexec.go +++ b/pkg/shellexec/shellexec.go @@ -265,7 +265,7 @@ func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr st cmdCombined = fmt.Sprintf(`%s=%s %s`, wavebase.WaveSwapTokenVarName, packedToken, cmdCombined) } jwtToken := cmdOpts.SwapToken.Env[wavebase.WaveJwtTokenVarName] - if jwtToken != "" { + if jwtToken != "" && cmdOpts.ForceJwt { conn.Debugf(ctx, "adding JWT token to environment\n") cmdCombined = fmt.Sprintf(`%s=%s %s`, wavebase.WaveJwtTokenVarName, jwtToken, cmdCombined) } diff --git a/pkg/tsgen/tsgen.go b/pkg/tsgen/tsgen.go index 5c223b298c..b152db83fc 100644 --- a/pkg/tsgen/tsgen.go +++ b/pkg/tsgen/tsgen.go @@ -499,6 +499,9 @@ func GenerateWaveObjTypes(tsTypesMap map[reflect.Type]string) { GenerateTSType(reflect.TypeOf(extraType), tsTypesMap) } for _, rtype := range waveobj.AllWaveObjTypes() { + if rtype.String() == "*waveobj.MainServer" { + continue + } GenerateTSType(rtype, tsTypesMap) } } diff --git a/pkg/util/packetparser/packetparser.go b/pkg/util/packetparser/packetparser.go index a44b6bde5f..14b9884845 100644 --- a/pkg/util/packetparser/packetparser.go +++ b/pkg/util/packetparser/packetparser.go @@ -10,6 +10,7 @@ import ( "io" "log" + "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/util/utilfn" ) @@ -18,7 +19,7 @@ type PacketParser struct { Ch chan []byte } -func ParseWithLinesChan(input chan utilfn.LineOutput, packetCh chan []byte, rawCh chan []byte) { +func ParseWithLinesChan(input chan utilfn.LineOutput, packetCh chan baseds.RpcInputChType, rawCh chan []byte) { defer close(packetCh) defer close(rawCh) for { @@ -37,14 +38,14 @@ func ParseWithLinesChan(input chan utilfn.LineOutput, packetCh chan []byte, rawC } if bytes.HasPrefix([]byte(line.Line), []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix([]byte(line.Line), []byte{'}'}) { // strip off the leading "##" - packetCh <- []byte(line.Line[3:len(line.Line)]) + packetCh <- baseds.RpcInputChType{MsgBytes: []byte(line.Line[3:len(line.Line)])} } else { rawCh <- []byte(line.Line) } } } -func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error { +func Parse(input io.Reader, packetCh chan baseds.RpcInputChType, rawCh chan []byte) error { bufReader := bufio.NewReader(input) defer close(packetCh) defer close(rawCh) @@ -63,7 +64,7 @@ func Parse(input io.Reader, packetCh chan []byte, rawCh chan []byte) error { } if bytes.HasPrefix(line, []byte{'#', '#', 'N', '{'}) && bytes.HasSuffix(line, []byte{'}', '\n'}) { // strip off the leading "##" and trailing "\n" (single byte) - packetCh <- line[3 : len(line)-1] + packetCh <- baseds.RpcInputChType{MsgBytes: line[3 : len(line)-1]} } else { rawCh <- line } diff --git a/pkg/util/shellutil/shellquote.go b/pkg/util/shellutil/shellquote.go index faa430d321..504b52cd04 100644 --- a/pkg/util/shellutil/shellquote.go +++ b/pkg/util/shellutil/shellquote.go @@ -13,7 +13,7 @@ const ( ) var ( - safePattern = regexp.MustCompile(`^[a-zA-Z0-9_/.-]+$`) + safePattern = regexp.MustCompile(`^[a-zA-Z0-9_@:,+=/.-]+$`) envVarNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) ) @@ -41,8 +41,6 @@ func HardQuote(s string) string { switch s[i] { case '"', '\\', '$', '`': buf = append(buf, '\\', s[i]) - case '\n': - buf = append(buf, '\\', '\n') default: buf = append(buf, s[i]) } diff --git a/pkg/util/shellutil/tokenswap.go b/pkg/util/shellutil/tokenswap.go index 339ec6445f..302bfa59b8 100644 --- a/pkg/util/shellutil/tokenswap.go +++ b/pkg/util/shellutil/tokenswap.go @@ -18,7 +18,6 @@ var tokenMapLock = &sync.Mutex{} type TokenSwapEntry struct { Token string `json:"token"` - SockName string `json:"sockname,omitempty"` RpcContext *wshrpc.RpcContext `json:"rpccontext,omitempty"` Env map[string]string `json:"env,omitempty"` ScriptText string `json:"scripttext,omitempty"` @@ -27,7 +26,6 @@ type TokenSwapEntry struct { type UnpackedTokenType struct { Token string `json:"token"` // uuid - SockName string `json:"sockname,omitempty"` RpcContext *wshrpc.RpcContext `json:"rpccontext,omitempty"` } @@ -57,7 +55,6 @@ func UnpackSwapToken(token string) (*UnpackedTokenType, error) { func (t *TokenSwapEntry) PackForClient() (string, error) { unpackedToken := &UnpackedTokenType{ Token: t.Token, - SockName: t.SockName, RpcContext: t.RpcContext, } return unpackedToken.Pack() diff --git a/pkg/util/sigutil/sigusr1_notwindows.go b/pkg/util/sigutil/sigusr1_notwindows.go index d006c7f322..df0cb8babe 100644 --- a/pkg/util/sigutil/sigusr1_notwindows.go +++ b/pkg/util/sigutil/sigusr1_notwindows.go @@ -6,6 +6,7 @@ package sigutil import ( + "log" "os" "os/signal" "syscall" @@ -14,6 +15,8 @@ import ( "github.com/wavetermdev/waveterm/pkg/util/utilfn" ) +const DumpFilePath = "/tmp/waveterm-usr1-dump.log" + func InstallSIGUSR1Handler() { sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGUSR1) @@ -22,7 +25,13 @@ func InstallSIGUSR1Handler() { panichandler.PanicHandler("InstallSIGUSR1Handler", recover()) }() for range sigCh { - utilfn.DumpGoRoutineStacks() + file, err := os.Create(DumpFilePath) + if err != nil { + log.Printf("error creating dump file %q: %v", DumpFilePath, err) + continue + } + utilfn.DumpGoRoutineStacks(file) + file.Close() } }() } diff --git a/pkg/util/utilfn/utilfn.go b/pkg/util/utilfn/utilfn.go index 2b43f567ba..e988297692 100644 --- a/pkg/util/utilfn/utilfn.go +++ b/pkg/util/utilfn/utilfn.go @@ -1016,10 +1016,10 @@ func HasBinaryData(data []byte) bool { return false } -func DumpGoRoutineStacks() { +func DumpGoRoutineStacks(w io.Writer) { buf := make([]byte, 1<<20) n := runtime.Stack(buf, true) - os.Stdout.Write(buf[:n]) + w.Write(buf[:n]) } func ConvertToWallClockPT(t time.Time) time.Time { diff --git a/pkg/waveapp/waveapp.go b/pkg/waveapp/waveapp.go index 4b7ff00e28..339e09403b 100644 --- a/pkg/waveapp/waveapp.go +++ b/pkg/waveapp/waveapp.go @@ -181,11 +181,11 @@ func (client *Client) Connect() error { return fmt.Errorf("error setting up domain socket rpc client: %v", err) } client.RpcClient = rpcClient - authRtn, err := wshclient.AuthenticateCommand(client.RpcClient, jwtToken, nil) + _, err = wshclient.AuthenticateCommand(client.RpcClient, jwtToken, &wshrpc.RpcOpts{Route: wshutil.ControlRoute}) if err != nil { return fmt.Errorf("error authenticating rpc connection: %v", err) } - client.RouteId = authRtn.RouteId + client.RouteId = rpcCtx.RouteId return nil } diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index 2f37d5ff4e..bdaff9b137 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -59,7 +59,6 @@ const WaveLockFile = "wave.lock" const DomainSocketBaseName = "wave.sock" const RemoteDomainSocketBaseName = "wave-remote.sock" const WaveDBDir = "db" -const JwtSecret = "waveterm" // TODO generate and store this const ConfigDir = "config" const RemoteWaveHomeDirName = ".waveterm" const RemoteWshBinDirName = "bin" diff --git a/pkg/wavejwt/wavejwt.go b/pkg/wavejwt/wavejwt.go new file mode 100644 index 0000000000..45a621a9a3 --- /dev/null +++ b/pkg/wavejwt/wavejwt.go @@ -0,0 +1,140 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wavejwt + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "fmt" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + IssuerWaveTerm = "waveterm" +) + +var ( + globalLock sync.Mutex + publicKey ed25519.PublicKey + privateKey ed25519.PrivateKey +) + +type WaveJwtClaims struct { + jwt.RegisteredClaims + Sock string `json:"sock,omitempty"` + RouteId string `json:"routeid,omitempty"` + BlockId string `json:"blockid,omitempty"` + Conn string `json:"conn,omitempty"` + Router bool `json:"router,omitempty"` +} + +type KeyPair struct { + PublicKey []byte + PrivateKey []byte +} + +func GenerateKeyPair() (*KeyPair, error) { + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate key pair: %w", err) + } + + return &KeyPair{ + PublicKey: pubKey, + PrivateKey: privKey, + }, nil +} + +func SetPublicKey(keyData []byte) error { + if len(keyData) != ed25519.PublicKeySize { + return fmt.Errorf("invalid public key size: expected %d, got %d", ed25519.PublicKeySize, len(keyData)) + } + globalLock.Lock() + defer globalLock.Unlock() + publicKey = ed25519.PublicKey(keyData) + return nil +} + +func GetPublicKey() []byte { + globalLock.Lock() + defer globalLock.Unlock() + return publicKey +} + +func GetPublicKeyBase64() string { + pubKey := GetPublicKey() + if len(pubKey) == 0 { + return "" + } + return base64.StdEncoding.EncodeToString(pubKey) +} + +func SetPrivateKey(keyData []byte) error { + if len(keyData) != ed25519.PrivateKeySize { + return fmt.Errorf("invalid private key size: expected %d, got %d", ed25519.PrivateKeySize, len(keyData)) + } + globalLock.Lock() + defer globalLock.Unlock() + privateKey = ed25519.PrivateKey(keyData) + return nil +} + +func ValidateAndExtract(tokenStr string) (*WaveJwtClaims, error) { + globalLock.Lock() + pubKey := publicKey + globalLock.Unlock() + + if pubKey == nil { + return nil, fmt.Errorf("public key not set") + } + + token, err := jwt.ParseWithClaims(tokenStr, &WaveJwtClaims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return pubKey, nil + }) + if err != nil { + return nil, fmt.Errorf("failed to parse token: %w", err) + } + + claims, ok := token.Claims.(*WaveJwtClaims) + if !ok || !token.Valid { + return nil, fmt.Errorf("invalid token") + } + + return claims, nil +} + +func Sign(claims *WaveJwtClaims) (string, error) { + globalLock.Lock() + privKey := privateKey + globalLock.Unlock() + + if privKey == nil { + return "", fmt.Errorf("private key not set") + } + + if claims.IssuedAt == nil { + claims.IssuedAt = jwt.NewNumericDate(time.Now()) + } + if claims.Issuer == "" { + claims.Issuer = IssuerWaveTerm + } + if claims.ExpiresAt == nil { + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)) + } + + token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + tokenStr, err := token.SignedString(privKey) + if err != nil { + return "", fmt.Errorf("error signing token: %w", err) + } + + return tokenStr, nil +} diff --git a/pkg/waveobj/wtype.go b/pkg/waveobj/wtype.go index fbd63876b3..89165bc7b0 100644 --- a/pkg/waveobj/wtype.go +++ b/pkg/waveobj/wtype.go @@ -28,6 +28,7 @@ const ( OType_Tab = "tab" OType_LayoutState = "layout" OType_Block = "block" + OType_MainServer = "mainserver" OType_Temp = "temp" OType_Builder = "builder" // not persisted to DB ) @@ -39,6 +40,7 @@ var ValidOTypes = map[string]bool{ OType_Tab: true, OType_LayoutState: true, OType_Block: true, + OType_MainServer: true, OType_Temp: true, OType_Builder: true, } @@ -293,6 +295,18 @@ func (*Block) GetOType() string { return OType_Block } +type MainServer struct { + OID string `json:"oid"` + Version int `json:"version"` + Meta MetaMapType `json:"meta"` + JwtPrivateKey string `json:"jwtprivatekey"` // base64 + JwtPublicKey string `json:"jwtpublickey"` // base64 +} + +func (*MainServer) GetOType() string { + return OType_MainServer +} + func AllWaveObjTypes() []reflect.Type { return []reflect.Type{ reflect.TypeOf(&Client{}), @@ -301,6 +315,7 @@ func AllWaveObjTypes() []reflect.Type { reflect.TypeOf(&Tab{}), reflect.TypeOf(&Block{}), reflect.TypeOf(&LayoutState{}), + reflect.TypeOf(&MainServer{}), } } diff --git a/pkg/wcloud/wcloud.go b/pkg/wcloud/wcloud.go index 6a13ee92b9..3b96df838b 100644 --- a/pkg/wcloud/wcloud.go +++ b/pkg/wcloud/wcloud.go @@ -140,9 +140,11 @@ func makeAnonPostReq(ctx context.Context, apiUrl string, data interface{}) (*htt return req, nil } -func doRequest(req *http.Request, outputObj interface{}) (*http.Response, error) { +func doRequest(req *http.Request, outputObj interface{}, verbose bool) (*http.Response, error) { apiUrl := req.Header.Get("X-PromptAPIUrl") - log.Printf("[wcloud] sending request %s %v\n", req.Method, req.URL) + if verbose { + log.Printf("[wcloud] sending request %s %v\n", req.Method, req.URL) + } resp, err := http.DefaultClient.Do(req) if err != nil { return nil, fmt.Errorf("error contacting wcloud %q service: %v", apiUrl, err) @@ -192,7 +194,7 @@ func sendTEventsBatch(clientId string) (bool, int, error) { return true, 0, err } startTime := time.Now() - _, err = doRequest(req, nil) + _, err = doRequest(req, nil, true) latency := time.Since(startTime) log.Printf("[wcloud] sent %d tevents (latency: %v)\n", len(events), latency) if err != nil { @@ -269,7 +271,7 @@ func sendTelemetry(clientId string) error { if err != nil { return err } - _, err = doRequest(req, nil) + _, err = doRequest(req, nil, true) if err != nil { return err } @@ -285,7 +287,7 @@ func SendNoTelemetryUpdate(ctx context.Context, clientId string, noTelemetryVal if err != nil { return err } - _, err = doRequest(req, nil) + _, err = doRequest(req, nil, true) if err != nil { return err } @@ -341,7 +343,7 @@ func SendDiagnosticPing(ctx context.Context, clientId string, usageTelemetry boo if err != nil { return err } - _, err = doRequest(req, nil) + _, err = doRequest(req, nil, false) if err != nil { return err } diff --git a/pkg/wcore/wcore.go b/pkg/wcore/wcore.go index de6cba9f50..d8603f5caf 100644 --- a/pkg/wcore/wcore.go +++ b/pkg/wcore/wcore.go @@ -6,6 +6,10 @@ package wcore import ( "context" + "crypto/ed25519" + "crypto/x509" + "encoding/base64" + "encoding/pem" "fmt" "log" "strings" @@ -14,6 +18,7 @@ import ( "github.com/google/uuid" "github.com/wavetermdev/waveterm/pkg/panichandler" + "github.com/wavetermdev/waveterm/pkg/wavejwt" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wcloud" "github.com/wavetermdev/waveterm/pkg/wps" @@ -154,3 +159,70 @@ func GoSendNoTelemetryUpdate(telemetryEnabled bool) { } }() } + +func InitMainServer() error { + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + + mainServer, err := wstore.DBGetSingleton[*waveobj.MainServer](ctx) + if err == wstore.ErrNotFound { + mainServer = &waveobj.MainServer{ + OID: uuid.NewString(), + } + err = wstore.DBInsert(ctx, mainServer) + if err != nil { + return fmt.Errorf("error inserting mainserver: %w", err) + } + } else if err != nil { + return fmt.Errorf("error getting mainserver: %w", err) + } + + needsUpdate := false + if mainServer.JwtPrivateKey == "" || mainServer.JwtPublicKey == "" { + keyPair, err := wavejwt.GenerateKeyPair() + if err != nil { + return fmt.Errorf("error generating jwt keypair: %w", err) + } + mainServer.JwtPrivateKey = base64.StdEncoding.EncodeToString(keyPair.PrivateKey) + mainServer.JwtPublicKey = base64.StdEncoding.EncodeToString(keyPair.PublicKey) + needsUpdate = true + } + + if needsUpdate { + err = wstore.DBUpdate(ctx, mainServer) + if err != nil { + return fmt.Errorf("error updating mainserver: %w", err) + } + } + + privateKeyBytes, err := base64.StdEncoding.DecodeString(mainServer.JwtPrivateKey) + if err != nil { + return fmt.Errorf("error decoding jwt private key: %w", err) + } + publicKeyBytes, err := base64.StdEncoding.DecodeString(mainServer.JwtPublicKey) + if err != nil { + return fmt.Errorf("error decoding jwt public key: %w", err) + } + + err = wavejwt.SetPrivateKey(privateKeyBytes) + if err != nil { + return fmt.Errorf("error setting jwt private key: %w", err) + } + err = wavejwt.SetPublicKey(publicKeyBytes) + if err != nil { + return fmt.Errorf("error setting jwt public key: %w", err) + } + + pubKeyDer, err := x509.MarshalPKIXPublicKey(ed25519.PublicKey(publicKeyBytes)) + if err != nil { + log.Printf("warning: could not marshal public key for logging: %v", err) + } else { + pubKeyPem := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: pubKeyDer, + }) + log.Printf("JWT Public Key:\n%s", string(pubKeyPem)) + } + + return nil +} diff --git a/pkg/web/ws.go b/pkg/web/ws.go index 549b89a396..0e6f0b0f9b 100644 --- a/pkg/web/ws.go +++ b/pkg/web/ws.go @@ -16,6 +16,7 @@ import ( "github.com/gorilla/mux" "github.com/gorilla/websocket" "github.com/wavetermdev/waveterm/pkg/authkey" + "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/eventbus" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/web/webcmd" @@ -31,8 +32,13 @@ const wsMaxMessageSize = 10 * 1024 * 1024 const DefaultCommandTimeout = 2 * time.Second +type StableConnInfo struct { + ConnId string + LinkId baseds.LinkId +} + var GlobalLock = &sync.Mutex{} -var RouteToConnMap = map[string]string{} // routeid => connid +var RouteToConnMap = map[string]*StableConnInfo{} // stableid => StableConnInfo func RunWebSocketServer(listener net.Listener) { gr := mux.NewRouter() @@ -79,7 +85,7 @@ func getStringFromMap(jmsg map[string]any, key string) string { return "" } -func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan []byte) { +func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan baseds.RpcInputChType) { var rtnErr error var cmdType string defer func() { @@ -119,7 +125,7 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan [] log.Printf("[websocket] error marshalling rpc message: %v\n", err) return } - rpcInputCh <- msgBytes + rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes} case *webcmd.BlockInputWSCommand: data := wshrpc.CommandBlockInputData{ @@ -136,7 +142,7 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan [] log.Printf("[websocket] error marshalling rpc message: %v\n", err) return } - rpcInputCh <- msgBytes + rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes} case *webcmd.WSRpcCommand: rpcMsg := cmd.Message @@ -151,11 +157,11 @@ func processWSCommand(jmsg map[string]any, outputCh chan any, rpcInputCh chan [] // this really should never fail since we just unmarshalled this value return } - rpcInputCh <- msgBytes + rpcInputCh <- baseds.RpcInputChType{MsgBytes: msgBytes} } } -func processMessage(jmsg map[string]any, outputCh chan any, rpcInputCh chan []byte) { +func processMessage(jmsg map[string]any, outputCh chan any, rpcInputCh chan baseds.RpcInputChType) { wsCommand := getStringFromMap(jmsg, "wscommand") if wsCommand == "" { return @@ -163,7 +169,7 @@ func processMessage(jmsg map[string]any, outputCh chan any, rpcInputCh chan []by processWSCommand(jmsg, outputCh, rpcInputCh) } -func ReadLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, rpcInputCh chan []byte, routeId string) { +func ReadLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, rpcInputCh chan baseds.RpcInputChType, routeId string) { readWait := wsReadWaitTimeout conn.SetReadLimit(wsMaxMessageSize) conn.SetReadDeadline(time.Now().Add(readWait)) @@ -251,35 +257,41 @@ func WriteLoop(conn *websocket.Conn, outputCh chan any, closeCh chan any, routeI } } -func registerConn(wsConnId string, routeId string, wproxy *wshutil.WshRpcProxy) { +func registerConn(wsConnId string, stableId string, wproxy *wshutil.WshRpcProxy) { GlobalLock.Lock() defer GlobalLock.Unlock() - curConnId := RouteToConnMap[routeId] - if curConnId != "" { - log.Printf("[websocket] warning: replacing existing connection for route %q\n", routeId) - wshutil.DefaultRouter.UnregisterRoute(routeId) + curConnInfo := RouteToConnMap[stableId] + if curConnInfo != nil { + log.Printf("[websocket] warning: replacing existing connection for stableid %q\n", stableId) + if curConnInfo.LinkId != baseds.NoLinkId { + wshutil.DefaultRouter.UnregisterLink(curConnInfo.LinkId) + } + } + linkId := wshutil.DefaultRouter.RegisterTrustedRouter(wproxy) + RouteToConnMap[stableId] = &StableConnInfo{ + ConnId: wsConnId, + LinkId: linkId, } - RouteToConnMap[routeId] = wsConnId - wshutil.DefaultRouter.RegisterRoute(routeId, wproxy, true) } -func unregisterConn(wsConnId string, routeId string) { +func unregisterConn(wsConnId string, stableId string) { GlobalLock.Lock() defer GlobalLock.Unlock() - curConnId := RouteToConnMap[routeId] - if curConnId != wsConnId { - // only unregister if we are the current connection (otherwise we were already removed) - log.Printf("[websocket] warning: trying to unregister connection %q for route %q but it is not the current connection (ignoring)\n", wsConnId, routeId) + curConnInfo := RouteToConnMap[stableId] + if curConnInfo == nil || curConnInfo.ConnId != wsConnId { + log.Printf("[websocket] warning: trying to unregister connection %q for stableid %q but it is not the current connection (ignoring)\n", wsConnId, stableId) return } - delete(RouteToConnMap, routeId) - wshutil.DefaultRouter.UnregisterRoute(routeId) + delete(RouteToConnMap, stableId) + if curConnInfo.LinkId != baseds.NoLinkId { + wshutil.DefaultRouter.UnregisterLink(curConnInfo.LinkId) + } } func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { - routeId := r.URL.Query().Get("routeid") - if routeId == "" { - return fmt.Errorf("routeid is required") + stableId := r.URL.Query().Get("stableid") + if stableId == "" { + return fmt.Errorf("stableid is required") } err := authkey.ValidateIncomingRequest(r) if err != nil { @@ -296,13 +308,13 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { wsConnId := uuid.New().String() outputCh := make(chan any, 100) closeCh := make(chan any) - log.Printf("[websocket] new connection: connid:%s routeid:%s\n", wsConnId, routeId) - eventbus.RegisterWSChannel(wsConnId, routeId, outputCh) + log.Printf("[websocket] new connection: connid:%s stableid:%s\n", wsConnId, stableId) + eventbus.RegisterWSChannel(wsConnId, stableId, outputCh) defer eventbus.UnregisterWSChannel(wsConnId) - wproxy := wshutil.MakeRpcProxy() // we create a wshproxy to handle rpc messages to/from the window + wproxy := wshutil.MakeRpcProxy(fmt.Sprintf("ws:%s", stableId)) defer close(wproxy.ToRemoteCh) - registerConn(wsConnId, routeId, wproxy) - defer unregisterConn(wsConnId, routeId) + registerConn(wsConnId, stableId, wproxy) + defer unregisterConn(wsConnId, stableId) wg := &sync.WaitGroup{} wg.Add(2) go func() { @@ -323,17 +335,15 @@ func HandleWsInternal(w http.ResponseWriter, r *http.Request) error { defer func() { panichandler.PanicHandler("HandleWsInternal:ReadLoop", recover()) }() - // read loop defer wg.Done() - ReadLoop(conn, outputCh, closeCh, wproxy.FromRemoteCh, routeId) + ReadLoop(conn, outputCh, closeCh, wproxy.FromRemoteCh, stableId) }() go func() { defer func() { panichandler.PanicHandler("HandleWsInternal:WriteLoop", recover()) }() - // write loop defer wg.Done() - WriteLoop(conn, outputCh, closeCh, routeId) + WriteLoop(conn, outputCh, closeCh, stableId) }() wg.Wait() close(wproxy.FromRemoteCh) diff --git a/pkg/wshrpc/wshclient/barerpcclient.go b/pkg/wshrpc/wshclient/barerpcclient.go index f05b75095c..62d1f27ea7 100644 --- a/pkg/wshrpc/wshclient/barerpcclient.go +++ b/pkg/wshrpc/wshclient/barerpcclient.go @@ -29,10 +29,8 @@ const BareClientRoute = "bare" func GetBareRpcClient() *wshutil.WshRpc { waveSrvClient_Once.Do(func() { - inputCh := make(chan []byte, DefaultInputChSize) - outputCh := make(chan []byte, DefaultOutputChSize) - waveSrvClient_Singleton = wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, &WshServerImpl, "bare-client") - wshutil.DefaultRouter.RegisterRoute(BareClientRoute, waveSrvClient_Singleton, true) + waveSrvClient_Singleton = wshutil.MakeWshRpc(wshrpc.RpcContext{}, &WshServerImpl, "bare-client") + wshutil.DefaultRouter.RegisterTrustedLeaf(waveSrvClient_Singleton, BareClientRoute) wps.Broker.SetClient(wshutil.DefaultRouter) }) return waveSrvClient_Singleton diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index e9f518ab02..8a5d9317c3 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -41,6 +41,12 @@ func AuthenticateTokenCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticate return resp, err } +// command "authenticatetokenverify", wshserver.AuthenticateTokenVerifyCommand +func AuthenticateTokenVerifyCommand(w *wshutil.WshRpc, data wshrpc.CommandAuthenticateTokenData, opts *wshrpc.RpcOpts) (wshrpc.CommandAuthenticateRtnData, error) { + resp, err := sendRpcRequestCallHelper[wshrpc.CommandAuthenticateRtnData](w, "authenticatetokenverify", data, opts) + return resp, err +} + // command "blockinfo", wshserver.BlockInfoCommand func BlockInfoCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) (*wshrpc.BlockInfoData, error) { resp, err := sendRpcRequestCallHelper[*wshrpc.BlockInfoData](w, "blockinfo", data, opts) @@ -374,6 +380,12 @@ func GetFullConfigCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (wconfig.Full return resp, err } +// command "getjwtpublickey", wshserver.GetJwtPublicKeyCommand +func GetJwtPublicKeyCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) (string, error) { + resp, err := sendRpcRequestCallHelper[string](w, "getjwtpublickey", nil, opts) + return resp, err +} + // command "getmeta", wshserver.GetMetaCommand func GetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandGetMetaData, opts *wshrpc.RpcOpts) (waveobj.MetaMapType, error) { resp, err := sendRpcRequestCallHelper[waveobj.MetaMapType](w, "getmeta", data, opts) @@ -646,6 +658,12 @@ func SetMetaCommand(w *wshutil.WshRpc, data wshrpc.CommandSetMetaData, opts *wsh return err } +// command "setpeerinfo", wshserver.SetPeerInfoCommand +func SetPeerInfoCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { + _, err := sendRpcRequestCallHelper[any](w, "setpeerinfo", data, opts) + return err +} + // command "setrtinfo", wshserver.SetRTInfoCommand func SetRTInfoCommand(w *wshutil.WshRpc, data wshrpc.CommandSetRTInfoData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "setrtinfo", data, opts) @@ -664,12 +682,6 @@ func SetVarCommand(w *wshutil.WshRpc, data wshrpc.CommandVarData, opts *wshrpc.R return err } -// command "setview", wshserver.SetViewCommand -func SetViewCommand(w *wshutil.WshRpc, data wshrpc.CommandBlockSetViewData, opts *wshrpc.RpcOpts) error { - _, err := sendRpcRequestCallHelper[any](w, "setview", data, opts) - return err -} - // command "startbuilder", wshserver.StartBuilderCommand func StartBuilderCommand(w *wshutil.WshRpc, data wshrpc.CommandStartBuilderData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "startbuilder", data, opts) diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 14ffb14779..239bbbae26 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -8,9 +8,7 @@ import ( "bytes" "context" "encoding/json" - "log" "os" - "reflect" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/ijson" @@ -52,16 +50,20 @@ const ( // TODO generate these constants from the interface const ( - Command_Authenticate = "authenticate" // special - Command_AuthenticateToken = "authenticatetoken" // special - Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only) - Command_RouteAnnounce = "routeannounce" // special (for routing) - Command_RouteUnannounce = "routeunannounce" // special (for routing) + Command_Authenticate = "authenticate" // $control + Command_AuthenticateToken = "authenticatetoken" // $control + Command_AuthenticateTokenVerify = "authenticatetokenverify" // $control:root (internal, for token validation only) + Command_Dispose = "dispose" // $control (disposes of the route, for multiproxy only) + Command_RouteAnnounce = "routeannounce" // $control (for routing) + Command_RouteUnannounce = "routeunannounce" // $control (for routing) + Command_SetPeerInfo = "setpeerinfo" // $control (sets peer info on proxy) + Command_ControlMessage = "controlmessage" // $control + Command_Ping = "ping" // $control + + Command_GetJwtPublicKey = "getjwtpublickey" Command_Message = "message" - Command_GetMeta = "getmeta" Command_SetMeta = "setmeta" - Command_SetView = "setview" Command_ControllerInput = "controllerinput" Command_ControllerRestart = "controllerrestart" Command_ControllerStop = "controllerstop" @@ -198,14 +200,16 @@ type RespOrErrorUnion[T any] struct { type WshRpcInterface interface { AuthenticateCommand(ctx context.Context, data string) (CommandAuthenticateRtnData, error) AuthenticateTokenCommand(ctx context.Context, data CommandAuthenticateTokenData) (CommandAuthenticateRtnData, error) + AuthenticateTokenVerifyCommand(ctx context.Context, data CommandAuthenticateTokenData) (CommandAuthenticateRtnData, error) // (special) validates token without binding, root router only DisposeCommand(ctx context.Context, data CommandDisposeData) error RouteAnnounceCommand(ctx context.Context) error // (special) announces a new route to the main router RouteUnannounceCommand(ctx context.Context) error // (special) unannounces a route to the main router + SetPeerInfoCommand(ctx context.Context, peerInfo string) error + GetJwtPublicKeyCommand(ctx context.Context) (string, error) // (special) gets the public JWT signing key MessageCommand(ctx context.Context, data CommandMessageData) error GetMetaCommand(ctx context.Context, data CommandGetMetaData) (waveobj.MetaMapType, error) SetMetaCommand(ctx context.Context, data CommandSetMetaData) error - SetViewCommand(ctx context.Context, data CommandBlockSetViewData) error ControllerInputCommand(ctx context.Context, data CommandBlockInputData) error ControllerStopCommand(ctx context.Context, blockId string) error ControllerResyncCommand(ctx context.Context, data CommandControllerResyncData) error @@ -373,59 +377,22 @@ type RpcOpts struct { NoResponse bool `json:"noresponse,omitempty"` Route string `json:"route,omitempty"` - StreamCancelFn func() `json:"-"` // this is an *output* parameter, set by the handler + StreamCancelFn func(context.Context) error `json:"-"` // this is an *output* parameter, set by the handler } -const ( - ClientType_ConnServer = "connserver" - ClientType_BlockController = "blockcontroller" -) - type RpcContext struct { - ClientType string `json:"ctype,omitempty"` - BlockId string `json:"blockid,omitempty"` - TabId string `json:"tabid,omitempty"` - Conn string `json:"conn,omitempty"` -} - -func HackRpcContextIntoData(dataPtr any, rpcContext RpcContext) { - dataVal := reflect.ValueOf(dataPtr).Elem() - if dataVal.Kind() != reflect.Struct { - return - } - dataType := dataVal.Type() - for i := 0; i < dataVal.NumField(); i++ { - field := dataVal.Field(i) - if !field.IsZero() { - continue - } - fieldType := dataType.Field(i) - tag := fieldType.Tag.Get("wshcontext") - if tag == "" { - continue - } - switch tag { - case "BlockId": - field.SetString(rpcContext.BlockId) - case "TabId": - field.SetString(rpcContext.TabId) - case "BlockORef": - if rpcContext.BlockId != "" { - field.Set(reflect.ValueOf(waveobj.MakeORef(waveobj.OType_Block, rpcContext.BlockId))) - } - default: - log.Printf("invalid wshcontext tag: %q in type(%T)", tag, dataPtr) - } - } + SockName string `json:"sockname,omitempty"` // the domain socket name + RouteId string `json:"routeid"` // the routeid from the jwt + BlockId string `json:"blockid,omitempty"` // blockid for this rpc + Conn string `json:"conn,omitempty"` // the conn name + IsRouter bool `json:"isrouter,omitempty"` // if this is for a sub-router } type CommandAuthenticateRtnData struct { - RouteId string `json:"routeid"` - AuthToken string `json:"authtoken,omitempty"` - // these fields are only set when doing a token swap Env map[string]string `json:"env,omitempty"` InitScriptText string `json:"initscripttext,omitempty"` + RpcContext *RpcContext `json:"rpccontext,omitempty"` } type CommandAuthenticateTokenData struct { @@ -438,21 +405,20 @@ type CommandDisposeData struct { } type CommandMessageData struct { - ORef waveobj.ORef `json:"oref" wshcontext:"BlockORef"` - Message string `json:"message"` + Message string `json:"message"` } type CommandGetMetaData struct { - ORef waveobj.ORef `json:"oref" wshcontext:"BlockORef"` + ORef waveobj.ORef `json:"oref"` } type CommandSetMetaData struct { - ORef waveobj.ORef `json:"oref" wshcontext:"BlockORef"` + ORef waveobj.ORef `json:"oref"` Meta waveobj.MetaMapType `json:"meta"` } type CommandResolveIdsData struct { - BlockId string `json:"blockid" wshcontext:"BlockId"` + BlockId string `json:"blockid"` Ids []string `json:"ids"` } @@ -461,7 +427,7 @@ type CommandResolveIdsRtnData struct { } type CommandCreateBlockData struct { - TabId string `json:"tabid" wshcontext:"TabId"` + TabId string `json:"tabid"` BlockDef *waveobj.BlockDef `json:"blockdef"` RtOpts *waveobj.RuntimeOpts `json:"rtopts,omitempty"` Magnified bool `json:"magnified,omitempty"` @@ -476,15 +442,10 @@ type CommandCreateSubBlockData struct { BlockDef *waveobj.BlockDef `json:"blockdef"` } -type CommandBlockSetViewData struct { - BlockId string `json:"blockid" wshcontext:"BlockId"` - View string `json:"view"` -} - type CommandControllerResyncData struct { ForceRestart bool `json:"forcerestart,omitempty"` - TabId string `json:"tabid" wshcontext:"TabId"` - BlockId string `json:"blockid" wshcontext:"BlockId"` + TabId string `json:"tabid"` + BlockId string `json:"blockid"` RtOpts *waveobj.RuntimeOpts `json:"rtopts,omitempty"` } @@ -494,7 +455,7 @@ type CommandControllerAppendOutputData struct { } type CommandBlockInputData struct { - BlockId string `json:"blockid" wshcontext:"BlockId"` + BlockId string `json:"blockid"` InputData64 string `json:"inputdata64,omitempty"` SigName string `json:"signame,omitempty"` TermSize *waveobj.TermSize `json:"termsize,omitempty"` @@ -560,7 +521,7 @@ type FileCreateData struct { } type CommandAppendIJsonData struct { - ZoneId string `json:"zoneid" wshcontext:"BlockId"` + ZoneId string `json:"zoneid"` FileName string `json:"filename"` Data ijson.Command `json:"data"` } @@ -571,7 +532,7 @@ type CommandWaitForRouteData struct { } type CommandDeleteBlockData struct { - BlockId string `json:"blockid" wshcontext:"BlockId"` + BlockId string `json:"blockid"` } type CommandEventReadHistoryData struct { @@ -749,8 +710,8 @@ type WebSelectorOpts struct { type CommandWebSelectorData struct { WorkspaceId string `json:"workspaceid"` - BlockId string `json:"blockid" wshcontext:"BlockId"` - TabId string `json:"tabid" wshcontext:"TabId"` + BlockId string `json:"blockid"` + TabId string `json:"tabid"` Selector string `json:"selector"` Opts *WebSelectorOpts `json:"opts,omitempty"` } @@ -846,7 +807,7 @@ type CommandWaveAIGetToolDiffRtnData struct { } type CommandCaptureBlockScreenshotData struct { - BlockId string `json:"blockid" wshcontext:"BlockId"` + BlockId string `json:"blockid"` } type CommandVarData struct { @@ -867,7 +828,7 @@ type PathCommandData struct { PathType string `json:"pathtype"` Open bool `json:"open"` OpenExternal bool `json:"openexternal"` - TabId string `json:"tabid" wshcontext:"TabId"` + TabId string `json:"tabid"` } type ActivityDisplayType struct { diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index cc4fc1d7ea..0f23d1b654 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -48,6 +48,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/waveappstore" "github.com/wavetermdev/waveterm/pkg/waveapputil" "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/wavejwt" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wcloud" "github.com/wavetermdev/waveterm/pkg/wconfig" @@ -69,17 +70,8 @@ func (*WshServer) WshServerImpl() {} var WshServerImpl = WshServer{} -// TODO remove this after implementing in multiproxy, just for wsl -func (ws *WshServer) AuthenticateTokenCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) { - entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token) - if entry == nil { - return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("invalid token") - } - rtn := wshrpc.CommandAuthenticateRtnData{ - Env: entry.Env, - InitScriptText: entry.ScriptText, - } - return rtn, nil +func (ws *WshServer) GetJwtPublicKeyCommand(ctx context.Context) (string, error) { + return wavejwt.GetPublicKeyBase64(), nil } func (ws *WshServer) TestCommand(ctx context.Context, data string) error { @@ -93,7 +85,7 @@ func (ws *WshServer) TestCommand(ctx context.Context, data string) error { // for testing func (ws *WshServer) MessageCommand(ctx context.Context, data wshrpc.CommandMessageData) error { - log.Printf("MESSAGE: %s | %q\n", data.ORef, data.Message) + log.Printf("MESSAGE: %s\n", data.Message) return nil } @@ -290,23 +282,6 @@ func (ws *WshServer) CreateSubBlockCommand(ctx context.Context, data wshrpc.Comm return blockRef, nil } -func (ws *WshServer) SetViewCommand(ctx context.Context, data wshrpc.CommandBlockSetViewData) error { - log.Printf("SETVIEW: %s | %q\n", data.BlockId, data.View) - ctx = waveobj.ContextWithUpdates(ctx) - block, err := wstore.DBGet[*waveobj.Block](ctx, data.BlockId) - if err != nil { - return fmt.Errorf("error getting block: %w", err) - } - block.Meta[waveobj.MetaKey_View] = data.View - err = wstore.DBUpdate(ctx, block) - if err != nil { - return fmt.Errorf("error updating block: %w", err) - } - updates := waveobj.ContextGetUpdatesRtn(ctx) - wps.Broker.SendUpdateEvents(updates) - return nil -} - func (ws *WshServer) ControllerStopCommand(ctx context.Context, blockId string) error { blockcontroller.StopBlockController(blockId) return nil @@ -495,6 +470,9 @@ func (ws *WshServer) WriteTempFileCommand(ctx context.Context, data wshrpc.Comma } func (ws *WshServer) DeleteSubBlockCommand(ctx context.Context, data wshrpc.CommandDeleteBlockData) error { + if data.BlockId == "" { + return fmt.Errorf("blockid is required") + } err := wcore.DeleteBlock(ctx, data.BlockId, false) if err != nil { return fmt.Errorf("error deleting block: %w", err) @@ -503,6 +481,9 @@ func (ws *WshServer) DeleteSubBlockCommand(ctx context.Context, data wshrpc.Comm } func (ws *WshServer) DeleteBlockCommand(ctx context.Context, data wshrpc.CommandDeleteBlockData) error { + if data.BlockId == "" { + return fmt.Errorf("blockid is required") + } ctx = waveobj.ContextWithUpdates(ctx) tabId, err := wstore.DBFindTabForBlockId(ctx, data.BlockId) if err != nil { @@ -1364,10 +1345,15 @@ func (ws *WshServer) PathCommand(ctx context.Context, data wshrpc.PathCommandDat } if openInternal { - _, err := ws.CreateBlockCommand(ctx, wshrpc.CommandCreateBlockData{BlockDef: &waveobj.BlockDef{Meta: map[string]any{ - waveobj.MetaKey_View: "preview", - waveobj.MetaKey_File: path, - }}, Ephemeral: true, Focused: true, TabId: data.TabId}) + _, err := ws.CreateBlockCommand(ctx, wshrpc.CommandCreateBlockData{ + TabId: data.TabId, + BlockDef: &waveobj.BlockDef{Meta: map[string]any{ + waveobj.MetaKey_View: "preview", + waveobj.MetaKey_File: path, + }}, + Ephemeral: true, + Focused: true, + }) if err != nil { return path, fmt.Errorf("error opening path: %w", err) diff --git a/pkg/wshrpc/wshserver/wshserverutil.go b/pkg/wshrpc/wshserver/wshserverutil.go index 252d358914..c4e3f3543d 100644 --- a/pkg/wshrpc/wshserver/wshserverutil.go +++ b/pkg/wshrpc/wshserver/wshserverutil.go @@ -21,9 +21,7 @@ var waveSrvClient_Once = &sync.Once{} // returns the wavesrv main rpc client singleton func GetMainRpcClient() *wshutil.WshRpc { waveSrvClient_Once.Do(func() { - inputCh := make(chan []byte, DefaultInputChSize) - outputCh := make(chan []byte, DefaultOutputChSize) - waveSrvClient_Singleton = wshutil.MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, &WshServerImpl, "main-client") + waveSrvClient_Singleton = wshutil.MakeWshRpc(wshrpc.RpcContext{}, &WshServerImpl, "main-client") }) return waveSrvClient_Singleton } diff --git a/pkg/wshutil/wshadapter.go b/pkg/wshutil/wshadapter.go index 22667dbfe2..a91bb4568f 100644 --- a/pkg/wshutil/wshadapter.go +++ b/pkg/wshutil/wshadapter.go @@ -53,7 +53,7 @@ func noImplHandler(handler *RpcResponseHandler) bool { return true } -func recodeCommandData(command string, data any, rpcCtx *wshrpc.RpcContext) (any, error) { +func recodeCommandData(command string, data any) (any, error) { // only applies to initial command packet if command == "" { return data, nil @@ -71,9 +71,6 @@ func recodeCommandData(command string, data any, rpcCtx *wshrpc.RpcContext) (any if err != nil { return data, fmt.Errorf("error re-marshalling command data: %w", err) } - if rpcCtx != nil { - wshrpc.HackRpcContextIntoData(commandDataPtr, *rpcCtx) - } } return reflect.ValueOf(commandDataPtr).Elem().Interface(), nil } @@ -107,8 +104,7 @@ func serverImplAdapter(impl any) func(*RpcResponseHandler) bool { var callParams []reflect.Value callParams = append(callParams, reflect.ValueOf(handler.Context())) if methodDecl.CommandDataType != nil { - rpcCtx := handler.GetRpcContext() - cmdData, err := recodeCommandData(cmd, handler.GetCommandRawData(), &rpcCtx) + cmdData, err := recodeCommandData(cmd, handler.GetCommandRawData()) if err != nil { handler.SendResponseError(err) return true diff --git a/pkg/wshutil/wshcmdreader.go b/pkg/wshutil/wshcmdreader.go index b4cec125b2..60007fe0e5 100644 --- a/pkg/wshutil/wshcmdreader.go +++ b/pkg/wshutil/wshcmdreader.go @@ -8,6 +8,8 @@ import ( "fmt" "io" "sync" + + "github.com/wavetermdev/waveterm/pkg/baseds" ) const ( @@ -25,13 +27,13 @@ type PtyBuffer struct { EscSeqBuf []byte OSCPrefix string InputReader io.Reader - MessageCh chan []byte + MessageCh chan baseds.RpcInputChType AtEOF bool Err error } // closes messageCh when input is closed (or error) -func MakePtyBuffer(oscPrefix string, input io.Reader, messageCh chan []byte) *PtyBuffer { +func MakePtyBuffer(oscPrefix string, input io.Reader, messageCh chan baseds.RpcInputChType) *PtyBuffer { if len(oscPrefix) != WaveOSCPrefixLen { panic(fmt.Sprintf("invalid OSC prefix length: %d", len(oscPrefix))) } @@ -64,7 +66,7 @@ func (b *PtyBuffer) setEOF() { } func (b *PtyBuffer) processWaveEscSeq(escSeq []byte) { - b.MessageCh <- escSeq + b.MessageCh <- baseds.RpcInputChType{MsgBytes: escSeq} } func (b *PtyBuffer) run() { diff --git a/pkg/wshutil/wshmultiproxy.go b/pkg/wshutil/wshmultiproxy.go deleted file mode 100644 index b9eb8c2537..0000000000 --- a/pkg/wshutil/wshmultiproxy.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2025, Command Line Inc. -// SPDX-License-Identifier: Apache-2.0 - -package wshutil - -import ( - "encoding/json" - "fmt" - "sync" - - "github.com/google/uuid" - "github.com/wavetermdev/waveterm/pkg/panichandler" - "github.com/wavetermdev/waveterm/pkg/wshrpc" -) - -type multiProxyRouteInfo struct { - RouteId string - AuthToken string - Proxy *WshRpcProxy - RpcContext *wshrpc.RpcContext -} - -// handles messages from multiple unauthenitcated clients -type WshRpcMultiProxy struct { - Lock *sync.Mutex - RouteInfo map[string]*multiProxyRouteInfo // authtoken to info - ToRemoteCh chan []byte - FromRemoteRawCh chan []byte // raw message from the remote -} - -func MakeRpcMultiProxy() *WshRpcMultiProxy { - return &WshRpcMultiProxy{ - Lock: &sync.Mutex{}, - RouteInfo: make(map[string]*multiProxyRouteInfo), - ToRemoteCh: make(chan []byte, DefaultInputChSize), - FromRemoteRawCh: make(chan []byte, DefaultOutputChSize), - } -} - -func (p *WshRpcMultiProxy) DisposeRoutes() { - p.Lock.Lock() - defer p.Lock.Unlock() - for authToken, routeInfo := range p.RouteInfo { - DefaultRouter.UnregisterRoute(routeInfo.RouteId) - delete(p.RouteInfo, authToken) - } -} - -func (p *WshRpcMultiProxy) getRouteInfo(authToken string) *multiProxyRouteInfo { - p.Lock.Lock() - defer p.Lock.Unlock() - return p.RouteInfo[authToken] -} - -func (p *WshRpcMultiProxy) setRouteInfo(authToken string, routeInfo *multiProxyRouteInfo) { - p.Lock.Lock() - defer p.Lock.Unlock() - p.RouteInfo[authToken] = routeInfo -} - -func (p *WshRpcMultiProxy) removeRouteInfo(authToken string) { - p.Lock.Lock() - defer p.Lock.Unlock() - delete(p.RouteInfo, authToken) -} - -func (p *WshRpcMultiProxy) sendResponseError(msg RpcMessage, sendErr error) { - if msg.ReqId == "" { - // no response needed - return - } - resp := RpcMessage{ - ResId: msg.ReqId, - Error: sendErr.Error(), - } - respBytes, _ := json.Marshal(resp) - p.ToRemoteCh <- respBytes -} - -func (p *WshRpcMultiProxy) sendAuthResponse(msg RpcMessage, routeId string, authToken string) { - if msg.ReqId == "" { - // no response needed - return - } - resp := RpcMessage{ - ResId: msg.ReqId, - Data: wshrpc.CommandAuthenticateRtnData{RouteId: routeId, AuthToken: authToken}, - } - respBytes, _ := json.Marshal(resp) - p.ToRemoteCh <- respBytes -} - -func (p *WshRpcMultiProxy) handleUnauthMessage(msgBytes []byte) { - var msg RpcMessage - err := json.Unmarshal(msgBytes, &msg) - if err != nil { - // nothing to do here, malformed message - return - } - if msg.Command == wshrpc.Command_Authenticate { - rpcContext, routeId, err := handleAuthenticationCommand(msg) - if err != nil { - p.sendResponseError(msg, err) - return - } - routeInfo := &multiProxyRouteInfo{ - RouteId: routeId, - AuthToken: uuid.New().String(), - RpcContext: rpcContext, - } - routeInfo.Proxy = MakeRpcProxy() - routeInfo.Proxy.SetRpcContext(rpcContext) - p.setRouteInfo(routeInfo.AuthToken, routeInfo) - p.sendAuthResponse(msg, routeId, routeInfo.AuthToken) - go func() { - defer func() { - panichandler.PanicHandler("WshRpcMultiProxy:handleUnauthMessage", recover()) - }() - for msgBytes := range routeInfo.Proxy.ToRemoteCh { - p.ToRemoteCh <- msgBytes - } - }() - DefaultRouter.RegisterRoute(routeId, routeInfo.Proxy, true) - return - } - // TODO implement authenticatetoken for multiproxy unauth message - if msg.AuthToken == "" { - p.sendResponseError(msg, fmt.Errorf("no auth token")) - return - } - routeInfo := p.getRouteInfo(msg.AuthToken) - if routeInfo == nil { - p.sendResponseError(msg, fmt.Errorf("invalid auth token")) - return - } - if msg.Command != "" && msg.Source != routeInfo.RouteId { - p.sendResponseError(msg, fmt.Errorf("invalid source route for auth token")) - return - } - if msg.Command == wshrpc.Command_Dispose { - DefaultRouter.UnregisterRoute(routeInfo.RouteId) - p.removeRouteInfo(msg.AuthToken) - close(routeInfo.Proxy.ToRemoteCh) - close(routeInfo.Proxy.FromRemoteCh) - return - } - routeInfo.Proxy.FromRemoteCh <- msgBytes -} - -func (p *WshRpcMultiProxy) RunUnauthLoop() { - // loop over unauthenticated message - // handle Authenicate commands, and pass authenticated messages to the AuthCh - for msgBytes := range p.FromRemoteRawCh { - p.handleUnauthMessage(msgBytes) - } -} diff --git a/pkg/wshutil/wshproxy.go b/pkg/wshutil/wshproxy.go index be53a41b08..4a8126789c 100644 --- a/pkg/wshutil/wshproxy.go +++ b/pkg/wshutil/wshproxy.go @@ -4,14 +4,11 @@ package wshutil import ( - "encoding/json" "fmt" "sync" - "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/panichandler" - "github.com/wavetermdev/waveterm/pkg/util/shellutil" - "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) @@ -19,237 +16,31 @@ type WshRpcProxy struct { Lock *sync.Mutex RpcContext *wshrpc.RpcContext ToRemoteCh chan []byte - FromRemoteCh chan []byte - AuthToken string + FromRemoteCh chan baseds.RpcInputChType + PeerInfo string } -func MakeRpcProxy() *WshRpcProxy { +func MakeRpcProxy(peerInfo string) *WshRpcProxy { return &WshRpcProxy{ Lock: &sync.Mutex{}, ToRemoteCh: make(chan []byte, DefaultInputChSize), - FromRemoteCh: make(chan []byte, DefaultOutputChSize), + FromRemoteCh: make(chan baseds.RpcInputChType, DefaultOutputChSize), + PeerInfo: peerInfo, } } -func (p *WshRpcProxy) SetRpcContext(rpcCtx *wshrpc.RpcContext) { - p.Lock.Lock() - defer p.Lock.Unlock() - p.RpcContext = rpcCtx -} - -func (p *WshRpcProxy) GetRpcContext() *wshrpc.RpcContext { - p.Lock.Lock() - defer p.Lock.Unlock() - return p.RpcContext -} - -func (p *WshRpcProxy) SetAuthToken(authToken string) { - p.Lock.Lock() - defer p.Lock.Unlock() - p.AuthToken = authToken +func (p *WshRpcProxy) GetPeerInfo() string { + return p.PeerInfo } -func (p *WshRpcProxy) GetAuthToken() string { +func (p *WshRpcProxy) SetPeerInfo(peerInfo string) { p.Lock.Lock() defer p.Lock.Unlock() - return p.AuthToken -} - -func (p *WshRpcProxy) sendResponseError(msg RpcMessage, sendErr error) { - if msg.ReqId == "" { - // no response needed - return - } - resp := RpcMessage{ - ResId: msg.ReqId, - Route: msg.Source, - Error: sendErr.Error(), - } - respBytes, _ := json.Marshal(resp) - p.SendRpcMessage(respBytes, "resp-error") -} - -func (p *WshRpcProxy) sendAuthenticateResponse(msg RpcMessage, routeId string) { - if msg.ReqId == "" { - // no response needed - return - } - resp := RpcMessage{ - ResId: msg.ReqId, - Route: msg.Source, - Data: wshrpc.CommandAuthenticateRtnData{RouteId: routeId}, - } - respBytes, _ := json.Marshal(resp) - p.SendRpcMessage(respBytes, "auth-resp") -} - -func (p *WshRpcProxy) sendAuthenticateTokenResponse(msg RpcMessage, entry *shellutil.TokenSwapEntry) { - if msg.ReqId == "" { - // no response needed - return - } - routeId, _ := MakeRouteIdFromCtx(entry.RpcContext) // already validated so don't need to check error - resp := RpcMessage{ - ResId: msg.ReqId, - Route: msg.Source, - Data: wshrpc.CommandAuthenticateRtnData{ - RouteId: routeId, - Env: entry.Env, - InitScriptText: entry.ScriptText, - }, - } - respBytes, _ := json.Marshal(resp) - p.SendRpcMessage(respBytes, "auth-token-resp") -} - -func validateRpcContextFromAuth(newCtx *wshrpc.RpcContext) (string, error) { - if newCtx == nil { - return "", fmt.Errorf("no context found in jwt token") - } - if newCtx.BlockId == "" && newCtx.Conn == "" { - return "", fmt.Errorf("no blockid or conn found in jwt token") - } - if newCtx.BlockId != "" { - if _, err := uuid.Parse(newCtx.BlockId); err != nil { - return "", fmt.Errorf("invalid blockId in jwt token") - } - } - routeId, err := MakeRouteIdFromCtx(newCtx) - if err != nil { - return "", fmt.Errorf("error making routeId from context: %w", err) - } - return routeId, nil -} - -func handleAuthenticationCommand(msg RpcMessage) (*wshrpc.RpcContext, string, error) { - if msg.Data == nil { - return nil, "", fmt.Errorf("no data in authenticate message") - } - strData, ok := msg.Data.(string) - if !ok { - return nil, "", fmt.Errorf("data in authenticate message not a string") - } - newCtx, err := ValidateAndExtractRpcContextFromToken(strData) - if err != nil { - return nil, "", fmt.Errorf("error validating token: %w", err) - } - routeId, err := validateRpcContextFromAuth(newCtx) - if err != nil { - return nil, "", err - } - return newCtx, routeId, nil -} - -func handleAuthenticateTokenCommand(msg RpcMessage) (*shellutil.TokenSwapEntry, error) { - if msg.Data == nil { - return nil, fmt.Errorf("no data in authenticatetoken message") - } - var tokenData wshrpc.CommandAuthenticateTokenData - err := utilfn.ReUnmarshal(&tokenData, msg.Data) - if err != nil { - return nil, fmt.Errorf("error unmarshalling token data: %w", err) - } - if tokenData.Token == "" { - return nil, fmt.Errorf("no token in authenticatetoken message") - } - entry := shellutil.GetAndRemoveTokenSwapEntry(tokenData.Token) - if entry == nil { - return nil, fmt.Errorf("no token entry found") - } - _, err = validateRpcContextFromAuth(entry.RpcContext) - if err != nil { - return nil, err - } - return entry, nil -} - -// runs on the client (stdio client) -func (p *WshRpcProxy) HandleClientProxyAuth(router *WshRouter) (string, error) { - for { - msgBytes, ok := <-p.FromRemoteCh - if !ok { - return "", fmt.Errorf("remote closed, not authenticated") - } - var origMsg RpcMessage - err := json.Unmarshal(msgBytes, &origMsg) - if err != nil { - // nothing to do, can't even send a response since we don't have Source or ReqId - continue - } - if origMsg.Command == "" { - // this message is not allowed (protocol error at this point), ignore - continue - } - if origMsg.Command == wshrpc.Command_Authenticate { - authRtn, err := router.HandleProxyAuth(origMsg.Data) - if err != nil { - respErr := fmt.Errorf("error handling proxy auth: %w", err) - p.sendResponseError(origMsg, respErr) - return "", respErr - } - p.SetAuthToken(authRtn.AuthToken) - announceMsg := RpcMessage{ - Command: wshrpc.Command_RouteAnnounce, - Source: authRtn.RouteId, - AuthToken: authRtn.AuthToken, - } - announceBytes, _ := json.Marshal(announceMsg) - router.InjectMessage(announceBytes, authRtn.RouteId) - p.sendAuthenticateResponse(origMsg, authRtn.RouteId) - return authRtn.RouteId, nil - } - if origMsg.Command == wshrpc.Command_AuthenticateToken { - // TODO implement authenticatetoken for proxyauth - } - respErr := fmt.Errorf("connection not authenticated") - p.sendResponseError(origMsg, respErr) - continue - } -} - -// runs on the server -func (p *WshRpcProxy) HandleAuthentication() (*wshrpc.RpcContext, error) { - for { - msgBytes, ok := <-p.FromRemoteCh - if !ok { - return nil, fmt.Errorf("remote closed, not authenticated") - } - var msg RpcMessage - err := json.Unmarshal(msgBytes, &msg) - if err != nil { - // nothing to do, can't even send a response since we don't have Source or ReqId - continue - } - if msg.Command == "" { - // this message is not allowed (protocol error at this point), ignore - continue - } - if msg.Command == wshrpc.Command_Authenticate { - newCtx, routeId, err := handleAuthenticationCommand(msg) - if err != nil { - p.sendResponseError(msg, err) - continue - } - p.sendAuthenticateResponse(msg, routeId) - return newCtx, nil - } - if msg.Command == wshrpc.Command_AuthenticateToken { - entry, err := handleAuthenticateTokenCommand(msg) - if err != nil { - p.sendResponseError(msg, err) - continue - } - p.sendAuthenticateTokenResponse(msg, entry) - return entry.RpcContext, nil - } - respErr := fmt.Errorf("connection not authenticated") - p.sendResponseError(msg, respErr) - continue - } + p.PeerInfo = peerInfo } // TODO: Figure out who is sending to closed routes and why we're not catching it -func (p *WshRpcProxy) SendRpcMessage(msg []byte, debugStr string) { +func (p *WshRpcProxy) SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string) { defer func() { panicCtx := "WshRpcProxy.SendRpcMessage" if debugStr != "" { @@ -261,31 +52,6 @@ func (p *WshRpcProxy) SendRpcMessage(msg []byte, debugStr string) { } func (p *WshRpcProxy) RecvRpcMessage() ([]byte, bool) { - msgBytes, more := <-p.FromRemoteCh - authToken := p.GetAuthToken() - if !more || (p.RpcContext == nil && authToken == "") { - return msgBytes, more - } - var msg RpcMessage - err := json.Unmarshal(msgBytes, &msg) - if err != nil { - // nothing to do here -- will error out at another level - return msgBytes, true - } - if p.RpcContext != nil { - msg.Data, err = recodeCommandData(msg.Command, msg.Data, p.RpcContext) - if err != nil { - // nothing to do here -- will error out at another level - return msgBytes, true - } - } - if msg.AuthToken == "" { - msg.AuthToken = authToken - } - newBytes, err := json.Marshal(msg) - if err != nil { - // nothing to do here - return msgBytes, true - } - return newBytes, true + inputVal, more := <-p.FromRemoteCh + return inputVal.MsgBytes, more } diff --git a/pkg/wshutil/wshrouter.go b/pkg/wshutil/wshrouter.go index 80c6a038a6..cbc2f47ab3 100644 --- a/pkg/wshutil/wshrouter.go +++ b/pkg/wshutil/wshrouter.go @@ -9,21 +9,24 @@ import ( "errors" "fmt" "log" + "strings" "sync" "time" "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/panichandler" - "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) const ( - DefaultRoute = "wavesrv" - UpstreamRoute = "upstream" - SysRoute = "sys" // this route doesn't exist, just a placeholder for system messages - ElectronRoute = "electron" + DefaultRoute = "wavesrv" + ElectronRoute = "electron" + ControlRoute = "$control" // control plane route + ControlRootRoute = "$control:root" // control plane route to root router + + ControlPrefix = "$" RoutePrefix_Conn = "conn:" RoutePrefix_Controller = "controller:" @@ -43,19 +46,48 @@ type routeInfo struct { DestRouteId string } -type msgAndRoute struct { - msgBytes []byte - fromRouteId string +const LinkKind_Leaf = "leaf" +const LinkKind_Router = "router" + +type linkMeta struct { + linkId baseds.LinkId + trusted bool + linkKind string + sourceRouteId string + client AbstractRpcClient +} + +func (lm *linkMeta) Name() string { + return fmt.Sprintf("%d#[%s]", lm.linkId, lm.client.GetPeerInfo()) +} + +type rpcRoutingInfo struct { + rpcId string + sourceLinkId baseds.LinkId + destRouteId string +} + +type messageWrap struct { + msgBytes []byte + debugStr string } type WshRouter struct { - Lock *sync.Mutex - RouteMap map[string]AbstractRpcClient // routeid => client - UpstreamClient AbstractRpcClient // upstream client (if we are not the terminal router) - AnnouncedRoutes map[string]string // routeid => local routeid - RpcMap map[string]*routeInfo // rpcid => routeinfo - SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel - InputCh chan msgAndRoute + lock *sync.Mutex + isRootRouter bool + nextLinkId baseds.LinkId + upstreamLinkId baseds.LinkId + inputCh chan baseds.RpcInputChType + rpcMap map[string]rpcRoutingInfo // rpcid => routeinfo + routeMap map[string]baseds.LinkId // routeid => linkid + linkMap map[baseds.LinkId]*linkMeta + + upstreamBufLock sync.Mutex + upstreamBufCond *sync.Cond + upstreamBuf []messageWrap + upstreamLoopStarted bool + + controlRpc *WshRpc } func MakeConnectionRouteId(connId string) string { @@ -70,6 +102,10 @@ func MakeProcRouteId(procId string) string { return "proc:" + procId } +func MakeRandomProcRouteId() string { + return MakeProcRouteId(uuid.New().String()) +} + func MakeTabRouteId(tabId string) string { return "tab:" + tabId } @@ -82,21 +118,43 @@ func MakeBuilderRouteId(builderId string) string { return "builder:" + builderId } -var DefaultRouter = NewWshRouter() +var DefaultRouter *WshRouter func NewWshRouter() *WshRouter { rtn := &WshRouter{ - Lock: &sync.Mutex{}, - RouteMap: make(map[string]AbstractRpcClient), - AnnouncedRoutes: make(map[string]string), - RpcMap: make(map[string]*routeInfo), - SimpleRequestMap: make(map[string]chan *RpcMessage), - InputCh: make(chan msgAndRoute, DefaultInputChSize), - } + lock: &sync.Mutex{}, + nextLinkId: 0, + upstreamLinkId: baseds.NoLinkId, + inputCh: make(chan baseds.RpcInputChType), + rpcMap: make(map[string]rpcRoutingInfo), + linkMap: make(map[baseds.LinkId]*linkMeta), + routeMap: make(map[string]baseds.LinkId), + } + rtn.upstreamBufCond = sync.NewCond(&rtn.upstreamBufLock) + rtn.registerControlPlane() go rtn.runServer() return rtn } +func (router *WshRouter) IsRootRouter() bool { + router.lock.Lock() + defer router.lock.Unlock() + return router.isRootRouter +} + +func (router *WshRouter) SetAsRootRouter() { + router.lock.Lock() + defer router.lock.Unlock() + router.isRootRouter = true + + // also bind $control:root to the control RPC + linkId := router.routeMap[ControlRoute] + if linkId != baseds.NoLinkId { + router.routeMap[ControlRootRoute] = linkId + log.Printf("wshrouter registered control:root route linkid=%d", linkId) + } +} + func noRouteErr(routeId string) error { if routeId == "" { return errors.New("no default route") @@ -108,8 +166,8 @@ func (router *WshRouter) SendEvent(routeId string, event wps.WaveEvent) { defer func() { panichandler.PanicHandler("WshRouter.SendEvent", recover()) }() - rpc := router.GetRpc(routeId) - if rpc == nil { + lm := router.getLinkForRoute(routeId) + if lm == nil { return } msg := RpcMessage{ @@ -122,10 +180,14 @@ func (router *WshRouter) SendEvent(routeId string, event wps.WaveEvent) { // nothing to do return } - rpc.SendRpcMessage(msgBytes, "eventrecv") + lm.client.SendRpcMessage(msgBytes, baseds.NoLinkId, "eventrecv") } -func (router *WshRouter) handleNoRoute(msg RpcMessage) { +func (router *WshRouter) handleNoRoute(msg RpcMessage, ingressLinkId baseds.LinkId) { + lm := router.getLinkMeta(ingressLinkId) + if lm == nil { + return + } nrErr := noRouteErr(msg.Route) if msg.ReqId == "" { if msg.Command == wshrpc.Command_Message { @@ -133,9 +195,14 @@ func (router *WshRouter) handleNoRoute(msg RpcMessage) { return } // no response needed, but send message back to source - respMsg := RpcMessage{Command: wshrpc.Command_Message, Route: msg.Source, Data: wshrpc.CommandMessageData{Message: nrErr.Error()}} + respMsg := RpcMessage{ + Command: wshrpc.Command_Message, + Route: msg.Source, + Source: ControlRoute, + Data: wshrpc.CommandMessageData{Message: nrErr.Error()}, + } respBytes, _ := json.Marshal(respMsg) - router.InputCh <- msgAndRoute{msgBytes: respBytes, fromRouteId: SysRoute} + lm.client.SendRpcMessage(respBytes, baseds.NoLinkId, "no-route-err") return } // send error response @@ -144,86 +211,70 @@ func (router *WshRouter) handleNoRoute(msg RpcMessage) { Error: nrErr.Error(), } respBytes, _ := json.Marshal(response) - router.sendRoutedMessage(respBytes, msg.Source) + router.sendRoutedMessage(respBytes, msg.Source, msg.Command, baseds.NoLinkId) } -func (router *WshRouter) registerRouteInfo(rpcId string, sourceRouteId string, destRouteId string) { +func (router *WshRouter) registerRouteInfo(rpcId string, sourceLinkId baseds.LinkId, destRouteId string) { if rpcId == "" { return } - router.Lock.Lock() - defer router.Lock.Unlock() - router.RpcMap[rpcId] = &routeInfo{RpcId: rpcId, SourceRouteId: sourceRouteId, DestRouteId: destRouteId} + router.lock.Lock() + defer router.lock.Unlock() + router.rpcMap[rpcId] = rpcRoutingInfo{ + rpcId: rpcId, + sourceLinkId: sourceLinkId, + destRouteId: destRouteId, + } } func (router *WshRouter) unregisterRouteInfo(rpcId string) { - router.Lock.Lock() - defer router.Lock.Unlock() - delete(router.RpcMap, rpcId) -} - -func (router *WshRouter) getRouteInfo(rpcId string) *routeInfo { - router.Lock.Lock() - defer router.Lock.Unlock() - return router.RpcMap[rpcId] -} - -func (router *WshRouter) handleAnnounceMessage(msg RpcMessage, input msgAndRoute) { - if msg.Source != input.fromRouteId { - router.Lock.Lock() - router.AnnouncedRoutes[msg.Source] = input.fromRouteId - router.Lock.Unlock() - } - upstream := router.GetUpstreamClient() - if upstream != nil { - upstream.SendRpcMessage(input.msgBytes, "announce-upstream") - } + router.lock.Lock() + defer router.lock.Unlock() + delete(router.rpcMap, rpcId) } -func (router *WshRouter) handleUnannounceMessage(msg RpcMessage, input msgAndRoute) { - router.Lock.Lock() - delete(router.AnnouncedRoutes, msg.Source) - router.Lock.Unlock() - - upstream := router.GetUpstreamClient() - if upstream != nil { - upstream.SendRpcMessage(input.msgBytes, "unannounce-upstream") +func (router *WshRouter) getRouteInfo(rpcId string) *rpcRoutingInfo { + router.lock.Lock() + defer router.lock.Unlock() + rtn, ok := router.rpcMap[rpcId] + if !ok { + return nil } -} - -func (router *WshRouter) getAnnouncedRoute(routeId string) string { - router.Lock.Lock() - defer router.Lock.Unlock() - return router.AnnouncedRoutes[routeId] + return &rtn } // returns true if message was sent, false if failed -func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string) bool { - rpc := router.GetRpc(routeId) - if rpc != nil { - rpc.SendRpcMessage(msgBytes, "route") +func (router *WshRouter) sendRoutedMessage(msgBytes []byte, routeId string, commandName string, ingressLinkId baseds.LinkId) bool { + lm := router.getLinkForRoute(routeId) + if lm != nil { + lm.client.SendRpcMessage(msgBytes, ingressLinkId, "route") return true } - localRouteId := router.getAnnouncedRoute(routeId) - if localRouteId != "" { - rpc := router.GetRpc(localRouteId) - if rpc != nil { - rpc.SendRpcMessage(msgBytes, "route-local") - return true - } - } - upstream := router.GetUpstreamClient() + upstream := router.getUpstreamClient() if upstream != nil { - upstream.SendRpcMessage(msgBytes, "route-upstream") + upstream.SendRpcMessage(msgBytes, ingressLinkId, "route-upstream") return true } - log.Printf("[router] no rpc for route id %q\n", routeId) + if commandName != "" { + log.Printf("[router] no rpc for route id %q command:%s\n", routeId, commandName) + } else { + log.Printf("[router] no rpc for route id %q\n", routeId) + } return false } +func (router *WshRouter) sendMessageToLink(msgBytes []byte, linkId baseds.LinkId, ingressLinkId baseds.LinkId) bool { + lm := router.getLinkMeta(linkId) + if lm == nil { + return false + } + lm.client.SendRpcMessage(msgBytes, ingressLinkId, "link") + return true +} + func (router *WshRouter) runServer() { - for input := range router.InputCh { - msgBytes := input.msgBytes + for input := range router.inputCh { + msgBytes := input.MsgBytes var msg RpcMessage err := json.Unmarshal(msgBytes, &msg) if err != nil { @@ -231,22 +282,14 @@ func (router *WshRouter) runServer() { continue } routeId := msg.Route - if msg.Command == wshrpc.Command_RouteAnnounce { - router.handleAnnounceMessage(msg, input) - continue - } - if msg.Command == wshrpc.Command_RouteUnannounce { - router.handleUnannounceMessage(msg, input) - continue - } if msg.Command != "" { // new comand, setup new rpc - ok := router.sendRoutedMessage(msgBytes, routeId) + ok := router.sendRoutedMessage(msgBytes, routeId, msg.Command, input.IngressLinkId) if !ok { - router.handleNoRoute(msg) + router.handleNoRoute(msg, input.IngressLinkId) continue } - router.registerRouteInfo(msg.ReqId, msg.Source, routeId) + router.registerRouteInfo(msg.ReqId, input.IngressLinkId, routeId) continue } // look at reqid or resid to route correctly @@ -257,19 +300,15 @@ func (router *WshRouter) runServer() { continue } // no need to check the return value here (noop if failed) - router.sendRoutedMessage(msgBytes, routeInfo.DestRouteId) + router.sendRoutedMessage(msgBytes, routeInfo.destRouteId, "", input.IngressLinkId) continue } else if msg.ResId != "" { - ok := router.trySimpleResponse(&msg) - if ok { - continue - } routeInfo := router.getRouteInfo(msg.ResId) if routeInfo == nil { // no route info, nothing to do continue } - router.sendRoutedMessage(msgBytes, routeInfo.SourceRouteId) + router.sendMessageToLink(msgBytes, routeInfo.sourceLinkId, input.IngressLinkId) if !msg.Cont { router.unregisterRouteInfo(msg.ResId) } @@ -283,10 +322,7 @@ func (router *WshRouter) runServer() { func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) error { for { - if router.GetRpc(routeId) != nil { - return nil - } - if router.getAnnouncedRoute(routeId) != "" { + if router.getLinkForRoute(routeId) != nil { return nil } select { @@ -298,195 +334,381 @@ func (router *WshRouter) WaitForRegister(ctx context.Context, routeId string) er } } -// this will also consume the output channel of the abstract client -func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, shouldAnnounce bool) { - if routeId == SysRoute || routeId == UpstreamRoute { - // cannot register sys route - log.Printf("error: WshRouter cannot register %s route\n", routeId) +// this will never block, can be called while holding router.Lock +func (router *WshRouter) queueUpstreamMessage(msgBytes []byte, debugStr string) { + if router.getUpstreamClient() == nil { return } - log.Printf("[router] registering wsh route %q\n", routeId) - router.Lock.Lock() - defer router.Lock.Unlock() - alreadyExists := router.RouteMap[routeId] != nil - if alreadyExists { - log.Printf("[router] warning: route %q already exists (replacing)\n", routeId) - } - router.RouteMap[routeId] = rpc - go func() { - defer func() { - panichandler.PanicHandler("WshRouter:registerRoute:recvloop", recover()) - }() - // announce - if shouldAnnounce && !alreadyExists && router.GetUpstreamClient() != nil { - announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId} - announceBytes, _ := json.Marshal(announceMsg) - router.GetUpstreamClient().SendRpcMessage(announceBytes, "route-announce") + router.upstreamBufLock.Lock() + defer router.upstreamBufLock.Unlock() + router.upstreamBuf = append(router.upstreamBuf, messageWrap{msgBytes: msgBytes, debugStr: debugStr}) + if !router.upstreamLoopStarted { + router.upstreamLoopStarted = true + go router.runUpstreamBufferLoop() + } + router.upstreamBufCond.Signal() +} + +func (router *WshRouter) runUpstreamBufferLoop() { + defer func() { + panichandler.PanicHandler("WshRouter:runUpstreamBufferLoop", recover()) + }() + for { + router.upstreamBufLock.Lock() + for len(router.upstreamBuf) == 0 { + router.upstreamBufCond.Wait() } - for { - msgBytes, ok := rpc.RecvRpcMessage() - if !ok { - break + msg := router.upstreamBuf[0] + router.upstreamBuf = router.upstreamBuf[1:] + router.upstreamBufLock.Unlock() + + upstream := router.getUpstreamClient() + if upstream != nil { + upstream.SendRpcMessage(msg.msgBytes, baseds.NoLinkId, msg.debugStr) + } + } +} + +func (router *WshRouter) RegisterUntrustedLink(client AbstractRpcClient) baseds.LinkId { + router.lock.Lock() + defer router.lock.Unlock() + router.nextLinkId++ + linkId := router.nextLinkId + lm := &linkMeta{ + linkId: linkId, + trusted: false, + client: client, + } + log.Printf("wshrouter register link %s", lm.Name()) + router.linkMap[linkId] = lm + go router.runLinkClientRecvLoop(linkId, client) + return linkId +} + +func (router *WshRouter) trustLink(linkId baseds.LinkId, linkKind string) { + router.lock.Lock() + defer router.lock.Unlock() + lm := router.linkMap[linkId] + if lm == nil { + return + } + log.Printf("wshrouter trust link %s kind=%s", lm.Name(), linkKind) + lm.trusted = true + lm.linkKind = linkKind +} + +func (router *WshRouter) runLinkClientRecvLoop(linkId baseds.LinkId, client AbstractRpcClient) { + defer func() { + panichandler.PanicHandler("WshRouter:runLinkClientRecvLoop", recover()) + }() + exitReason := "unknown" + lmForLog := router.getLinkMeta(linkId) + linkName := fmt.Sprintf("%d", linkId) + if lmForLog != nil { + linkName = lmForLog.Name() + } + log.Printf("link recvloop start for %s", linkName) + defer log.Printf("link recvloop done for %s (%s)", linkName, exitReason) + for { + msgBytes, ok := client.RecvRpcMessage() + if !ok { + exitReason = "recv-eof" + break + } + var rpcMsg RpcMessage + err := json.Unmarshal(msgBytes, &rpcMsg) + if err != nil { + continue + } + lm := router.getLinkMeta(linkId) + if lm == nil { + exitReason = "link-gone" + break + } + if rpcMsg.IsRpcRequest() { + if lm.sourceRouteId != "" { + rpcMsg.Source = lm.sourceRouteId + } + if rpcMsg.Route == "" { + rpcMsg.Route = DefaultRoute } - var rpcMsg RpcMessage - err := json.Unmarshal(msgBytes, &rpcMsg) + msgBytes, err = json.Marshal(rpcMsg) if err != nil { continue } - if rpcMsg.Command != "" { - if rpcMsg.Source == "" { - rpcMsg.Source = routeId - } - if rpcMsg.Route == "" { - rpcMsg.Route = DefaultRoute - } - msgBytes, err = json.Marshal(rpcMsg) - if err != nil { + // allow control routes even for untrusted links (for authentication) + isControlRoute := rpcMsg.Route == ControlRoute || rpcMsg.Route == ControlRootRoute + if !lm.trusted { + if !isControlRoute { + sendControlUnauthenticatedErrorResponse(rpcMsg, *lm) continue } + log.Printf("wshrouter control-msg route=%s link=%s command=%s source=%s", rpcMsg.Route, lm.Name(), rpcMsg.Command, rpcMsg.Source) + } + } else { + // non-request messages (responses) + if !lm.trusted { + // drop responses from untrusted links + continue } - router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: routeId} } - }() + router.inputCh <- baseds.RpcInputChType{MsgBytes: msgBytes, IngressLinkId: linkId} + } } -func (router *WshRouter) UnregisterRoute(routeId string) { - log.Printf("[router] unregistering wsh route %q\n", routeId) - router.Lock.Lock() - delete(router.RouteMap, routeId) - // clear out announced routes - for announcedRouteId, localRouteId := range router.AnnouncedRoutes { - if localRouteId == routeId { - delete(router.AnnouncedRoutes, announcedRouteId) - } +// synchronized, returns a copy +func (router *WshRouter) getLinkMeta(linkId baseds.LinkId) *linkMeta { + if linkId == baseds.NoLinkId { + return nil } - upstream := router.UpstreamClient - router.Lock.Unlock() - - if upstream != nil { - unannounceMsg := RpcMessage{Command: wshrpc.Command_RouteUnannounce, Source: routeId} - unannounceBytes, _ := json.Marshal(unannounceMsg) - upstream.SendRpcMessage(unannounceBytes, "route-unannounce") - } - - go func() { - defer func() { - panichandler.PanicHandler("WshRouter:unregisterRoute:routegone", recover()) - }() - wps.Broker.UnsubscribeAll(routeId) - wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteGone, Scopes: []string{routeId}}) - }() + router.lock.Lock() + defer router.lock.Unlock() + lm := router.linkMap[linkId] + if lm == nil { + return nil + } + lmCopy := *lm + return &lmCopy } -// this may return nil (returns default only for empty routeId) -func (router *WshRouter) GetRpc(routeId string) AbstractRpcClient { - router.Lock.Lock() - defer router.Lock.Unlock() - return router.RouteMap[routeId] +// synchronized, returns a copy +func (router *WshRouter) getLinkForRoute(routeId string) *linkMeta { + if routeId == "" { + return nil + } + router.lock.Lock() + defer router.lock.Unlock() + linkId := router.routeMap[routeId] + if linkId == baseds.NoLinkId { + return nil + } + lm := router.linkMap[linkId] + if lm == nil { + return nil + } + lmCopy := *lm + return &lmCopy } -func (router *WshRouter) SetUpstreamClient(rpc AbstractRpcClient) { - router.Lock.Lock() - defer router.Lock.Unlock() - router.UpstreamClient = rpc +func (router *WshRouter) GetLinkIdForRoute(routeId string) baseds.LinkId { + lm := router.getLinkForRoute(routeId) + if lm == nil { + return baseds.NoLinkId + } + return lm.linkId } -func (router *WshRouter) GetUpstreamClient() AbstractRpcClient { - router.Lock.Lock() - defer router.Lock.Unlock() - return router.UpstreamClient +// only for leaves +func (router *WshRouter) RegisterTrustedLeaf(rpc AbstractRpcClient, routeId string) (baseds.LinkId, error) { + if !isBindableRouteId(routeId) { + return 0, fmt.Errorf("invalid routeid %q", routeId) + } + linkId := router.RegisterUntrustedLink(rpc) + router.trustLink(linkId, LinkKind_Leaf) + router.bindRoute(linkId, routeId, true) + return linkId, nil } -func (router *WshRouter) InjectMessage(msgBytes []byte, fromRouteId string) { - router.InputCh <- msgAndRoute{msgBytes: msgBytes, fromRouteId: fromRouteId} +// only for routers +func (router *WshRouter) RegisterTrustedRouter(rpc AbstractRpcClient) baseds.LinkId { + linkId := router.RegisterUntrustedLink(rpc) + router.trustLink(linkId, LinkKind_Router) + return linkId } -func (router *WshRouter) registerSimpleRequest(reqId string) chan *RpcMessage { - router.Lock.Lock() - defer router.Lock.Unlock() - rtn := make(chan *RpcMessage, 1) - router.SimpleRequestMap[reqId] = rtn - return rtn +func (router *WshRouter) RegisterUpstream(rpc AbstractRpcClient) baseds.LinkId { + if router.IsRootRouter() { + panic("cannot register upstream for root router") + } + linkId := router.RegisterUntrustedLink(rpc) + router.trustLink(linkId, LinkKind_Router) + router.lock.Lock() + defer router.lock.Unlock() + router.upstreamLinkId = linkId + return linkId } -func (router *WshRouter) trySimpleResponse(msg *RpcMessage) bool { - router.Lock.Lock() - defer router.Lock.Unlock() - respCh := router.SimpleRequestMap[msg.ResId] - if respCh == nil { - return false +func (router *WshRouter) registerControlPlane() { + controlImpl := &WshRouterControlImpl{Router: router} + controlRpcCtx := wshrpc.RpcContext{RouteId: ControlRoute} + router.controlRpc = MakeWshRpc(controlRpcCtx, controlImpl, "control") + + linkId := router.RegisterUntrustedLink(router.controlRpc) + router.trustLink(linkId, LinkKind_Leaf) + + router.lock.Lock() + defer router.lock.Unlock() + lm := router.linkMap[linkId] + if lm != nil { + lm.sourceRouteId = ControlRoute + router.routeMap[ControlRoute] = linkId + log.Printf("wshrouter registered control route %q linkid=%d", ControlRoute, linkId) } - respCh <- msg - delete(router.SimpleRequestMap, msg.ResId) - return true } -func (router *WshRouter) clearSimpleRequest(reqId string) { - router.Lock.Lock() - defer router.Lock.Unlock() - delete(router.SimpleRequestMap, reqId) +func (router *WshRouter) announceUpstream(routeId string) { + msg := RpcMessage{ + Command: wshrpc.Command_RouteAnnounce, + Route: ControlRoute, + Source: routeId, + } + msgBytes, _ := json.Marshal(msg) + router.queueUpstreamMessage(msgBytes, "upstream-announce") } -func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage, fromRouteId string) (*RpcMessage, error) { - if msg.Command == "" { - return nil, errors.New("no command") +func (router *WshRouter) unannounceUpstream(routeId string) { + msg := RpcMessage{ + Command: wshrpc.Command_RouteUnannounce, + Route: ControlRoute, + Source: routeId, } - msgBytes, err := json.Marshal(msg) - if err != nil { - return nil, err - } - var respCh chan *RpcMessage - if msg.ReqId != "" { - respCh = router.registerSimpleRequest(msg.ReqId) - } - router.InjectMessage(msgBytes, fromRouteId) - if respCh == nil { - return nil, nil - } - select { - case <-ctx.Done(): - router.clearSimpleRequest(msg.ReqId) - return nil, ctx.Err() - case resp := <-respCh: - if resp.Error != "" { - return nil, errors.New(resp.Error) + msgBytes, _ := json.Marshal(msg) + router.queueUpstreamMessage(msgBytes, "upstream-unannounce") +} + +func (router *WshRouter) getRoutesForLink(linkId baseds.LinkId) []string { + router.lock.Lock() + defer router.lock.Unlock() + var routes []string + for routeId, mappedLinkId := range router.routeMap { + if mappedLinkId == linkId { + routes = append(routes, routeId) } - return resp, nil } + return routes } -func (router *WshRouter) HandleProxyAuth(jwtTokenAny any) (*wshrpc.CommandAuthenticateRtnData, error) { - if jwtTokenAny == nil { - return nil, errors.New("no jwt token") +func (router *WshRouter) UnregisterLink(linkId baseds.LinkId) { + routes := router.getRoutesForLink(linkId) + for _, routeId := range routes { + router.unbindRoute(linkId, routeId) } - jwtToken, ok := jwtTokenAny.(string) - if !ok { - return nil, errors.New("jwt token not a string") + router.lock.Lock() + defer router.lock.Unlock() + lm := router.linkMap[linkId] + if lm != nil { + log.Printf("wshrouter unregister link %s", lm.Name()) } - if jwtToken == "" { - return nil, errors.New("empty jwt token") + delete(router.linkMap, linkId) + if router.upstreamLinkId == linkId { + router.upstreamLinkId = baseds.NoLinkId } - msg := RpcMessage{ - Command: wshrpc.Command_Authenticate, - ReqId: uuid.New().String(), - Data: jwtToken, +} + +func isBindableRouteId(routeId string) bool { + if routeId == "" || strings.HasPrefix(routeId, ControlPrefix) { + return false + } + return true +} + +func (router *WshRouter) unbindRouteLocally(linkId baseds.LinkId, routeId string) error { + if linkId == baseds.NoLinkId { + return fmt.Errorf("cannot unbind %q to NoLinkId", routeId) + } + router.lock.Lock() + defer router.lock.Unlock() + if router.routeMap[routeId] == linkId { + delete(router.routeMap, routeId) } - ctx, cancelFn := context.WithTimeout(context.Background(), DefaultTimeoutMs*time.Millisecond) - defer cancelFn() - resp, err := router.RunSimpleRawCommand(ctx, msg, "") + return nil +} + +func (router *WshRouter) unbindRoute(linkId baseds.LinkId, routeId string) error { + err := router.unbindRouteLocally(linkId, routeId) if err != nil { - return nil, err + return err } - if resp == nil || resp.Data == nil { - return nil, errors.New("no data in authenticate response") + lm := router.getLinkMeta(linkId) + if lm != nil { + log.Printf("wshrouter unbind route %q from %s", routeId, lm.Name()) } - var respData wshrpc.CommandAuthenticateRtnData - err = utilfn.ReUnmarshal(&respData, resp.Data) + router.unannounceUpstream(routeId) + if router.IsRootRouter() { + router.unsubscribeFromBroker(routeId) + } + return nil +} + +func (router *WshRouter) bindRouteLocally(linkId baseds.LinkId, routeId string, isSourceRoute bool) error { + if linkId == baseds.NoLinkId { + return fmt.Errorf("cannot bindroute %q to NoLinkId", routeId) + } + if !isBindableRouteId(routeId) { + return fmt.Errorf("router cannot register %q route (invalid routeid)", routeId) + } + router.lock.Lock() + defer router.lock.Unlock() + lm := router.linkMap[linkId] + if lm == nil { + return fmt.Errorf("cannot bind route %q, no link with id %d found", routeId, linkId) + } + if !lm.trusted { + return fmt.Errorf("cannot bind route %q, link %d is not trusted", routeId, linkId) + } + if isSourceRoute { + if lm.linkKind != LinkKind_Leaf { + return fmt.Errorf("cannot bind source route %q to link %d (link is not a leaf)", routeId, linkId) + } + if lm.sourceRouteId != "" && lm.sourceRouteId != routeId { + return fmt.Errorf("cannot bind source route %q to link %d (link already has source route %q)", routeId, linkId, lm.sourceRouteId) + } + lm.sourceRouteId = routeId + } else { + if lm.linkKind != LinkKind_Router { + return fmt.Errorf("cannot bind route %q to link %d (link is not a router)", routeId, linkId) + } + } + router.routeMap[routeId] = linkId + return nil +} + +func (router *WshRouter) bindRoute(linkId baseds.LinkId, routeId string, isSourceRoute bool) error { + err := router.bindRouteLocally(linkId, routeId, isSourceRoute) if err != nil { - return nil, fmt.Errorf("error unmarshalling authenticate response: %v", err) + return err + } + lm := router.getLinkMeta(linkId) + if lm != nil { + log.Printf("wshrouter bind route %q to %s", routeId, lm.Name()) + } + // don't announce control routes upstream (they are local only) + if !strings.HasPrefix(routeId, ControlPrefix) { + router.announceUpstream(routeId) + } + return nil +} + +func (router *WshRouter) getUpstreamClient() AbstractRpcClient { + router.lock.Lock() + defer router.lock.Unlock() + if router.upstreamLinkId == baseds.NoLinkId { + return nil + } + lm := router.linkMap[router.upstreamLinkId] + if lm == nil { + return nil + } + return lm.client +} + +func (router *WshRouter) unsubscribeFromBroker(routeId string) { + defer func() { + panichandler.PanicHandler("WshRouter:unregisterRoute:routegone", recover()) + }() + wps.Broker.UnsubscribeAll(routeId) + wps.Broker.Publish(wps.WaveEvent{Event: wps.Event_RouteGone, Scopes: []string{routeId}}) +} + +func sendControlUnauthenticatedErrorResponse(cmdMsg RpcMessage, linkMeta linkMeta) { + if cmdMsg.ReqId == "" { + return } - if respData.AuthToken == "" { - return nil, errors.New("no auth token in authenticate response") + rtnMsg := RpcMessage{ + Source: ControlRoute, + ResId: cmdMsg.ReqId, + Error: fmt.Sprintf("link is unauthenticated (%s), cannot call %q", linkMeta.Name(), cmdMsg.Command), } - return &respData, nil + rtnBytes, _ := json.Marshal(rtnMsg) + linkMeta.client.SendRpcMessage(rtnBytes, baseds.NoLinkId, "unauthenticated") } diff --git a/pkg/wshutil/wshrouter_controlimpl.go b/pkg/wshutil/wshrouter_controlimpl.go new file mode 100644 index 0000000000..f6f557eabc --- /dev/null +++ b/pkg/wshutil/wshrouter_controlimpl.go @@ -0,0 +1,218 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshutil + +import ( + "context" + "fmt" + "log" + + "github.com/wavetermdev/waveterm/pkg/baseds" + "github.com/wavetermdev/waveterm/pkg/util/shellutil" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +type WshRouterControlImpl struct { + Router *WshRouter +} + +func (impl *WshRouterControlImpl) WshServerImpl() {} + +func (impl *WshRouterControlImpl) RouteAnnounceCommand(ctx context.Context) error { + source := GetRpcSourceFromContext(ctx) + if source == "" { + return fmt.Errorf("no source in routeannounce") + } + handler := GetRpcResponseHandlerFromContext(ctx) + if handler == nil { + return fmt.Errorf("no response handler in context") + } + linkId := handler.GetIngressLinkId() + if linkId == baseds.NoLinkId { + return fmt.Errorf("no ingress link found") + } + return impl.Router.bindRoute(linkId, source, false) +} + +func (impl *WshRouterControlImpl) RouteUnannounceCommand(ctx context.Context) error { + source := GetRpcSourceFromContext(ctx) + if source == "" { + return fmt.Errorf("no source in routeunannounce") + } + handler := GetRpcResponseHandlerFromContext(ctx) + if handler == nil { + return fmt.Errorf("no response handler in context") + } + linkId := handler.GetIngressLinkId() + if linkId == baseds.NoLinkId { + return fmt.Errorf("no ingress link found") + } + return impl.Router.unbindRoute(linkId, source) +} + +func (impl *WshRouterControlImpl) SetPeerInfoCommand(ctx context.Context, peerInfo string) error { + source := GetRpcSourceFromContext(ctx) + linkId := impl.Router.GetLinkIdForRoute(source) + if linkId == baseds.NoLinkId { + return fmt.Errorf("no link found for source route %q", source) + } + lm := impl.Router.getLinkMeta(linkId) + if lm == nil { + return fmt.Errorf("no link meta found for linkId %d", linkId) + } + if proxy, ok := lm.client.(*WshRpcProxy); ok { + proxy.SetPeerInfo(peerInfo) + return nil + } + return fmt.Errorf("setpeerinfo only valid for proxy connections") +} + +func (impl *WshRouterControlImpl) AuthenticateCommand(ctx context.Context, data string) (wshrpc.CommandAuthenticateRtnData, error) { + handler := GetRpcResponseHandlerFromContext(ctx) + if handler == nil { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no response handler in context") + } + linkId := handler.GetIngressLinkId() + if linkId == baseds.NoLinkId { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no ingress link found") + } + + newCtx, err := ValidateAndExtractRpcContextFromToken(data) + if err != nil { + log.Printf("wshrouter authenticate error linkid=%d: %v", linkId, err) + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("error validating token: %w", err) + } + routeId, err := validateRpcContextFromAuth(newCtx) + if err != nil { + return wshrpc.CommandAuthenticateRtnData{}, err + } + + rtnData := wshrpc.CommandAuthenticateRtnData{} + if newCtx.IsRouter { + log.Printf("wshrouter authenticate success linkid=%d (router)", linkId) + impl.Router.trustLink(linkId, LinkKind_Router) + } else { + log.Printf("wshrouter authenticate success linkid=%d routeid=%q", linkId, routeId) + impl.Router.trustLink(linkId, LinkKind_Leaf) + impl.Router.bindRoute(linkId, routeId, true) + } + + return rtnData, nil +} + +func (impl *WshRouterControlImpl) AuthenticateTokenCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) { + handler := GetRpcResponseHandlerFromContext(ctx) + if handler == nil { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no response handler in context") + } + linkId := handler.GetIngressLinkId() + if linkId == baseds.NoLinkId { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no ingress link found") + } + + if data.Token == "" { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token in authenticatetoken message") + } + + var rtnData wshrpc.CommandAuthenticateRtnData + var rpcContext *wshrpc.RpcContext + if impl.Router.IsRootRouter() { + entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token) + if entry == nil { + log.Printf("wshrouter authenticate-token error linkid=%d: no token entry found", linkId) + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found") + } + _, err := validateRpcContextFromAuth(entry.RpcContext) + if err != nil { + return wshrpc.CommandAuthenticateRtnData{}, err + } + if entry.RpcContext.IsRouter { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token") + } + if entry.RpcContext.RouteId == "" { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid") + } + rpcContext = entry.RpcContext + rtnData = wshrpc.CommandAuthenticateRtnData{ + Env: entry.Env, + InitScriptText: entry.ScriptText, + RpcContext: rpcContext, + } + } else { + wshRpc := GetWshRpcFromContext(ctx) + if wshRpc == nil { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no wshrpc in context") + } + respData, err := wshRpc.SendRpcRequest(wshrpc.Command_AuthenticateTokenVerify, data, &wshrpc.RpcOpts{Route: ControlRootRoute}) + if err != nil { + log.Printf("wshrouter authenticate-token error linkid=%d: failed to verify token: %v", linkId, err) + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("failed to verify token: %w", err) + } + err = utilfn.ReUnmarshal(&rtnData, respData) + if err != nil { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("failed to unmarshal response: %w", err) + } + rpcContext = rtnData.RpcContext + } + + if rpcContext == nil { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no rpccontext in token response") + } + log.Printf("wshrouter authenticate-token success linkid=%d routeid=%q", linkId, rpcContext.RouteId) + impl.Router.trustLink(linkId, LinkKind_Leaf) + impl.Router.bindRoute(linkId, rpcContext.RouteId, true) + + return rtnData, nil +} + +func (impl *WshRouterControlImpl) AuthenticateTokenVerifyCommand(ctx context.Context, data wshrpc.CommandAuthenticateTokenData) (wshrpc.CommandAuthenticateRtnData, error) { + if !impl.Router.IsRootRouter() { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("authenticatetokenverify can only be called on root router") + } + + if data.Token == "" { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token in authenticatetoken message") + } + entry := shellutil.GetAndRemoveTokenSwapEntry(data.Token) + if entry == nil { + log.Printf("wshrouter authenticate-token-verify error: no token entry found") + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no token entry found") + } + _, err := validateRpcContextFromAuth(entry.RpcContext) + if err != nil { + return wshrpc.CommandAuthenticateRtnData{}, err + } + if entry.RpcContext.IsRouter { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("cannot auth router via token") + } + if entry.RpcContext.RouteId == "" { + return wshrpc.CommandAuthenticateRtnData{}, fmt.Errorf("no routeid") + } + + rtnData := wshrpc.CommandAuthenticateRtnData{ + Env: entry.Env, + InitScriptText: entry.ScriptText, + RpcContext: entry.RpcContext, + } + + log.Printf("wshrouter authenticate-token-verify success routeid=%q", entry.RpcContext.RouteId) + return rtnData, nil +} + +func validateRpcContextFromAuth(newCtx *wshrpc.RpcContext) (string, error) { + if newCtx == nil { + return "", fmt.Errorf("no context found in jwt token") + } + if newCtx.IsRouter && newCtx.RouteId != "" { + return "", fmt.Errorf("invalid context, router cannot have a routeid") + } + if !newCtx.IsRouter && newCtx.RouteId == "" { + return "", fmt.Errorf("invalid context, must have a routeid") + } + if newCtx.IsRouter { + return "", nil + } + return newCtx.RouteId, nil +} diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index ebfca5cc9f..7d94777193 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -16,6 +16,7 @@ import ( "time" "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/util/ds" "github.com/wavetermdev/waveterm/pkg/util/utilfn" @@ -40,17 +41,17 @@ type ServerImpl interface { } type AbstractRpcClient interface { - SendRpcMessage(msg []byte, debugStr string) + GetPeerInfo() string + SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string) RecvRpcMessage() ([]byte, bool) // blocking } type WshRpc struct { Lock *sync.Mutex - InputCh chan []byte + InputCh chan baseds.RpcInputChType OutputCh chan []byte CtxDoneCh chan string // for context cancellation, value is ResId RpcContext *atomic.Pointer[wshrpc.RpcContext] - AuthToken string RpcMap map[string]*rpcData ServerImpl ServerImpl EventListener *EventListener @@ -103,8 +104,12 @@ func GetRpcResponseHandlerFromContext(ctx context.Context) *RpcResponseHandler { return rtn.(*RpcResponseHandler) } -func (w *WshRpc) SendRpcMessage(msg []byte, debugStr string) { - w.InputCh <- msg +func (w *WshRpc) GetPeerInfo() string { + return w.DebugName +} + +func (w *WshRpc) SendRpcMessage(msg []byte, ingressLinkId baseds.LinkId, debugStr string) { + w.InputCh <- baseds.RpcInputChType{MsgBytes: msg, IngressLinkId: ingressLinkId} } func (w *WshRpc) RecvRpcMessage() ([]byte, bool) { @@ -113,18 +118,17 @@ func (w *WshRpc) RecvRpcMessage() ([]byte, bool) { } type RpcMessage struct { - Command string `json:"command,omitempty"` - ReqId string `json:"reqid,omitempty"` - ResId string `json:"resid,omitempty"` - Timeout int64 `json:"timeout,omitempty"` - Route string `json:"route,omitempty"` // to route/forward requests to alternate servers - AuthToken string `json:"authtoken,omitempty"` // needed for routing unauthenticated requests (WshRpcMultiProxy) - Source string `json:"source,omitempty"` // source route id - Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming - Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming) - Error string `json:"error,omitempty"` - DataType string `json:"datatype,omitempty"` - Data any `json:"data,omitempty"` + Command string `json:"command,omitempty"` + ReqId string `json:"reqid,omitempty"` + ResId string `json:"resid,omitempty"` + Timeout int64 `json:"timeout,omitempty"` + Route string `json:"route,omitempty"` // to route/forward requests to alternate servers + Source string `json:"source,omitempty"` // source route id + Cont bool `json:"cont,omitempty"` // flag if additional requests/responses are forthcoming + Cancel bool `json:"cancel,omitempty"` // used to cancel a streaming request or response (sent from the side that is not streaming) + Error string `json:"error,omitempty"` + DataType string `json:"datatype,omitempty"` + Data any `json:"data,omitempty"` } func (r *RpcMessage) IsRpcRequest() bool { @@ -201,9 +205,9 @@ func validateServerImpl(serverImpl ServerImpl) { } // closes outputCh when inputCh is closed/done -func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcContext, serverImpl ServerImpl, debugName string) *WshRpc { +func MakeWshRpcWithChannels(inputCh chan baseds.RpcInputChType, outputCh chan []byte, rpcCtx wshrpc.RpcContext, serverImpl ServerImpl, debugName string) *WshRpc { if inputCh == nil { - inputCh = make(chan []byte, DefaultInputChSize) + inputCh = make(chan baseds.RpcInputChType, DefaultInputChSize) } if outputCh == nil { outputCh = make(chan []byte, DefaultOutputChSize) @@ -226,6 +230,10 @@ func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcCont return rtn } +func MakeWshRpc(rpcCtx wshrpc.RpcContext, serverImpl ServerImpl, debugName string) *WshRpc { + return MakeWshRpcWithChannels(nil, nil, rpcCtx, serverImpl, debugName) +} + func (w *WshRpc) GetRpcContext() wshrpc.RpcContext { rtnPtr := w.RpcContext.Load() return *rtnPtr @@ -235,14 +243,6 @@ func (w *WshRpc) SetRpcContext(ctx wshrpc.RpcContext) { w.RpcContext.Store(&ctx) } -func (w *WshRpc) SetAuthToken(token string) { - w.AuthToken = token -} - -func (w *WshRpc) GetAuthToken() string { - return w.AuthToken -} - func (w *WshRpc) registerResponseHandler(reqId string, handler *RpcResponseHandler) { w.Lock.Lock() defer w.Lock.Unlock() @@ -268,9 +268,9 @@ func (w *WshRpc) cancelRequest(reqId string) { } -func (w *WshRpc) handleRequest(req *RpcMessage) { +func (w *WshRpc) handleRequest(req *RpcMessage, ingressLinkId baseds.LinkId) { pprof.Do(context.Background(), pprof.Labels("rpc", req.Command), func(pprofCtx context.Context) { - w.handleRequestInternal(req, pprofCtx) + w.handleRequestInternal(req, ingressLinkId, pprofCtx) }) } @@ -286,7 +286,7 @@ func (w *WshRpc) handleEventRecv(req *RpcMessage) { w.EventListener.RecvEvent(&waveEvent) } -func (w *WshRpc) handleRequestInternal(req *RpcMessage, pprofCtx context.Context) { +func (w *WshRpc) handleRequestInternal(req *RpcMessage, ingressLinkId baseds.LinkId, pprofCtx context.Context) { if req.Command == wshrpc.Command_EventRecv { w.handleEventRecv(req) return @@ -306,6 +306,7 @@ func (w *WshRpc) handleRequestInternal(req *RpcMessage, pprofCtx context.Context command: req.Command, commandData: req.Data, source: req.Source, + ingressLinkId: ingressLinkId, done: &atomic.Bool{}, canceled: &atomic.Bool{}, contextCancelFn: &atomic.Pointer[context.CancelFunc]{}, @@ -347,17 +348,17 @@ func (w *WshRpc) runServer() { }() outer: for { - var msgBytes []byte + var inputVal baseds.RpcInputChType var inputChMore bool var resIdTimeout string select { - case msgBytes, inputChMore = <-w.InputCh: + case inputVal, inputChMore = <-w.InputCh: if !inputChMore { break outer } if w.Debug { - log.Printf("[%s] received message: %s\n", w.DebugName, string(msgBytes)) + log.Printf("[%s] received message: %s\n", w.DebugName, string(inputVal.MsgBytes)) } case resIdTimeout = <-w.CtxDoneCh: if w.Debug { @@ -368,7 +369,7 @@ outer: } var msg RpcMessage - err := json.Unmarshal(msgBytes, &msg) + err := json.Unmarshal(inputVal.MsgBytes, &msg) if err != nil { log.Printf("wshrpc received bad message: %v\n", err) continue @@ -380,11 +381,12 @@ outer: continue } if msg.IsRpcRequest() { + ingressLinkId := inputVal.IngressLinkId go func() { defer func() { panichandler.PanicHandler("handleRequest:goroutine", recover()) }() - w.handleRequest(&msg) + w.handleRequest(&msg, ingressLinkId) }() } else { w.sendRespWithBlockMessage(msg) @@ -509,9 +511,8 @@ func (handler *RpcRequestHandler) SendCancel(ctx context.Context) error { panichandler.PanicHandler("SendCancel", recover()) }() msg := &RpcMessage{ - Cancel: true, - ReqId: handler.reqId, - AuthToken: handler.w.GetAuthToken(), + Cancel: true, + ReqId: handler.reqId, } barr, _ := json.Marshal(msg) // will never fail select { @@ -580,6 +581,7 @@ type RpcResponseHandler struct { command string commandData any rpcCtx wshrpc.RpcContext + ingressLinkId baseds.LinkId canceled *atomic.Bool // canceled by requestor done *atomic.Bool } @@ -604,6 +606,10 @@ func (handler *RpcResponseHandler) GetSource() string { return handler.source } +func (handler *RpcResponseHandler) GetIngressLinkId() baseds.LinkId { + return handler.ingressLinkId +} + func (handler *RpcResponseHandler) NeedsResponse() bool { return handler.reqId != "" } @@ -614,8 +620,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) { Data: wshrpc.CommandMessageData{ Message: msg, }, - AuthToken: handler.w.GetAuthToken(), - Route: handler.source, // send back to source + Route: handler.source, // send back to source } msgBytes, _ := json.Marshal(rpcMsg) // will never fail select { @@ -638,10 +643,9 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error { return nil } msg := &RpcMessage{ - ResId: handler.reqId, - Data: data, - Cont: !done, - AuthToken: handler.w.GetAuthToken(), + ResId: handler.reqId, + Data: data, + Cont: !done, } barr, err := json.Marshal(msg) if err != nil { @@ -667,9 +671,8 @@ func (handler *RpcResponseHandler) SendResponseError(err error) { return } msg := &RpcMessage{ - ResId: handler.reqId, - Error: err.Error(), - AuthToken: handler.w.GetAuthToken(), + ResId: handler.reqId, + Error: err.Error(), } barr, _ := json.Marshal(msg) // will never fail select { @@ -736,12 +739,11 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp handler.reqId = uuid.New().String() } req := &RpcMessage{ - Command: command, - ReqId: handler.reqId, - Data: data, - Timeout: timeoutMs, - Route: opts.Route, - AuthToken: w.GetAuthToken(), + Command: command, + ReqId: handler.reqId, + Data: data, + Timeout: timeoutMs, + Route: opts.Route, } barr, err := json.Marshal(req) if err != nil { diff --git a/pkg/wshutil/wshrpcio.go b/pkg/wshutil/wshrpcio.go index 7db864626b..3345bdd9d3 100644 --- a/pkg/wshutil/wshrpcio.go +++ b/pkg/wshutil/wshrpcio.go @@ -7,6 +7,7 @@ import ( "fmt" "io" + "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/util/utilfn" ) @@ -15,9 +16,9 @@ import ( // * stream (json lines) // * websocket (json packets) -func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error { +func AdaptStreamToMsgCh(input io.Reader, output chan baseds.RpcInputChType) error { return utilfn.StreamToLines(input, func(line []byte) { - output <- line + output <- baseds.RpcInputChType{MsgBytes: line} }) } diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index c27af7b6f8..28e3db8a77 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -18,15 +18,15 @@ import ( "sync" "sync/atomic" "syscall" - "time" "github.com/golang-jwt/jwt/v5" - "github.com/google/uuid" + "github.com/wavetermdev/waveterm/pkg/baseds" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/util/packetparser" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/wavejwt" "github.com/wavetermdev/waveterm/pkg/wshrpc" "golang.org/x/term" ) @@ -202,10 +202,10 @@ func RestoreTermState() { // returns (wshRpc, wrappedStdin) func SetupTerminalRpcClient(serverImpl ServerImpl, debugStr string) (*WshRpc, io.Reader) { - messageCh := make(chan []byte, DefaultInputChSize) + messageCh := make(chan baseds.RpcInputChType, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) ptyBuf := MakePtyBuffer(WaveServerOSCPrefix, os.Stdin, messageCh) - rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) + rpcClient := MakeWshRpcWithChannels(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) go func() { defer func() { panichandler.PanicHandler("SetupTerminalRpcClient", recover()) @@ -217,17 +217,16 @@ func SetupTerminalRpcClient(serverImpl ServerImpl, debugStr string) (*WshRpc, io continue } os.Stdout.Write(barr) - os.Stdout.Write([]byte{'\n'}) } }() return rpcClient, ptyBuf } func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl, debugStr string) (*WshRpc, chan []byte) { - messageCh := make(chan []byte, DefaultInputChSize) + messageCh := make(chan baseds.RpcInputChType, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) rawCh := make(chan []byte, DefaultOutputChSize) - rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) + rpcClient := MakeWshRpcWithChannels(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) go packetparser.Parse(input, messageCh, rawCh) go func() { defer func() { @@ -241,7 +240,7 @@ func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerIm } func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl, debugStr string) (*WshRpc, chan error, error) { - inputCh := make(chan []byte, DefaultInputChSize) + inputCh := make(chan baseds.RpcInputChType, DefaultInputChSize) outputCh := make(chan []byte, DefaultOutputChSize) writeErrCh := make(chan error, 1) go func() { @@ -262,7 +261,7 @@ func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl, debugStr string) ( defer conn.Close() AdaptStreamToMsgCh(conn, inputCh) }() - rtn := MakeWshRpc(inputCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) + rtn := MakeWshRpcWithChannels(inputCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr) return rtn, writeErrCh, nil } @@ -275,6 +274,7 @@ func tryTcpSocket(sockName string) (net.Conn, error) { } func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl, debugName string) (*WshRpc, error) { + sockName = wavebase.ExpandHomeDirSafe(sockName) conn, tcpErr := tryTcpSocket(sockName) var unixErr error if tcpErr != nil { @@ -297,86 +297,41 @@ func SetupDomainSocketRpcClient(sockName string, serverImpl ServerImpl, debugNam return rtn, err } -func MakeClientJWTToken(rpcCtx wshrpc.RpcContext, sockName string) (string, error) { - claims := jwt.MapClaims{} - claims["iat"] = time.Now().Unix() - claims["iss"] = "waveterm" - claims["sock"] = sockName - claims["exp"] = time.Now().Add(time.Hour * 24 * 365).Unix() - if rpcCtx.BlockId != "" { - claims["blockid"] = rpcCtx.BlockId - } - if rpcCtx.TabId != "" { - claims["tabid"] = rpcCtx.TabId - } - if rpcCtx.Conn != "" { - claims["conn"] = rpcCtx.Conn - } - if rpcCtx.ClientType != "" { - claims["ctype"] = rpcCtx.ClientType +func MakeClientJWTToken(rpcCtx wshrpc.RpcContext) (string, error) { + if wavebase.IsDevMode() { + if rpcCtx.IsRouter && rpcCtx.RouteId != "" { + panic("Invalid RpcCtx, router w/ routeid") + } + if !rpcCtx.IsRouter && rpcCtx.RouteId == "" { + panic("Invalid RpcCtx, no routeid") + } } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := token.SignedString([]byte(wavebase.JwtSecret)) - if err != nil { - return "", fmt.Errorf("error signing token: %w", err) + claims := &wavejwt.WaveJwtClaims{ + Sock: rpcCtx.SockName, + RouteId: rpcCtx.RouteId, + BlockId: rpcCtx.BlockId, + Conn: rpcCtx.Conn, + Router: rpcCtx.IsRouter, } - return tokenStr, nil + return wavejwt.Sign(claims) } -func ValidateAndExtractRpcContextFromToken(tokenStr string) (*wshrpc.RpcContext, error) { - parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) - token, err := parser.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { - return []byte(wavebase.JwtSecret), nil - }) - if err != nil { - return nil, fmt.Errorf("error parsing token: %w", err) - } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("error getting claims from token") - } - // validate "exp" claim - if exp, ok := claims["exp"].(float64); ok { - if int64(exp) < time.Now().Unix() { - return nil, fmt.Errorf("token has expired") - } - } else { - return nil, fmt.Errorf("exp claim is missing or invalid") +func claimsToRpcCtx(claims *wavejwt.WaveJwtClaims) *wshrpc.RpcContext { + return &wshrpc.RpcContext{ + SockName: claims.Sock, + RouteId: claims.RouteId, + BlockId: claims.BlockId, + Conn: claims.Conn, + IsRouter: claims.Router, } - // validate "iss" claim - if iss, ok := claims["iss"].(string); ok { - if iss != "waveterm" { - return nil, fmt.Errorf("unexpected issuer: %s", iss) - } - } else { - return nil, fmt.Errorf("iss claim is missing or invalid") - } - return mapClaimsToRpcContext(claims), nil } -func mapClaimsToRpcContext(claims jwt.MapClaims) *wshrpc.RpcContext { - rpcCtx := &wshrpc.RpcContext{} - if claims["blockid"] != nil { - if blockId, ok := claims["blockid"].(string); ok { - rpcCtx.BlockId = blockId - } - } - if claims["tabid"] != nil { - if tabId, ok := claims["tabid"].(string); ok { - rpcCtx.TabId = tabId - } - } - if claims["conn"] != nil { - if conn, ok := claims["conn"].(string); ok { - rpcCtx.Conn = conn - } - } - if claims["ctype"] != nil { - if ctype, ok := claims["ctype"].(string); ok { - rpcCtx.ClientType = ctype - } +func ValidateAndExtractRpcContextFromToken(tokenStr string) (*wshrpc.RpcContext, error) { + claims, err := wavejwt.ValidateAndExtract(tokenStr) + if err != nil { + return nil, err } - return rpcCtx + return claimsToRpcCtx(claims), nil } func RunWshRpcOverListener(listener net.Listener) { @@ -395,26 +350,6 @@ func RunWshRpcOverListener(listener net.Listener) { } } -func MakeRouteIdFromCtx(rpcCtx *wshrpc.RpcContext) (string, error) { - if rpcCtx.ClientType != "" { - if rpcCtx.ClientType == wshrpc.ClientType_ConnServer { - if rpcCtx.Conn != "" { - return MakeConnectionRouteId(rpcCtx.Conn), nil - } - return "", fmt.Errorf("invalid connserver connection, no conn id") - } - if rpcCtx.ClientType == wshrpc.ClientType_BlockController { - if rpcCtx.BlockId != "" { - return MakeControllerRouteId(rpcCtx.BlockId), nil - } - return "", fmt.Errorf("invalid block controller connection, no block id") - } - return "", fmt.Errorf("invalid client type: %q", rpcCtx.ClientType) - } - procId := uuid.New().String() - return MakeProcRouteId(procId), nil -} - type WriteFlusher interface { Write([]byte) (int, error) Flush() error @@ -422,23 +357,24 @@ type WriteFlusher interface { // blocking, returns if there is an error, or on EOF of input func HandleStdIOClient(logName string, input chan utilfn.LineOutput, output io.Writer) { - proxy := MakeRpcMultiProxy() + proxy := MakeRpcProxy(logName) + linkId := DefaultRouter.RegisterTrustedRouter(proxy) rawCh := make(chan []byte, DefaultInputChSize) - go packetparser.ParseWithLinesChan(input, proxy.FromRemoteRawCh, rawCh) + go func() { + defer func() { + panichandler.PanicHandler("HandleStdIOClient:ParseWithLinesChan", recover()) + }() + packetparser.ParseWithLinesChan(input, proxy.FromRemoteCh, rawCh) + }() doneCh := make(chan struct{}) var doneOnce sync.Once closeDoneCh := func() { doneOnce.Do(func() { close(doneCh) + DefaultRouter.UnregisterLink(linkId) + close(proxy.FromRemoteCh) }) - proxy.DisposeRoutes() } - go func() { - defer func() { - panichandler.PanicHandler("HandleStdIOClient:RunUnauthLoop", recover()) - }() - proxy.RunUnauthLoop() - }() go func() { defer func() { panichandler.PanicHandler("HandleStdIOClient:ToRemoteChLoop", recover()) @@ -468,8 +404,8 @@ func HandleStdIOClient(logName string, input chan utilfn.LineOutput, output io.W } func handleDomainSocketClient(conn net.Conn) { - var routeIdContainer atomic.Pointer[string] - proxy := MakeRpcProxy() + var linkIdContainer atomic.Int32 + proxy := MakeRpcProxy("domain") go func() { defer func() { panichandler.PanicHandler("handleDomainSocketClient:AdaptOutputChToStream", recover()) @@ -488,61 +424,42 @@ func handleDomainSocketClient(conn net.Conn) { conn.Close() close(proxy.FromRemoteCh) close(proxy.ToRemoteCh) - routeIdPtr := routeIdContainer.Load() - if routeIdPtr != nil && *routeIdPtr != "" { - DefaultRouter.UnregisterRoute(*routeIdPtr) + linkId := linkIdContainer.Load() + if linkId != baseds.NoLinkId { + DefaultRouter.UnregisterLink(baseds.LinkId(linkId)) } }() AdaptStreamToMsgCh(conn, proxy.FromRemoteCh) }() - rpcCtx, err := proxy.HandleAuthentication() - if err != nil { - conn.Close() - log.Printf("error handling authentication: %v\n", err) - return - } - // now that we're authenticated, set the ctx and attach to the router - log.Printf("domain socket connection authenticated: %#v\n", rpcCtx) - proxy.SetRpcContext(rpcCtx) - routeId, err := MakeRouteIdFromCtx(rpcCtx) - if err != nil { - conn.Close() - log.Printf("error making route id: %v\n", err) - return - } - routeIdContainer.Store(&routeId) - DefaultRouter.RegisterRoute(routeId, proxy, true) + linkId := DefaultRouter.RegisterUntrustedLink(proxy) + linkIdContainer.Store(int32(linkId)) } // only for use on client func ExtractUnverifiedRpcContext(tokenStr string) (*wshrpc.RpcContext, error) { - // this happens on the client who does not have access to the secret key - // we want to read the claims without validating the signature - token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, jwt.MapClaims{}) + token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, &wavejwt.WaveJwtClaims{}) if err != nil { return nil, fmt.Errorf("error parsing token: %w", err) } - claims, ok := token.Claims.(jwt.MapClaims) + claims, ok := token.Claims.(*wavejwt.WaveJwtClaims) if !ok { return nil, fmt.Errorf("error getting claims from token") } - return mapClaimsToRpcContext(claims), nil + return claimsToRpcCtx(claims), nil } // only for use on client func ExtractUnverifiedSocketName(tokenStr string) (string, error) { - // this happens on the client who does not have access to the secret key - // we want to read the claims without validating the signature - token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, jwt.MapClaims{}) + token, _, err := new(jwt.Parser).ParseUnverified(tokenStr, &wavejwt.WaveJwtClaims{}) if err != nil { return "", fmt.Errorf("error parsing token: %w", err) } - claims, ok := token.Claims.(jwt.MapClaims) + claims, ok := token.Claims.(*wavejwt.WaveJwtClaims) if !ok { return "", fmt.Errorf("error getting claims from token") } - sockName, ok := claims["sock"].(string) - if !ok { + sockName := claims.Sock + if sockName == "" { return "", fmt.Errorf("sock claim is missing or invalid") } sockName = wavebase.ExpandHomeDirSafe(sockName) diff --git a/pkg/wsl/wsl-unix.go b/pkg/wsl/wsl-unix.go index feb8e8e53e..eba2ff9696 100644 --- a/pkg/wsl/wsl-unix.go +++ b/pkg/wsl/wsl-unix.go @@ -63,7 +63,7 @@ func (c *WslCmd) SetStdout(stdout io.Writer) { } func (c *WslCmd) SetStderr(stderr io.Writer) { - c.Stdout = stderr + c.Stderr = stderr } func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) { diff --git a/pkg/wsl/wsl-win.go b/pkg/wsl/wsl-win.go index fb3be424aa..64a60a66d8 100644 --- a/pkg/wsl/wsl-win.go +++ b/pkg/wsl/wsl-win.go @@ -102,7 +102,7 @@ func (c *WslCmd) SetStdout(stdout io.Writer) { } func (c *WslCmd) SetStderr(stderr io.Writer) { - c.c.Stdout = stderr + c.c.Stderr = stderr } func GetDistroCmd(ctx context.Context, wslDistroName string, cmd string) (*WslCmd, error) { diff --git a/pkg/wslconn/wslconn.go b/pkg/wslconn/wslconn.go index 7fe6594907..9cd5b60a0b 100644 --- a/pkg/wslconn/wslconn.go +++ b/pkg/wslconn/wslconn.go @@ -69,7 +69,7 @@ type WslConn struct { var ConnServerCmdTemplate = strings.TrimSpace( strings.Join([]string{ "%s version 2> /dev/null || (echo -n \"not-installed \"; uname -sm);", - "exec %s connserver --router", + "exec %s connserver --router --conn %s %s", }, "\n")) func GetAllConnStatus() []wshrpc.ConnStatus { @@ -259,15 +259,6 @@ func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo } client := conn.GetClient() wshPath := conn.getWshPath() - rpcCtx := wshrpc.RpcContext{ - ClientType: wshrpc.ClientType_ConnServer, - Conn: conn.GetName(), - } - sockName := conn.GetDomainSocketName() - jwtToken, err := wshutil.MakeClientJWTToken(rpcCtx, sockName) - if err != nil { - return false, "", "", fmt.Errorf("unable to create jwt token for conn controller: %w", err) - } conn.Infof(ctx, "WSL-NEWSESSION (StartConnServer)\n") connServerCtx, cancelFn := context.WithCancel(context.Background()) conn.WithLock(func() { @@ -276,7 +267,11 @@ func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo } conn.cancelFn = cancelFn }) - cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath) + devFlag := "" + if wavebase.IsDevMode() { + devFlag = "--dev" + } + cmdStr := fmt.Sprintf(ConnServerCmdTemplate, wshPath, wshPath, shellutil.HardQuote(conn.GetName()), devFlag) shWrappedCmdStr := fmt.Sprintf("sh -c %s", shellutil.HardQuote(cmdStr)) cmd := client.WslCommand(connServerCtx, shWrappedCmdStr) pipeRead, pipeWrite := io.Pipe() @@ -286,7 +281,7 @@ func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo cmd.SetStdin(inputPipeRead) log.Printf("starting conn controller: %q\n", cmdStr) blocklogger.Debugf(ctx, "[conndebug] wrapped command:\n%s\n", shWrappedCmdStr) - err = cmd.Start() + err := cmd.Start() if err != nil { return false, "", "", fmt.Errorf("unable to start conn controller cmd: %w", err) } @@ -311,21 +306,6 @@ func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo cancelFn() return true, clientVersion, osArchStr, nil } - jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second) - if err != nil { - cancelFn() - return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err) - } - conn.Infof(ctx, "got jwt status line: %s\n", jwtLine) - if strings.TrimSpace(jwtLine) == wavebase.NeedJwtConst { - // write the jwt - conn.Infof(ctx, "writing jwt token to connserver\n") - _, err = fmt.Fprintf(inputPipeWrite, "%s\n", jwtToken) - if err != nil { - cancelFn() - return false, clientVersion, "", fmt.Errorf("failed to write JWT token: %w", err) - } - } conn.WithLock(func() { conn.ConnController = cmd }) @@ -359,7 +339,7 @@ func (conn *WslConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo conn.Infof(ctx, "connserver started, waiting for route to be registered\n") regCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) defer cancelFn() - err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(rpcCtx.Conn)) + err = wshutil.DefaultRouter.WaitForRegister(regCtx, wshutil.MakeConnectionRouteId(conn.GetName())) if err != nil { return false, clientVersion, "", fmt.Errorf("timeout waiting for connserver to register") } diff --git a/pkg/wstore/wstore_dbops.go b/pkg/wstore/wstore_dbops.go index 2467a09fe5..43f6ff7cfc 100644 --- a/pkg/wstore/wstore_dbops.go +++ b/pkg/wstore/wstore_dbops.go @@ -366,7 +366,6 @@ func DBFindTabForBlockId(ctx context.Context, blockId string) (string, error) { } func DBFindWorkspaceForTabId(ctx context.Context, tabId string) (string, error) { - log.Printf("DBFindWorkspaceForTabId tabId: %s\n", tabId) return WithTxRtn(ctx, func(tx *TxWrap) (string, error) { query := ` WITH variable(value) AS ( @@ -386,7 +385,6 @@ func DBFindWorkspaceForTabId(ctx context.Context, tabId string) (string, error) ); ` wsId := tx.GetString(query, tabId) - log.Printf("DBFindWorkspaceForTabId wsId: %s\n", wsId) return wsId, nil }) }