Skip to content

Commit 13bf7bc

Browse files
feat(go): Add cancelAction and early trace ID to headers in go reflection server (#3885)
Co-authored-by: Alex Pascal <apascal07@gmail.com>
1 parent fff4b43 commit 13bf7bc

File tree

9 files changed

+668
-26
lines changed

9 files changed

+668
-26
lines changed

go/core/action.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw
260260
return r.Result, nil
261261
}
262262

263-
// RunJSON runs the action with a JSON input, and returns a JSON result along with telemetry info.
263+
// RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info.
264264
func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) {
265265
i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema)
266266
if err != nil {

go/core/flow.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessa
112112
return (*ActionDef[In, Out, Stream])(f).RunJSON(ctx, input, cb)
113113
}
114114

115-
// RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON.
115+
// RunJSONWithTelemetry runs the flow with JSON input and streaming callback and returns the output as JSON along with telemetry info.
116116
func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) {
117117
return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb)
118118
}

go/core/schemas.config

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ SpanStatus omit
2626
TimeEvent omit
2727
TimeEventAnnotation omit
2828
TraceData omit
29+
SpanStartEvent omit
30+
SpanEndEvent omit
31+
SpanEventBase omit
32+
TraceEvent omit
2933

3034
GenerationCommonConfig.maxOutputTokens type int
3135
GenerationCommonConfig.topK type int

go/core/tracing/tracing.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ type SpanMetadata struct {
173173

174174
// RunInNewSpan runs f on input in a new span with the provided metadata.
175175
// The metadata contains all span configuration including name, type, labels, etc.
176+
// If a telemetry callback was set on the context via WithTelemetryCallback,
177+
// it will be called with the trace ID and span ID as soon as the span is created.
176178
func RunInNewSpan[I, O any](
177179
ctx context.Context,
178180
metadata *SpanMetadata,
@@ -239,6 +241,12 @@ func RunInNewSpan[I, O any](
239241
TraceID: span.SpanContext().TraceID().String(),
240242
SpanID: span.SpanContext().SpanID().String(),
241243
}
244+
245+
// Fire telemetry callback immediately if one was set on the context
246+
if cb := telemetryCallback(ctx); cb != nil {
247+
cb(sm.TraceInfo.TraceID, sm.TraceInfo.SpanID)
248+
}
249+
242250
defer span.End()
243251
defer func() { span.SetAttributes(sm.attributes()...) }()
244252
ctx = spanMetaKey.NewContext(ctx, sm)
@@ -371,6 +379,20 @@ func (sm *spanMetadata) attributes() []attribute.KeyValue {
371379
// spanMetaKey is for storing spanMetadatas in a context.
372380
var spanMetaKey = base.NewContextKey[*spanMetadata]()
373381

382+
// telemetryCbKey is the context key for telemetry callbacks.
383+
var telemetryCbKey = base.NewContextKey[func(traceID, spanID string)]()
384+
385+
// WithTelemetryCallback returns a context with the telemetry callback attached.
386+
// Used by the reflection server to pass callbacks to actions.
387+
func WithTelemetryCallback(ctx context.Context, cb func(traceID, spanID string)) context.Context {
388+
return telemetryCbKey.NewContext(ctx, cb)
389+
}
390+
391+
// telemetryCallback retrieves the telemetry callback from context, or nil if not set.
392+
func telemetryCallback(ctx context.Context) func(traceID, spanID string) {
393+
return telemetryCbKey.FromContext(ctx)
394+
}
395+
374396
// SpanPath returns the path as recorded in the current span metadata.
375397
func SpanPath(ctx context.Context) string {
376398
return spanMetaKey.FromContext(ctx).Path

go/genkit/reflection.go

Lines changed: 195 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package genkit
1919
import (
2020
"context"
2121
"encoding/json"
22+
"errors"
2223
"fmt"
2324
"log/slog"
2425
"net"
@@ -28,6 +29,7 @@ import (
2829
"sort"
2930
"strconv"
3031
"strings"
32+
"sync"
3133
"time"
3234

3335
"github.com/firebase/genkit/go/core"
@@ -52,7 +54,46 @@ type runtimeFileData struct {
5254
// reflectionServer encapsulates everything needed to serve the Reflection API.
5355
type reflectionServer struct {
5456
*http.Server
55-
RuntimeFilePath string // Path to the runtime file that was written at startup.
57+
RuntimeFilePath string // Path to the runtime file that was written at startup.
58+
activeActions *activeActionsMap // Tracks active actions for cancellation support.
59+
}
60+
61+
// activeAction represents an in-flight action that can be cancelled.
62+
type activeAction struct {
63+
cancel context.CancelFunc
64+
startTime time.Time
65+
traceID string
66+
}
67+
68+
// activeActionsMap safely manages active actions.
69+
type activeActionsMap struct {
70+
mu sync.RWMutex
71+
actions map[string]*activeAction
72+
}
73+
74+
func newActiveActionsMap() *activeActionsMap {
75+
return &activeActionsMap{
76+
actions: make(map[string]*activeAction),
77+
}
78+
}
79+
80+
func (m *activeActionsMap) Set(traceID string, action *activeAction) {
81+
m.mu.Lock()
82+
defer m.mu.Unlock()
83+
m.actions[traceID] = action
84+
}
85+
86+
func (m *activeActionsMap) Get(traceID string) (*activeAction, bool) {
87+
m.mu.RLock()
88+
defer m.mu.RUnlock()
89+
action, ok := m.actions[traceID]
90+
return action, ok
91+
}
92+
93+
func (m *activeActionsMap) Delete(traceID string) {
94+
m.mu.Lock()
95+
defer m.mu.Unlock()
96+
delete(m.actions, traceID)
5697
}
5798

5899
func (s *reflectionServer) runtimeID() string {
@@ -102,6 +143,7 @@ func startReflectionServer(ctx context.Context, g *Genkit, errCh chan<- error, s
102143
Server: &http.Server{
103144
Addr: addr,
104145
},
146+
activeActions: newActiveActionsMap(),
105147
}
106148
s.Handler = serveMux(g, s)
107149

@@ -258,8 +300,9 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux {
258300
w.WriteHeader(http.StatusOK)
259301
})
260302
mux.HandleFunc("GET /api/actions", wrapReflectionHandler(handleListActions(g)))
261-
mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g)))
303+
mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions)))
262304
mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify()))
305+
mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions)))
263306
return mux
264307
}
265308

