diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 5c33947640..9882d5a6a2 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -60,6 +60,7 @@ import ( dlmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" fccontroller "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/controller" fcregistry "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/registry" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" @@ -330,7 +331,10 @@ func (r *Runner) Run(ctx context.Context) error { // --- Admission Control Initialization --- var admissionController requestcontrol.AdmissionController + var locator contracts.PodLocator + locator = requestcontrol.NewDatastorePodLocator(ds) if r.featureGates[flowcontrol.FeatureGate] { + locator = requestcontrol.NewCachedPodLocator(ctx, locator, time.Millisecond*50) setupLog.Info("Initializing experimental Flow Control layer") fcCfg, err := flowControlConfig.ValidateAndApplyDefaults() if err != nil { @@ -342,24 +346,28 @@ func (r *Runner) Run(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to initialize Flow Registry: %w", err) } - fc, err := fccontroller.NewFlowController(ctx, fcCfg.Controller, registry, saturationDetector, setupLog) + fc, err := fccontroller.NewFlowController( + ctx, + fcCfg.Controller, + registry, saturationDetector, + locator, + setupLog, + ) if err != nil { return fmt.Errorf("failed to initialize Flow Controller: %w", err) } go registry.Run(ctx) - admissionController = requestcontrol.NewFlowControlAdmissionController(saturationDetector, fc) + admissionController = requestcontrol.NewFlowControlAdmissionController(fc) } else { setupLog.Info("Experimental Flow Control layer is disabled, using legacy admission control") - admissionController = requestcontrol.NewLegacyAdmissionController(saturationDetector) + admissionController = requestcontrol.NewLegacyAdmissionController(saturationDetector, locator) } - locator := requestcontrol.NewDatastorePodLocator(ds) - cachedLocator := requestcontrol.NewCachedPodLocator(ctx, locator, time.Millisecond*50) director := requestcontrol.NewDirectorWithConfig( ds, scheduler, admissionController, - cachedLocator, + locator, r.requestControlConfig) // --- Setup ExtProc Server Runner --- diff --git a/pkg/epp/flowcontrol/contracts/mocks/mocks.go b/pkg/epp/flowcontrol/contracts/mocks/mocks.go index 49205ee1ec..c042f5aa46 100644 --- a/pkg/epp/flowcontrol/contracts/mocks/mocks.go +++ b/pkg/epp/flowcontrol/contracts/mocks/mocks.go @@ -41,6 +41,8 @@ import ( typesmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" ) +// --- RegistryShard Mocks --- + // MockRegistryShard is a simple "stub-style" mock for testing. // Its methods are implemented as function fields (e.g., `IDFunc`). A test can inject behavior by setting the desired // function field in the test setup. If a func is nil, the method will return a zero value. @@ -111,6 +113,8 @@ func (m *MockRegistryShard) Stats() contracts.ShardStats { return contracts.ShardStats{} } +// --- Dependency Mocks --- + // MockSaturationDetector is a simple "stub-style" mock for testing. type MockSaturationDetector struct { IsSaturatedFunc func(ctx context.Context, candidatePods []metrics.PodMetrics) bool @@ -123,6 +127,30 @@ func (m *MockSaturationDetector) IsSaturated(ctx context.Context, candidatePods return false } +// MockPodLocator provides a mock implementation of the contracts.PodLocator interface. +// It allows tests to control the exact set of pods returned for a given request. +type MockPodLocator struct { + // LocateFunc allows injecting custom logic. + LocateFunc func(ctx context.Context, requestMetadata map[string]any) []metrics.PodMetrics + // Pods is a static return value used if LocateFunc is nil. + Pods []metrics.PodMetrics +} + +func (m *MockPodLocator) Locate(ctx context.Context, requestMetadata map[string]any) []metrics.PodMetrics { + if m.LocateFunc != nil { + return m.LocateFunc(ctx, requestMetadata) + } + // Return copy to be safe + if m.Pods == nil { + return nil + } + result := make([]metrics.PodMetrics, len(m.Pods)) + copy(result, m.Pods) + return result +} + +// --- ManagedQueue Mock --- + // MockManagedQueue is a high-fidelity, thread-safe mock of the `contracts.ManagedQueue` interface, designed // specifically for testing the concurrent `controller/internal.ShardProcessor`. // diff --git a/pkg/epp/flowcontrol/controller/controller.go b/pkg/epp/flowcontrol/controller/controller.go index 751a420f49..29eda4ec58 100644 --- a/pkg/epp/flowcontrol/controller/controller.go +++ b/pkg/epp/flowcontrol/controller/controller.go @@ -60,6 +60,7 @@ type shardProcessorFactory func( ctx context.Context, shard contracts.RegistryShard, saturationDetector contracts.SaturationDetector, + podLocator contracts.PodLocator, clock clock.WithTicker, cleanupSweepInterval time.Duration, enqueueChannelBufferSize int, @@ -95,6 +96,7 @@ type FlowController struct { config Config registry registryClient saturationDetector contracts.SaturationDetector + podLocator contracts.PodLocator clock clock.WithTicker logger logr.Logger shardProcessorFactory shardProcessorFactory @@ -126,6 +128,7 @@ func NewFlowController( config Config, registry contracts.FlowRegistry, sd contracts.SaturationDetector, + podLocator contracts.PodLocator, logger logr.Logger, opts ...flowControllerOption, ) (*FlowController, error) { @@ -133,6 +136,7 @@ func NewFlowController( config: config, registry: registry, saturationDetector: sd, + podLocator: podLocator, clock: clock.RealClock{}, logger: logger.WithName("flow-controller"), parentCtx: ctx, @@ -142,6 +146,7 @@ func NewFlowController( ctx context.Context, shard contracts.RegistryShard, saturationDetector contracts.SaturationDetector, + podLocator contracts.PodLocator, clock clock.WithTicker, cleanupSweepInterval time.Duration, enqueueChannelBufferSize int, @@ -151,6 +156,7 @@ func NewFlowController( ctx, shard, saturationDetector, + podLocator, clock, cleanupSweepInterval, enqueueChannelBufferSize, @@ -448,6 +454,7 @@ func (fc *FlowController) getOrStartWorker(shard contracts.RegistryShard) *manag processorCtx, shard, fc.saturationDetector, + fc.podLocator, fc.clock, fc.config.ExpiryCleanupInterval, fc.config.EnqueueChannelBufferSize, diff --git a/pkg/epp/flowcontrol/controller/controller_test.go b/pkg/epp/flowcontrol/controller/controller_test.go index 9f46374466..7fcaa50edc 100644 --- a/pkg/epp/flowcontrol/controller/controller_test.go +++ b/pkg/epp/flowcontrol/controller/controller_test.go @@ -79,7 +79,6 @@ type testHarness struct { // clock is the clock interface used by the controller. clock clock.WithTicker mockRegistry *mockRegistryClient - mockDetector *mocks.MockSaturationDetector // mockClock provides access to FakeClock methods (Step, HasWaiters) if and only if the underlying clock is a // FakeClock. mockClock *testclock.FakeClock @@ -91,6 +90,7 @@ type testHarness struct { func newUnitHarness(t *testing.T, ctx context.Context, cfg Config, registry *mockRegistryClient) *testHarness { t.Helper() mockDetector := &mocks.MockSaturationDetector{} + mockPodLocator := &mocks.MockPodLocator{} // Initialize the FakeClock with the current system time. // The controller implementation uses the injected clock to calculate the deadline timestamp,vbut uses the standard @@ -113,7 +113,7 @@ func newUnitHarness(t *testing.T, ctx context.Context, cfg Config, registry *moc withClock(mockClock), withShardProcessorFactory(mockProcessorFactory.new), } - fc, err := NewFlowController(ctx, cfg, registry, mockDetector, logr.Discard(), opts...) + fc, err := NewFlowController(ctx, cfg, registry, mockDetector, mockPodLocator, logr.Discard(), opts...) require.NoError(t, err, "failed to create FlowController for unit test harness") h := &testHarness{ @@ -121,7 +121,6 @@ func newUnitHarness(t *testing.T, ctx context.Context, cfg Config, registry *moc cfg: cfg, clock: mockClock, mockRegistry: registry, - mockDetector: mockDetector, mockClock: mockClock, mockProcessorFactory: mockProcessorFactory, } @@ -133,8 +132,9 @@ func newUnitHarness(t *testing.T, ctx context.Context, cfg Config, registry *moc func newIntegrationHarness(t *testing.T, ctx context.Context, cfg Config, registry *mockRegistryClient) *testHarness { t.Helper() mockDetector := &mocks.MockSaturationDetector{} - // Align FakeClock with system time. See explanation in newUnitHarness. + mockPodLocator := &mocks.MockPodLocator{} + // Align FakeClock with system time. See explanation in newUnitHarness. mockClock := testclock.NewFakeClock(time.Now()) if registry == nil { registry = &mockRegistryClient{} @@ -144,7 +144,7 @@ func newIntegrationHarness(t *testing.T, ctx context.Context, cfg Config, regist withRegistryClient(registry), withClock(mockClock), } - fc, err := NewFlowController(ctx, cfg, registry, mockDetector, logr.Discard(), opts...) + fc, err := NewFlowController(ctx, cfg, registry, mockDetector, mockPodLocator, logr.Discard(), opts...) require.NoError(t, err, "failed to create FlowController for integration test harness") h := &testHarness{ @@ -152,7 +152,6 @@ func newIntegrationHarness(t *testing.T, ctx context.Context, cfg Config, regist cfg: cfg, clock: mockClock, mockRegistry: registry, - mockDetector: mockDetector, mockClock: mockClock, } return h @@ -247,6 +246,7 @@ func (f *mockShardProcessorFactory) new( _ context.Context, // The factory does not use the lifecycle context; it's passed to the processor's Run method later. shard contracts.RegistryShard, _ contracts.SaturationDetector, + _ contracts.PodLocator, _ clock.WithTicker, _ time.Duration, _ int, @@ -1001,6 +1001,7 @@ func TestFlowController_WorkerManagement(t *testing.T) { ctx context.Context, // The context created by getOrStartWorker for the potential new processor. shard contracts.RegistryShard, _ contracts.SaturationDetector, + _ contracts.PodLocator, _ clock.WithTicker, _ time.Duration, _ int, diff --git a/pkg/epp/flowcontrol/controller/internal/processor.go b/pkg/epp/flowcontrol/controller/internal/processor.go index 489f421be5..ae419af25a 100644 --- a/pkg/epp/flowcontrol/controller/internal/processor.go +++ b/pkg/epp/flowcontrol/controller/internal/processor.go @@ -65,6 +65,7 @@ var ErrProcessorBusy = errors.New("shard processor is busy") type ShardProcessor struct { shard contracts.RegistryShard saturationDetector contracts.SaturationDetector + podLocator contracts.PodLocator clock clock.WithTicker cleanupSweepInterval time.Duration logger logr.Logger @@ -86,6 +87,7 @@ func NewShardProcessor( ctx context.Context, shard contracts.RegistryShard, saturationDetector contracts.SaturationDetector, + podLocator contracts.PodLocator, clock clock.WithTicker, cleanupSweepInterval time.Duration, enqueueChannelBufferSize int, @@ -94,6 +96,7 @@ func NewShardProcessor( return &ShardProcessor{ shard: shard, saturationDetector: saturationDetector, + podLocator: podLocator, clock: clock, cleanupSweepInterval: cleanupSweepInterval, logger: logger, @@ -307,8 +310,8 @@ func (sp *ShardProcessor) dispatchCycle(ctx context.Context) bool { // --- Viability Check (Saturation/HoL Blocking) --- req := item.OriginalRequest() - candidatePods := req.CandidatePodsForScheduling() - if sp.saturationDetector.IsSaturated(ctx, candidatePods) { + candidates := sp.podLocator.Locate(ctx, req.GetMetadata()) + if sp.saturationDetector.IsSaturated(ctx, candidates) { sp.logger.V(logutil.DEBUG).Info("Policy's chosen item is saturated; enforcing HoL blocking.", "flowKey", req.FlowKey(), "reqID", req.ID(), "priorityName", originalBand.PriorityName()) // Stop the dispatch cycle entirely to respect strict policy decision and prevent priority inversion where diff --git a/pkg/epp/flowcontrol/controller/internal/processor_test.go b/pkg/epp/flowcontrol/controller/internal/processor_test.go index ff75130a71..ea7a1f27ac 100644 --- a/pkg/epp/flowcontrol/controller/internal/processor_test.go +++ b/pkg/epp/flowcontrol/controller/internal/processor_test.go @@ -75,6 +75,7 @@ type testHarness struct { clock *testclock.FakeClock logger logr.Logger saturationDetector *mocks.MockSaturationDetector + podLocator *mocks.MockPodLocator // --- Centralized Mock State --- // The harness's mutex protects the single source of truth for all mock state. @@ -96,6 +97,7 @@ func newTestHarness(t *testing.T, expiryCleanupInterval time.Duration) *testHarn clock: testclock.NewFakeClock(time.Now()), logger: logr.Discard(), saturationDetector: &mocks.MockSaturationDetector{}, + podLocator: &mocks.MockPodLocator{Pods: []metrics.PodMetrics{&metrics.FakePodMetrics{}}}, startSignal: make(chan struct{}), queues: make(map[types.FlowKey]*mocks.MockManagedQueue), priorityFlows: make(map[int][]types.FlowKey), @@ -123,6 +125,7 @@ func newTestHarness(t *testing.T, expiryCleanupInterval time.Duration) *testHarn h.ctx, h, h.saturationDetector, + h.podLocator, h.clock, expiryCleanupInterval, 100, diff --git a/pkg/epp/flowcontrol/types/mocks/mocks.go b/pkg/epp/flowcontrol/types/mocks/mocks.go index 5fabf36831..2e12b04499 100644 --- a/pkg/epp/flowcontrol/types/mocks/mocks.go +++ b/pkg/epp/flowcontrol/types/mocks/mocks.go @@ -21,20 +21,19 @@ package mocks import ( "time" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" ) -// MockFlowControlRequest provides a mock implementation of the `types.FlowControlRequest` interface. +// MockFlowControlRequest provides a mock implementation of the types.FlowControlRequest interface. type MockFlowControlRequest struct { - FlowKeyV types.FlowKey - ByteSizeV uint64 - InitialEffectiveTTLV time.Duration - IDV string - CandidatePodsForSchedulingV []*metrics.FakePodMetrics + FlowKeyV types.FlowKey + ByteSizeV uint64 + InitialEffectiveTTLV time.Duration + IDV string + MetadataV map[string]any } -// NewMockFlowControlRequest creates a new `MockFlowControlRequest` instance. +// NewMockFlowControlRequest creates a new MockFlowControlRequest instance. func NewMockFlowControlRequest( byteSize uint64, id string, @@ -44,6 +43,7 @@ func NewMockFlowControlRequest( ByteSizeV: byteSize, IDV: id, FlowKeyV: key, + MetadataV: make(map[string]any), } } @@ -51,14 +51,7 @@ func (m *MockFlowControlRequest) FlowKey() types.FlowKey { return m. func (m *MockFlowControlRequest) ByteSize() uint64 { return m.ByteSizeV } func (m *MockFlowControlRequest) InitialEffectiveTTL() time.Duration { return m.InitialEffectiveTTLV } func (m *MockFlowControlRequest) ID() string { return m.IDV } - -func (m *MockFlowControlRequest) CandidatePodsForScheduling() []metrics.PodMetrics { - pods := make([]metrics.PodMetrics, 0, len(m.CandidatePodsForSchedulingV)) - for i, pod := range m.CandidatePodsForSchedulingV { - pods[i] = pod - } - return pods -} +func (m *MockFlowControlRequest) GetMetadata() map[string]any { return m.MetadataV } var _ types.FlowControlRequest = &MockFlowControlRequest{} diff --git a/pkg/epp/flowcontrol/types/request.go b/pkg/epp/flowcontrol/types/request.go index e427b0abaa..61ac32049e 100644 --- a/pkg/epp/flowcontrol/types/request.go +++ b/pkg/epp/flowcontrol/types/request.go @@ -18,8 +18,6 @@ package types import ( "time" - - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" ) // FlowControlRequest is the contract for an incoming request submitted to the `controller.FlowController`. It @@ -45,15 +43,15 @@ type FlowControlRequest interface { // applied. InitialEffectiveTTL() time.Duration - // CandidatePodsForScheduling passes through a set of candidate pods a request may be admitted to. - // This is necessary for invoking `contracts.SaturationDetector.IsSaturated`, but it is otherwise unused in the Flow - // Control system. - CandidatePodsForScheduling() []metrics.PodMetrics - // ID returns an optional, user-facing unique identifier for this specific request. It is intended for logging, // tracing, and observability. The `controller.FlowController` does not use this ID for dispatching decisions; it uses // the internal, opaque `QueueItemHandle`. ID() string + + // GetMetadata returns the opaque metadata associated with the request (e.g., header-derived context, subset filters). + // This data is passed transparently to components like the contracts.PodLocator to resolve resources (candidate pods) + // lazily during the dispatch cycle. + GetMetadata() map[string]any } // QueueItemHandle is an opaque handle to an item that has been successfully added to a `framework.SafeQueue`. It acts diff --git a/pkg/epp/requestcontrol/admission.go b/pkg/epp/requestcontrol/admission.go index 69fd5adf8b..2a806b56df 100644 --- a/pkg/epp/requestcontrol/admission.go +++ b/pkg/epp/requestcontrol/admission.go @@ -20,9 +20,11 @@ import ( "context" "time" + "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" @@ -40,7 +42,6 @@ type AdmissionController interface { // Args: // ctx: The request context, carrying deadlines, cancellation signals, and logger. // reqCtx: The handlers.RequestContext containing details about the incoming request. - // candidatePods: A list of potential backend pods that can serve the request. // priority: The priority level of the request, as determined by the InferenceObjective. // // Returns: @@ -49,7 +50,6 @@ type AdmissionController interface { Admit( ctx context.Context, reqCtx *handlers.RequestContext, - candidatePods []backendmetrics.PodMetrics, priority int, ) error } @@ -65,18 +65,17 @@ type flowController interface { EnqueueAndWait(ctx context.Context, req types.FlowControlRequest) (types.QueueOutcome, error) } -// rejectIfSheddableAndSaturated checks if a request should be immediately rejected because it's sheddable -// (priority < 0) and the system is saturated. +// rejectIfSheddableAndSaturated checks if a request should be immediately rejected. func rejectIfSheddableAndSaturated( ctx context.Context, sd saturationDetector, + locator contracts.PodLocator, reqCtx *handlers.RequestContext, - candidatePods []backendmetrics.PodMetrics, priority int, + logger logr.Logger, ) error { if requtil.IsSheddable(priority) { - logger := log.FromContext(ctx) - if sd.IsSaturated(ctx, candidatePods) { + if sd.IsSaturated(ctx, locator.Locate(ctx, reqCtx.Request.Metadata)) { logger.V(logutil.TRACE).Info("Request rejected: system saturated and request is sheddable", "requestID", reqCtx.SchedulingRequest.RequestId) return errutil.Error{ @@ -95,11 +94,18 @@ func rejectIfSheddableAndSaturated( // saturated. Non-sheddable requests always bypass the saturation check. type LegacyAdmissionController struct { saturationDetector saturationDetector + podLocator contracts.PodLocator } // NewLegacyAdmissionController creates a new LegacyAdmissionController. -func NewLegacyAdmissionController(sd saturationDetector) *LegacyAdmissionController { - return &LegacyAdmissionController{saturationDetector: sd} +func NewLegacyAdmissionController( + sd saturationDetector, + pl contracts.PodLocator, +) *LegacyAdmissionController { + return &LegacyAdmissionController{ + saturationDetector: sd, + podLocator: pl, + } } // Admit implements the AdmissionController interface for the legacy strategy. @@ -107,13 +113,18 @@ func NewLegacyAdmissionController(sd saturationDetector) *LegacyAdmissionControl func (lac *LegacyAdmissionController) Admit( ctx context.Context, reqCtx *handlers.RequestContext, - candidatePods []backendmetrics.PodMetrics, priority int, ) error { logger := log.FromContext(ctx) logger.V(logutil.TRACE).Info("Executing LegacyAdmissionController", "priority", priority, "fairnessID", reqCtx.FairnessID) - if err := rejectIfSheddableAndSaturated(ctx, lac.saturationDetector, reqCtx, candidatePods, priority); err != nil { + if err := rejectIfSheddableAndSaturated( + ctx, + lac.saturationDetector, + lac.podLocator, + reqCtx, priority, + logger, + ); err != nil { return err } logger.V(logutil.TRACE).Info("Request admitted", "requestID", reqCtx.SchedulingRequest.RequestId) @@ -123,19 +134,15 @@ func (lac *LegacyAdmissionController) Admit( // --- FlowControlAdmissionController --- // FlowControlAdmissionController delegates admission decisions to the Flow Control layer. -// It first checks if the request is sheddable and the system is saturated, rejecting immediately if both conditions are -// true. Otherwise, it uses the provided flowController to enqueue the request and await an outcome. +// It uses the provided Flow Controller to enqueue the request and await an outcome. type FlowControlAdmissionController struct { - saturationDetector saturationDetector - flowController flowController + flowController flowController } // NewFlowControlAdmissionController creates a new FlowControlAdmissionController. -// It requires a SaturationDetector and a flowController instance. -func NewFlowControlAdmissionController(sd saturationDetector, fc flowController) *FlowControlAdmissionController { +func NewFlowControlAdmissionController(fc flowController) *FlowControlAdmissionController { return &FlowControlAdmissionController{ - saturationDetector: sd, - flowController: fc, + flowController: fc, } } @@ -144,24 +151,18 @@ func NewFlowControlAdmissionController(sd saturationDetector, fc flowController) func (fcac *FlowControlAdmissionController) Admit( ctx context.Context, reqCtx *handlers.RequestContext, - candidatePods []backendmetrics.PodMetrics, priority int, ) error { logger := log.FromContext(ctx) logger.V(logutil.TRACE).Info("Executing FlowControlAdmissionController", "requestID", reqCtx.SchedulingRequest.RequestId, "priority", priority, "fairnessID", reqCtx.FairnessID) - if err := rejectIfSheddableAndSaturated(ctx, fcac.saturationDetector, reqCtx, candidatePods, priority); err != nil { - return err - } - - logger.V(logutil.TRACE).Info("Request proceeding to flow control", "requestID", reqCtx.SchedulingRequest.RequestId) fcReq := &flowControlRequest{ requestID: reqCtx.SchedulingRequest.RequestId, fairnessID: reqCtx.FairnessID, priority: priority, requestByteSize: uint64(reqCtx.RequestSize), - candidatePods: candidatePods, + reqMetadata: reqCtx.Request.Metadata, } outcome, err := fcac.flowController.EnqueueAndWait(ctx, fcReq) @@ -176,7 +177,7 @@ type flowControlRequest struct { fairnessID string priority int requestByteSize uint64 - candidatePods []backendmetrics.PodMetrics + reqMetadata map[string]any } var _ types.FlowControlRequest = &flowControlRequest{} @@ -184,12 +185,12 @@ var _ types.FlowControlRequest = &flowControlRequest{} func (r *flowControlRequest) ID() string { return r.requestID } func (r *flowControlRequest) InitialEffectiveTTL() time.Duration { return 0 } // Use controller default. func (r *flowControlRequest) ByteSize() uint64 { return r.requestByteSize } -func (r *flowControlRequest) CandidatePodsForScheduling() []backendmetrics.PodMetrics { - return r.candidatePods -} func (r *flowControlRequest) FlowKey() types.FlowKey { return types.FlowKey{ID: r.fairnessID, Priority: r.priority} } +func (r *flowControlRequest) GetMetadata() map[string]any { + return r.reqMetadata +} // translateFlowControlOutcome maps the context-rich outcome of the Flow Control layer to the public errutil.Error // contract used by the Director. diff --git a/pkg/epp/requestcontrol/admission_test.go b/pkg/epp/requestcontrol/admission_test.go index 0857782009..388a5c9924 100644 --- a/pkg/epp/requestcontrol/admission_test.go +++ b/pkg/epp/requestcontrol/admission_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/require" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts/mocks" fctypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" @@ -34,14 +35,6 @@ import ( // --- Mocks --- -type mockSaturationDetector struct { - isSaturated bool -} - -func (m *mockSaturationDetector) IsSaturated(_ context.Context, _ []backendmetrics.PodMetrics) bool { - return m.isSaturated -} - type mockFlowController struct { outcome fctypes.QueueOutcome err error @@ -56,18 +49,25 @@ func (m *mockFlowController) EnqueueAndWait( return m.outcome, m.err } +// --- Legacy Controller Tests --- + func TestLegacyAdmissionController_Admit(t *testing.T) { t.Parallel() ctx := logutil.NewTestLoggerIntoContext(context.Background()) - candidatePods := []backendmetrics.PodMetrics{} reqCtx := &handlers.RequestContext{ SchedulingRequest: &schedulingtypes.LLMRequest{RequestId: "test-req"}, + Request: &handlers.Request{ + Metadata: map[string]any{}, + }, } + mockPods := []backendmetrics.PodMetrics{&backendmetrics.FakePodMetrics{}} + testCases := []struct { name string priority int isSaturated bool + locatorPods []backendmetrics.PodMetrics expectErr bool expectErrCode string expectErrSubstr string @@ -76,18 +76,30 @@ func TestLegacyAdmissionController_Admit(t *testing.T) { name: "non_sheddable_saturated_admit", priority: 0, isSaturated: true, + locatorPods: mockPods, expectErr: false, }, { name: "sheddable_not_saturated_admit", priority: -1, isSaturated: false, + locatorPods: mockPods, expectErr: false, }, { name: "sheddable_saturated_reject", priority: -1, isSaturated: true, + locatorPods: mockPods, + expectErr: true, + expectErrCode: errutil.InferencePoolResourceExhausted, + expectErrSubstr: "system saturated, sheddable request dropped", + }, + { + name: "sheddable_no_pods_reject", + priority: -1, + isSaturated: true, + locatorPods: []backendmetrics.PodMetrics{}, expectErr: true, expectErrCode: errutil.InferencePoolResourceExhausted, expectErrSubstr: "system saturated, sheddable request dropped", @@ -97,10 +109,13 @@ func TestLegacyAdmissionController_Admit(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - saturationDetector := &mockSaturationDetector{isSaturated: tc.isSaturated} - ac := NewLegacyAdmissionController(saturationDetector) + saturationDetector := &mocks.MockSaturationDetector{ + IsSaturatedFunc: func(context.Context, []backendmetrics.PodMetrics) bool { return tc.isSaturated }, + } + locator := &mocks.MockPodLocator{Pods: tc.locatorPods} + ac := NewLegacyAdmissionController(saturationDetector, locator) - err := ac.Admit(ctx, reqCtx, candidatePods, tc.priority) + err := ac.Admit(ctx, reqCtx, tc.priority) if !tc.expectErr { assert.NoError(t, err, "Admit() should not have returned an error for scenario: %s", tc.name) @@ -116,9 +131,10 @@ func TestLegacyAdmissionController_Admit(t *testing.T) { } } +// --- Flow Control Controller Tests --- + func TestFlowControlRequestAdapter(t *testing.T) { t.Parallel() - candidatePods := []backendmetrics.PodMetrics{&backendmetrics.FakePodMetrics{}} testCases := []struct { name string @@ -146,59 +162,46 @@ func TestFlowControlRequestAdapter(t *testing.T) { fairnessID: tc.fairnessID, priority: tc.priority, requestByteSize: tc.requestByteSize, - candidatePods: candidatePods, } assert.Equal(t, tc.requestID, fcReq.ID(), "ID() mismatch") assert.Equal(t, tc.requestByteSize, fcReq.ByteSize(), "ByteSize() mismatch") - assert.Equal(t, candidatePods, fcReq.CandidatePodsForScheduling(), "CandidatePodsForScheduling() mismatch") assert.Equal(t, tc.expectFlowKey, fcReq.FlowKey(), "FlowKey() mismatch") assert.Zero(t, fcReq.InitialEffectiveTTL(), "InitialEffectiveTTL() should be zero") }) } } + func TestFlowControlAdmissionController_Admit(t *testing.T) { t.Parallel() ctx := logutil.NewTestLoggerIntoContext(context.Background()) - candidatePods := []backendmetrics.PodMetrics{} - reqCtx := &handlers.RequestContext{ SchedulingRequest: &schedulingtypes.LLMRequest{RequestId: "test-req"}, + Request: &handlers.Request{ + Metadata: map[string]any{}, + }, } testCases := []struct { name string priority int - isSaturated bool fcOutcome fctypes.QueueOutcome fcErr error expectErr bool expectErrCode string expectErrSubstr string - expectFCSkipped bool }{ { - name: "sheddable_saturated_reject", - priority: -1, - isSaturated: true, - expectErr: true, - expectErrCode: errutil.InferencePoolResourceExhausted, - expectErrSubstr: "system saturated, sheddable request dropped", - expectFCSkipped: true, - }, - { - name: "sheddable_not_saturated_dispatch", - priority: -1, - isSaturated: false, - fcOutcome: fctypes.QueueOutcomeDispatched, - expectErr: false, + name: "sheddable_dispatched", + priority: -1, + fcOutcome: fctypes.QueueOutcomeDispatched, + expectErr: false, }, { - name: "non_sheddable_saturated_dispatch", - priority: 0, - isSaturated: true, - fcOutcome: fctypes.QueueOutcomeDispatched, - expectErr: false, + name: "non_sheddable_dispatched", + priority: 0, + fcOutcome: fctypes.QueueOutcomeDispatched, + expectErr: false, }, { name: "fc_reject_capacity", @@ -255,17 +258,12 @@ func TestFlowControlAdmissionController_Admit(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - sd := &mockSaturationDetector{isSaturated: tc.isSaturated} fc := &mockFlowController{outcome: tc.fcOutcome, err: tc.fcErr} - ac := NewFlowControlAdmissionController(sd, fc) + ac := NewFlowControlAdmissionController(fc) - err := ac.Admit(ctx, reqCtx, candidatePods, tc.priority) + err := ac.Admit(ctx, reqCtx, tc.priority) - if tc.expectFCSkipped { - assert.False(t, fc.called, "FlowController should not have been called for scenario: %s", tc.name) - } else { - assert.True(t, fc.called, "FlowController should have been called for scenario: %s", tc.name) - } + assert.True(t, fc.called, "FlowController should have been called for scenario: %s", tc.name) if !tc.expectErr { assert.NoError(t, err, "Admit() returned an unexpected error for scenario: %s", tc.name) diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 19e69cf2b8..e4af1c4bb3 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -159,15 +159,17 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo ctx = log.IntoContext(ctx, logger) logger.V(logutil.DEBUG).Info("LLM request assembled") - // Get candidate pods for scheduling - candidatePods := d.podLocator.Locate(ctx, reqCtx.Request.Metadata) - if len(candidatePods) == 0 { - return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"} - } - if err := d.admissionController.Admit(ctx, reqCtx, candidatePods, *infObjective.Spec.Priority); err != nil { + if err := d.admissionController.Admit(ctx, reqCtx, *infObjective.Spec.Priority); err != nil { logger.V(logutil.DEFAULT).Info("Request rejected by admission control", "error", err) return reqCtx, err } + candidatePods := d.podLocator.Locate(ctx, reqCtx.Request.Metadata) + if len(candidatePods) == 0 { + return reqCtx, errutil.Error{ + Code: errutil.ServiceUnavailable, + Msg: "failed to find candidate pods for serving the request", + } + } snapshotOfCandidatePods := d.toSchedulerPodMetrics(candidatePods) // Prepare per request data by running PrepareData plugins. diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 58372e766b..9b583cf2cc 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -60,12 +60,7 @@ type mockAdmissionController struct { admitErr error } -func (m *mockAdmissionController) Admit( - _ context.Context, - _ *handlers.RequestContext, - _ []backendmetrics.PodMetrics, - _ int, -) error { +func (m *mockAdmissionController) Admit(context.Context, *handlers.RequestContext, int) error { return m.admitErr } @@ -673,6 +668,7 @@ func TestDirector_HandleRequest(t *testing.T) { director.datastore = mockDs director.podLocator = NewCachedPodLocator(context.Background(), NewDatastorePodLocator(mockDs), time.Minute) } + reqCtx := &handlers.RequestContext{ Request: &handlers.Request{ // Create a copy of the map for each test run to avoid mutation issues. diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index a1aff76e86..a3557a6ed4 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -1240,14 +1240,13 @@ func BeforeSuite() func() { } detector := saturationdetector.NewDetector(sdConfig, logger.WithName("saturation-detector")) serverRunner.SaturationDetector = detector - admissionController := requestcontrol.NewLegacyAdmissionController(detector) locator := requestcontrol.NewDatastorePodLocator(serverRunner.Datastore) - cachedLocator := requestcontrol.NewCachedPodLocator(context.Background(), locator, time.Millisecond*50) + admissionController := requestcontrol.NewLegacyAdmissionController(detector, locator) serverRunner.Director = requestcontrol.NewDirectorWithConfig( serverRunner.Datastore, scheduler, admissionController, - cachedLocator, + locator, requestcontrol.NewConfig(), ) serverRunner.SecureServing = false