@@ -19,6 +19,7 @@ package genkit
1919import (
2020 "context"
2121 "encoding/json"
22+ "errors"
2223 "fmt"
2324 "log/slog"
2425 "net"
@@ -28,6 +29,7 @@ import (
2829 "sort"
2930 "strconv"
3031 "strings"
32+ "sync"
3133 "time"
3234
3335 "github.com/firebase/genkit/go/core"
@@ -52,7 +54,46 @@ type runtimeFileData struct {
5254// reflectionServer encapsulates everything needed to serve the Reflection API.
5355type reflectionServer struct {
5456 * http.Server
55- RuntimeFilePath string // Path to the runtime file that was written at startup.
57+ RuntimeFilePath string // Path to the runtime file that was written at startup.
58+ activeActions * activeActionsMap // Tracks active actions for cancellation support.
59+ }
60+
61+ // activeAction represents an in-flight action that can be cancelled.
62+ type activeAction struct {
63+ cancel context.CancelFunc
64+ startTime time.Time
65+ traceID string
66+ }
67+
68+ // activeActionsMap safely manages active actions.
69+ type activeActionsMap struct {
70+ mu sync.RWMutex
71+ actions map [string ]* activeAction
72+ }
73+
74+ func newActiveActionsMap () * activeActionsMap {
75+ return & activeActionsMap {
76+ actions : make (map [string ]* activeAction ),
77+ }
78+ }
79+
80+ func (m * activeActionsMap ) Set (traceID string , action * activeAction ) {
81+ m .mu .Lock ()
82+ defer m .mu .Unlock ()
83+ m .actions [traceID ] = action
84+ }
85+
86+ func (m * activeActionsMap ) Get (traceID string ) (* activeAction , bool ) {
87+ m .mu .RLock ()
88+ defer m .mu .RUnlock ()
89+ action , ok := m .actions [traceID ]
90+ return action , ok
91+ }
92+
93+ func (m * activeActionsMap ) Delete (traceID string ) {
94+ m .mu .Lock ()
95+ defer m .mu .Unlock ()
96+ delete (m .actions , traceID )
5697}
5798
5899func (s * reflectionServer ) runtimeID () string {
@@ -102,6 +143,7 @@ func startReflectionServer(ctx context.Context, g *Genkit, errCh chan<- error, s
102143 Server : & http.Server {
103144 Addr : addr ,
104145 },
146+ activeActions : newActiveActionsMap (),
105147 }
106148 s .Handler = serveMux (g , s )
107149
@@ -258,8 +300,9 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux {
258300 w .WriteHeader (http .StatusOK )
259301 })
260302 mux .HandleFunc ("GET /api/actions" , wrapReflectionHandler (handleListActions (g )))
261- mux .HandleFunc ("POST /api/runAction" , wrapReflectionHandler (handleRunAction (g )))
303+ mux .HandleFunc ("POST /api/runAction" , wrapReflectionHandler (handleRunAction (g , s . activeActions )))
262304 mux .HandleFunc ("POST /api/notify" , wrapReflectionHandler (handleNotify ()))
305+ mux .HandleFunc ("POST /api/cancelAction" , wrapReflectionHandler (handleCancelAction (s .activeActions )))
263306 return mux
264307}
265308
@@ -290,7 +333,7 @@ func wrapReflectionHandler(h func(w http.ResponseWriter, r *http.Request) error)
290333
291334// handleRunAction looks up an action by name in the registry, runs it with the
292335// provided JSON input, and writes back the JSON-marshaled request.
293- func handleRunAction (g * Genkit ) func (w http.ResponseWriter , r * http.Request ) error {
336+ func handleRunAction (g * Genkit , activeActions * activeActionsMap ) func (w http.ResponseWriter , r * http.Request ) error {
294337 return func (w http.ResponseWriter , r * http.Request ) error {
295338 ctx := r .Context ()
296339
@@ -312,11 +355,54 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err
312355
313356 logger .FromContext (ctx ).Debug ("running action" , "key" , body .Key , "stream" , stream )
314357
358+ // Create cancellable context for this action
359+ actionCtx , cancel := context .WithCancel (ctx )
360+ defer cancel ()
361+
362+ // Track whether headers have been sent
363+ headersSent := false
364+ var callbackTraceID string // Trace ID captured from telemetry callback for early header sending
365+ var mu sync.Mutex
366+
367+ // Set up telemetry callback to capture and send trace ID early
368+ // This is used for BOTH streaming and non-streaming to match JS behavior
369+ telemetryCb := func (tid string , sid string ) {
370+ mu .Lock ()
371+ defer mu .Unlock ()
372+
373+ if ! headersSent {
374+ callbackTraceID = tid
375+
376+ // Track active action for cancellation
377+ activeActions .Set (callbackTraceID , & activeAction {
378+ cancel : cancel ,
379+ startTime : time .Now (),
380+ traceID : callbackTraceID ,
381+ })
382+
383+ // Send headers immediately with trace ID
384+ w .Header ().Set ("X-Genkit-Trace-Id" , callbackTraceID )
385+ w .Header ().Set ("X-Genkit-Span-Id" , sid )
386+ w .Header ().Set ("X-Genkit-Version" , "go/" + internal .Version )
387+
388+ if stream {
389+ w .Header ().Set ("Content-Type" , "text/plain" )
390+ w .Header ().Set ("Transfer-Encoding" , "chunked" )
391+ } else {
392+ w .Header ().Set ("Content-Type" , "application/json" )
393+ }
394+
395+ w .WriteHeader (http .StatusOK )
396+ if f , ok := w .(http.Flusher ); ok {
397+ f .Flush ()
398+ }
399+ headersSent = true
400+ }
401+ }
402+
403+ // Set up streaming callback if needed
315404 var cb streamingCallback [json.RawMessage ]
316405 if stream {
317- w .Header ().Set ("Content-Type" , "text/plain" )
318- w .Header ().Set ("Transfer-Encoding" , "chunked" )
319- // Stream results are newline-separated JSON.
320406 cb = func (ctx context.Context , msg json.RawMessage ) error {
321407 _ , err := fmt .Fprintf (w , "%s\n " , msg )
322408 if err != nil {
@@ -334,35 +420,119 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err
334420 json .Unmarshal (body .Context , & contextMap )
335421 }
336422
337- resp , err := runAction (ctx , g , body .Key , body .Input , body .TelemetryLabels , cb , contextMap )
423+ // Attach telemetry callback to context so action can invoke it when span is created
424+ actionCtx = tracing .WithTelemetryCallback (actionCtx , telemetryCb )
425+ resp , err := runAction (actionCtx , g , body .Key , body .Input , body .TelemetryLabels , cb , contextMap )
426+
427+ // Clean up active action using the trace ID from response
428+ if resp != nil && resp .Telemetry .TraceID != "" {
429+ activeActions .Delete (resp .Telemetry .TraceID )
430+ }
431+
338432 if err != nil {
339- if stream {
340- refErr := core .ToReflectionError (err )
341- refErr .Details .TraceID = & resp .Telemetry .TraceID
342- reflectErr , err := json .Marshal (refErr )
343- if err != nil {
344- return err
433+ // Check if context was cancelled
434+ if errors .Is (err , context .Canceled ) {
435+ // Use gRPC CANCELLED code (1) in JSON body to match TypeScript behavior
436+ var traceIDPtr * string
437+ if resp != nil && resp .Telemetry .TraceID != "" {
438+ traceIDPtr = & resp .Telemetry .TraceID
439+ }
440+ errResp := errorResponse {
441+ Error : core.ReflectionError {
442+ Code : core .CodeCancelled , // gRPC CANCELLED = 1
443+ Message : "Action was cancelled" ,
444+ Details : & core.ReflectionErrorDetails {
445+ TraceID : traceIDPtr ,
446+ },
447+ },
345448 }
346449
347- _ , err = fmt .Fprintf (w , "{\" error\" : %s }" , reflectErr )
348- if err != nil {
349- return err
450+ if stream {
451+ // For streaming, write error as final chunk
452+ json .NewEncoder (w ).Encode (errResp )
453+ } else {
454+ // For non-streaming, return error response
455+ if ! headersSent {
456+ w .WriteHeader (http .StatusOK ) // Match TS: response.status(200).json(...)
457+ }
458+ json .NewEncoder (w ).Encode (errResp )
350459 }
460+ return nil
461+ }
351462
352- if f , ok := w .(http.Flusher ); ok {
353- f .Flush ()
463+ // Handle other errors
464+ if stream {
465+ refErr := core .ToReflectionError (err )
466+ if resp != nil && resp .Telemetry .TraceID != "" {
467+ refErr .Details .TraceID = & resp .Telemetry .TraceID
354468 }
469+
470+ json .NewEncoder (w ).Encode (errorResponse {Error : refErr })
355471 return nil
356472 }
473+
474+ // Non-streaming error
357475 errorResponse := core .ToReflectionError (err )
358- if resp != nil {
476+ if resp != nil && resp . Telemetry . TraceID != "" {
359477 errorResponse .Details .TraceID = & resp .Telemetry .TraceID
360478 }
361- w .WriteHeader (errorResponse .Code )
479+
480+ if ! headersSent {
481+ w .WriteHeader (errorResponse .Code )
482+ }
362483 return writeJSON (ctx , w , errorResponse )
363484 }
364485
365- return writeJSON (ctx , w , resp )
486+ // Success case
487+ if stream {
488+ // For streaming, write the final chunk with result and telemetry
489+ // This matches JS: response.write(JSON.stringify({result, telemetry}))
490+ finalResponse := runActionResponse {
491+ Result : resp .Result ,
492+ Telemetry : telemetry {TraceID : resp .Telemetry .TraceID },
493+ }
494+ json .NewEncoder (w ).Encode (finalResponse )
495+ } else {
496+ // For non-streaming, headers were already sent via telemetry callback
497+ // Response already includes telemetry.traceId in body
498+ return writeJSON (ctx , w , resp )
499+ }
500+
501+ return nil
502+ }
503+ }
504+
505+ // handleCancelAction cancels an in-flight action by trace ID.
506+ func handleCancelAction (activeActions * activeActionsMap ) func (w http.ResponseWriter , r * http.Request ) error {
507+ return func (w http.ResponseWriter , r * http.Request ) error {
508+ var body struct {
509+ TraceID string `json:"traceId"`
510+ }
511+
512+ defer r .Body .Close ()
513+ if err := json .NewDecoder (r .Body ).Decode (& body ); err != nil {
514+ return core .NewError (core .INVALID_ARGUMENT , err .Error ())
515+ }
516+
517+ if body .TraceID == "" {
518+ return core .NewError (core .INVALID_ARGUMENT , "traceId is required" )
519+ }
520+
521+ action , exists := activeActions .Get (body .TraceID )
522+ if ! exists {
523+ w .WriteHeader (http .StatusNotFound )
524+ return writeJSON (r .Context (), w , map [string ]string {
525+ "error" : "Action not found or already completed" ,
526+ })
527+ }
528+
529+ // Cancel the action's context
530+ action .cancel ()
531+ activeActions .Delete (body .TraceID )
532+
533+ return writeJSON (r .Context (), w , map [string ]string {
534+ "message" : "Action cancelled" ,
535+ })
366536 }
367537}
368538
@@ -462,6 +632,10 @@ type telemetry struct {
462632 TraceID string `json:"traceId"`
463633}
464634
635+ type errorResponse struct {
636+ Error core.ReflectionError `json:"error"`
637+ }
638+
465639func runAction (ctx context.Context , g * Genkit , key string , input json.RawMessage , telemetryLabels json.RawMessage , cb streamingCallback [json.RawMessage ], runtimeContext map [string ]any ) (* runActionResponse , error ) {
466640 action := g .reg .ResolveAction (key )
467641 if action == nil {
0 commit comments