Skip to content

Commit 65433cf

Browse files
committed
Caching - handle empty query and prepared statement to support Hibernate integration.
1 parent c12d3ed commit 65433cf

File tree

2 files changed

+145
-5
lines changed

2 files changed

+145
-5
lines changed

wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.sql.Connection;
2020
import java.sql.DatabaseMetaData;
21+
import java.sql.PreparedStatement;
2122
import java.sql.ResultSet;
2223
import java.sql.SQLException;
2324
import java.util.Arrays;
@@ -43,7 +44,7 @@
4344

4445
public class DataRemoteCachePlugin extends AbstractConnectionPlugin {
4546
private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName());
46-
private static final String QUERY_HINT_START_PATTERN = "/*+";
47+
private static final String QUERY_HINT_START_PATTERN = "/*";
4748
private static final String QUERY_HINT_END_PATTERN = "*/";
4849
private static final String CACHE_PARAM_PATTERN = "CACHE_PARAM(";
4950
private static final String TELEMETRY_CACHE_LOOKUP = "jdbc-cache-lookup";
@@ -260,7 +261,7 @@ public <T, E extends Exception> T execute(
260261

261262
ResultSet result;
262263
boolean needToCache = false;
263-
final String sql = getQuery(jdbcMethodArgs);
264+
final String sql = getQuery(methodInvokeOn, jdbcMethodArgs);
264265

265266
TelemetryContext cacheContext = null;
266267
TelemetryContext dbContext = null;
@@ -271,7 +272,7 @@ public <T, E extends Exception> T execute(
271272
int endOfQueryHint = 0;
272273
Integer configuredQueryTtl = null;
273274
// Queries longer than 16KB is not cacheable
274-
if ((sql.length() < maxCacheableQuerySize) && sql.startsWith(QUERY_HINT_START_PATTERN)) {
275+
if (!StringUtils.isNullOrEmpty(sql) && (sql.length() < maxCacheableQuerySize) && sql.contains(QUERY_HINT_START_PATTERN)) {
275276
endOfQueryHint = sql.indexOf(QUERY_HINT_END_PATTERN);
276277
if (endOfQueryHint > 0) {
277278
configuredQueryTtl = getTtlForQuery(sql.substring(QUERY_HINT_START_PATTERN.length(), endOfQueryHint).trim());
@@ -355,11 +356,20 @@ private void incrCounter(TelemetryCounter counter) {
355356
counter.inc();
356357
}
357358

358-
protected String getQuery(final Object[] jdbcMethodArgs) {
359+
protected String getQuery(final Object methodInvokeOn, final Object[] jdbcMethodArgs) {
359360
// Get query from method argument
360361
if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) {
361362
return jdbcMethodArgs[0].toString().trim();
362363
}
364+
365+
// If the query is not in the method arguments, check for prepared statement query. Get the query
366+
// string from the prepared statement. The exact query string is dependent on the underlying driver.
367+
if (methodInvokeOn instanceof PreparedStatement) {
368+
// For postgres, this gives the raw query itself. i.e. "select * from T where A = 1".
369+
// For MySQL, this gives "com.mysql.cj.jdbc.ClientPreparedStatement: select * from T where A = 1"
370+
// For MariaDB, this gives "ClientPreparedStatement{sql:'select * from T where A=1', parameters:[]}"
371+
return methodInvokeOn.toString();
372+
}
363373
return null;
364374
}
365375
}

wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public class DataRemoteCachePluginTest {
4343
@Mock TelemetryContext mockTelemetryContext;
4444
@Mock ResultSet mockResult1;
4545
@Mock Statement mockStatement;
46+
@Mock PreparedStatement mockPreparedStatement;
4647
@Mock ResultSetMetaData mockMetaData;
4748
@Mock Connection mockConnection;
4849
@Mock SessionStateService mockSessionStateService;
@@ -170,6 +171,67 @@ void test_execute_noCaching() throws Exception {
170171
verify(mockTelemetryContext).closeContext();
171172
}
172173

174+
@Test
175+
void test_execute_emptyQuery_noCaching() throws Exception {
176+
plugin = new DataRemoteCachePlugin(mockPluginService, props);
177+
plugin.setCacheConnection(mockCacheConn);
178+
// Query is not cacheable
179+
when(mockPluginService.isInTransaction()).thenReturn(false);
180+
when(mockCallable.call()).thenReturn(mockResult1);
181+
182+
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement,
183+
methodName, mockCallable, new String[]{});
184+
185+
// Mock result set containing 1 row
186+
when(mockResult1.next()).thenReturn(true, true, false);
187+
when(mockResult1.getObject(1)).thenReturn("bar1", "bar1");
188+
compareResults(mockResult1, rs);
189+
verify(mockPluginService).isInTransaction();
190+
verify(mockCallable).call();
191+
verify(mockTotalQueryCounter, times(1)).inc();
192+
verify(mockCacheHitCounter, never()).inc();
193+
verify(mockCacheBypassCounter, times(1)).inc();
194+
verify(mockCacheMissCounter, never()).inc();
195+
// Verify TelemetryContext behavior for no-caching scenario
196+
verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL);
197+
verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any());
198+
verify(mockTelemetryContext).closeContext();
199+
}
200+
201+
@Test
202+
void test_execute_emptyPreparedStatement_noCaching() throws Exception {
203+
plugin = new DataRemoteCachePlugin(mockPluginService, props);
204+
plugin.setCacheConnection(mockCacheConn);
205+
// Query is not cacheable
206+
when(mockPluginService.isInTransaction()).thenReturn(false);
207+
when(mockPreparedStatement.toString()).thenReturn("", null);
208+
when(mockCallable.call()).thenReturn(mockResult1);
209+
210+
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement,
211+
methodName, mockCallable, new String[]{});
212+
// Mock result set containing 1 row
213+
when(mockResult1.next()).thenReturn(true, true, false, true, true, false);
214+
when(mockResult1.getObject(1)).thenReturn("bar1", "bar1", "bar1", "bar1");
215+
compareResults(mockResult1, rs);
216+
217+
rs = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement,
218+
methodName, mockCallable, new String[]{});
219+
// Mock result set containing 1 row
220+
compareResults(mockResult1, rs);
221+
222+
verify(mockPluginService, times(2)).isInTransaction();
223+
verify(mockCallable, times(2)).call();
224+
verify(mockTotalQueryCounter, times(2)).inc();
225+
verify(mockCacheHitCounter, never()).inc();
226+
verify(mockCacheBypassCounter, times(2)).inc();
227+
verify(mockCacheMissCounter, never()).inc();
228+
// Verify TelemetryContext behavior for no-caching scenario
229+
verify(mockTelemetryFactory, times(2)).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL);
230+
verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any());
231+
//verify(mockPreparedStatement, times(2)).toString();
232+
verify(mockTelemetryContext, times(2)).closeContext();
233+
}
234+
173235
@Test
174236
void test_execute_noCachingLongQuery() throws Exception {
175237
plugin = new DataRemoteCachePlugin(mockPluginService, props);
@@ -179,7 +241,7 @@ void test_execute_noCachingLongQuery() throws Exception {
179241
when(mockCallable.call()).thenReturn(mockResult1);
180242

181243
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement,
182-
methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)});
244+
methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T " + RandomStringUtils.randomAlphanumeric(16350)});
183245

