Skip to content

Commit 1981b9c

Browse files
Frank-Gu-81QuChen88
authored andcommitted
Implemented query hint feature that supports multiple query parameters
- Replace /* cacheTTL=300s */ with /*+ CACHE_PARAM(ttl=300s) */ format - Add case-insensitive CACHE_PARAM parsing with flexible placement - Implement malformed hint detection with JdbcCacheMalformedQueryHint telemetry - Add comprehensive test coverage for valid/invalid/malformed cases - Support multiple parameters for future extensibility Fixes: Query hint parsing to follow Oracle SQL hint standards feat: add TTL validation for cache query hints - Reject TTL values <= 0 as malformed hints - Allow large TTL values (no upper limit) - Increment malformed hint counter for invalid TTL values - Add comprehensive test coverage for TTL edge cases Validates: ttl=0s, ttl=-10s → not cacheable Allows: ttl=999999s → cache with large TTL updated test cases to reflect new implementation
1 parent fa106dc commit 1981b9c

File tree

2 files changed

+185
-51
lines changed

2 files changed

+185
-51
lines changed

wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import software.amazon.jdbc.JdbcCallable;
3030
import software.amazon.jdbc.JdbcMethod;
3131
import software.amazon.jdbc.PluginService;
32-
import software.amazon.jdbc.PropertyDefinition;
3332
import software.amazon.jdbc.plugin.AbstractConnectionPlugin;
3433
import software.amazon.jdbc.util.Messages;
3534
import software.amazon.jdbc.util.StringUtils;
@@ -39,6 +38,9 @@
3938

