diff --git a/go/core/action.go b/go/core/action.go index 45d71e3177..9acfc03008 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -260,7 +260,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw return r.Result, nil } -// RunJSON runs the action with a JSON input, and returns a JSON result along with telemetry info. +// RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info. func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema) if err != nil { diff --git a/go/core/flow.go b/go/core/flow.go index b5311bbbf3..0cd12120f2 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -112,7 +112,7 @@ func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessa return (*ActionDef[In, Out, Stream])(f).RunJSON(ctx, input, cb) } -// RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON. +// RunJSONWithTelemetry runs the flow with JSON input and streaming callback and returns the output as JSON along with telemetry info. func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb) } diff --git a/go/core/schemas.config b/go/core/schemas.config index 7598011f17..1d4ff98001 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -26,6 +26,10 @@ SpanStatus omit TimeEvent omit TimeEventAnnotation omit TraceData omit +SpanStartEvent omit +SpanEndEvent omit +SpanEventBase omit +TraceEvent omit GenerationCommonConfig.maxOutputTokens type int GenerationCommonConfig.topK type int diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index bcd05aac19..727625f75a 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -173,6 +173,8 @@ type SpanMetadata struct { // RunInNewSpan runs f on input in a new span with the provided metadata. // The metadata contains all span configuration including name, type, labels, etc. +// If a telemetry callback was set on the context via WithTelemetryCallback, +// it will be called with the trace ID and span ID as soon as the span is created. func RunInNewSpan[I, O any]( ctx context.Context, metadata *SpanMetadata, @@ -239,6 +241,12 @@ func RunInNewSpan[I, O any]( TraceID: span.SpanContext().TraceID().String(), SpanID: span.SpanContext().SpanID().String(), } + + // Fire telemetry callback immediately if one was set on the context + if cb := telemetryCallback(ctx); cb != nil { + cb(sm.TraceInfo.TraceID, sm.TraceInfo.SpanID) + } + defer span.End() defer func() { span.SetAttributes(sm.attributes()...) }() ctx = spanMetaKey.NewContext(ctx, sm) @@ -371,6 +379,20 @@ func (sm *spanMetadata) attributes() []attribute.KeyValue { // spanMetaKey is for storing spanMetadatas in a context. var spanMetaKey = base.NewContextKey[*spanMetadata]() +// telemetryCbKey is the context key for telemetry callbacks. +var telemetryCbKey = base.NewContextKey[func(traceID, spanID string)]() + +// WithTelemetryCallback returns a context with the telemetry callback attached. +// Used by the reflection server to pass callbacks to actions. +func WithTelemetryCallback(ctx context.Context, cb func(traceID, spanID string)) context.Context { + return telemetryCbKey.NewContext(ctx, cb) +} + +// telemetryCallback retrieves the telemetry callback from context, or nil if not set. +func telemetryCallback(ctx context.Context) func(traceID, spanID string) { + return telemetryCbKey.FromContext(ctx) +} + // SpanPath returns the path as recorded in the current span metadata. func SpanPath(ctx context.Context) string { return spanMetaKey.FromContext(ctx).Path diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index bb32c79bff..f0dcbf5578 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -19,6 +19,7 @@ package genkit import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "net" @@ -28,6 +29,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/firebase/genkit/go/core" @@ -52,7 +54,46 @@ type runtimeFileData struct { // reflectionServer encapsulates everything needed to serve the Reflection API. type reflectionServer struct { *http.Server - RuntimeFilePath string // Path to the runtime file that was written at startup. + RuntimeFilePath string // Path to the runtime file that was written at startup. + activeActions *activeActionsMap // Tracks active actions for cancellation support. +} + +// activeAction represents an in-flight action that can be cancelled. +type activeAction struct { + cancel context.CancelFunc + startTime time.Time + traceID string +} + +// activeActionsMap safely manages active actions. +type activeActionsMap struct { + mu sync.RWMutex + actions map[string]*activeAction +} + +func newActiveActionsMap() *activeActionsMap { + return &activeActionsMap{ + actions: make(map[string]*activeAction), + } +} + +func (m *activeActionsMap) Set(traceID string, action *activeAction) { + m.mu.Lock() + defer m.mu.Unlock() + m.actions[traceID] = action +} + +func (m *activeActionsMap) Get(traceID string) (*activeAction, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + action, ok := m.actions[traceID] + return action, ok +} + +func (m *activeActionsMap) Delete(traceID string) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.actions, traceID) } func (s *reflectionServer) runtimeID() string { @@ -102,6 +143,7 @@ func startReflectionServer(ctx context.Context, g *Genkit, errCh chan<- error, s Server: &http.Server{ Addr: addr, }, + activeActions: newActiveActionsMap(), } s.Handler = serveMux(g, s) @@ -258,8 +300,9 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux { w.WriteHeader(http.StatusOK) }) mux.HandleFunc("GET /api/actions", wrapReflectionHandler(handleListActions(g))) - mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g))) + mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions))) mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify())) + mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions))) return mux } @@ -290,7 +333,7 @@ func wrapReflectionHandler(h func(w http.ResponseWriter, r *http.Request) error) // handleRunAction looks up an action by name in the registry, runs it with the // provided JSON input, and writes back the JSON-marshaled request. -func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) error { +func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.ResponseWriter, r *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() @@ -312,11 +355,54 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err logger.FromContext(ctx).Debug("running action", "key", body.Key, "stream", stream) + // Create cancellable context for this action + actionCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Track whether headers have been sent + headersSent := false + var callbackTraceID string // Trace ID captured from telemetry callback for early header sending + var mu sync.Mutex + + // Set up telemetry callback to capture and send trace ID early + // This is used for BOTH streaming and non-streaming to match JS behavior + telemetryCb := func(tid string, sid string) { + mu.Lock() + defer mu.Unlock() + + if !headersSent { + callbackTraceID = tid + + // Track active action for cancellation + activeActions.Set(callbackTraceID, &activeAction{ + cancel: cancel, + startTime: time.Now(), + traceID: callbackTraceID, + }) + + // Send headers immediately with trace ID + w.Header().Set("X-Genkit-Trace-Id", callbackTraceID) + w.Header().Set("X-Genkit-Span-Id", sid) + w.Header().Set("X-Genkit-Version", "go/"+internal.Version) + + if stream { + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Transfer-Encoding", "chunked") + } else { + w.Header().Set("Content-Type", "application/json") + } + + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + headersSent = true + } + } + + // Set up streaming callback if needed var cb streamingCallback[json.RawMessage] if stream { - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Transfer-Encoding", "chunked") - // Stream results are newline-separated JSON. cb = func(ctx context.Context, msg json.RawMessage) error { _, err := fmt.Fprintf(w, "%s\n", msg) if err != nil { @@ -334,35 +420,119 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err json.Unmarshal(body.Context, &contextMap) } - resp, err := runAction(ctx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap) + // Attach telemetry callback to context so action can invoke it when span is created + actionCtx = tracing.WithTelemetryCallback(actionCtx, telemetryCb) + resp, err := runAction(actionCtx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap) + + // Clean up active action using the trace ID from response + if resp != nil && resp.Telemetry.TraceID != "" { + activeActions.Delete(resp.Telemetry.TraceID) + } + if err != nil { - if stream { - refErr := core.ToReflectionError(err) - refErr.Details.TraceID = &resp.Telemetry.TraceID - reflectErr, err := json.Marshal(refErr) - if err != nil { - return err + // Check if context was cancelled + if errors.Is(err, context.Canceled) { + // Use gRPC CANCELLED code (1) in JSON body to match TypeScript behavior + var traceIDPtr *string + if resp != nil && resp.Telemetry.TraceID != "" { + traceIDPtr = &resp.Telemetry.TraceID + } + errResp := errorResponse{ + Error: core.ReflectionError{ + Code: core.CodeCancelled, // gRPC CANCELLED = 1 + Message: "Action was cancelled", + Details: &core.ReflectionErrorDetails{ + TraceID: traceIDPtr, + }, + }, } - _, err = fmt.Fprintf(w, "{\"error\": %s }", reflectErr) - if err != nil { - return err + if stream { + // For streaming, write error as final chunk + json.NewEncoder(w).Encode(errResp) + } else { + // For non-streaming, return error response + if !headersSent { + w.WriteHeader(http.StatusOK) // Match TS: response.status(200).json(...) + } + json.NewEncoder(w).Encode(errResp) } + return nil + } - if f, ok := w.(http.Flusher); ok { - f.Flush() + // Handle other errors + if stream { + refErr := core.ToReflectionError(err) + if resp != nil && resp.Telemetry.TraceID != "" { + refErr.Details.TraceID = &resp.Telemetry.TraceID } + + json.NewEncoder(w).Encode(errorResponse{Error: refErr}) return nil } + + // Non-streaming error errorResponse := core.ToReflectionError(err) - if resp != nil { + if resp != nil && resp.Telemetry.TraceID != "" { errorResponse.Details.TraceID = &resp.Telemetry.TraceID } - w.WriteHeader(errorResponse.Code) + + if !headersSent { + w.WriteHeader(errorResponse.Code) + } return writeJSON(ctx, w, errorResponse) } - return writeJSON(ctx, w, resp) + // Success case + if stream { + // For streaming, write the final chunk with result and telemetry + // This matches JS: response.write(JSON.stringify({result, telemetry})) + finalResponse := runActionResponse{ + Result: resp.Result, + Telemetry: telemetry{TraceID: resp.Telemetry.TraceID}, + } + json.NewEncoder(w).Encode(finalResponse) + } else { + // For non-streaming, headers were already sent via telemetry callback + // Response already includes telemetry.traceId in body + return writeJSON(ctx, w, resp) + } + + return nil + } +} + +// handleCancelAction cancels an in-flight action by trace ID. +func handleCancelAction(activeActions *activeActionsMap) func(w http.ResponseWriter, r *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + var body struct { + TraceID string `json:"traceId"` + } + + defer r.Body.Close() + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + return core.NewError(core.INVALID_ARGUMENT, err.Error()) + } + + if body.TraceID == "" { + return core.NewError(core.INVALID_ARGUMENT, "traceId is required") + } + + action, exists := activeActions.Get(body.TraceID) + if !exists { + w.WriteHeader(http.StatusNotFound) + return writeJSON(r.Context(), w, map[string]string{ + "error": "Action not found or already completed", + }) + } + + // Cancel the action's context + action.cancel() + activeActions.Delete(body.TraceID) + + return writeJSON(r.Context(), w, map[string]string{ + "message": "Action cancelled", + }) } } @@ -462,6 +632,10 @@ type telemetry struct { TraceID string `json:"traceId"` } +type errorResponse struct { + Error core.ReflectionError `json:"error"` +} + func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { action := g.reg.ResolveAction(key) if action == nil { diff --git a/go/genkit/reflection_test.go b/go/genkit/reflection_test.go index d47a10a027..7b11914348 100644 --- a/go/genkit/reflection_test.go +++ b/go/genkit/reflection_test.go @@ -21,6 +21,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "os" @@ -91,7 +92,8 @@ func TestServeMux(t *testing.T) { core.DefineAction(g.reg, "test/dec", api.ActionTypeCustom, nil, nil, dec) s := &reflectionServer{ - Server: &http.Server{}, + Server: &http.Server{}, + activeActions: newActiveActionsMap(), } ts := httptest.NewServer(serveMux(g, s)) s.Addr = strings.TrimPrefix(ts.URL, "http://") @@ -290,3 +292,326 @@ func TestServeMux(t *testing.T) { } }) } + +// TestEarlyTraceIDTransmission verifies that trace ID headers are sent BEFORE the action completes. +// +// The key thing we're testing: headers arrive while the action is still running, not after. +// This allows clients to get the trace ID immediately for cancellation or logging. +func TestEarlyTraceIDTransmission(t *testing.T) { + g := Init(context.Background()) + tc := tracing.NewTestOnlyTelemetryClient() + tracing.WriteTelemetryImmediate(tc) + + actionStarted := make(chan struct{}) + actionCanProceed := make(chan struct{}) + + // Action that waits for permission to complete - this lets us check headers while it's running + core.DefineAction(g.reg, "test/slow", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, input any) (any, error) { + close(actionStarted) // Signal we've started + <-actionCanProceed // Wait for test to say we can finish + return "completed", nil + }) + + s := &reflectionServer{Server: &http.Server{}, activeActions: newActiveActionsMap()} + ts := httptest.NewServer(serveMux(g, s)) + defer ts.Close() + + t.Run("headers arrive before body completes", func(t *testing.T) { + // Channel to receive headers as soon as they arrive + type headerResult struct { + traceID string + spanID string + version string + } + gotHeaders := make(chan headerResult) + + go func() { + req, _ := http.NewRequest("POST", ts.URL+"/api/runAction", + strings.NewReader(`{"key":"/custom/test/slow","input":null}`)) + req.Header.Set("Content-Type", "application/json") + + // Do() returns as soon as headers are received (before body is read) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + // Send headers immediately - body isn't done yet! + gotHeaders <- headerResult{ + traceID: resp.Header.Get("X-Genkit-Trace-Id"), + spanID: resp.Header.Get("X-Genkit-Span-Id"), + version: resp.Header.Get("X-Genkit-Version"), + } + + // Now read body (which will block until action completes) + io.ReadAll(resp.Body) + }() + + // Wait for action to start + <-actionStarted + + // Check headers arrived WHILE action is still running + select { + case h := <-gotHeaders: + if h.traceID == "" { + t.Error("Expected X-Genkit-Trace-Id header") + } + if h.spanID == "" { + t.Error("Expected X-Genkit-Span-Id header") + } + if !strings.HasPrefix(h.version, "go/") { + t.Errorf("Expected X-Genkit-Version to start with 'go/', got %q", h.version) + } + t.Logf("Got headers while action running: traceID=%s", h.traceID) + case <-time.After(1 * time.Second): + t.Fatal("Headers did not arrive while action was still running") + } + + // Let action complete + close(actionCanProceed) + }) + + // Backwards compatability + t.Run("trace ID in headers matches body", func(t *testing.T) { + // Reset channels for this subtest + actionStarted = make(chan struct{}) + actionCanProceed = make(chan struct{}) + + // Re-register action for this subtest + core.DefineAction(g.reg, "test/slow2", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, input any) (any, error) { + close(actionStarted) + <-actionCanProceed + return "completed", nil + }) + + req, _ := http.NewRequest("POST", ts.URL+"/api/runAction", + strings.NewReader(`{"key":"/custom/test/slow2","input":null}`)) + req.Header.Set("Content-Type", "application/json") + + // Start request in background + type result struct { + headerTraceID string + bodyTraceID string + } + done := make(chan result) + + go func() { + resp, err := http.DefaultClient.Do(req) + if err != nil { + done <- result{} + return + } + defer resp.Body.Close() + headerTraceID := resp.Header.Get("X-Genkit-Trace-Id") + + var body map[string]interface{} + json.NewDecoder(resp.Body).Decode(&body) + bodyTraceID := "" + if tel, ok := body["telemetry"].(map[string]interface{}); ok { + bodyTraceID, _ = tel["traceId"].(string) + } + done <- result{headerTraceID, bodyTraceID} + }() + + <-actionStarted + close(actionCanProceed) + + r := <-done + if r.headerTraceID == "" { + t.Error("No trace ID in headers") + } + if r.bodyTraceID == "" { + t.Error("No trace ID in body") + } + if r.headerTraceID != r.bodyTraceID { + t.Errorf("Trace ID mismatch: header=%q, body=%q", r.headerTraceID, r.bodyTraceID) + } + }) +} + +// TestActionCancellation verifies that running actions can be cancelled via /api/cancelAction. +// +// Flow: +// 1. Start a long-running action that sends its trace ID via channel when it starts +// 2. Call POST /api/cancelAction with that trace ID +// 3. Verify: cancel endpoint returns 200, action's ctx.Done() fires, response has error code 1 (gRPC CANCELLED) +func TestActionCancellation(t *testing.T) { + g := Init(context.Background()) + tc := tracing.NewTestOnlyTelemetryClient() + tracing.WriteTelemetryImmediate(tc) + + gotTraceID := make(chan string, 1) + gotCancelled := make(chan struct{}) + + // Long-running action that respects cancellation + core.DefineStreamingAction(g.reg, "test/cancellable", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, input any, cb func(context.Context, any) error) (any, error) { + // Send trace ID so test can cancel us + gotTraceID <- tracing.SpanTraceInfo(ctx).TraceID + + for i := 0; i < 100; i++ { + select { + case <-ctx.Done(): + if ctx.Err() != context.Canceled { + return nil, fmt.Errorf("expected context.Canceled, got %v", ctx.Err()) + } + close(gotCancelled) + return nil, ctx.Err() + case <-time.After(50 * time.Millisecond): + if cb != nil && i%10 == 0 { + cb(ctx, fmt.Sprintf("progress: %d", i)) + } + } + } + return "completed", nil + }) + + s := &reflectionServer{Server: &http.Server{}, activeActions: newActiveActionsMap()} + ts := httptest.NewServer(serveMux(g, s)) + defer ts.Close() + + // Start action in background + actionDone := make(chan string) // receives response body when done + go func() { + req, _ := http.NewRequest("POST", ts.URL+"/api/runAction?stream=true", + strings.NewReader(`{"key":"/custom/test/cancellable","input":null}`)) + req.Header.Set("Content-Type", "application/json") + resp, _ := http.DefaultClient.Do(req) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + actionDone <- string(body) + }() + + // Wait for action to start + traceID := <-gotTraceID + time.Sleep(50 * time.Millisecond) // ensure it's tracked + + // Cancel it + cancelReq, _ := http.NewRequest("POST", ts.URL+"/api/cancelAction", + strings.NewReader(fmt.Sprintf(`{"traceId":"%s"}`, traceID))) + cancelReq.Header.Set("Content-Type", "application/json") + cancelResp, err := http.DefaultClient.Do(cancelReq) + if err != nil { + t.Fatal(err) + } + defer cancelResp.Body.Close() + + if cancelResp.StatusCode != http.StatusOK { + t.Fatalf("Cancel failed with status %d", cancelResp.StatusCode) + } + + // Verify action acknowledged cancellation + select { + case <-gotCancelled: + case <-time.After(1 * time.Second): + t.Fatal("Action did not acknowledge cancellation") + } + + // Verify response indicates cancellation + responseBody := <-actionDone + if !strings.Contains(responseBody, "\"code\":1") { + t.Errorf("Expected error code 1 (gRPC CANCELLED) in response, got: %s", responseBody) + } + if !strings.Contains(responseBody, "Action was cancelled") { + t.Errorf("Expected 'Action was cancelled' message in response, got: %s", responseBody) + } +} + +func TestCancelActionEndpoint(t *testing.T) { + g := Init(context.Background()) + + s := &reflectionServer{ + Server: &http.Server{}, + activeActions: newActiveActionsMap(), + } + ts := httptest.NewServer(serveMux(g, s)) + defer ts.Close() + + t.Run("cancel non-existent action", func(t *testing.T) { + cancelReq, _ := http.NewRequest("POST", ts.URL+"/api/cancelAction", + strings.NewReader(`{"traceId":"non-existent-trace-id"}`)) + cancelReq.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(cancelReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected 404 for non-existent action, got %d", resp.StatusCode) + } + + var result map[string]interface{} + json.NewDecoder(resp.Body).Decode(&result) + if error, ok := result["error"].(string); !ok || error != "Action not found or already completed" { + t.Errorf("Unexpected error message: %v", result) + } + }) + + t.Run("cancel with missing traceId", func(t *testing.T) { + cancelReq, _ := http.NewRequest("POST", ts.URL+"/api/cancelAction", + strings.NewReader(`{}`)) + cancelReq.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(cancelReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest && resp.StatusCode != http.StatusInternalServerError { + t.Errorf("Expected 400 or 500 for missing traceId, got %d", resp.StatusCode) + } + }) + + t.Run("cancel active action", func(t *testing.T) { + // Manually add an action to activeActions + testTraceID := "test-trace-id-12345" + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.activeActions.Set(testTraceID, &activeAction{ + cancel: cancel, + startTime: time.Now(), + traceID: testTraceID, + }) + + // Send cancel request + cancelReq, _ := http.NewRequest("POST", ts.URL+"/api/cancelAction", + strings.NewReader(fmt.Sprintf(`{"traceId":"%s"}`, testTraceID))) + cancelReq.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(cancelReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 for successful cancellation, got %d", resp.StatusCode) + } + + var result map[string]interface{} + json.NewDecoder(resp.Body).Decode(&result) + if message, ok := result["message"].(string); !ok || message != "Action cancelled" { + t.Errorf("Expected 'Action cancelled' message, got: %v", result) + } + + // Verify action was removed from activeActions + if action, exists := s.activeActions.Get(testTraceID); exists { + t.Errorf("Action should have been removed from activeActions, but still exists: %v", action) + } + + // Verify context was cancelled + select { + case <-ctx.Done(): + // Good, context was cancelled + default: + t.Error("Context should have been cancelled") + } + }) +} diff --git a/go/internal/cmd/jsonschemagen/jsonschema.go b/go/internal/cmd/jsonschemagen/jsonschema.go index 2ee6930954..0afbde5e74 100644 --- a/go/internal/cmd/jsonschemagen/jsonschema.go +++ b/go/internal/cmd/jsonschemagen/jsonschema.go @@ -33,7 +33,7 @@ type Schema struct { Description string `json:"description,omitempty"` Properties map[string]*Schema `json:"properties,omitempty"` AdditionalProperties *Schema `json:"additionalProperties,omitempty"` - Const bool `json:"const,omitempty"` + Const any `json:"const,omitempty"` Required []string `json:"required,omitempty"` Items *Schema `json:"items,omitempty"` Enum []string `json:"enum,omitempty"` diff --git a/go/samples/flow-sample1/main.go b/go/samples/flow-sample1/main.go index a6f567042b..9b37943d7c 100644 --- a/go/samples/flow-sample1/main.go +++ b/go/samples/flow-sample1/main.go @@ -41,6 +41,7 @@ import ( "log" "net/http" "strconv" + "time" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/genkit" @@ -111,6 +112,122 @@ func main() { return fmt.Sprintf("done: %d, streamed: %d times", count, i), nil }) + // Long-running flow for testing early trace ID transmission and cancellation. + // Each step takes ~5 seconds with nested sub-steps. + // + // Test with: + // curl -d '{"key":"/flow/longRunning/longRunning", "input":{"start": {"input":3}}}' \ + // http://localhost:3100/api/runAction?stream=true + // + // To test cancellation, note the X-Genkit-Trace-Id header and call: + // curl -d '{"traceId":""}' http://localhost:3100/api/cancelAction + type stepResult struct { + Step int `json:"step"` + Timestamp string `json:"timestamp"` + Elapsed int64 `json:"elapsed_ms"` + } + + type longRunningResult struct { + TotalDuration int64 `json:"total_duration_ms"` + StepsCompleted int `json:"steps_completed"` + Timeline []stepResult `json:"timeline"` + } + + genkit.DefineStreamingFlow(g, "longRunning", + func(ctx context.Context, steps int, cb func(context.Context, stepResult) error) (longRunningResult, error) { + if steps <= 0 { + steps = 3 + } + startTime := time.Now() + timeline := make([]stepResult, 0, steps) + + log.Printf("🚀 Starting long-running flow: %d steps × 5s = ~%ds", steps, steps*5) + + for i := 1; i <= steps; i++ { + stepStart := time.Now() + + // Check for cancellation before each step + select { + case <-ctx.Done(): + log.Printf("❌ Cancelled at step %d/%d", i, steps) + return longRunningResult{ + TotalDuration: time.Since(startTime).Milliseconds(), + StepsCompleted: i - 1, + Timeline: timeline, + }, ctx.Err() + default: + } + + log.Printf("[%s] 🔄 Step %d/%d starting...", time.Now().Format(time.RFC3339), i, steps) + + // Nested sub-steps (like the TS version) + _, err := core.Run(ctx, fmt.Sprintf("step-%d-fetch", i), func() (string, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1500 * time.Millisecond): + } + log.Printf(" 📡 Fetched data for step %d", i) + return fmt.Sprintf("fetch-%d", i), nil + }) + if err != nil { + return longRunningResult{TotalDuration: time.Since(startTime).Milliseconds(), StepsCompleted: i - 1, Timeline: timeline}, err + } + + _, err = core.Run(ctx, fmt.Sprintf("step-%d-process", i), func() (string, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1500 * time.Millisecond): + } + log.Printf(" ⚙️ Processed data for step %d", i) + return fmt.Sprintf("process-%d", i), nil + }) + if err != nil { + return longRunningResult{TotalDuration: time.Since(startTime).Milliseconds(), StepsCompleted: i - 1, Timeline: timeline}, err + } + + _, err = core.Run(ctx, fmt.Sprintf("step-%d-save", i), func() (string, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1500 * time.Millisecond): + } + log.Printf(" 💾 Saved results for step %d", i) + return fmt.Sprintf("save-%d", i), nil + }) + if err != nil { + return longRunningResult{TotalDuration: time.Since(startTime).Milliseconds(), StepsCompleted: i - 1, Timeline: timeline}, err + } + + elapsed := time.Since(stepStart).Milliseconds() + log.Printf("[%s] ✅ Step %d/%d completed (%dms)", time.Now().Format(time.RFC3339), i, steps, elapsed) + + result := stepResult{ + Step: i, + Timestamp: time.Now().Format(time.RFC3339), + Elapsed: elapsed, + } + timeline = append(timeline, result) + + // Stream progress if callback provided + if cb != nil { + if err := cb(ctx, result); err != nil { + return longRunningResult{}, err + } + } + } + + totalDuration := time.Since(startTime).Milliseconds() + log.Printf("🎉 Long-running flow completed in %dms", totalDuration) + + return longRunningResult{ + TotalDuration: totalDuration, + StepsCompleted: steps, + Timeline: timeline, + }, nil + }) + mux := http.NewServeMux() for _, a := range genkit.ListFlows(g) { mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index bb8cc688d2..c1826c6ea5 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -376,7 +376,7 @@ export class ReflectionServer { }, }; - // Headers may have been sent already (via onTelemetry), so check before setting status + // Headers may have been sent already (via onTraceStart), so check before setting status if (!res.headersSent) { res.status(500).json(errorResponse); } else {