@@ -290,7 +333,7 @@ func wrapReflectionHandler(h func(w http.ResponseWriter, r *http.Request) error)
290333

291334
// handleRunAction looks up an action by name in the registry, runs it with the
292335
// provided JSON input, and writes back the JSON-marshaled request.
293-
func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) error {
336+
func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.ResponseWriter, r *http.Request) error {
294337
return func(w http.ResponseWriter, r *http.Request) error {
295338
ctx := r.Context()
296339

@@ -312,11 +355,54 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err
312355

313356
logger.FromContext(ctx).Debug("running action", "key", body.Key, "stream", stream)
314357

358+
// Create cancellable context for this action
359+
actionCtx, cancel := context.WithCancel(ctx)
360+
defer cancel()
361+
362+
// Track whether headers have been sent
363+
headersSent := false
364+
var callbackTraceID string // Trace ID captured from telemetry callback for early header sending
365+
var mu sync.Mutex
366+
367+
// Set up telemetry callback to capture and send trace ID early
368+
// This is used for BOTH streaming and non-streaming to match JS behavior
369+
telemetryCb := func(tid string, sid string) {
370+
mu.Lock()
371+
defer mu.Unlock()
372+
373+
if !headersSent {
374+
callbackTraceID = tid
375+
376+
// Track active action for cancellation
377+
activeActions.Set(callbackTraceID, &activeAction{
378+
cancel: cancel,
379+
startTime: time.Now(),
380+
traceID: callbackTraceID,
381+
})
382+
383+
// Send headers immediately with trace ID
384+
w.Header().Set("X-Genkit-Trace-Id", callbackTraceID)
385+
w.Header().Set("X-Genkit-Span-Id", sid)
386+
w.Header().Set("X-Genkit-Version", "go/"+internal.Version)
387+
388+
if stream {
389+
w.Header().Set("Content-Type", "text/plain")
390+
w.Header().Set("Transfer-Encoding", "chunked")
391+
} else {
392+
w.Header().Set("Content-Type", "application/json")
393+
}
394+
395+
w.WriteHeader(http.StatusOK)
396+
if f, ok := w.(http.Flusher); ok {
397+
f.Flush()
398+
}
399+
headersSent = true
400+
}
401+
}
402+
403+
// Set up streaming callback if needed
315404
var cb streamingCallback[json.RawMessage]
316405
if stream {
317-
w.Header().Set("Content-Type", "text/plain")
318-
w.Header().Set("Transfer-Encoding", "chunked")
319-
// Stream results are newline-separated JSON.
320406
cb = func(ctx context.Context, msg json.RawMessage) error {
321407
_, err := fmt.Fprintf(w, "%s\n", msg)
322408
if err != nil {
@@ -334,35 +420,119 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err
334420
json.Unmarshal(body.Context, &contextMap)
335421
}
336422

337-
resp, err := runAction(ctx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap)
423+
// Attach telemetry callback to context so action can invoke it when span is created
424+
actionCtx = tracing.WithTelemetryCallback(actionCtx, telemetryCb)
425+
resp, err := runAction(actionCtx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap)
426+
427+
// Clean up active action using the trace ID from response
428+
if resp != nil && resp.Telemetry.TraceID != "" {
429+
activeActions.Delete(resp.Telemetry.TraceID)
430+
}
431+
338432
if err != nil {
339-
if stream {
340-
refErr := core.ToReflectionError(err)
341-
refErr.Details.TraceID = &resp.Telemetry.TraceID
342-
reflectErr, err := json.Marshal(refErr)
343-
if err != nil {
344-
return err
433+
// Check if context was cancelled
434+
if errors.Is(err, context.Canceled) {
435+
// Use gRPC CANCELLED code (1) in JSON body to match TypeScript behavior
436+
var traceIDPtr *string
437+
if resp != nil && resp.Telemetry.TraceID != "" {
438+
traceIDPtr = &resp.Telemetry.TraceID
439+
}
440+
errResp := errorResponse{
441+
Error: core.ReflectionError{
442+
Code: core.CodeCancelled, // gRPC CANCELLED = 1
443+
Message: "Action was cancelled",
444+
Details: &core.ReflectionErrorDetails{
445+
TraceID: traceIDPtr,
446+
},
447+
},
345448
}
346449

347-
_, err = fmt.Fprintf(w, "{\"error\": %s }", reflectErr)
348-
if err != nil {
349-
return err
450+
if stream {
451+
// For streaming, write error as final chunk
452+
json.NewEncoder(w).Encode(errResp)
453+
} else {
454+
// For non-streaming, return error response
455+
if !headersSent {
456+
w.WriteHeader(http.StatusOK) // Match TS: response.status(200).json(...)
457+
}
458+
json.NewEncoder(w).Encode(errResp)
350459
}
460+
return nil
461+
}
351462

352-
if f, ok := w.(http.Flusher); ok {
353-
f.Flush()
463+
// Handle other errors
464+
if stream {
465+
refErr := core.ToReflectionError(err)
466+
if resp != nil && resp.Telemetry.TraceID != "" {
467+
refErr.Details.TraceID = &resp.Telemetry.TraceID
354468
}
469+
470+
json.NewEncoder(w).Encode(errorResponse{Error: refErr})
355471
return nil
356472
}
473+
474+
// Non-streaming error
357475
errorResponse := core.ToReflectionError(err)
358-
if resp != nil {
476+
if resp != nil && resp.Telemetry.TraceID != "" {
359477
errorResponse.Details.TraceID = &resp.Telemetry.TraceID
360478
}
361-
w.WriteHeader(errorResponse.Code)
479+
480+
if !headersSent {
481+
w.WriteHeader(errorResponse.Code)
482+
}
362483
return writeJSON(ctx, w, errorResponse)
363484
}
364485

365-
return writeJSON(ctx, w, resp)
486+
// Success case
487+
if stream {
488+
// For streaming, write the final chunk with result and telemetry
489+
// This matches JS: response.write(JSON.stringify({result, telemetry}))
490+
finalResponse := runActionResponse{
491+
Result: resp.Result,
492+
Telemetry: telemetry{TraceID: resp.Telemetry.TraceID},
493+
}
494+
json.NewEncoder(w).Encode(finalResponse)
495+
} else {
496+
// For non-streaming, headers were already sent via telemetry callback
497+
// Response already includes telemetry.traceId in body
498+
return writeJSON(ctx, w, resp)
499+
}
500+
501+
return nil
502+
}
503+
}
504+
505+
// handleCancelAction cancels an in-flight action by trace ID.
506+
func handleCancelAction(activeActions *activeActionsMap) func(w http.ResponseWriter, r *http.Request) error {
507+
return func(w http.ResponseWriter, r *http.Request) error {
508+
var body struct {
509+
TraceID string `json:"traceId"`
510+
}
511+
512+
defer r.Body.Close()
513+
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
514+
return core.NewError(core.INVALID_ARGUMENT, err.Error())
515+
}
516+
517+
if body.TraceID == "" {
518+
return core.NewError(core.INVALID_ARGUMENT, "traceId is required")
519+
}
520+
521+
action, exists := activeActions.Get(body.TraceID)
522+
if !exists {
523+
w.WriteHeader(http.StatusNotFound)
524+
return writeJSON(r.Context(), w, map[string]string{
525+
"error": "Action not found or already completed",
526+
})
527+
}
528+
529+
// Cancel the action's context
530+
action.cancel()
531+
activeActions.Delete(body.TraceID)
532+
533+
return writeJSON(r.Context(), w, map[string]string{
534+
"message": "Action cancelled",
535+
})
366536
}
367537
}
368538

@@ -462,6 +632,10 @@ type telemetry struct {
462632
TraceID string `json:"traceId"`
463633
}
464634

635+
type errorResponse struct {
636+
Error core.ReflectionError `json:"error"`
637+
}
638+
465639
func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) {
466640
action := g.reg.ResolveAction(key)
467641
if action == nil {

0 commit comments

Comments
 (0)