This repository was archived by the owner on Sep 30, 2024. It is now read-only.

Commit e5388c3

rafax and willdollman authored
Support rejecting invalid requests (#57414)
* Only record prompt prefixes for .com actors
* Support blocking requests (behind a flag)
* Consistent cases
* Consistent naming
* Bad merge resolution
* Update cmd/cody-gateway/shared/config.go
* PR feedback
* Move flagging result, add GetModel()
* Export events about blocked requests
* Fix tests
* Minor fixes

Co-authored-by: Will Dollman <will.dollman@sourcegraph.com>
1 parent d262812 · commit e5388c3

File tree

9 files changed: +145, -57 lines changed

cmd/cody-gateway/internal/httpapi/completions/anthropic.go

Lines changed: 21 additions & 16 deletions
@@ -27,8 +27,11 @@ const anthropicAPIURL = "https://api.anthropic.com/v1/complete"
 const (
     logPromptPrefixLength = 250
 
-    promptTokenLimit   = 18000
-    responseTokenLimit = 1000
+    promptTokenFlaggingLimit   = 18000
+    responseTokenFlaggingLimit = 1000
+
+    promptTokenBlockingLimit   = 20000
+    responseTokenBlockingLimit = 1000
 )
 
 func isFlaggedAnthropicRequest(tk *tokenizer.Tokenizer, ar anthropicRequest, promptRegexps []*regexp.Regexp) (*flaggingResult, error) {
@@ -44,7 +47,7 @@ func isFlaggedAnthropicRequest(tk *tokenizer.Tokenizer, ar anthropicRequest, promptRegexps []*regexp.Regexp) (*flaggingResult, error) {
     }
 
     // If this request has a very high token count for responses, then flag it.
-    if ar.MaxTokensToSample > responseTokenLimit {
+    if ar.MaxTokensToSample > responseTokenFlaggingLimit {
         reasons = append(reasons, "high_max_tokens_to_sample")
     }
 
@@ -53,11 +56,16 @@ func isFlaggedAnthropicRequest(tk *tokenizer.Tokenizer, ar anthropicRequest, promptRegexps []*regexp.Regexp) (*flaggingResult, error) {
     if err != nil {
         return &flaggingResult{}, errors.Wrap(err, "tokenize prompt")
     }
-    if tokenCount > promptTokenLimit {
+    if tokenCount > promptTokenFlaggingLimit {
         reasons = append(reasons, "high_prompt_token_count")
     }
 
     if len(reasons) > 0 {
+        blocked := false
+        if tokenCount > promptTokenBlockingLimit || ar.MaxTokensToSample > responseTokenBlockingLimit {
+            blocked = true
+        }
+
         promptPrefix := ar.Prompt
         if len(promptPrefix) > logPromptPrefixLength {
             promptPrefix = promptPrefix[0:logPromptPrefixLength]
@@ -67,6 +75,7 @@ func isFlaggedAnthropicRequest(tk *tokenizer.Tokenizer, ar anthropicRequest, promptRegexps []*regexp.Regexp) (*flaggingResult, error) {
             maxTokensToSample: int(ar.MaxTokensToSample),
             promptPrefix:      promptPrefix,
             promptTokenCount:  tokenCount,
+            shouldBlock:       blocked,
         }, nil
     }
 
@@ -99,6 +108,7 @@ func NewAnthropicHandler(
     maxTokensToSample int,
     promptRecorder PromptRecorder,
     allowedPromptPatterns []string,
+    requestBlockingEnabled bool,
 ) (http.Handler, error) {
     // Tokenizer only needs to be initialized once, and can be shared globally.
     anthropicTokenizer, err := tokenizer.NewAnthropicClaudeTokenizer()
@@ -132,6 +142,9 @@ func NewAnthropicHandler(
                 if err := promptRecorder.Record(ctx, ar.Prompt); err != nil {
                     logger.Warn("failed to record flagged prompt", log.Error(err))
                 }
+                if requestBlockingEnabled && result.shouldBlock {
+                    return http.StatusBadRequest, result, errors.Errorf("request blocked - if you think this is a mistake, please contact support@sourcegraph.com")
+                }
                 return 0, result, nil
             }
 
@@ -249,6 +262,10 @@ type anthropicRequest struct {
     promptTokens *anthropicTokenCount
 }
 
+func (ar anthropicRequest) GetModel() string {
+    return ar.Model
+}
+
 type anthropicTokenCount struct {
     count int
     err   error
@@ -276,15 +293,3 @@ type anthropicResponse struct {
     Completion string `json:"completion,omitempty"`
     StopReason string `json:"stop_reason,omitempty"`
 }
-
-type flaggingResult struct {
-    blocked           bool
-    reasons           []string
-    promptPrefix      string
-    maxTokensToSample int
-    promptTokenCount  int
-}
-
-func (f *flaggingResult) IsFlagged() bool {
-    return f != nil
-}
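
Taken together, the constants above implement a two-tier policy: crossing the flagging limits gets a request recorded for review, while crossing the higher blocking limits gets it rejected outright once the feature flag is on. A minimal standalone sketch of that decision logic (classify is a hypothetical helper for illustration, not gateway code):

package main

import "fmt"

const (
    // Values mirror the constants in anthropic.go.
    promptTokenFlaggingLimit   = 18000
    promptTokenBlockingLimit   = 20000
    responseTokenFlaggingLimit = 1000
    responseTokenBlockingLimit = 1000
)

// classify reproduces the two-tier check: a request can be flagged
// (logged for review) without being blocked, but anything over the
// blocking limits is both flagged and rejected.
func classify(promptTokens, maxTokensToSample int) (flagged, blocked bool) {
    flagged = promptTokens > promptTokenFlaggingLimit || maxTokensToSample > responseTokenFlaggingLimit
    blocked = promptTokens > promptTokenBlockingLimit || maxTokensToSample > responseTokenBlockingLimit
    return flagged, blocked
}

func main() {
    fmt.Println(classify(19000, 500)) // true false: flagged only
    fmt.Println(classify(21000, 500)) // true true: flagged and blocked
}

Note that the two response-token limits are currently equal, so any request flagged for high_max_tokens_to_sample also qualifies for blocking.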

cmd/cody-gateway/internal/httpapi/completions/anthropic_test.go

Lines changed: 21 additions & 3 deletions
@@ -26,6 +26,7 @@ func TestIsFlaggedAnthropicRequest(t *testing.T) {
         result, err := isFlaggedAnthropicRequest(tk, ar, []*regexp.Regexp{regexp.MustCompile(validPreamble)})
         require.NoError(t, err)
         require.True(t, result.IsFlagged())
+        require.False(t, result.shouldBlock)
         require.Contains(t, result.reasons, "unknown_prompt")
     })
 
@@ -41,21 +42,38 @@ func TestIsFlaggedAnthropicRequest(t *testing.T) {
         result, err := isFlaggedAnthropicRequest(tk, ar, []*regexp.Regexp{})
         require.NoError(t, err)
         require.True(t, result.IsFlagged())
+        require.True(t, result.shouldBlock)
         require.Contains(t, result.reasons, "high_max_tokens_to_sample")
         require.Equal(t, int32(result.maxTokensToSample), ar.MaxTokensToSample)
     })
-    t.Run("high prompt token count", func(t *testing.T) {
+    t.Run("high prompt token count (below block limit)", func(t *testing.T) {
         tokenLengths, err := tk.Tokenize(validPreamble)
         require.NoError(t, err)
 
         validPreambleTokens := len(tokenLengths)
-        longPrompt := strings.Repeat("word ", promptTokenLimit+1)
+        longPrompt := strings.Repeat("word ", promptTokenFlaggingLimit+1)
         ar := anthropicRequest{Model: "claude-2", Prompt: validPreamble + " " + longPrompt}
         result, err := isFlaggedAnthropicRequest(tk, ar, []*regexp.Regexp{regexp.MustCompile(validPreamble)})
         require.NoError(t, err)
         require.True(t, result.IsFlagged())
+        require.False(t, result.shouldBlock)
         require.Contains(t, result.reasons, "high_prompt_token_count")
-        require.Equal(t, result.promptTokenCount, validPreambleTokens+1+promptTokenLimit+1)
+        require.Equal(t, result.promptTokenCount, validPreambleTokens+1+promptTokenFlaggingLimit+1)
+    })
+
+    t.Run("high prompt token count (above block limit)", func(t *testing.T) {
+        tokenLengths, err := tk.Tokenize(validPreamble)
+        require.NoError(t, err)
+
+        validPreambleTokens := len(tokenLengths)
+        longPrompt := strings.Repeat("word ", promptTokenBlockingLimit+1)
+        ar := anthropicRequest{Model: "claude-2", Prompt: validPreamble + " " + longPrompt}
+        result, err := isFlaggedAnthropicRequest(tk, ar, []*regexp.Regexp{regexp.MustCompile(validPreamble)})
+        require.NoError(t, err)
+        require.True(t, result.IsFlagged())
+        require.True(t, result.shouldBlock)
+        require.Contains(t, result.reasons, "high_prompt_token_count")
+        require.Equal(t, result.promptTokenCount, validPreambleTokens+1+promptTokenBlockingLimit+1)
     })
 }

cmd/cody-gateway/internal/httpapi/completions/fireworks.go

Lines changed: 4 additions & 0 deletions
@@ -140,6 +140,10 @@ type fireworksRequest struct {
     Stop        []string `json:"stop,omitempty"`
 }
 
+func (fr fireworksRequest) GetModel() string {
+    return fr.Model
+}
+
 type fireworksResponse struct {
     Choices []struct {
         Text string `json:"text"`

cmd/cody-gateway/internal/httpapi/completions/openai.go

Lines changed: 4 additions & 0 deletions
@@ -157,6 +157,10 @@ type openaiRequest struct {
     User        string `json:"user,omitempty"`
 }
 
+func (r openaiRequest) GetModel() string {
+    return r.Model
+}
+
 type openaiUsage struct {
     PromptTokens     int `json:"prompt_tokens"`
     CompletionTokens int `json:"completion_tokens"`

cmd/cody-gateway/internal/httpapi/completions/upstream.go

Lines changed: 63 additions & 12 deletions
@@ -73,7 +73,9 @@ type upstreamHandlerMethods[ReqT UpstreamRequest] struct {
     parseResponseAndUsage func(log.Logger, ReqT, io.Reader) (promptUsage, completionUsage usageStats)
 }
 
-type UpstreamRequest interface{}
+type UpstreamRequest interface {
+    GetModel() string
+}
 
 func makeUpstreamHandler[ReqT UpstreamRequest](
     baseLogger log.Logger,
@@ -167,6 +169,37 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
             if status == 0 {
                 response.JSONError(logger, w, http.StatusBadRequest, errors.Wrap(err, "invalid request"))
             }
+            if flaggingResult.IsFlagged() && flaggingResult.shouldBlock {
+                requestMetadata := getFlaggingMetadata(flaggingResult, act)
+                err := eventLogger.LogEvent(
+                    r.Context(),
+                    events.Event{
+                        Name:       codygateway.EventNameRequestBlocked,
+                        Source:     act.Source.Name(),
+                        Identifier: act.ID,
+                        Metadata: mergeMaps(requestMetadata, map[string]any{
+                            codygateway.CompletionsEventFeatureMetadataField: feature,
+                            "model":    fmt.Sprintf("%s/%s", upstreamName, body.GetModel()),
+                            "provider": upstreamName,
+
+                            // Response details
+                            "resolved_status_code": status,
+
+                            // Request metadata
+                            "prompt_token_count":   flaggingResult.promptTokenCount,
+                            "max_tokens_to_sample": flaggingResult.maxTokensToSample,
+
+                            // Actor details, specific to the actor Source
+                            "sg_actor_id":            sgActorID,
+                            "sg_actor_anonymous_uid": sgActorAnonymousUID,
+                        }),
+                    },
+                )
+                if err != nil {
+                    logger.Error("failed to log event", log.Error(err))
+                }
+            }
+
             response.JSONError(logger, w, status, err)
             return
         }
@@ -228,17 +261,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
                 attribute.Int("resolvedStatusCode", resolvedStatusCode))
         }
         if flaggingResult.IsFlagged() {
-            // keep this for backwards-compatibility of abuse data
-            requestMetadata["flagged"] = true
-            flaggingMetadata := map[string]any{
-                "reason":  flaggingResult.reasons,
-                "blocked": flaggingResult.blocked,
-            }
-            // only record prompt prefixes for .com actors
-            if act.IsDotComActor() {
-                flaggingMetadata["promptPrefix"] = flaggingResult.promptPrefix
-            }
-            requestMetadata["flagging_result"] = flaggingMetadata
+            requestMetadata = mergeMaps(requestMetadata, getFlaggingMetadata(flaggingResult, act))
         }
         usageData := map[string]any{
             "prompt_character_count": promptUsage.characters,
@@ -357,6 +380,22 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
     }))
 }
 
+func getFlaggingMetadata(flaggingResult *flaggingResult, act *actor.Actor) map[string]any {
+    requestMetadata := map[string]any{}
+
+    requestMetadata["flagged"] = true
+    flaggingMetadata := map[string]any{
+        "reason":       flaggingResult.reasons,
+        "should_block": flaggingResult.shouldBlock,
+    }
+
+    if act.IsDotComActor() {
+        flaggingMetadata["prompt_prefix"] = flaggingResult.promptPrefix
+    }
+    requestMetadata["flagging_result"] = flaggingMetadata
+    return requestMetadata
+}
+
 func isAllowedModel(allowedModels []string, model string) bool {
     for _, m := range allowedModels {
         if strings.EqualFold(m, model) {
@@ -383,3 +422,15 @@ func mergeMaps(dst map[string]any, srcs ...map[string]any) map[string]any {
     }
     return dst
 }
+
+type flaggingResult struct {
+    shouldBlock       bool
+    reasons           []string
+    promptPrefix      string
+    maxTokensToSample int
+    promptTokenCount  int
+}
+
+func (f *flaggingResult) IsFlagged() bool {
+    return f != nil
+}
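
The narrowed UpstreamRequest interface is what ties the new GetModel() methods in anthropic.go, openai.go, and fireworks.go together: the generic handler can stamp blocked-request events with a provider/model label without knowing the concrete request type. A minimal sketch of the pattern, with toyRequest standing in for the real per-provider request structs:

package main

import "fmt"

// UpstreamRequest mirrors the interface from the diff: every provider
// request type must now expose the model it targets.
type UpstreamRequest interface {
    GetModel() string
}

// toyRequest stands in for anthropicRequest / openaiRequest /
// fireworksRequest, each of which implements GetModel in this commit.
type toyRequest struct {
    Model string `json:"model"`
}

func (r toyRequest) GetModel() string { return r.Model }

// eventModelField shows the payoff: generic code can build the
// "provider/model" metadata field for any request type.
func eventModelField(upstreamName string, body UpstreamRequest) string {
    return fmt.Sprintf("%s/%s", upstreamName, body.GetModel())
}

func main() {
    fmt.Println(eventModelField("anthropic", toyRequest{Model: "claude-2"})) // anthropic/claude-2
}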

cmd/cody-gateway/internal/httpapi/handler.go

Lines changed: 13 additions & 11 deletions
@@ -25,17 +25,18 @@ import (
 )
 
 type Config struct {
-    RateLimitNotifier              notify.RateLimitNotifier
-    AnthropicAccessToken           string
-    AnthropicAllowedModels         []string
-    AnthropicAllowedPromptPatterns []string
-    AnthropicMaxTokensToSample     int
-    OpenAIAccessToken              string
-    OpenAIOrgID                    string
-    OpenAIAllowedModels            []string
-    FireworksAccessToken           string
-    FireworksAllowedModels         []string
-    EmbeddingsAllowedModels        []string
+    RateLimitNotifier               notify.RateLimitNotifier
+    AnthropicAccessToken            string
+    AnthropicAllowedModels          []string
+    AnthropicAllowedPromptPatterns  []string
+    AnthropicRequestBlockingEnabled bool
+    AnthropicMaxTokensToSample      int
+    OpenAIAccessToken               string
+    OpenAIOrgID                     string
+    OpenAIAllowedModels             []string
+    FireworksAccessToken            string
+    FireworksAllowedModels          []string
+    EmbeddingsAllowedModels         []string
 }
 
 var meter = otel.GetMeterProvider().Meter("cody-gateway/internal/httpapi")
@@ -82,6 +83,7 @@ func NewHandler(
         config.AnthropicMaxTokensToSample,
         promptRecorder,
         config.AnthropicAllowedPromptPatterns,
+        config.AnthropicRequestBlockingEnabled,
     )
     if err != nil {
         return nil, errors.Wrap(err, "init Anthropic handler")

cmd/cody-gateway/shared/config.go

Lines changed: 6 additions & 4 deletions
@@ -28,10 +28,11 @@ type Config struct {
     }
 
     Anthropic struct {
-        AllowedModels         []string
-        AccessToken           string
-        MaxTokensToSample     int
-        AllowedPromptPatterns []string
+        AllowedModels          []string
+        AccessToken            string
+        MaxTokensToSample      int
+        AllowedPromptPatterns  []string
+        RequestBlockingEnabled bool
     }
 
     OpenAI struct {
@@ -113,6 +114,7 @@ func (c *Config) Load() {
     }
     c.Anthropic.MaxTokensToSample = c.GetInt("CODY_GATEWAY_ANTHROPIC_MAX_TOKENS_TO_SAMPLE", "10000", "Maximum permitted value of maxTokensToSample")
     c.Anthropic.AllowedPromptPatterns = splitMaybe(c.GetOptional("CODY_GATEWAY_ANTHROPIC_ALLOWED_PROMPT_PATTERNS", "Prompt patterns to allow."))
+    c.Anthropic.RequestBlockingEnabled = c.GetBool("CODY_GATEWAY_ANTHROPIC_REQUEST_BLOCKING_ENABLED", "false", "Whether we should block requests that match our blocking criteria.")
 
     c.OpenAI.AccessToken = c.GetOptional("CODY_GATEWAY_OPENAI_ACCESS_TOKEN", "The OpenAI access token to be used.")
     c.OpenAI.OrgID = c.GetOptional("CODY_GATEWAY_OPENAI_ORG_ID", "The OpenAI organization to count billing towards. Setting this ensures we always use the correct negotiated terms.")
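
Blocking defaults to off, so an operator opts in by setting the new environment variable. A rough sketch of the expected semantics using only the standard library — getBool here is a hypothetical stand-in for the gateway's c.GetBool helper, whose exact behavior may differ:

package main

import (
    "fmt"
    "os"
    "strconv"
)

// getBool reads an environment variable, applies a default when the
// variable is unset, and parses the result as a boolean.
func getBool(name, defaultValue string) bool {
    v := os.Getenv(name)
    if v == "" {
        v = defaultValue
    }
    b, err := strconv.ParseBool(v)
    if err != nil {
        return false
    }
    return b
}

func main() {
    // Unset: the "false" default keeps blocking disabled.
    fmt.Println(getBool("CODY_GATEWAY_ANTHROPIC_REQUEST_BLOCKING_ENABLED", "false")) // false

    // Operator opt-in.
    os.Setenv("CODY_GATEWAY_ANTHROPIC_REQUEST_BLOCKING_ENABLED", "true")
    fmt.Println(getBool("CODY_GATEWAY_ANTHROPIC_REQUEST_BLOCKING_ENABLED", "false")) // true
}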

cmd/cody-gateway/shared/main.go

Lines changed: 12 additions & 11 deletions
@@ -149,17 +149,18 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu
             redis: redispool.Cache,
         },
         &httpapi.Config{
-            RateLimitNotifier:              rateLimitNotifier,
-            AnthropicAccessToken:           config.Anthropic.AccessToken,
-            AnthropicAllowedModels:         config.Anthropic.AllowedModels,
-            AnthropicMaxTokensToSample:     config.Anthropic.MaxTokensToSample,
-            AnthropicAllowedPromptPatterns: config.Anthropic.AllowedPromptPatterns,
-            OpenAIAccessToken:              config.OpenAI.AccessToken,
-            OpenAIOrgID:                    config.OpenAI.OrgID,
-            OpenAIAllowedModels:            config.OpenAI.AllowedModels,
-            FireworksAccessToken:           config.Fireworks.AccessToken,
-            FireworksAllowedModels:         config.Fireworks.AllowedModels,
-            EmbeddingsAllowedModels:        config.AllowedEmbeddingsModels,
+            RateLimitNotifier:               rateLimitNotifier,
+            AnthropicAccessToken:            config.Anthropic.AccessToken,
+            AnthropicAllowedModels:          config.Anthropic.AllowedModels,
+            AnthropicMaxTokensToSample:      config.Anthropic.MaxTokensToSample,
+            AnthropicAllowedPromptPatterns:  config.Anthropic.AllowedPromptPatterns,
+            AnthropicRequestBlockingEnabled: config.Anthropic.RequestBlockingEnabled,
+            OpenAIAccessToken:               config.OpenAI.AccessToken,
+            OpenAIOrgID:                     config.OpenAI.OrgID,
+            OpenAIAllowedModels:             config.OpenAI.AllowedModels,
+            FireworksAccessToken:            config.Fireworks.AccessToken,
+            FireworksAllowedModels:          config.Fireworks.AllowedModels,
+            EmbeddingsAllowedModels:         config.AllowedEmbeddingsModels,
         })
     if err != nil {
         return errors.Wrap(err, "httpapi.NewHandler")

internal/codygateway/consts.go

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ const (
     EventNameRateLimited         EventName = "RateLimited"
     EventNameCompletionsFinished EventName = "CompletionsFinished"
     EventNameEmbeddingsFinished  EventName = "EmbeddingsFinished"
+    EventNameRequestBlocked      EventName = "RequestBlocked"
 )
 
 const FeatureHeaderName = "X-Sourcegraph-Feature"
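
From a caller's perspective, a blocked request surfaces as an HTTP 400 whose message points at support@sourcegraph.com (see the anthropic.go change above). A self-contained sketch against a stub server — the real gateway serializes the error through response.JSONError, so the exact body shape here is illustrative only:

package main

import (
    "fmt"
    "io"
    "net/http"
    "net/http/httptest"
)

func main() {
    // Stub standing in for the gateway with blocking enabled: respond
    // 400 with the support-contact message from the diff.
    srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        http.Error(w, "request blocked - if you think this is a mistake, please contact support@sourcegraph.com", http.StatusBadRequest)
    }))
    defer srv.Close()

    resp, err := http.Post(srv.URL, "application/json", nil)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    body, _ := io.ReadAll(resp.Body)
    fmt.Println(resp.StatusCode) // 400
    fmt.Println(string(body))
}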
