Skip to content

Commit 213d89f

Browse files
committed
Caching - don't cache queries that are part of a multi-statement transaction. Add unit test for CacheConnection logic.
1 parent 8c4f65d commit 213d89f

File tree

3 files changed

+164
-29
lines changed

3 files changed

+164
-29
lines changed

wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io.lettuce.core.RedisClient;
44
import io.lettuce.core.RedisCommandExecutionException;
55
import io.lettuce.core.RedisURI;
6+
import io.lettuce.core.SetArgs;
67
import io.lettuce.core.api.StatefulRedisConnection;
78
import io.lettuce.core.api.async.RedisAsyncCommands;
89
import io.lettuce.core.codec.ByteArrayCodec;
@@ -174,38 +175,40 @@ public byte[] readFromCache(String key) {
174175
}
175176
}
176177

178+
protected void handleCompletedCacheWrite(StatefulRedisConnection<byte[], byte[]> conn, Throwable ex) {
179+
// Note: this callback upon completion of cache write is on a different thread
180+
if (ex != null) {
181+
LOGGER.warning("Failed to write to cache: " + ex.getMessage());
182+
if (writeConnectionPool != null) {
183+
try {
184+
returnConnectionBackToPool(conn, true, false);
185+
} catch (Exception e) {
186+
LOGGER.warning("Error returning broken write connection back to pool: " + e.getMessage());
187+
}
188+
}
189+
} else {
190+
if (writeConnectionPool != null) {
191+
try {
192+
returnConnectionBackToPool(conn, false, false);
193+
} catch (Exception e) {
194+
LOGGER.warning("Error returning write connection back to pool: " + e.getMessage());
195+
}
196+
}
197+
}
198+
}
199+
177200
public void writeToCache(String key, byte[] value, int expiry) {
178201
StatefulRedisConnection<byte[], byte[]> conn = null;
179202
try {
180203
initializeCacheConnectionIfNeeded(false);
181204
// get a connection from the write connection pool
182205
conn = writeConnectionPool.borrowObject();
183-
// Add support to make write to the cache to be async.
206+
// Write to the cache is async.
184207
RedisAsyncCommands<byte[], byte[]> asyncCommands = conn.async();
185208
byte[] keyHash = computeHashDigest(key.getBytes(StandardCharsets.UTF_8));
186-
187209
StatefulRedisConnection<byte[], byte[]> finalConn = conn;
188-
asyncCommands.setex(keyHash, expiry, value)
189-
.whenComplete((result, exception) -> {
190-
if (exception != null) {
191-
LOGGER.warning("Failed to write to cache: " + exception.getMessage());
192-
if (writeConnectionPool != null) {
193-
try {
194-
returnConnectionBackToPool(finalConn, true, false);
195-
} catch (Exception ex) {
196-
LOGGER.warning("Error returning broken write connection back to pool: " + ex.getMessage());
197-
}
198-
}
199-
} else {
200-
if (writeConnectionPool != null) {
201-
try {
202-
returnConnectionBackToPool(finalConn, false, false);
203-
} catch (Exception ex) {
204-
LOGGER.warning("Error returning write connection back to pool: " + ex.getMessage());
205-
}
206-
}
207-
}
208-
});
210+
asyncCommands.set(keyHash, value, SetArgs.Builder.ex(expiry))
211+
.whenComplete((result, exception) -> handleCompletedCacheWrite(finalConn, exception));
209212
} catch (Exception e) {
210213
LOGGER.warning("Failed to write to cache: " + e.getMessage());
211214
if (conn != null && writeConnectionPool != null) {
@@ -230,4 +233,11 @@ private void returnConnectionBackToPool(StatefulRedisConnection <byte[], byte[]>
230233
pool.returnObject(connection);
231234
}
232235
}
236+
237+
// Used for unit testing only
238+
protected void setConnectionPools(GenericObjectPool<StatefulRedisConnection<byte[], byte[]>> readPool,
239+
GenericObjectPool<StatefulRedisConnection<byte[], byte[]>> writePool) {
240+
readConnectionPool = readPool;
241+
writeConnectionPool = writePool;
242+
}
233243
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package software.amazon.jdbc.plugin.cache;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertNull;
5+
6+
import io.lettuce.core.RedisFuture;
7+
import io.lettuce.core.api.StatefulRedisConnection;
8+
import io.lettuce.core.api.async.RedisAsyncCommands;
9+
import io.lettuce.core.api.sync.RedisCommands;
10+
import org.apache.commons.pool2.impl.GenericObjectPool;
11+
import org.junit.jupiter.api.AfterEach;
12+
import org.junit.jupiter.api.BeforeEach;
13+
import org.junit.jupiter.api.Test;
14+
import org.mockito.Mock;
15+
import org.mockito.MockitoAnnotations;
16+
import java.nio.charset.StandardCharsets;
17+
import java.util.function.BiConsumer;
18+
import java.util.Properties;
19+
20+
import static org.mockito.Mockito.*;
21+
22+
public class CacheConnectionTest {
23+
@Mock GenericObjectPool<StatefulRedisConnection<byte[], byte[]>> mockReadConnPool;
24+
@Mock GenericObjectPool<StatefulRedisConnection<byte[], byte[]>> mockWriteConnPool;
25+
@Mock StatefulRedisConnection<byte[], byte[]> mockConnection;
26+
@Mock RedisCommands<byte[], byte[]> mockSyncCommands;
27+
@Mock RedisAsyncCommands<byte[], byte[]> mockAsyncCommands;
28+
@Mock RedisFuture<String> mockCacheResult;
29+
private AutoCloseable closeable;
30+
private CacheConnection cacheConnection;
31+
32+
@BeforeEach
33+
void setUp() {
34+
closeable = MockitoAnnotations.openMocks(this);
35+
Properties props = new Properties();
36+
props.setProperty("wrapperPlugins", "dataRemoteCache");
37+
props.setProperty("cacheEndpointAddrRw", "localhost:6379");
38+
props.setProperty("cacheEndpointAddrRo", "localhost:6380");
39+
cacheConnection = new CacheConnection(props);
40+
cacheConnection.setConnectionPools(mockReadConnPool, mockWriteConnPool);
41+
}
42+
43+
@AfterEach
44+
void cleanUp() throws Exception {
45+
closeable.close();
46+
}
47+
48+
@Test
49+
void test_writeToCache() throws Exception {
50+
String key = "myQueryKey";
51+
byte[] value = "myValue".getBytes(StandardCharsets.UTF_8);
52+
when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection);
53+
when(mockConnection.async()).thenReturn(mockAsyncCommands);
54+
when(mockAsyncCommands.set(any(), any(), any())).thenReturn(mockCacheResult);
55+
when(mockCacheResult.whenComplete(any(BiConsumer.class))).thenReturn(null);
56+
cacheConnection.writeToCache(key, value, 100);
57+
verify(mockWriteConnPool).borrowObject();
58+
verify(mockConnection).async();
59+
verify(mockAsyncCommands).set(any(), any(), any());
60+
verify(mockCacheResult).whenComplete(any(BiConsumer.class));
61+
}
62+
63+
@Test
64+
void test_writeToCacheException() throws Exception {
65+
String key = "myQueryKey";
66+
byte[] value = "myValue".getBytes(StandardCharsets.UTF_8);
67+
when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection);
68+
when(mockConnection.async()).thenReturn(mockAsyncCommands);
69+
when(mockAsyncCommands.set(any(), any(), any())).thenThrow(new RuntimeException("test exception"));
70+
cacheConnection.writeToCache(key, value, 100);
71+
verify(mockWriteConnPool).borrowObject();
72+
verify(mockConnection).async();
73+
verify(mockAsyncCommands).set(any(), any(), any());
74+
verify(mockWriteConnPool).invalidateObject(mockConnection);
75+
}
76+
77+
@Test
78+
void test_handleCompletedCacheWrite() throws Exception {
79+
cacheConnection.handleCompletedCacheWrite(mockConnection, null);
80+
verify(mockWriteConnPool).returnObject(mockConnection);
81+
cacheConnection.handleCompletedCacheWrite(mockConnection, new RuntimeException("test"));
82+
verify(mockWriteConnPool).invalidateObject(mockConnection);
83+
}
84+
85+
@Test
86+
void test_readFromCache() throws Exception {
87+
byte[] value = "myValue".getBytes(StandardCharsets.UTF_8);
88+
when(mockReadConnPool.borrowObject()).thenReturn(mockConnection);
89+
when(mockConnection.sync()).thenReturn(mockSyncCommands);
90+
when(mockSyncCommands.get(any())).thenReturn(value);
91+
byte[] result = cacheConnection.readFromCache("myQueryKey");
92+
assertEquals(value, result);
93+
verify(mockReadConnPool).borrowObject();
94+
verify(mockConnection).sync();
95+
verify(mockSyncCommands).get(any());
96+
verify(mockReadConnPool).returnObject(mockConnection);
97+
}
98+
99+
@Test
100+
void test_readFromCacheException() throws Exception {
101+
when(mockReadConnPool.borrowObject()).thenReturn(mockConnection);
102+
when(mockConnection.sync()).thenReturn(mockSyncCommands);
103+
when(mockSyncCommands.get(any())).thenThrow(new RuntimeException("test"));
104+
assertNull(cacheConnection.readFromCache("myQueryKey"));
105+
verify(mockReadConnPool).borrowObject();
106+
verify(mockConnection).sync();
107+
verify(mockSyncCommands).get(any());
108+
verify(mockReadConnPool).invalidateObject(mockConnection);
109+
}
110+
}

wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ void setUp() throws SQLException {
5252
when(mockTelemetryFactory.createCounter("remoteCache.cache.hit")).thenReturn(mockHitCounter);
5353
when(mockTelemetryFactory.createCounter("remoteCache.cache.miss")).thenReturn(mockMissCounter);
5454
when(mockTelemetryFactory.createCounter("remoteCache.cache.totalCalls")).thenReturn(mockTotalCallsCounter);
55-
5655
when(mockResult1.getMetaData()).thenReturn(mockMetaData);
5756
when(mockMetaData.getColumnCount()).thenReturn(1);
5857
when(mockMetaData.getColumnName(1)).thenReturn("fooName");
@@ -99,6 +98,22 @@ void test_getTTLFromQueryHint() throws Exception {
9998
assertNull(plugin.getTtlForQuery(updateQuery));
10099
}
101100

101+
@Test
102+
void test_inTransaction_noCaching() throws Exception {
103+
// Query is not cacheable
104+
when(mockPluginService.isInTransaction()).thenReturn(true);
105+
when(mockCallable.call()).thenReturn(mockResult1);
106+
ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement,
107+
methodName, mockCallable, new String[]{"/* cacheTtl=50s */ select * from B"});
108+
109+
// Mock result set containing 1 row
110+
when(mockResult1.next()).thenReturn(true, true, false, false);
111+
when(mockResult1.getObject(1)).thenReturn("bar1", "bar1");
112+
compareResults(mockResult1, rs);
113+
verify(mockCallable).call();
114+
verify(mockTotalCallsCounter).inc();
115+
}
116+
102117
@Test
103118
void test_execute_noCaching() throws Exception {
104119
// Query is not cacheable
@@ -158,7 +173,7 @@ void test_execute_cachingMiss() throws Exception {
158173

159174
// Cached result set contains 1 row
160175
assertTrue(rs.next());
161-
assertEquals(rs.getString("fooName"), "bar1");
176+
assertEquals("bar1", rs.getString("fooName"));
162177
assertFalse(rs.next());
163178
verify(mockPluginService, times(2)).getCurrentConnection();
164179
verify(mockPluginService).isInTransaction();
@@ -191,11 +206,11 @@ void test_execute_cachingHit() throws Exception {
191206

192207
// Cached result set contains 2 rows
193208
assertTrue(rs.next());
194-
assertEquals(rs.getString("date"), "2009-09-30");
195-
assertEquals(rs.getString("code"), "avata");
209+
assertEquals("2009-09-30", rs.getString("date"));
210+
assertEquals("avata", rs.getString("code"));
196211
assertTrue(rs.next());
197-
assertEquals(rs.getString("date"), "2015-05-30");
198-
assertEquals(rs.getString("code"), "dracu");
212+
assertEquals("2015-05-30", rs.getString("date"));
213+
assertEquals("dracu", rs.getString("code"));
199214
assertFalse(rs.next());
200215
verify(mockPluginService).getCurrentConnection();
201216
verify(mockPluginService).isInTransaction();

0 commit comments

Comments
 (0)