Skip to content

Commit 305f949

Browse files
authored
feat(gen-ai): add feature flag and mms prompts COMPASS-10082 (#7598)
* add feature flag * migrate prompts to compass * co-pilot feedback * clean up * add tests and codefence * use toJsString * fix check
1 parent c4052c2 commit 305f949

File tree

5 files changed

+331
-0
lines changed

5 files changed

+331
-0
lines changed

package-lock.json

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/compass-generative-ai/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
"depcheck": "^1.4.1",
8686
"electron-mocha": "^12.2.0",
8787
"mocha": "^10.2.0",
88+
"mongodb-query-parser": "^4.5.0",
8889
"nyc": "^15.1.0",
8990
"p-queue": "^7.4.1",
9091
"sinon": "^9.2.3",
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import { expect } from 'chai';
2+
import {
3+
buildFindQueryPrompt,
4+
buildAggregateQueryPrompt,
5+
type UserPromptForQueryOptions,
6+
} from './gen-ai-prompt';
7+
import { toJSString } from 'mongodb-query-parser';
8+
import { ObjectId } from 'bson';
9+
10+
const OPTIONS: UserPromptForQueryOptions = {
11+
userPrompt: 'Find all users older than 30',
12+
databaseName: 'airbnb',
13+
collectionName: 'listings',
14+
schema: {
15+
_id: {
16+
types: [
17+
{
18+
bsonType: 'ObjectId',
19+
},
20+
],
21+
},
22+
userId: {
23+
types: [
24+
{
25+
bsonType: 'ObjectId',
26+
},
27+
],
28+
},
29+
},
30+
sampleDocuments: [
31+
{
32+
_id: new ObjectId('68a2dfe93d5adb16ebf4c866'),
33+
userId: new ObjectId('68a2dfe93d5adb16ebf4c865'),
34+
},
35+
],
36+
};
37+
38+
describe('GenAI Prompts', function () {
39+
it('buildFindQueryPrompt', function () {
40+
const {
41+
prompt,
42+
metadata: { instructions },
43+
} = buildFindQueryPrompt(OPTIONS);
44+
45+
expect(instructions).to.be.a('string');
46+
expect(instructions).to.include(
47+
'The current date is',
48+
'includes date instruction'
49+
);
50+
51+
expect(prompt).to.be.a('string');
52+
expect(prompt).to.include(
53+
`Write a query that does the following: "${OPTIONS.userPrompt}"`,
54+
'includes user prompt'
55+
);
56+
expect(prompt).to.include(
57+
`Database name: "${OPTIONS.databaseName}"`,
58+
'includes database name'
59+
);
60+
expect(prompt).to.include(
61+
`Collection name: "${OPTIONS.collectionName}"`,
62+
'includes collection name'
63+
);
64+
expect(prompt).to.include(
65+
'Schema from a sample of documents from the collection:',
66+
'includes schema text'
67+
);
68+
expect(prompt).to.include(
69+
toJSString(OPTIONS.schema),
70+
'includes actual schema'
71+
);
72+
expect(prompt).to.include(
73+
'Sample documents from the collection:',
74+
'includes sample documents text'
75+
);
76+
expect(prompt).to.include(
77+
toJSString(OPTIONS.sampleDocuments),
78+
'includes actual sample documents'
79+
);
80+
});
81+
82+
it('buildAggregateQueryPrompt', function () {
83+
const {
84+
prompt,
85+
metadata: { instructions },
86+
} = buildAggregateQueryPrompt(OPTIONS);
87+
88+
expect(instructions).to.be.a('string');
89+
expect(instructions).to.include(
90+
'The current date is',
91+
'includes date instruction'
92+
);
93+
94+
expect(prompt).to.be.a('string');
95+
expect(prompt).to.include(
96+
`Generate an aggregation that does the following: "${OPTIONS.userPrompt}"`,
97+
'includes user prompt'
98+
);
99+
expect(prompt).to.include(
100+
`Database name: "${OPTIONS.databaseName}"`,
101+
'includes database name'
102+
);
103+
expect(prompt).to.include(
104+
`Collection name: "${OPTIONS.collectionName}"`,
105+
'includes collection name'
106+
);
107+
expect(prompt).to.include(
108+
'Schema from a sample of documents from the collection:',
109+
'includes schema text'
110+
);
111+
expect(prompt).to.include(
112+
toJSString(OPTIONS.schema),
113+
'includes actual schema'
114+
);
115+
expect(prompt).to.include(
116+
'Sample documents from the collection:',
117+
'includes sample documents text'
118+
);
119+
expect(prompt).to.include(
120+
toJSString(OPTIONS.sampleDocuments),
121+
'includes actual sample documents'
122+
);
123+
});
124+
});
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import { toJSString } from 'mongodb-query-parser';
2+
3+
// When including sample documents, we want to ensure that we do not
4+
// attach large documents and exceed the limit. OpenAI roughly estimates
5+
// 4 characters = 1 token and we should not exceed context window limits.
6+
// This roughly translates to 128k tokens.
7+
// TODO(COMPASS-10129): Adjust this limit based on the model's context window.
8+
const MAX_TOTAL_PROMPT_LENGTH = 512000;
9+
const MIN_SAMPLE_DOCUMENTS = 1;
10+
11+
function getCurrentTimeString() {
12+
const dateTime = new Date();
13+
const options: Intl.DateTimeFormatOptions = {
14+
weekday: 'short',
15+
year: 'numeric',
16+
month: 'short',
17+
day: '2-digit',
18+
hour: '2-digit',
19+
minute: '2-digit',
20+
second: '2-digit',
21+
timeZoneName: 'short',
22+
hour12: false,
23+
};
24+
// e.g. Tue, Nov 25, 2025, 12:00:00 GMT+1
25+
return dateTime.toLocaleString('en-US', options);
26+
}
27+
28+
function buildInstructionsForFindQuery() {
29+
return [
30+
'Reduce prose to the minimum, your output will be parsed by a machine. ' +
31+
'You generate MongoDB find query arguments. Provide filter, project, sort, skip, ' +
32+
'limit and aggregation in shell syntax, wrap each argument with XML delimiters as follows:',
33+
'<filter>{}</filter>',
34+
'<project>{}</project>',
35+
'<sort>{}</sort>',
36+
'<skip>0</skip>',
37+
'<limit>0</limit>',
38+
'<aggregation>[]</aggregation>',
39+
'Additional instructions:',
40+
'- Only use the aggregation field when the request cannot be represented with the other fields.',
41+
'- Do not use the aggregation field if a find query fulfills the objective.',
42+
'- If specifying latitude and longitude coordinates, list the longitude first, and then latitude.',
43+
`- The current date is ${getCurrentTimeString()}`,
44+
].join('\n');
45+
}
46+
47+
function buildInstructionsForAggregateQuery() {
48+
return [
49+
'Reduce prose to the minimum, your output will be parsed by a machine. ' +
50+
'You generate MongoDB aggregation pipelines. Provide only the aggregation ' +
51+
'pipeline contents in an array in shell syntax, wrapped with XML delimiters as follows:',
52+
'<aggregation>[]</aggregation>',
53+
'Additional instructions:',
54+
'- If specifying latitude and longitude coordinates, list the longitude first, and then latitude.',
55+
'- Only pass the contents of the aggregation, no surrounding syntax.',
56+
`- The current date is ${getCurrentTimeString()}`,
57+
].join('\n');
58+
}
59+
60+
export type UserPromptForQueryOptions = {
61+
userPrompt: string;
62+
databaseName?: string;
63+
collectionName?: string;
64+
schema?: unknown;
65+
sampleDocuments?: unknown[];
66+
};
67+
68+
function withCodeFence(code: string): string {
69+
return [
70+
'', // Line break
71+
'```',
72+
code,
73+
'```',
74+
].join('\n');
75+
}
76+
77+
function buildUserPromptForQuery({
78+
type,
79+
userPrompt,
80+
databaseName,
81+
collectionName,
82+
schema,
83+
sampleDocuments,
84+
}: UserPromptForQueryOptions & { type: 'find' | 'aggregate' }): string {
85+
const messages = [];
86+
87+
const queryPrompt = [
88+
type === 'find' ? 'Write a query' : 'Generate an aggregation',
89+
'that does the following:',
90+
`"${userPrompt}"`,
91+
].join(' ');
92+
93+
if (databaseName) {
94+
messages.push(`Database name: "${databaseName}"`);
95+
}
96+
if (collectionName) {
97+
messages.push(`Collection name: "${collectionName}"`);
98+
}
99+
if (schema) {
100+
messages.push(
101+
`Schema from a sample of documents from the collection:${withCodeFence(
102+
toJSString(schema)!
103+
)}`
104+
);
105+
}
106+
if (sampleDocuments) {
107+
// When attaching the sample documents, we want to ensure that we do not
108+
// exceed the token limit. So we try following:
109+
// 1. If attaching all the sample documents exceeds then limit, we attach only 1 document.
110+
// 2. If attaching 1 document still exceeds the limit, we do not attach any sample documents.
111+
const sampleDocumentsStr = toJSString(sampleDocuments);
112+
const singleDocumentStr = toJSString(
113+
sampleDocuments.slice(0, MIN_SAMPLE_DOCUMENTS)
114+
);
115+
const promptLengthWithoutSampleDocs =
116+
messages.join('\n').length + queryPrompt.length;
117+
if (
118+
sampleDocumentsStr &&
119+
sampleDocumentsStr.length + promptLengthWithoutSampleDocs <=
120+
MAX_TOTAL_PROMPT_LENGTH
121+
) {
122+
messages.push(
123+
`Sample documents from the collection:${withCodeFence(
124+
sampleDocumentsStr
125+
)}`
126+
);
127+
} else if (
128+
singleDocumentStr &&
129+
singleDocumentStr.length + promptLengthWithoutSampleDocs <=
130+
MAX_TOTAL_PROMPT_LENGTH
131+
) {
132+
messages.push(
133+
`Sample document from the collection:${withCodeFence(
134+
singleDocumentStr
135+
)}`
136+
);
137+
}
138+
}
139+
messages.push(queryPrompt);
140+
return messages.join('\n');
141+
}
142+
143+
export type AiQueryPrompt = {
144+
prompt: string;
145+
metadata: {
146+
instructions: string;
147+
};
148+
};
149+
150+
export function buildFindQueryPrompt({
151+
userPrompt,
152+
databaseName,
153+
collectionName,
154+
schema,
155+
sampleDocuments,
156+
}: UserPromptForQueryOptions): AiQueryPrompt {
157+
const prompt = buildUserPromptForQuery({
158+
type: 'find',
159+
userPrompt,
160+
databaseName,
161+
collectionName,
162+
schema,
163+
sampleDocuments,
164+
});
165+
const instructions = buildInstructionsForFindQuery();
166+
return {
167+
prompt,
168+
metadata: {
169+
instructions,
170+
},
171+
};
172+
}
173+
174+
export function buildAggregateQueryPrompt({
175+
userPrompt,
176+
databaseName,
177+
collectionName,
178+
schema,
179+
sampleDocuments,
180+
}: UserPromptForQueryOptions): AiQueryPrompt {
181+
const prompt = buildUserPromptForQuery({
182+
type: 'aggregate',
183+
userPrompt,
184+
databaseName,
185+
collectionName,
186+
schema,
187+
sampleDocuments,
188+
});
189+
const instructions = buildInstructionsForAggregateQuery();
190+
return {
191+
prompt,
192+
metadata: {
193+
instructions,
194+
},
195+
};
196+
}

packages/compass-preferences-model/src/feature-flags.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,14 @@ export const FEATURE_FLAG_DEFINITIONS = [
224224
'Enable automatic relationship inference during data model generation',
225225
},
226226
},
227+
{
228+
name: 'enableChatbotEndpointForGenAI',
229+
stage: 'development',
230+
atlasCloudFeatureFlagName: null,
231+
description: {
232+
short: 'Enable Chatbot API for Generative AI',
233+
},
234+
},
227235
] as const satisfies ReadonlyArray<FeatureFlagDefinition>;
228236

229237
type FeatureFlagDefinitions = typeof FEATURE_FLAG_DEFINITIONS;

0 commit comments

Comments
 (0)