184246
// Mock result set containing 1 row
185247
when(mockResult1.next()).thenReturn(true, true, false);
@@ -262,6 +324,74 @@ void test_execute_cachingMissAndHit() throws Exception {
262324
verify(mockTelemetryContext, times(3)).closeContext();
263325
}
264326

327+
@Test
328+
void test_cachingMissAndHit_preparedStatement() throws Exception {
329+
plugin = new DataRemoteCachePlugin(mockPluginService, props);
330+
plugin.setCacheConnection(mockCacheConn);
331+
// Query is a cache miss
332+
when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection);
333+
when(mockPluginService.isInTransaction()).thenReturn(false);
334+
when(mockConnection.getMetaData()).thenReturn(mockDbMetadata);
335+
when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService);
336+
when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()).thenReturn(Optional.of("mysql"));
337+
when(mockSessionStateService.getSchema()).thenReturn(Optional.empty());
338+
when(mockConnection.getCatalog()).thenReturn("mysql");
339+
when(mockConnection.getSchema()).thenReturn(null);
340+
when(mockDbMetadata.getUserName()).thenReturn("user1@1.1.1.1");
341+
when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(null);
342+
when(mockCallable.call()).thenReturn(mockResult1);
343+
344+
// Result set contains 1 row
345+
when(mockResult1.next()).thenReturn(true, false);
346+
when(mockResult1.getObject(1)).thenReturn("bar1");
347+
when(mockPreparedStatement.toString()).thenReturn("/* CACHE_PARAM(ttl=50s) */ select * from A");
348+
349+
// Now query is a cache hit
350+
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement,
351+
methodName, mockCallable, new String[]{});
352+
353+
// Cached result set contains 1 row
354+
assertTrue(rs.next());
355+
assertEquals("bar1", rs.getString("fooName"));
356+
assertFalse(rs.next());
357+
358+
rs.beforeFirst();
359+
byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray();
360+
when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(serializedTestResultSet);
361+
362+
ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement,
363+
methodName, mockCallable, new String[]{});
364+
365+
assertTrue(rs2.next());
366+
assertEquals("bar1", rs2.getString("fooName"));
367+
assertFalse(rs2.next());
368+
verify(mockPluginService, times(3)).getCurrentConnection();
369+
verify(mockPluginService, times(2)).isInTransaction();
370+
verify(mockCacheConn, times(2)).readFromCache("mysql_null_user1_select * from A");
371+
verify(mockPluginService, times(3)).getSessionStateService();
372+
verify(mockSessionStateService, times(3)).getCatalog();
373+
verify(mockSessionStateService, times(3)).getSchema();
374+
verify(mockConnection).getCatalog();
375+
verify(mockConnection).getSchema();
376+
verify(mockSessionStateService).setCatalog("mysql");
377+
verify(mockDbMetadata).getUserName();
378+
verify(mockCallable).call();
379+
verify(mockCacheConn).writeToCache(eq("mysql_null_user1_select * from A"), any(), eq(50));
380+
verify(mockTotalQueryCounter, times(2)).inc();
381+
verify(mockCacheMissCounter, times(1)).inc();
382+
verify(mockCacheHitCounter, times(1)).inc();
383+
verify(mockCacheBypassCounter, never()).inc();
384+
// Verify TelemetryContext behavior for cache miss and hit scenario
385+
// First call: Cache miss + Database call
386+
verify(mockTelemetryFactory, times(2)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL));
387+
verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL));
388+
// Cache context calls: 1 miss (setSuccess(false)) + 1 hit (setSuccess(true))
389+
verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss
390+
verify(mockTelemetryContext, times(1)).setSuccess(true); // Cache hit
391+
// Context closure: 2 cache contexts + 1 database context = 3 total
392+
verify(mockTelemetryContext, times(3)).closeContext();
393+
}
394+
265395
@Test
266396
void test_transaction_cacheQuery() throws Exception {
267397
props.setProperty("user", "dbuser");

0 commit comments

Comments
 (0)