Skip to content

Commit f531ad6

Browse files
authored
fix(policy): always wrap mutation in a transaction (#276)
* fix(policy): always wrap mutation in a transaction * fix error handling
1 parent 198c352 commit f531ad6

File tree

5 files changed

+131
-190
lines changed

5 files changed

+131
-190
lines changed

packages/runtime/src/client/errors.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
/**
2+
* Base for all ZenStack runtime errors.
3+
*/
4+
export class ZenStackError extends Error {}
5+
16
/**
27
* Error thrown when input validation fails.
38
*/
4-
export class InputValidationError extends Error {
9+
export class InputValidationError extends ZenStackError {
510
constructor(message: string, cause?: unknown) {
611
super(message, { cause });
712
}
@@ -10,7 +15,7 @@ export class InputValidationError extends Error {
1015
/**
1116
* Error thrown when a query fails.
1217
*/
13-
export class QueryError extends Error {
18+
export class QueryError extends ZenStackError {
1419
constructor(message: string, cause?: unknown) {
1520
super(message, { cause });
1621
}
@@ -19,12 +24,12 @@ export class QueryError extends Error {
1924
/**
2025
* Error thrown when an internal error occurs.
2126
*/
22-
export class InternalError extends Error {}
27+
export class InternalError extends ZenStackError {}
2328

2429
/**
2530
* Error thrown when an entity is not found.
2631
*/
27-
export class NotFoundError extends Error {
32+
export class NotFoundError extends ZenStackError {
2833
constructor(model: string, details?: string) {
2934
super(`Entity not found for model "${model}"${details ? `: ${details}` : ''}`);
3035
}

packages/runtime/src/client/executor/zenstack-query-executor.ts

Lines changed: 110 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import { match } from 'ts-pattern';
2525
import type { GetModels, SchemaDef } from '../../schema';
2626
import { type ClientImpl } from '../client-impl';
2727
import { TransactionIsolationLevel, type ClientContract } from '../contract';
28-
import { InternalError, QueryError } from '../errors';
28+
import { InternalError, QueryError, ZenStackError } from '../errors';
2929
import { stripAlias } from '../kysely-utils';
3030
import type { AfterEntityMutationCallback, OnKyselyQueryCallback } from '../plugin';
3131
import { QueryNameMapper } from './name-mapper';
@@ -65,21 +65,53 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
6565
return this.client.$options;
6666
}
6767

68-
override async executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) {
68+
override executeQuery(compiledQuery: CompiledQuery, queryId: QueryId) {
6969
// proceed with the query with kysely interceptors
7070
// if the query is a raw query, we need to carry over the parameters
7171
const queryParams = (compiledQuery as any).$raw ? compiledQuery.parameters : undefined;
72-
const result = await this.proceedQueryWithKyselyInterceptors(compiledQuery.query, queryParams, queryId.queryId);
7372

74-
return result.result;
73+
return this.provideConnection(async (connection) => {
74+
let startedTx = false;
75+
try {
76+
// mutations are wrapped in tx if not already in one
77+
if (this.isMutationNode(compiledQuery.query) && !this.driver.isTransactionConnection(connection)) {
78+
await this.driver.beginTransaction(connection, {
79+
isolationLevel: TransactionIsolationLevel.RepeatableRead,
80+
});
81+
startedTx = true;
82+
}
83+
const result = await this.proceedQueryWithKyselyInterceptors(
84+
connection,
85+
compiledQuery.query,
86+
queryParams,
87+
queryId.queryId,
88+
);
89+
if (startedTx) {
90+
await this.driver.commitTransaction(connection);
91+
}
92+
return result;
93+
} catch (err) {
94+
if (startedTx) {
95+
await this.driver.rollbackTransaction(connection);
96+
}
97+
if (err instanceof ZenStackError) {
98+
throw err;
99+
} else {
100+
// wrap error
101+
const message = `Failed to execute query: ${err}, sql: ${compiledQuery?.sql}`;
102+
throw new QueryError(message, err);
103+
}
104+
}
105+
});
75106
}
76107

77108
private async proceedQueryWithKyselyInterceptors(
109+
connection: DatabaseConnection,
78110
queryNode: RootOperationNode,
79111
parameters: readonly unknown[] | undefined,
80112
queryId: string,
81113
) {
82-
let proceed = (q: RootOperationNode) => this.proceedQuery(q, parameters, queryId);
114+
let proceed = (q: RootOperationNode) => this.proceedQuery(connection, q, parameters, queryId);
83115

84116
const hooks: OnKyselyQueryCallback<Schema>[] = [];
85117
// tsc perf
@@ -92,18 +124,14 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
92124
for (const hook of hooks) {
93125
const _proceed = proceed;
94126
proceed = async (query: RootOperationNode) => {
95-
const _p = async (q: RootOperationNode) => {
96-
const r = await _proceed(q);
97-
return r.result;
98-
};
99-
127+
const _p = (q: RootOperationNode) => _proceed(q);
100128
const hookResult = await hook!({
101129
client: this.client as ClientContract<Schema>,
102130
schema: this.client.$schema,
103131
query,
104132
proceed: _p,
105133
});
106-
return { result: hookResult };
134+
return hookResult;
107135
};
108136
}
109137

@@ -132,161 +160,83 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
132160
return { model, action, where };
133161
}
134162

135-
private async proceedQuery(query: RootOperationNode, parameters: readonly unknown[] | undefined, queryId: string) {
163+
private async proceedQuery(
164+
connection: DatabaseConnection,
165+
query: RootOperationNode,
166+
parameters: readonly unknown[] | undefined,
167+
queryId: string,
168+
) {
136169
let compiled: CompiledQuery | undefined;
137170

138-
try {
139-
return await this.provideConnection(async (connection) => {
140-
if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) {
141-
// no need to handle mutation hooks, just proceed
142-
const finalQuery = this.nameMapper.transformNode(query);
143-
compiled = this.compileQuery(finalQuery);
144-
if (parameters) {
145-
compiled = { ...compiled, parameters };
146-
}
147-
const result = await connection.executeQuery<any>(compiled);
148-
return { result };
149-
}
171+
if (this.suppressMutationHooks || !this.isMutationNode(query) || !this.hasEntityMutationPlugins) {
172+
// no need to handle mutation hooks, just proceed
173+
const finalQuery = this.nameMapper.transformNode(query);
174+
compiled = this.compileQuery(finalQuery);
175+
if (parameters) {
176+
compiled = { ...compiled, parameters };
177+
}
178+
return connection.executeQuery<any>(compiled);
179+
}
150180

151-
if (
152-
(InsertQueryNode.is(query) || UpdateQueryNode.is(query)) &&
153-
this.hasEntityMutationPluginsWithAfterMutationHooks
154-
) {
155-
// need to make sure the query node has "returnAll" for insert and update queries
156-
// so that after-mutation hooks can get the mutated entities with all fields
157-
query = {
158-
...query,
159-
returning: ReturningNode.create([SelectionNode.createSelectAll()]),
160-
};
161-
}
162-
const finalQuery = this.nameMapper.transformNode(query);
163-
compiled = this.compileQuery(finalQuery);
164-
if (parameters) {
165-
compiled = { ...compiled, parameters };
166-
}
181+
if (
182+
(InsertQueryNode.is(query) || UpdateQueryNode.is(query)) &&
183+
this.hasEntityMutationPluginsWithAfterMutationHooks
184+
) {
185+
// need to make sure the query node has "returnAll" for insert and update queries
186+
// so that after-mutation hooks can get the mutated entities with all fields
187+
query = {
188+
...query,
189+
returning: ReturningNode.create([SelectionNode.createSelectAll()]),
190+
};
191+
}
192+
const finalQuery = this.nameMapper.transformNode(query);
193+
compiled = this.compileQuery(finalQuery);
194+
if (parameters) {
195+
compiled = { ...compiled, parameters };
196+
}
167197

168-
// the client passed to hooks needs to be in sync with current in-transaction
169-
// status so that it doesn't try to create a nested one
170-
const currentlyInTx = this.driver.isTransactionConnection(connection);
171-
172-
const connectionClient = this.createClientForConnection(connection, currentlyInTx);
173-
174-
const mutationInfo = this.getMutationInfo(finalQuery);
175-
176-
// cache already loaded before-mutation entities
177-
let beforeMutationEntities: Record<string, unknown>[] | undefined;
178-
const loadBeforeMutationEntities = async () => {
179-
if (
180-
beforeMutationEntities === undefined &&
181-
(UpdateQueryNode.is(query) || DeleteQueryNode.is(query))
182-
) {
183-
beforeMutationEntities = await this.loadEntities(
184-
mutationInfo.model,
185-
mutationInfo.where,
186-
connection,
187-
);
188-
}
189-
return beforeMutationEntities;
190-
};
191-
192-
// call before mutation hooks
193-
await this.callBeforeMutationHooks(
194-
finalQuery,
195-
mutationInfo,
196-
loadBeforeMutationEntities,
197-
connectionClient,
198-
queryId,
199-
);
198+
// the client passed to hooks needs to be in sync with current in-transaction
199+
// status so that it doesn't try to create a nested one
200+
const currentlyInTx = this.driver.isTransactionConnection(connection);
200201

201-
// if mutation interceptor demands to run afterMutation hook in the transaction but we're not already
202-
// inside one, we need to create one on the fly
203-
const shouldCreateTx =
204-
this.hasPluginRequestingAfterMutationWithinTransaction &&
205-
!this.driver.isTransactionConnection(connection);
206-
207-
if (!shouldCreateTx) {
208-
// if no on-the-fly tx is needed, just proceed with the query as is
209-
const result = await connection.executeQuery<any>(compiled);
210-
211-
if (!this.driver.isTransactionConnection(connection)) {
212-
// not in a transaction, just call all after-mutation hooks
213-
await this.callAfterMutationHooks(
214-
result,
215-
finalQuery,
216-
mutationInfo,
217-
connectionClient,
218-
'all',
219-
queryId,
220-
);
221-
} else {
222-
// run after-mutation hooks that are requested to be run inside tx
223-
await this.callAfterMutationHooks(
224-
result,
225-
finalQuery,
226-
mutationInfo,
227-
connectionClient,
228-
'inTx',
229-
queryId,
230-
);
231-
232-
// register other after-mutation hooks to be run after the tx is committed
233-
this.driver.registerTransactionCommitCallback(connection, () =>
234-
this.callAfterMutationHooks(
235-
result,
236-
finalQuery,
237-
mutationInfo,
238-
connectionClient,
239-
'outTx',
240-
queryId,
241-
),
242-
);
243-
}
244-
245-
return { result };
246-
} else {
247-
// if an on-the-fly tx is created, create one and wrap the query execution inside
248-
await this.driver.beginTransaction(connection, {
249-
isolationLevel: TransactionIsolationLevel.ReadCommitted,
250-
});
251-
try {
252-
// execute the query inside the on-the-fly transaction
253-
const result = await connection.executeQuery<any>(compiled);
254-
255-
// run after-mutation hooks that are requested to be run inside tx
256-
await this.callAfterMutationHooks(
257-
result,
258-
finalQuery,
259-
mutationInfo,
260-
connectionClient,
261-
'inTx',
262-
queryId,
263-
);
264-
265-
// commit the transaction
266-
await this.driver.commitTransaction(connection);
267-
268-
// run other after-mutation hooks after the tx is committed
269-
await this.callAfterMutationHooks(
270-
result,
271-
finalQuery,
272-
mutationInfo,
273-
connectionClient,
274-
'outTx',
275-
queryId,
276-
);
277-
278-
return { result };
279-
} catch (err) {
280-
// rollback the transaction
281-
await this.driver.rollbackTransaction(connection);
282-
throw err;
283-
}
284-
}
285-
});
286-
} catch (err) {
287-
const message = `Failed to execute query: ${err}, sql: ${compiled?.sql}`;
288-
throw new QueryError(message, err);
202+
const connectionClient = this.createClientForConnection(connection, currentlyInTx);
203+
204+
const mutationInfo = this.getMutationInfo(finalQuery);
205+
206+
// cache already loaded before-mutation entities
207+
let beforeMutationEntities: Record<string, unknown>[] | undefined;
208+
const loadBeforeMutationEntities = async () => {
209+
if (beforeMutationEntities === undefined && (UpdateQueryNode.is(query) || DeleteQueryNode.is(query))) {
210+
beforeMutationEntities = await this.loadEntities(mutationInfo.model, mutationInfo.where, connection);
211+
}
212+
return beforeMutationEntities;
213+
};
214+
215+
// call before mutation hooks
216+
await this.callBeforeMutationHooks(
217+
finalQuery,
218+
mutationInfo,
219+
loadBeforeMutationEntities,
220+
connectionClient,
221+
queryId,
222+
);
223+
224+
const result = await connection.executeQuery<any>(compiled);
225+
226+
if (!this.driver.isTransactionConnection(connection)) {
227+
// not in a transaction, just call all after-mutation hooks
228+
await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'all', queryId);
229+
} else {
230+
// run after-mutation hooks that are requested to be run inside tx
231+
await this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'inTx', queryId);
232+
233+
// register other after-mutation hooks to be run after the tx is committed
234+
this.driver.registerTransactionCommitCallback(connection, () =>
235+
this.callAfterMutationHooks(result, finalQuery, mutationInfo, connectionClient, 'outTx', queryId),
236+
);
289237
}
238+
239+
return result;
290240
}
291241

292242
private createClientForConnection(connection: DatabaseConnection, inTx: boolean) {
@@ -307,12 +257,6 @@ export class ZenStackQueryExecutor<Schema extends SchemaDef> extends DefaultQuer
307257
return (this.client.$options.plugins ?? []).some((plugin) => plugin.onEntityMutation?.afterEntityMutation);
308258
}
309259

310-
private get hasPluginRequestingAfterMutationWithinTransaction() {
311-
return (this.client.$options.plugins ?? []).some(
312-
(plugin) => plugin.onEntityMutation?.runAfterMutationWithinTransaction,
313-
);
314-
}
315-
316260
private isMutationNode(queryNode: RootOperationNode): queryNode is MutationQueryNode {
317261
return InsertQueryNode.is(queryNode) || UpdateQueryNode.is(queryNode) || DeleteQueryNode.is(queryNode);
318262
}

packages/runtime/src/plugins/policy/errors.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import { ZenStackError } from '../../client';
2+
13
/**
24
* Reason code for policy rejection.
35
*/
@@ -21,7 +23,7 @@ export enum RejectedByPolicyReason {
2123
/**
2224
* Error thrown when an operation is rejected by access policy.
2325
*/
24-
export class RejectedByPolicyError extends Error {
26+
export class RejectedByPolicyError extends ZenStackError {
2527
constructor(
2628
public readonly model: string | undefined,
2729
public readonly reason: RejectedByPolicyReason = RejectedByPolicyReason.NO_ACCESS,

0 commit comments

Comments
 (0)