4039
public class DataRemoteCachePlugin extends AbstractConnectionPlugin {
4140
private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName());
41+
private static final String QUERY_HINT_START_PATTERN = "/*+";
42+
private static final String QUERY_HINT_END_PATTERN = "*/";
43+
private static final String CACHE_PARAM_PATTERN = "CACHE_PARAM(";
4244
private static final Set<String> subscribedMethods = Collections.unmodifiableSet(new HashSet<>(
4345
Arrays.asList(JdbcMethod.STATEMENT_EXECUTEQUERY.methodName,
4446
JdbcMethod.STATEMENT_EXECUTE.methodName,
@@ -52,6 +54,7 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin {
5254
private TelemetryCounter hitCounter;
5355
private TelemetryCounter missCounter;
5456
private TelemetryCounter totalCallsCounter;
57+
private TelemetryCounter malformedHintCounter;
5558
private CacheConnection cacheConnection;
5659

5760
public DataRemoteCachePlugin(final PluginService pluginService, final Properties properties) {
@@ -66,6 +69,7 @@ public DataRemoteCachePlugin(final PluginService pluginService, final Properties
6669
this.hitCounter = telemetryFactory.createCounter("remoteCache.cache.hit");
6770
this.missCounter = telemetryFactory.createCounter("remoteCache.cache.miss");
6871
this.totalCallsCounter = telemetryFactory.createCounter("remoteCache.cache.totalCalls");
72+
this.malformedHintCounter = telemetryFactory.createCounter("JdbcCacheMalformedQueryHint");
6973
this.cacheConnection = new CacheConnection(properties);
7074
}
7175

@@ -133,33 +137,68 @@ private ResultSet cacheResultSet(String queryStr, ResultSet rs, int expiry) thro
133137

134138
/**
135139
* Determine the TTL based on an input query
136-
* @param queryHint string. e.g. "NO CACHE", or "cacheTTL=100s"
140+
* @param queryHint string. e.g. "CACHE_PARAM(ttl=100s, key=custom)"
137141
* @return TTL in seconds to cache the query.
138142
* null if the query is not cacheable.
139143
*/
140144
protected Integer getTtlForQuery(String queryHint) {
141145
// Empty query is not cacheable
142146
if (StringUtils.isNullOrEmpty(queryHint)) return null;
143-
// Query longer than 16K is not cacheable
144-
String[] tokens = queryHint.toLowerCase().split("cache");
145-
if (tokens.length >= 2) {
146-
// Handle "no cache".
147-
if (!StringUtils.isNullOrEmpty(tokens[0]) && "no".equals(tokens[0])) return null;
148-
// Handle "cacheTTL=Xs"
149-
if (!StringUtils.isNullOrEmpty(tokens[1]) && tokens[1].startsWith("ttl=")) {
150-
int endIndex = tokens[1].indexOf('s');
151-
if (endIndex > 0) {
147+
// Find CACHE_PARAM anywhere in the hint string (case insensitive)
148+
String upperHint = queryHint.toUpperCase();
149+
int cacheParamStart = upperHint.indexOf(CACHE_PARAM_PATTERN);
150+
if (cacheParamStart == -1) return null;
151+
152+
// Find the matching closing parenthesis
153+
int paramsStart = cacheParamStart + CACHE_PARAM_PATTERN.length();
154+
int paramsEnd = upperHint.indexOf(")", paramsStart);
155+
if (paramsEnd == -1) return null;
156+
157+
// Extract parameters between parentheses
158+
String cacheParams = upperHint.substring(paramsStart, paramsEnd).trim();
159+
// Empty parameters
160+
if (StringUtils.isNullOrEmpty(cacheParams)) {
161+
LOGGER.warning("Empty CACHE_PARAM parameters");
162+
incrCounter(malformedHintCounter);
163+
return null;
164+
}
165+
166+
// Parse comma-separated parameters
167+
String[] params = cacheParams.split(",");
168+
Integer ttlValue = null;
169+
170+
for (String param : params) {
171+
String[] keyValue = param.trim().split("=");
172+
if (keyValue.length != 2) {
173+
LOGGER.warning("Invalid caching parameter format: " + param);
174+
incrCounter(malformedHintCounter);
175+
return null;
176+
}
177+
String key = keyValue[0].trim();
178+
String value = keyValue[1].trim();
179+
180+
if ("TTL".equals(key)) {
181+
if (!value.endsWith("S")) {
182+
LOGGER.warning("TTL must end with 's': " + value);
183+
incrCounter(malformedHintCounter);
184+
return null;
185+
} else{
186+
// Parse TTL value (e.g., "300s")
152187
try {
153-
return Integer.parseInt(tokens[1].substring(4, endIndex));
154-
} catch (Exception e) {
155-
LOGGER.warning("Encountered exception when parsing Cache TTL: " + e.getMessage());
188+
ttlValue = Integer.parseInt(value.substring(0, value.length() - 1));
189+
// treat negative and 0 ttls as not cacheable
190+
if (ttlValue <= 0) {
191+
return null;
192+
}
193+
} catch (NumberFormatException e) {
194+
LOGGER.warning(String.format("Invalid TTL format of %s for query %s", value, queryHint));
195+
incrCounter(malformedHintCounter);
196+
return null;
156197
}
157198
}
158199
}
159200
}
160-
161-
LOGGER.finest("Query hint " + queryHint + " indicates the query is not cacheable");
162-
return null;
201+
return ttlValue;
163202
}
164203

165204
@Override
@@ -181,8 +220,9 @@ public <T, E extends Exception> T execute(
181220
String mainQuery = sql; // The main part of the query with the query hint prefix trimmed
182221
int endOfQueryHint = 0;
183222
Integer configuredQueryTtl = null;
184-
if ((sql.length() < 16000) && sql.startsWith("/*")) {
185-
endOfQueryHint = sql.indexOf("*/");
223+
// Queries longer than 16KB is not cacheable
224+
if ((sql.length() < 16000) && sql.startsWith(QUERY_HINT_START_PATTERN)) {
225+
endOfQueryHint = sql.indexOf(QUERY_HINT_END_PATTERN);
186226
if (endOfQueryHint > 0) {
187227
configuredQueryTtl = getTtlForQuery(sql.substring(2, endOfQueryHint).trim());
188228
mainQuery = sql.substring(endOfQueryHint + 2).trim();

wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java

Lines changed: 126 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public class DataRemoteCachePluginTest {
3434
@Mock TelemetryCounter mockHitCounter;
3535
@Mock TelemetryCounter mockMissCounter;
3636
@Mock TelemetryCounter mockTotalCallsCounter;
37+
@Mock TelemetryCounter mockMalformedHintCounter;
3738
@Mock ResultSet mockResult1;
3839
@Mock Statement mockStatement;
3940
@Mock ResultSetMetaData mockMetaData;
@@ -52,6 +53,7 @@ void setUp() throws SQLException {
5253
when(mockTelemetryFactory.createCounter("remoteCache.cache.hit")).thenReturn(mockHitCounter);
5354
when(mockTelemetryFactory.createCounter("remoteCache.cache.miss")).thenReturn(mockMissCounter);
5455
when(mockTelemetryFactory.createCounter("remoteCache.cache.totalCalls")).thenReturn(mockTotalCallsCounter);
56+
when(mockTelemetryFactory.createCounter("JdbcCacheMalformedQueryHint")).thenReturn(mockMalformedHintCounter);
5557
when(mockResult1.getMetaData()).thenReturn(mockMetaData);
5658
when(mockMetaData.getColumnCount()).thenReturn(1);
5759
when(mockMetaData.getColumnLabel(1)).thenReturn("fooName");
@@ -66,36 +68,68 @@ void cleanUp() throws Exception {
6668

6769
@Test
6870
void test_getTTLFromQueryHint() throws Exception {
69-
// Null and empty query string are not cacheable
71+
// Null and empty query hint content are not cacheable
7072
assertNull(plugin.getTtlForQuery(null));
7173
assertNull(plugin.getTtlForQuery(""));
7274
assertNull(plugin.getTtlForQuery(" "));
73-
// Some other query hint
74-
assertNull(plugin.getTtlForQuery("/* cacheNotEnabled */ select * from T"));
75-
// Rule set is empty. All select queries are cached with 300 seconds TTL
76-
String selectQuery1 = "cachettl=300s";
77-
String selectQuery2 = " /* CACHETTL=100s */ SELECT ID from mytable2 ";
78-
String selectQuery3 = "/*CacheTTL=35s*/select * from table3 where ID = 1 and name = 'tom'";
79-
// Query hints that are not cacheable
80-
String selectQueryNoHint = "select * from table4";
81-
String selectQueryNoCache1 = "no cache";
82-
String selectQueryNoCache2 = " /* NO CACHE */ SELECT count(*) FROM (select player_id from roster where id = 1 FOR UPDATE) really_long_name_alias";
83-
String selectQueryNoCache3 = "/* cachettl=300 */ SELECT count(*) FROM (select player_id from roster where id = 1) really_long_name_alias";
84-
85-
// Non select queries are not cacheable
86-
String veryShortQuery = "BEGIN";
87-
String insertQuery = "/* This is an insert query */ insert into mytable values (1, 2)";
88-
String updateQuery = "/* Update query */ Update /* Another hint */ mytable set val = 1";
89-
assertEquals(300, plugin.getTtlForQuery(selectQuery1));
90-
assertEquals(100, plugin.getTtlForQuery(selectQuery2));
91-
assertEquals(35, plugin.getTtlForQuery(selectQuery3));
92-
assertNull(plugin.getTtlForQuery(selectQueryNoHint));
93-
assertNull(plugin.getTtlForQuery(selectQueryNoCache1));
94-
assertNull(plugin.getTtlForQuery(selectQueryNoCache2));
95-
assertNull(plugin.getTtlForQuery(selectQueryNoCache3));
96-
assertNull(plugin.getTtlForQuery(veryShortQuery));
97-
assertNull(plugin.getTtlForQuery(insertQuery));
98-
assertNull(plugin.getTtlForQuery(updateQuery));
75+
// Valid CACHE_PARAM cases - these are the hint contents after /*+ and before */
76+
assertEquals(300, plugin.getTtlForQuery("CACHE_PARAM(ttl=300s)"));
77+
assertEquals(100, plugin.getTtlForQuery("CACHE_PARAM(ttl=100s)"));
78+
assertEquals(35, plugin.getTtlForQuery("CACHE_PARAM(ttl=35s)"));
79+
80+
// Case insensitive
81+
assertEquals(200, plugin.getTtlForQuery("cache_param(ttl=200s)"));
82+
assertEquals(150, plugin.getTtlForQuery("Cache_Param(ttl=150s)"));
83+
assertEquals(200, plugin.getTtlForQuery("cache_param(tTl=200s)"));
84+
assertEquals(150, plugin.getTtlForQuery("Cache_Param(ttl=150S)"));
85+
assertEquals(200, plugin.getTtlForQuery("cache_param(TTL=200S)"));
86+
87+
// CACHE_PARAM anywhere in hint content (mixed with other hint directives)
88+
assertEquals(250, plugin.getTtlForQuery("INDEX(table1 idx1) CACHE_PARAM(ttl=250s)"));
89+
assertEquals(200, plugin.getTtlForQuery("CACHE_PARAM(ttl=200s) USE_NL(t1 t2)"));
90+
assertEquals(180, plugin.getTtlForQuery("FIRST_ROWS(10) CACHE_PARAM(ttl=180s) PARALLEL(4)"));
91+
assertEquals(200, plugin.getTtlForQuery("foo=bar,CACHE_PARAM(ttl=200s),baz=qux"));
92+
93+
// Whitespace handling
94+
assertEquals(400, plugin.getTtlForQuery("CACHE_PARAM( ttl=400s )"));
95+
assertEquals(500, plugin.getTtlForQuery("CACHE_PARAM(ttl = 500s)"));
96+
assertEquals(200, plugin.getTtlForQuery("CACHE_PARAM( ttl = 200s , key = test )"));
97+
98+
// Invalid cases - no CACHE_PARAM in hint content
99+
assertNull(plugin.getTtlForQuery("INDEX(table1 idx1)"));
100+
assertNull(plugin.getTtlForQuery("FIRST_ROWS(100)"));
101+
assertNull(plugin.getTtlForQuery("cachettl=300s")); // old format
102+
assertNull(plugin.getTtlForQuery("NO_CACHE"));
103+
104+
// Missing parentheses
105+
assertNull(plugin.getTtlForQuery("CACHE_PARAM ttl=300s"));
106+
assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=300s"));
107+
108+
// Multiple parameters (future-proofing)
109+
assertEquals(300, plugin.getTtlForQuery("CACHE_PARAM(ttl=300s, key=test)"));
110+
111+
// Large TTL values should work
112+
assertEquals(999999, plugin.getTtlForQuery("CACHE_PARAM(ttl=999999s)"));
113+
assertEquals(86400, plugin.getTtlForQuery("CACHE_PARAM(ttl=86400s)")); // 24 hours
114+
}
115+
116+
@Test
117+
void test_getTTLFromQueryHint_MalformedHints() throws Exception {
118+
// Test malformed cases
119+
assertNull(plugin.getTtlForQuery("CACHE_PARAM()"));
120+
assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=abc)"));
121+
assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=300)")); // missing 's'
122+
123+
assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=)"));
124+
assertNull(plugin.getTtlForQuery("CACHE_PARAM(invalid_format)"));
125+
126+
// Invalid TTL values (negative and zero) does not count toward malformed hints
127+
assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=0s)"));
128+
assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=-10s)"));
129+
assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=-1s)"));
130+
131+
// Verify counter was incremented 8 times (5 original + 3 new)
132+
verify(mockMalformedHintCounter, times(5)).inc();
99133
}
100134

101135
@Test
@@ -125,7 +159,7 @@ void test_execute_noCachingLongQuery() throws Exception {
125159
when(mockCallable.call()).thenReturn(mockResult1);
126160

127161
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement,
128-
methodName, mockCallable, new String[]{"/* cacheTTL=30s */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)});
162+
methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)});
129163

130164
// Mock result set containing 1 row
131165
when(mockResult1.next()).thenReturn(true, true, false, false);
@@ -153,7 +187,7 @@ void test_execute_cachingMissAndHit() throws Exception {
153187
when(mockResult1.getObject(1)).thenReturn("bar1");
154188

155189
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement,
156-
methodName, mockCallable, new String[]{"/*CACHETTL=100s*/ select * from A"});
190+
methodName, mockCallable, new String[]{"/*+CACHE_PARAM(ttl=50s)*/ select * from A"});
157191

158192
// Cached result set contains 1 row
159193
assertTrue(rs.next());
@@ -163,7 +197,7 @@ void test_execute_cachingMissAndHit() throws Exception {
163197
byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray();
164198
when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(serializedTestResultSet);
165199
ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockStatement,
166-
methodName, mockCallable, new String[]{" /* CacheTtl=50s */select * from A"});
200+
methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"});
167201

168202
assertTrue(rs2.next());
169203
assertEquals("bar1", rs2.getString("fooName"));
@@ -172,7 +206,7 @@ void test_execute_cachingMissAndHit() throws Exception {
172206
verify(mockPluginService, times(2)).isInTransaction();
173207
verify(mockCacheConn, times(2)).readFromCache("public_user_select * from A");
174208
verify(mockCallable).call();
175-
verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(100));
209+
verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(50));
176210
verify(mockTotalCallsCounter, times(2)).inc();
177211
verify(mockMissCounter).inc();
178212
verify(mockHitCounter).inc();
@@ -193,7 +227,7 @@ void test_transaction_cacheQuery() throws Exception {
193227
when(mockResult1.getObject(1)).thenReturn("bar1");
194228

195229
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement,
196-
methodName, mockCallable, new String[]{"/* cacheTTL=300s */ select * from T"});
230+
methodName, mockCallable, new String[]{"/*+ CACHE_PARAM(ttl=300s) */ select * from T"});
197231

198232
// Cached result set contains 1 row
199233
assertTrue(rs.next());
@@ -210,6 +244,66 @@ void test_transaction_cacheQuery() throws Exception {
210244
}
211245

212246
@Test
247+
void test_transaction_cacheQuery_multiple_query_params() throws Exception {
248+
// Query is cacheable
249+
when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection);
250+
when(mockPluginService.isInTransaction()).thenReturn(true);
251+
when(mockConnection.getMetaData()).thenReturn(mockDbMetadata);
252+
when(mockConnection.getSchema()).thenReturn("public");
253+
when(mockDbMetadata.getUserName()).thenReturn("user");
254+
when(mockCallable.call()).thenReturn(mockResult1);
255+
256+
// Result set contains 1 row
257+
when(mockResult1.next()).thenReturn(true, false);
258+
when(mockResult1.getObject(1)).thenReturn("bar1");
259+
260+
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, methodName, mockCallable, new String[]{"/*+ CACHE_PARAM(ttl=300s, otherParam=abc) */ select * from T"});
261+
262+
// Cached result set contains 1 row
263+
assertTrue(rs.next());
264+
assertEquals("bar1", rs.getString("fooName"));
265+
assertFalse(rs.next());
266+
verify(mockPluginService).getCurrentConnection();
267+
verify(mockPluginService).isInTransaction();
268+
verify(mockCacheConn, never()).readFromCache(anyString());
269+
verify(mockCallable).call();
270+
verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300));
271+
verify(mockTotalCallsCounter, never()).inc();
272+
verify(mockHitCounter, never()).inc();
273+
verify(mockMissCounter, never()).inc();
274+
}
275+
276+
@Test
277+
void test_transaction_cacheQuery_multiple_query_hints() throws Exception {// Query is cacheable
278+
when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection);
279+
when(mockPluginService.isInTransaction()).thenReturn(true);
280+
when(mockConnection.getMetaData()).thenReturn(mockDbMetadata);
281+
when(mockConnection.getSchema()).thenReturn("public");
282+
when(mockDbMetadata.getUserName()).thenReturn("user");
283+
when(mockCallable.call()).thenReturn(mockResult1);
284+
285+
// Result set contains 1 row
286+
when(mockResult1.next()).thenReturn(true, false);
287+
when(mockResult1.getObject(1)).thenReturn("bar1");
288+
289+
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement,
290+
methodName, mockCallable, new String[]{"/*+ hello CACHE_PARAM(ttl=300s, otherParam=abc) world */ select * from T"});
291+
292+
// Cached result set contains 1 row
293+
assertTrue(rs.next());
294+
assertEquals("bar1", rs.getString("fooName"));
295+
assertFalse(rs.next());
296+
verify(mockPluginService).getCurrentConnection();
297+
verify(mockPluginService).isInTransaction();
298+
verify(mockCacheConn, never()).readFromCache(anyString());
299+
verify(mockCallable).call();
300+
verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300));
301+
verify(mockTotalCallsCounter, never()).inc();
302+
verify(mockHitCounter, never()).inc();
303+
verify(mockMissCounter, never()).inc();
304+
}
305+
306+
@Test
213307
void test_transaction_noCaching() throws Exception {
214308
// Query is not cacheable
215309
when(mockPluginService.isInTransaction()).thenReturn(true);

0 commit comments

Comments
 (0)