From 792dc19dd0900ba2f6d1ec543dbd4cfb6cc634a7 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Fri, 14 Mar 2025 15:59:17 -0700 Subject: [PATCH 01/24] Initial prototype for remote query cache plugin. Re-factored the common code for CachedResultSet. --- examples/AWSDriverExample/build.gradle.kts | 1 + .../amazon/PgConnectionWithCacheExample.java | 39 + wrapper/build.gradle.kts | 2 + .../jdbc/ConnectionPluginChainBuilder.java | 3 + .../amazon/jdbc/ConnectionPluginManager.java | 2 + .../plugin/DataCacheConnectionPlugin.java | 1074 +--------------- .../jdbc/plugin/DataRemoteCachePlugin.java | 220 ++++ .../plugin/DataRemoteCachePluginFactory.java | 30 + .../amazon/jdbc/states/SessionState.java | 2 + .../amazon/jdbc/util/CacheConnection.java | 69 + .../amazon/jdbc/util/CachedResultSet.java | 1128 +++++++++++++++++ wrapper/src/test/build.gradle.kts | 1 + 12 files changed, 1498 insertions(+), 1073 deletions(-) create mode 100644 examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java create mode 100644 wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePlugin.java create mode 100644 wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePluginFactory.java create mode 100644 wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java create mode 100644 wrapper/src/main/java/software/amazon/jdbc/util/CachedResultSet.java diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts index 08fee3f19..d43e570aa 100644 --- a/examples/AWSDriverExample/build.gradle.kts +++ b/examples/AWSDriverExample/build.gradle.kts @@ -29,6 +29,7 @@ dependencies { implementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") implementation("org.jsoup:jsoup:1.21.1") implementation("com.mchange:c3p0:0.11.0") + implementation("io.lettuce:lettuce-core:6.6.0.RELEASE") } tasks.withType { diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java 
b/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java new file mode 100644 index 000000000..99c40dc45 --- /dev/null +++ b/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java @@ -0,0 +1,39 @@ +package software.amazon; + +import java.sql.*; +import java.util.*; + +public class PgConnectionWithCacheExample { + + private static final String CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com:5432/postgres"; + private static final String CACHE_SERVER_ADDR = "dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com"; + private static final String USERNAME = "postgres"; + private static final String PASSWORD = "adminadmin"; + + public static void main(String[] args) throws SQLException { + final Properties properties = new Properties(); + + // Configuring connection properties for the underlying JDBC driver. + properties.setProperty("user", USERNAME); + properties.setProperty("password", PASSWORD); + + // Configuring connection properties for the JDBC Wrapper. 
+ properties.setProperty("wrapperPlugins", "dataRemoteCache"); + properties.setProperty("cacheEndpointAddrRw", CACHE_SERVER_ADDR); + properties.setProperty("wrapperLogUnclosedConnections", "true"); + String queryStr = "select * from cinemas"; + + for (int i = 0 ; i < 5; i++) { + // Create a new database connection and issue a query to it + try (Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery(queryStr)) { + System.out.println("Executed the SQL query with result set: " + rs.toString()); + Thread.sleep(2000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + +} diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index 0e6fb29f1..107db7b80 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -50,6 +50,7 @@ dependencies { optionalImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") + compileOnly("io.lettuce:lettuce-core:6.6.0.RELEASE") compileOnly("org.checkerframework:checker-qual:3.49.5") compileOnly("com.mysql:mysql-connector-j:9.4.0") compileOnly("org.postgresql:postgresql:42.7.7") @@ -98,6 +99,7 @@ dependencies { testImplementation("org.slf4j:slf4j-simple:2.0.17") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") testImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") + testImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") testImplementation("io.opentelemetry:opentelemetry-api:1.52.0") testImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 952b00936..7548a4a7c 100644 --- 
a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -34,6 +34,7 @@ import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPluginFactory; import software.amazon.jdbc.plugin.ConnectTimeConnectionPluginFactory; import software.amazon.jdbc.plugin.DataCacheConnectionPluginFactory; +import software.amazon.jdbc.plugin.DataRemoteCachePluginFactory; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.DriverMetaDataConnectionPluginFactory; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPluginFactory; @@ -69,6 +70,7 @@ public class ConnectionPluginChainBuilder { put("executionTime", new ExecutionTimeConnectionPluginFactory()); put("logQuery", new LogQueryConnectionPluginFactory()); put("dataCache", new DataCacheConnectionPluginFactory()); + put("dataRemoteCache", DataRemoteCachePluginFactory.class); put("customEndpoint", new CustomEndpointPluginFactory()); put("efm", new HostMonitoringConnectionPluginFactory()); put("efm2", new software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPluginFactory()); @@ -101,6 +103,7 @@ public class ConnectionPluginChainBuilder { { put(DriverMetaDataConnectionPluginFactory.class, 100); put(DataCacheConnectionPluginFactory.class, 200); + put(DataRemoteCachePluginFactory.class, 250); put(CustomEndpointPluginFactory.class, 380); put(AuroraInitialConnectionStrategyPluginFactory.class, 390); put(AuroraConnectionTrackerPluginFactory.class, 400); diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 2697c5b03..da7a12742 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -34,6 +34,7 @@ import 
software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPlugin; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin; import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.DataRemoteCachePlugin; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPlugin; import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; @@ -73,6 +74,7 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { put(AuroraConnectionTrackerPlugin.class, "plugin:auroraConnectionTracker"); put(LogQueryConnectionPlugin.class, "plugin:logQuery"); put(DataCacheConnectionPlugin.class, "plugin:dataCache"); + put(DataRemoteCachePlugin.class, "plugin:dataRemoteCache"); put(HostMonitoringConnectionPlugin.class, "plugin:efm"); put(software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin.class, "plugin:efm2"); put(FailoverConnectionPlugin.class, "plugin:failover"); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java index a7cf53a9c..1c95f953d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java @@ -16,30 +16,10 @@ package software.amazon.jdbc.plugin; -import java.io.InputStream; -import java.io.Reader; -import java.math.BigDecimal; -import java.net.URL; -import java.sql.Array; -import java.sql.Blob; -import java.sql.Clob; -import java.sql.Date; -import java.sql.NClob; -import java.sql.Ref; import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.RowId; import java.sql.SQLException; -import java.sql.SQLWarning; -import java.sql.SQLXML; -import java.sql.Statement; -import java.sql.Time; -import java.sql.Timestamp; -import java.util.ArrayList; import java.util.Arrays; -import 
java.util.Calendar; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Properties; @@ -51,6 +31,7 @@ import software.amazon.jdbc.JdbcMethod; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.util.CachedResultSet; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.telemetry.TelemetryCounter; @@ -183,1057 +164,4 @@ protected String getQuery(final Object[] jdbcMethodArgs) { return null; } - public static class CachedRow { - protected final HashMap columnByIndex = new HashMap<>(); - protected final HashMap columnByName = new HashMap<>(); - - public void put(final int columnIndex, final String columnName, final Object columnValue) { - columnByIndex.put(columnIndex, columnValue); - columnByName.put(columnName, columnValue); - } - - @SuppressWarnings("unused") - public Object get(final int columnIndex) { - return columnByIndex.get(columnIndex); - } - - @SuppressWarnings("unused") - public Object get(final String columnName) { - return columnByName.get(columnName); - } - } - - @SuppressWarnings({"RedundantThrows", "checkstyle:OverloadMethodsDeclarationOrder"}) - public static class CachedResultSet implements ResultSet { - - protected ArrayList rows; - protected int currentRow; - - public CachedResultSet(final ResultSet resultSet) throws SQLException { - - final ResultSetMetaData md = resultSet.getMetaData(); - final int columns = md.getColumnCount(); - rows = new ArrayList<>(); - - while (resultSet.next()) { - final CachedRow row = new CachedRow(); - for (int i = 1; i <= columns; ++i) { - row.put(i, md.getColumnName(i), resultSet.getObject(i)); - } - rows.add(row); - } - currentRow = -1; - } - - @Override - public boolean next() throws SQLException { - if (rows.size() == 0 || isLast()) { - return false; - } - currentRow++; - return true; - } - - @Override - public 
void close() throws SQLException { - currentRow = rows.size() - 1; - } - - @Override - public boolean wasNull() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public BigDecimal getBigDecimal(final int columnIndex, final int scale) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte[] getBytes(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream 
getAsciiStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public InputStream getUnicodeStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getBinaryStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getInt(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public BigDecimal getBigDecimal(final String columnLabel, final int scale) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public byte[] getBytes(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final String 
columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getAsciiStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @Deprecated - public InputStream getUnicodeStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public InputStream getBinaryStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLWarning getWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void clearWarnings() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getCursorName() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public ResultSetMetaData getMetaData() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final int columnIndex) throws SQLException { - if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { - return null; // out of boundaries - } - final CachedRow row = this.rows.get(this.currentRow); - if (!row.columnByIndex.containsKey(columnIndex)) { - return null; // column index out of boundaries - } - return row.columnByIndex.get(columnIndex); - } - - @Override - public Object getObject(final String columnLabel) throws SQLException { - if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { - return null; // out of boundaries - } - final CachedRow row = this.rows.get(this.currentRow); - if (!row.columnByName.containsKey(columnLabel)) { - return null; // column name not found - } - return row.columnByName.get(columnLabel); - } - - @Override - public int 
findColumn(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getCharacterStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getCharacterStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getBigDecimal(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getBigDecimal(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isBeforeFirst() throws SQLException { - return this.currentRow < 0; - } - - @Override - public boolean isAfterLast() throws SQLException { - return this.currentRow >= this.rows.size(); - } - - @Override - public boolean isFirst() throws SQLException { - return this.currentRow == 0 && this.rows.size() > 0; - } - - @Override - public boolean isLast() throws SQLException { - return this.currentRow == (this.rows.size() - 1) && this.rows.size() > 0; - } - - @Override - public void beforeFirst() throws SQLException { - this.currentRow = -1; - } - - @Override - public void afterLast() throws SQLException { - this.currentRow = this.rows.size(); - } - - @Override - public boolean first() throws SQLException { - this.currentRow = 0; - return this.currentRow < this.rows.size(); - } - - @Override - public boolean last() throws SQLException { - this.currentRow = this.rows.size() - 1; - return this.currentRow >= 0; - } - - @Override - public int getRow() throws SQLException { - return this.currentRow + 1; - } - - @Override - public boolean absolute(final int row) throws SQLException { - if (row > 0) { - this.currentRow = row - 1; - } else { - this.currentRow = this.rows.size() + row; - } - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - 
public boolean relative(final int rows) throws SQLException { - this.currentRow += rows; - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - public boolean previous() throws SQLException { - this.currentRow--; - return this.currentRow >= 0 && this.currentRow < this.rows.size(); - } - - @Override - public void setFetchDirection(final int direction) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getFetchDirection() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void setFetchSize(final int rows) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getFetchSize() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getType() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getConcurrency() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowUpdated() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowInserted() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean rowDeleted() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNull(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBoolean(final int columnIndex, final boolean x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateByte(final int columnIndex, final byte x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateShort(final int columnIndex, final short x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void 
updateInt(final int columnIndex, final int x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateLong(final int columnIndex, final long x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateFloat(final int columnIndex, final float x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDouble(final int columnIndex, final double x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBigDecimal(final int columnIndex, final BigDecimal x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateString(final int columnIndex, final String x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBytes(final int columnIndex, final byte[] x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDate(final int columnIndex, final Date x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTime(final int columnIndex, final Time x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTimestamp(final int columnIndex, final Timestamp x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x, final int length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x, final int length) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final int columnIndex, final Reader x, final int length) throws SQLException { - throw new 
UnsupportedOperationException(); - } - - @Override - public void updateObject(final int columnIndex, final Object x, final int scaleOrLength) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final int columnIndex, final Object x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNull(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBoolean(final String columnLabel, final boolean x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateByte(final String columnLabel, final byte x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateShort(final String columnLabel, final short x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateInt(final String columnLabel, final int x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateLong(final String columnLabel, final long x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateFloat(final String columnLabel, final float x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateDouble(final String columnLabel, final double x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBigDecimal(final String columnLabel, final BigDecimal x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateString(final String columnLabel, final String x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBytes(final String columnLabel, final byte[] x) throws SQLException { - throw new 
UnsupportedOperationException(); - } - - @Override - public void updateDate(final String columnLabel, final Date x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTime(final String columnLabel, final Time x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateTimestamp(final String columnLabel, final Timestamp x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final String columnLabel, final InputStream x, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader, final int length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final String columnLabel, final Object x, final int scaleOrLength) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateObject(final String columnLabel, final Object x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void insertRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void deleteRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void refreshRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void cancelRowUpdates() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void moveToInsertRow() 
throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void moveToCurrentRow() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Statement getStatement() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final int columnIndex, final Map> map) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Ref getRef(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob getBlob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob getClob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Array getArray(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Object getObject(final String columnLabel, final Map> map) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Ref getRef(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Blob getBlob(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Clob getClob(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Array getArray(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time 
getTime(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Time getTime(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public URL getURL(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public URL getURL(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRef(final int columnIndex, final Ref x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRef(final String columnLabel, final Ref x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final Blob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final Blob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Clob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final Clob x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateArray(final int columnIndex, final Array x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateArray(final String columnLabel, final Array x) throws SQLException { - 
throw new UnsupportedOperationException(); - } - - @Override - public RowId getRowId(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public RowId getRowId(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRowId(final int columnIndex, final RowId x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateRowId(final String columnLabel, final RowId x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public int getHoldability() throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isClosed() throws SQLException { - return false; - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNString(final int columnIndex, final String nString) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNString(final String columnLabel, final String nString) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNClob(final int columnIndex, final NClob nClob) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) - public void updateNClob(final String columnLabel, final NClob nClob) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - @SuppressWarnings("checkstyle:MethodName") - public NClob getNClob(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public NClob getNClob(final String columnLabel) throws 
SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML getSQLXML(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public SQLXML getSQLXML(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateSQLXML(final int columnIndex, final SQLXML xmlObject) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateSQLXML(final String columnLabel, final SQLXML xmlObject) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getNString(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public String getNString(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getNCharacterStream(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public Reader getNCharacterStream(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final int columnIndex, final Reader x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void 
updateCharacterStream(final int columnIndex, final Reader x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final String columnLabel, final InputStream x, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final InputStream inputStream, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final InputStream inputStream, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final int columnIndex, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final String columnLabel, final Reader reader, final long length) - throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNCharacterStream(final int columnIndex, final Reader x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void 
updateNCharacterStream(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final int columnIndex, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final int columnIndex, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final int columnIndex, final Reader x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateAsciiStream(final String columnLabel, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBinaryStream(final String columnLabel, final InputStream x) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateCharacterStream(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final int columnIndex, final InputStream inputStream) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateBlob(final String columnLabel, final InputStream inputStream) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final int columnIndex, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateClob(final String columnLabel, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final int columnIndex, final Reader reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public void updateNClob(final String columnLabel, final Reader 
reader) throws SQLException { - throw new UnsupportedOperationException(); - } - - @Override - public T getObject(final int columnIndex, final Class type) throws SQLException { - return type.cast(getObject(columnIndex)); - } - - @Override - public T getObject(final String columnLabel, final Class type) throws SQLException { - return type.cast(getObject(columnLabel)); - } - - @Override - public T unwrap(final Class iface) throws SQLException { - return iface == ResultSet.class ? iface.cast(this) : null; - } - - @Override - public boolean isWrapperFor(final Class iface) throws SQLException { - return iface != null && iface.isAssignableFrom(this.getClass()); - } - } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePlugin.java new file mode 100644 index 000000000..9c4c4f6b2 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePlugin.java @@ -0,0 +1,220 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package software.amazon.jdbc.plugin; + +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.states.SessionStateService; +import software.amazon.jdbc.util.CacheConnection; +import software.amazon.jdbc.util.CachedResultSet; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +public class DataRemoteCachePlugin extends AbstractConnectionPlugin { + + private static long connectTime = 0L; + private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName()); + + private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( + Arrays.asList("Statement.executeQuery", "Statement.execute", + "PreparedStatement.execute", "PreparedStatement.executeQuery", + "CallableStatement.execute", "CallableStatement.executeQuery", + "connect", "forceConnect"))); + + static { + PropertyDefinition.registerPluginProperties(DataRemoteCachePlugin.class); + } + + private final PluginService pluginService; + private final TelemetryFactory telemetryFactory; + private final TelemetryCounter hitCounter; + private final TelemetryCounter missCounter; + private final TelemetryCounter totalCallsCounter; + private final CacheConnection cacheConnection; + + public DataRemoteCachePlugin(final PluginService pluginService, final Properties properties) { + this.pluginService = pluginService; + 
this.telemetryFactory = pluginService.getTelemetryFactory(); + this.hitCounter = telemetryFactory.createCounter("remoteCache.cache.hit"); + this.missCounter = telemetryFactory.createCounter("remoteCache.cache.miss"); + this.totalCallsCounter = telemetryFactory.createCounter("remoteCache.cache.totalCalls"); + this.cacheConnection = new CacheConnection(properties); + } + + @Override + public Set getSubscribedMethods() { + return subscribedMethods; + } + + private Connection connectHelper(JdbcCallable connectFunc) throws SQLException { + final long startTime = System.nanoTime(); + + final Connection result = connectFunc.call(); + + final long elapsedTimeNanos = System.nanoTime() - startTime; + connectTime += elapsedTimeNanos; + LOGGER.fine( + () -> Messages.get( + "DataRemoteCachePlugin.cacheConnectTime", + new Object[] {elapsedTimeNanos})); + return result; + } + + @Override + public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, + boolean isInitialConnection, JdbcCallable connectFunc) throws SQLException { + System.out.println("DataRemoteCachingPlugin.connect()..."); + return this.connectHelper(connectFunc); + } + + @Override + public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Properties props, + boolean isInitialConnection, JdbcCallable forceConnectFunc) + throws SQLException { + System.out.println("DataRemoteCachingPlugin.forceConnect()..."); + return this.connectHelper(forceConnectFunc); + } + + private String getCacheQueryKey(String query) { + // Check some basic session states. 
The important ones for caching include (but not limited to): + // schema name, username which can affect the query result from the DB in addition to the query string + try { + Connection currentConn = pluginService.getCurrentConnection(); + DatabaseMetaData metadata = currentConn.getMetaData(); + SessionStateService sessionStateService = pluginService.getSessionStateService(); + System.out.println("DB driver protocol " + pluginService.getDriverProtocol() + + ", schema: " + currentConn.getSchema() + + ", database product: " + metadata.getDatabaseProductName() + " " + metadata.getDatabaseProductVersion() + + ", user: " + metadata.getUserName() + + ", driver: " + metadata.getDriverName() + " " + metadata.getDriverVersion()); + // The cache key contains the schema name, user name, and the query string + String[] words = {currentConn.getSchema(), metadata.getUserName(), query}; + return String.join("_", words); + } catch (SQLException e) { + System.out.println("Error getting session state: " + e.getMessage()); + return null; + } + } + + private ResultSet fetchResultSetFromCache(String queryStr) { + if (cacheConnection == null) return null; + + String cacheQueryKey = getCacheQueryKey(queryStr); + if (cacheQueryKey == null) return null; // Treat this as a cache miss + byte[] result = cacheConnection.readFromCache(cacheQueryKey); + if (result == null) return null; + + // Convert result into ResultSet + try { + return CachedResultSet.deserializeFromJsonString(new String(result, StandardCharsets.UTF_8)); + } catch (Exception e) { + System.out.println("Error de-serializing cached result: " + e.getMessage()); + return null; // Treat this as a cache miss + } + } + + private void cacheResultSet(String queryStr, ResultSet rs) throws SQLException { + System.out.println("Caching resultSet returned from postgres database ....... 
"); + String jsonValue = CachedResultSet.serializeIntoJsonString(rs); + + // Write the resultSet into the cache as a single key + String cacheQueryKey = getCacheQueryKey(queryStr); + if (cacheQueryKey == null) return; // Treat this condition as un-cacheable + cacheConnection.writeToCache(cacheQueryKey, jsonValue.getBytes(StandardCharsets.UTF_8)); + } + + @Override + public T execute( + final Class resultClass, + final Class exceptionClass, + final Object methodInvokeOn, + final String methodName, + final JdbcCallable jdbcMethodFunc, + final Object[] jdbcMethodArgs) + throws E { + totalCallsCounter.inc(); + + ResultSet result; + boolean needToCache = false; + final String sql = getQuery(jdbcMethodArgs); + + // Try to fetch SELECT query from the cache + if (!StringUtils.isNullOrEmpty(sql) && sql.startsWith("select ")) { + result = fetchResultSetFromCache(sql); + if (result == null) { + System.out.println("We got a cache MISS........."); + // Cache miss. Need to fetch result from the database + needToCache = true; + missCounter.inc(); + LOGGER.finest( + () -> Messages.get( + "DataRemoteCachePlugin.queryResultsCached", + new Object[]{methodName, sql})); + } else { + System.out.println("We got a cache hit........."); + // Cache hit. 
Return the cached result + hitCounter.inc(); + try { + result.beforeFirst(); + } catch (final SQLException ex) { + if (exceptionClass.isAssignableFrom(ex.getClass())) { + throw exceptionClass.cast(ex); + } + throw new RuntimeException(ex); + } + return resultClass.cast(result); + } + } + + result = (ResultSet) jdbcMethodFunc.call(); + + if (needToCache) { + try { + cacheResultSet(sql, result); + result.beforeFirst(); + } catch (final SQLException ex) { + // ignore exception + System.out.println("Encountered SQLException when caching results..."); + } + } + + return resultClass.cast(result); + } + + protected String getQuery(final Object[] jdbcMethodArgs) { + // Get query from method argument + if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { + return jdbcMethodArgs[0].toString(); + } + return null; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePluginFactory.java new file mode 100644 index 000000000..efce32fa7 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePluginFactory.java @@ -0,0 +1,30 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package software.amazon.jdbc.plugin; + +import java.util.Properties; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.PluginService; + +public class DataRemoteCachePluginFactory implements ConnectionPluginFactory { + + @Override + public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { + return new DataRemoteCachePlugin(pluginService, props); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/states/SessionState.java b/wrapper/src/main/java/software/amazon/jdbc/states/SessionState.java index f29f915f0..708b0442f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/states/SessionState.java +++ b/wrapper/src/main/java/software/amazon/jdbc/states/SessionState.java @@ -29,6 +29,8 @@ public class SessionState { public SessionStateField transactionIsolation = new SessionStateField<>(); public SessionStateField>> typeMap = new SessionStateField<>(); + // TODO: add support for session states that affects the query result from the database + public SessionState copy() { final SessionState newSessionState = new SessionState(); newSessionState.autoCommit = this.autoCommit.copy(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java new file mode 100644 index 000000000..6b53bfe67 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java @@ -0,0 +1,69 @@ +package software.amazon.jdbc.util; + +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisURI; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.resource.ClientResources; +import software.amazon.jdbc.AwsWrapperProperty; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import 
java.util.Properties; + +// Abstraction layer on top of a connection to a remote cache server +public class CacheConnection { + // TODO: support connection pools to the remote cache server for read and write + private StatefulRedisConnection connection = null; + private final String cacheServerAddr; + private MessageDigest msgHashDigest = null; + + private static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR = + new AwsWrapperProperty( + "cacheEndpointAddrRw", + null, + "The cache server endpoint address."); + + public CacheConnection(final Properties properties) { + this.cacheServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties); + } + + private void initializeCacheConnectionIfNeeded() { + if (StringUtils.isNullOrEmpty(cacheServerAddr)) return; + // Initialize the message digest + if (msgHashDigest == null) { + try { + msgHashDigest = MessageDigest.getInstance("SHA-384"); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("SHA-384 not supported", e); + } + } + // Create a stateful redis connection with TLS enabled + if (connection == null) { + System.out.println("Now we are creating a new Redis connection......"); + ClientResources resources = ClientResources.builder().build(); + final RedisURI redisUriCluster = RedisURI.Builder.redis(cacheServerAddr). + withPort(6379).withSsl(true).withVerifyPeer(false).build(); + RedisClient clusterClient = RedisClient.create(resources, redisUriCluster); + connection = clusterClient.connect(new ByteArrayCodec()); + } + } + + // Get the hash digest of the given key. 
+ private byte[] computeHashDigest(byte[] key) { + msgHashDigest.update(key); + return msgHashDigest.digest(); + } + + public byte[] readFromCache(String key) { + initializeCacheConnectionIfNeeded(); + // TODO: get a connection from the read connection pool + return connection.sync().get(computeHashDigest(key.getBytes(StandardCharsets.UTF_8))); + } + + public void writeToCache(String key, byte[] value) { + initializeCacheConnectionIfNeeded(); + // TODO: get a connection from the write connection pool + connection.sync().setex(computeHashDigest(key.getBytes(StandardCharsets.UTF_8)), 300, value); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/util/CachedResultSet.java new file mode 100644 index 000000000..aeb2a2c2b --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/util/CachedResultSet.java @@ -0,0 +1,1128 @@ +package software.amazon.jdbc.util; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@SuppressWarnings({"RedundantThrows", "checkstyle:OverloadMethodsDeclarationOrder"}) +public class CachedResultSet implements ResultSet { + + public static class CachedRow { + protected final HashMap columnByIndex = new HashMap<>(); + protected final HashMap columnByName = new HashMap<>(); + + 
public void put(final int columnIndex, final String columnName, final Object columnValue) { + columnByIndex.put(columnIndex, columnValue); + columnByName.put(columnName, columnValue); + } + + @SuppressWarnings("unused") + public Object get(final int columnIndex) { + return columnByIndex.get(columnIndex); + } + + @SuppressWarnings("unused") + public Object get(final String columnName) { + return columnByName.get(columnName); + } + } + + protected ArrayList rows; + protected int currentRow; + + public CachedResultSet(final ResultSet resultSet) throws SQLException { + + final ResultSetMetaData md = resultSet.getMetaData(); + final int columns = md.getColumnCount(); + rows = new ArrayList<>(); + + while (resultSet.next()) { + final CachedRow row = new CachedRow(); + for (int i = 1; i <= columns; ++i) { + row.put(i, md.getColumnName(i), resultSet.getObject(i)); + } + rows.add(row); + } + currentRow = -1; + } + + public CachedResultSet(final List> resultList) { + rows = new ArrayList<>(); + for (Map rowMap : resultList) { + final CachedRow row = new CachedRow(); + int i = 1; + for (String columnName : rowMap.keySet()) { + row.put(i, columnName, rowMap.get(columnName)); + } + rows.add(row); + } + currentRow = -1; + } + + public static String serializeIntoJsonString(ResultSet rs) throws SQLException { + ObjectMapper mapper = new ObjectMapper(); + List> resultList = new ArrayList<>(); + ResultSetMetaData metaData = rs.getMetaData(); + int columns = metaData.getColumnCount(); + + while (rs.next()) { + Map rowMap = new HashMap<>(); + for (int i = 1; i <= columns; i++) { + rowMap.put(metaData.getColumnName(i), rs.getObject(i)); + } + resultList.add(rowMap); + } + try { + return mapper.writeValueAsString(resultList); + } catch (JsonProcessingException e) { + throw new SQLException("Error serializing ResultSet to JSON", e); + } + } + + public static ResultSet deserializeFromJsonString(String jsonString) throws SQLException { + if (jsonString == null || jsonString.isEmpty()) { 
return null; } + try { + ObjectMapper mapper = new ObjectMapper(); + List> resultList = mapper.readValue(jsonString, + mapper.getTypeFactory().constructCollectionType(List.class, Map.class)); + return new CachedResultSet(resultList); + } catch (JsonProcessingException e) { + throw new SQLException("Error de-serializing ResultSet from JSON", e); + } + } + + @Override + public boolean next() throws SQLException { + if (rows.size() == 0 || isLast()) { + return false; + } + currentRow++; + return true; + } + + @Override + public void close() throws SQLException { + currentRow = rows.size() - 1; + } + + @Override + public boolean wasNull() throws SQLException { + throw new UnsupportedOperationException(); + } + + // TODO: implement all the getXXX APIs. + @Override + public String getString(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @Deprecated + public BigDecimal getBigDecimal(final int columnIndex, final int scale) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override 
+ public byte[] getBytes(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Time getTime(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Timestamp getTimestamp(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getAsciiStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @Deprecated + public InputStream getUnicodeStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getBinaryStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getString(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean getBoolean(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(final String columnLabel) throws SQLException { 
+ throw new UnsupportedOperationException(); + } + + @Override + @Deprecated + public BigDecimal getBigDecimal(final String columnLabel, final int scale) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBytes(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Time getTime(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Timestamp getTimestamp(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getAsciiStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @Deprecated + public InputStream getUnicodeStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public InputStream getBinaryStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void clearWarnings() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getCursorName() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Object getObject(final int columnIndex) throws SQLException { + if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { + return null; // out of boundaries + } + final CachedRow row = this.rows.get(this.currentRow); + if 
(!row.columnByIndex.containsKey(columnIndex)) { + return null; // column index out of boundaries + } + return row.columnByIndex.get(columnIndex); + } + + @Override + public Object getObject(final String columnLabel) throws SQLException { + if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { + return null; // out of boundaries + } + final CachedRow row = this.rows.get(this.currentRow); + if (!row.columnByName.containsKey(columnLabel)) { + return null; // column name not found + } + return row.columnByName.get(columnLabel); + } + + @Override + public int findColumn(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getCharacterStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getCharacterStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal getBigDecimal(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public BigDecimal getBigDecimal(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isBeforeFirst() throws SQLException { + return this.currentRow < 0; + } + + @Override + public boolean isAfterLast() throws SQLException { + return this.currentRow >= this.rows.size(); + } + + @Override + public boolean isFirst() throws SQLException { + return this.currentRow == 0 && this.rows.size() > 0; + } + + @Override + public boolean isLast() throws SQLException { + return this.currentRow == (this.rows.size() - 1) && this.rows.size() > 0; + } + + @Override + public void beforeFirst() throws SQLException { + this.currentRow = -1; + } + + @Override + public void afterLast() throws SQLException { + this.currentRow = this.rows.size(); + } + + @Override + public boolean first() throws SQLException { + 
this.currentRow = 0; + return this.currentRow < this.rows.size(); + } + + @Override + public boolean last() throws SQLException { + this.currentRow = this.rows.size() - 1; + return this.currentRow >= 0; + } + + @Override + public int getRow() throws SQLException { + return this.currentRow + 1; + } + + @Override + public boolean absolute(final int row) throws SQLException { + if (row > 0) { + this.currentRow = row - 1; + } else { + this.currentRow = this.rows.size() + row; + } + return this.currentRow >= 0 && this.currentRow < this.rows.size(); + } + + @Override + public boolean relative(final int rows) throws SQLException { + this.currentRow += rows; + return this.currentRow >= 0 && this.currentRow < this.rows.size(); + } + + @Override + public boolean previous() throws SQLException { + this.currentRow--; + return this.currentRow >= 0 && this.currentRow < this.rows.size(); + } + + @Override + public void setFetchDirection(final int direction) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getFetchDirection() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void setFetchSize(final int rows) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getFetchSize() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getType() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getConcurrency() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean rowUpdated() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean rowInserted() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean rowDeleted() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void 
updateNull(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBoolean(final int columnIndex, final boolean x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateByte(final int columnIndex, final byte x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateShort(final int columnIndex, final short x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateInt(final int columnIndex, final int x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateLong(final int columnIndex, final long x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateFloat(final int columnIndex, final float x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDouble(final int columnIndex, final double x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBigDecimal(final int columnIndex, final BigDecimal x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateString(final int columnIndex, final String x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBytes(final int columnIndex, final byte[] x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDate(final int columnIndex, final Date x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTime(final int columnIndex, final Time x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTimestamp(final int columnIndex, final Timestamp x) throws 
SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final int columnIndex, final InputStream x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final int columnIndex, final InputStream x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final int columnIndex, final Reader x, final int length) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final int columnIndex, final Object x, final int scaleOrLength) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final int columnIndex, final Object x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNull(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBoolean(final String columnLabel, final boolean x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateByte(final String columnLabel, final byte x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateShort(final String columnLabel, final short x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateInt(final String columnLabel, final int x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateLong(final String columnLabel, final long x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateFloat(final String columnLabel, final float x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void 
updateDouble(final String columnLabel, final double x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBigDecimal(final String columnLabel, final BigDecimal x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateString(final String columnLabel, final String x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBytes(final String columnLabel, final byte[] x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateDate(final String columnLabel, final Date x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTime(final String columnLabel, final Time x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateTimestamp(final String columnLabel, final Timestamp x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final String columnLabel, final InputStream x, final int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final String columnLabel, final InputStream x, final int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final String columnLabel, final Reader reader, final int length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final String columnLabel, final Object x, final int scaleOrLength) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateObject(final String columnLabel, final Object x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void insertRow() throws 
SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void deleteRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void refreshRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void cancelRowUpdates() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void moveToInsertRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void moveToCurrentRow() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Statement getStatement() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Object getObject(final int columnIndex, final Map> map) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Ref getRef(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Blob getBlob(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Clob getClob(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Array getArray(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Object getObject(final String columnLabel, final Map> map) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Ref getRef(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Blob getBlob(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Clob 
getClob(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Array getArray(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(final int columnIndex, final Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(final String columnLabel, final Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Time getTime(final int columnIndex, final Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Time getTime(final String columnLabel, final Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Timestamp getTimestamp(final int columnIndex, final Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Timestamp getTimestamp(final String columnLabel, final Calendar cal) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public URL getURL(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public URL getURL(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRef(final int columnIndex, final Ref x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRef(final String columnLabel, final Ref x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final int columnIndex, final Blob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final String columnLabel, final Blob x) throws SQLException { + throw new 
UnsupportedOperationException(); + } + + @Override + public void updateClob(final int columnIndex, final Clob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final String columnLabel, final Clob x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateArray(final int columnIndex, final Array x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateArray(final String columnLabel, final Array x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public RowId getRowId(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public RowId getRowId(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRowId(final int columnIndex, final RowId x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateRowId(final String columnLabel, final RowId x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getHoldability() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isClosed() throws SQLException { + return false; + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNString(final int columnIndex, final String nString) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNString(final String columnLabel, final String nString) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNClob(final int 
columnIndex, final NClob nClob) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void updateNClob(final String columnLabel, final NClob nClob) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + @SuppressWarnings("checkstyle:MethodName") + public NClob getNClob(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public NClob getNClob(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLXML getSQLXML(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public SQLXML getSQLXML(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateSQLXML(final int columnIndex, final SQLXML xmlObject) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateSQLXML(final String columnLabel, final SQLXML xmlObject) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getNString(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getNString(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getNCharacterStream(final int columnIndex) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getNCharacterStream(final String columnLabel) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final int columnIndex, final Reader x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + 
public void updateNCharacterStream(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final int columnIndex, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final int columnIndex, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final int columnIndex, final Reader x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final String columnLabel, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final String columnLabel, final InputStream x, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final int columnIndex, final InputStream inputStream, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final String columnLabel, final InputStream inputStream, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final int columnIndex, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new 
UnsupportedOperationException(); + } + + @Override + public void updateNClob(final int columnIndex, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(final String columnLabel, final Reader reader, final long length) + throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final int columnIndex, final Reader x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNCharacterStream(final String columnLabel, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final int columnIndex, final InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final int columnIndex, final InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final int columnIndex, final Reader x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateAsciiStream(final String columnLabel, final InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBinaryStream(final String columnLabel, final InputStream x) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateCharacterStream(final String columnLabel, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final int columnIndex, final InputStream inputStream) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateBlob(final String columnLabel, final InputStream inputStream) throws SQLException { + throw new 
UnsupportedOperationException(); + } + + @Override + public void updateClob(final int columnIndex, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateClob(final String columnLabel, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(final int columnIndex, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public void updateNClob(final String columnLabel, final Reader reader) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T getObject(final int columnIndex, final Class type) throws SQLException { + return type.cast(getObject(columnIndex)); + } + + @Override + public T getObject(final String columnLabel, final Class type) throws SQLException { + return type.cast(getObject(columnLabel)); + } + + @Override + public T unwrap(final Class iface) throws SQLException { + return iface == ResultSet.class ? 
iface.cast(this) : null; + } + + @Override + public boolean isWrapperFor(final Class iface) throws SQLException { + return iface != null && iface.isAssignableFrom(this.getClass()); + } +} diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts index 3d26591e6..759fd4047 100644 --- a/wrapper/src/test/build.gradle.kts +++ b/wrapper/src/test/build.gradle.kts @@ -58,6 +58,7 @@ dependencies { testImplementation("io.opentelemetry:opentelemetry-sdk:1.42.1") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.43.0") testImplementation("io.opentelemetry:opentelemetry-exporter-otlp:1.44.1") + testImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 testImplementation("jakarta.persistence:jakarta.persistence-api:2.2.3") From 9ab2f67851941a3b9d3cb895f8420e7e957aa95c Mon Sep 17 00:00:00 2001 From: Nihal Mehta Date: Tue, 6 May 2025 15:06:18 -0700 Subject: [PATCH 02/24] Add Cache Connection pool support Signed-off-by: Nihal Mehta --- examples/AWSDriverExample/build.gradle.kts | 1 + wrapper/build.gradle.kts | 2 + .../amazon/jdbc/util/CacheConnection.java | 139 ++++++++++++++++-- wrapper/src/test/build.gradle.kts | 1 + 4 files changed, 127 insertions(+), 16 deletions(-) diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts index d43e570aa..ae4b9ab61 100644 --- a/examples/AWSDriverExample/build.gradle.kts +++ b/examples/AWSDriverExample/build.gradle.kts @@ -30,6 +30,7 @@ dependencies { implementation("org.jsoup:jsoup:1.21.1") implementation("com.mchange:c3p0:0.11.0") implementation("io.lettuce:lettuce-core:6.6.0.RELEASE") + implementation("org.apache.commons:commons-pool2:2.11.1") } tasks.withType { diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index 107db7b80..f0e201e0a 100644 --- a/wrapper/build.gradle.kts +++ 
b/wrapper/build.gradle.kts @@ -51,6 +51,7 @@ dependencies { optionalImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") compileOnly("io.lettuce:lettuce-core:6.6.0.RELEASE") + compileOnly("org.apache.commons:commons-pool2:2.11.1") compileOnly("org.checkerframework:checker-qual:3.49.5") compileOnly("com.mysql:mysql-connector-j:9.4.0") compileOnly("org.postgresql:postgresql:42.7.7") @@ -104,6 +105,7 @@ dependencies { testImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") testImplementation("io.opentelemetry:opentelemetry-exporter-otlp:1.52.0") + testImplementation("org.apache.commons:commons-pool2:2.11.1") testImplementation("org.jsoup:jsoup:1.21.1") testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java index 6b53bfe67..e199ce2b6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java @@ -10,14 +10,27 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Properties; +import io.lettuce.core.support.ConnectionPoolSupport; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.commons.pool2.impl.GenericObjectPoolConfig; // Abstraction layer on top of a connection to a remote cache server public class CacheConnection { - // TODO: support connection pools to the remote cache server for read and write - private StatefulRedisConnection connection = null; + // Adding support for read and write connection pools to the remote cache server + private static volatile GenericObjectPool> readConnectionPool; + private static volatile GenericObjectPool> 
writeConnectionPool; + private static final GenericObjectPoolConfig> poolConfig = createPoolConfig(); private final String cacheServerAddr; private MessageDigest msgHashDigest = null; + private static final int DEFAULT_POOL_SIZE = 10; + private static final int DEFAULT_POOL_MAX_IDLE = 10; + private static final int DEFAULT_POOL_MIN_IDLE = 0; + private static final long DEFAULT_MAX_BORROW_WAIT_MS = 50; + + private static final Object READ_LOCK = new Object(); + private static final Object WRITE_LOCK = new Object(); + private static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR = new AwsWrapperProperty( "cacheEndpointAddrRw", @@ -28,7 +41,12 @@ public CacheConnection(final Properties properties) { this.cacheServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties); } - private void initializeCacheConnectionIfNeeded() { + /* Here we check if we need to initialise connection pool for read or write to cache. + With isRead we check if we need to initialise connection pool for read or write to cache. + If isRead is true, we initialise connection pool for read. + If isRead is false, we initialise connection pool for write. + */ + private void initializeCacheConnectionIfNeeded(boolean isRead) { if (StringUtils.isNullOrEmpty(cacheServerAddr)) return; // Initialize the message digest if (msgHashDigest == null) { @@ -38,17 +56,56 @@ private void initializeCacheConnectionIfNeeded() { throw new RuntimeException("SHA-384 not supported", e); } } - // Create a stateful redis connection with TLS enabled - if (connection == null) { - System.out.println("Now we are creating a new Redis connection......"); - ClientResources resources = ClientResources.builder().build(); - final RedisURI redisUriCluster = RedisURI.Builder.redis(cacheServerAddr). + + GenericObjectPool> cacheConnectionPool = + isRead ? readConnectionPool : writeConnectionPool; + Object lock = isRead ? 
READ_LOCK : WRITE_LOCK; + + if (cacheConnectionPool == null) { + synchronized (lock) { + if ((isRead && readConnectionPool == null) || (!isRead && writeConnectionPool == null)) { + createConnectionPool(isRead); + } + } + } + } + + private void createConnectionPool(boolean isRead) { + ClientResources resources = ClientResources.builder().build(); + try { + RedisURI redisUriCluster = RedisURI.Builder.redis(cacheServerAddr). withPort(6379).withSsl(true).withVerifyPeer(false).build(); - RedisClient clusterClient = RedisClient.create(resources, redisUriCluster); - connection = clusterClient.connect(new ByteArrayCodec()); + + RedisClient client = RedisClient.create(resources, redisUriCluster); + GenericObjectPool> pool = + ConnectionPoolSupport.createGenericObjectPool( + () -> client.connect(new ByteArrayCodec()), + poolConfig + ); + + if (isRead) { + readConnectionPool = pool; + } else { + writeConnectionPool = pool; + } + + } catch (Exception e) { + String poolType = isRead ? "read" : "write"; + String errorMsg = String.format("Failed to create Cache %s connection pool", poolType); + System.err.println(errorMsg + ": " + e.getMessage()); + throw new RuntimeException(errorMsg, e); } } + private static GenericObjectPoolConfig> createPoolConfig() { + GenericObjectPoolConfig> poolConfig = new GenericObjectPoolConfig<>(); + poolConfig.setMaxTotal(DEFAULT_POOL_SIZE); + poolConfig.setMaxIdle(DEFAULT_POOL_MAX_IDLE); + poolConfig.setMinIdle(DEFAULT_POOL_MIN_IDLE); + poolConfig.setMaxWaitMillis(DEFAULT_MAX_BORROW_WAIT_MS); + return poolConfig; + } + // Get the hash digest of the given key. 
private byte[] computeHashDigest(byte[] key) { msgHashDigest.update(key); @@ -56,14 +113,64 @@ private byte[] computeHashDigest(byte[] key) { } public byte[] readFromCache(String key) { - initializeCacheConnectionIfNeeded(); - // TODO: get a connection from the read connection pool - return connection.sync().get(computeHashDigest(key.getBytes(StandardCharsets.UTF_8))); + boolean isBroken = false; + StatefulRedisConnection conn = null; + initializeCacheConnectionIfNeeded(true); + // get a connection from the read connection pool + try { + conn = readConnectionPool.borrowObject(); + return conn.sync().get(computeHashDigest(key.getBytes(StandardCharsets.UTF_8))); + } catch (Exception e) { + if (conn != null) { + isBroken = true; + } + System.err.println("Failed to read from cache: " + e.getMessage()); + return null; + } finally { + if (conn != null && readConnectionPool != null) { + try { + this.returnConnectionBackToPool(conn, isBroken, true); + } catch (Exception ex) { + System.err.println("Error closing read connection: " + ex.getMessage()); + } + } + } } public void writeToCache(String key, byte[] value) { - initializeCacheConnectionIfNeeded(); - // TODO: get a connection from the write connection pool - connection.sync().setex(computeHashDigest(key.getBytes(StandardCharsets.UTF_8)), 300, value); + boolean isBroken = false; + initializeCacheConnectionIfNeeded(false); + // get a connection from the write connection pool + StatefulRedisConnection conn = null; + try { + conn = writeConnectionPool.borrowObject(); + conn.sync().setex(computeHashDigest(key.getBytes(StandardCharsets.UTF_8)), 300, value); + } catch (Exception e) { + if (conn != null){ + isBroken = true; + } + System.err.println("Failed to write to cache: " + e.getMessage()); + } finally { + if (conn != null && writeConnectionPool != null) { + try { + this.returnConnectionBackToPool(conn, isBroken, false); + } catch (Exception ex) { + System.err.println("Error closing write connection: " + 
ex.getMessage()); + } + } + } + + private void returnConnectionBackToPool(StatefulRedisConnection connection, boolean isConnectionBroken, boolean isRead) { + GenericObjectPool> pool = isRead ? readConnectionPool : writeConnectionPool; + if (isConnectionBroken) { + try { + pool.invalidateObject(connection); + } catch (Exception e) { + throw new RuntimeException("Could not invalidate connection for the pool", e); + } + } else { + pool.returnObject(connection); + } } } diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts index 759fd4047..3e2f17251 100644 --- a/wrapper/src/test/build.gradle.kts +++ b/wrapper/src/test/build.gradle.kts @@ -51,6 +51,7 @@ dependencies { testImplementation("org.testcontainers:mariadb:1.20.4") testImplementation("org.testcontainers:junit-jupiter:1.20.4") testImplementation("org.testcontainers:toxiproxy:1.20.4") + testImplementation("org.apache.commons:commons-pool2:2.11.1") testImplementation("org.apache.poi:poi-ooxml:5.3.0") testImplementation("org.slf4j:slf4j-simple:2.0.13") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.17.1") From cde3165bbf6fb6a77b94d2f42e4e73f34a78f8d6 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Thu, 8 May 2025 11:46:05 -0700 Subject: [PATCH 03/24] Determine SQL query cacheability via query hints. A query hint is defined as a comment that prefixes the actual SQL query string. It can be in a case-insensitive form of "/* cacheTTL=60s */" to indicate that the query should be cached with 60 seconds of TTL. Or it can indicate the query should not be cached via "/* no cache */". Allow reading cached query results from a replica in cluster-mode-enabled Redis/Valkey, and support the cluster-mode-disabled setting. Re-factored all the caching logic into its own directory, and removed unnecessary code and logging.
Fix several issues with caching results: - Handle failure in cache connection initialization as a cache miss - Support getString(), getInt(), getDouble(), getFloat(), getBigDecimal(), getBoolean(), getDate(), getTime(), getTimestamp(), getShort() and getByte() APIs for CachedResultSet - Fix a bug with converting ResultSet into CachedResultSet, and populate basic info for ResultSetMetaData. --- examples/AWSDriverExample/build.gradle.kts | 1 + .../amazon/PgConnectionWithCacheExample.java | 21 +- wrapper/build.gradle.kts | 2 + .../jdbc/ConnectionPluginChainBuilder.java | 6 +- .../amazon/jdbc/ConnectionPluginManager.java | 4 +- .../java/software/amazon/jdbc/Driver.java | 2 +- .../cache}/CacheConnection.java | 81 +++-- .../cache}/CachedResultSet.java | 282 ++++++++++++--- .../plugin/cache/CachedResultSetMetaData.java | 148 ++++++++ .../DataCacheConnectionPlugin.java | 4 +- .../DataCacheConnectionPluginFactory.java | 2 +- .../{ => cache}/DataRemoteCachePlugin.java | 160 +++++---- .../DataRemoteCachePluginFactory.java | 2 +- wrapper/src/test/build.gradle.kts | 1 + .../container/tests/DataCachePluginTests.java | 4 +- .../plugin/cache/CachedResultSetTest.java | 328 ++++++++++++++++++ .../DataCacheConnectionPluginTest.java | 2 +- .../cache/DataRemoteCachePluginTest.java | 216 ++++++++++++ 18 files changed, 1103 insertions(+), 163 deletions(-) rename wrapper/src/main/java/software/amazon/jdbc/{util => plugin/cache}/CacheConnection.java (63%) rename wrapper/src/main/java/software/amazon/jdbc/{util => plugin/cache}/CachedResultSet.java (74%) create mode 100644 wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java rename wrapper/src/main/java/software/amazon/jdbc/plugin/{ => cache}/DataCacheConnectionPlugin.java (98%) rename wrapper/src/main/java/software/amazon/jdbc/plugin/{ => cache}/DataCacheConnectionPluginFactory.java (96%) rename wrapper/src/main/java/software/amazon/jdbc/plugin/{ => cache}/DataRemoteCachePlugin.java (56%) rename 
wrapper/src/main/java/software/amazon/jdbc/plugin/{ => cache}/DataRemoteCachePluginFactory.java (96%) create mode 100644 wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java rename wrapper/src/test/java/software/amazon/jdbc/plugin/{ => cache}/DataCacheConnectionPluginTest.java (99%) create mode 100644 wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts index ae4b9ab61..1115a5128 100644 --- a/examples/AWSDriverExample/build.gradle.kts +++ b/examples/AWSDriverExample/build.gradle.kts @@ -22,6 +22,7 @@ dependencies { implementation("software.amazon.awssdk:secretsmanager:2.33.5") implementation("software.amazon.awssdk:sts:2.33.5") implementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") + implementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.0") implementation(project(":aws-advanced-jdbc-wrapper")) implementation("io.opentelemetry:opentelemetry-api:1.52.0") implementation("io.opentelemetry:opentelemetry-sdk:1.51.0") diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java index 99c40dc45..9ba0a9b68 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java +++ b/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java @@ -6,9 +6,10 @@ public class PgConnectionWithCacheExample { private static final String CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com:5432/postgres"; - private static final String CACHE_SERVER_ADDR = "dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com"; + private static final String CACHE_RW_SERVER_ADDR = "dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com:6379"; + private static final String 
CACHE_RO_SERVER_ADDR = "dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com:6380"; private static final String USERNAME = "postgres"; - private static final String PASSWORD = "adminadmin"; + private static final String PASSWORD = "admin"; public static void main(String[] args) throws SQLException { final Properties properties = new Properties(); @@ -19,16 +20,20 @@ public static void main(String[] args) throws SQLException { // Configuring connection properties for the JDBC Wrapper. properties.setProperty("wrapperPlugins", "dataRemoteCache"); - properties.setProperty("cacheEndpointAddrRw", CACHE_SERVER_ADDR); + properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); + properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); properties.setProperty("wrapperLogUnclosedConnections", "true"); String queryStr = "select * from cinemas"; + String queryStr2 = "SELECT * from cinemas"; for (int i = 0 ; i < 5; i++) { - // Create a new database connection and issue a query to it - try (Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties); - Statement stmt = conn.createStatement(); - ResultSet rs = stmt.executeQuery(queryStr)) { - System.out.println("Executed the SQL query with result set: " + rs.toString()); + // Create a new database connection and issue queries to it + try { + Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery(queryStr); + ResultSet rs2 = stmt.executeQuery(queryStr2); + System.out.println("Executed the SQL query with result sets: " + rs.toString() + " and " + rs2.toString()); Thread.sleep(2000); } catch (InterruptedException e) { throw new RuntimeException(e); diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index f0e201e0a..f676949ad 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -50,6 +50,7 @@ dependencies { 
optionalImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") + compileOnly("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.0") compileOnly("io.lettuce:lettuce-core:6.6.0.RELEASE") compileOnly("org.apache.commons:commons-pool2:2.11.1") compileOnly("org.checkerframework:checker-qual:3.49.5") @@ -99,6 +100,7 @@ dependencies { testImplementation("org.apache.poi:poi-ooxml:5.4.1") testImplementation("org.slf4j:slf4j-simple:2.0.17") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") + testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.0") testImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") testImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") testImplementation("io.opentelemetry:opentelemetry-api:1.52.0") diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 7548a4a7c..10209d291 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -33,8 +33,8 @@ import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPluginFactory; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPluginFactory; import software.amazon.jdbc.plugin.ConnectTimeConnectionPluginFactory; -import software.amazon.jdbc.plugin.DataCacheConnectionPluginFactory; -import software.amazon.jdbc.plugin.DataRemoteCachePluginFactory; +import software.amazon.jdbc.plugin.cache.DataCacheConnectionPluginFactory; +import software.amazon.jdbc.plugin.cache.DataRemoteCachePluginFactory; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.DriverMetaDataConnectionPluginFactory; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPluginFactory; @@ -70,7 
+70,7 @@ public class ConnectionPluginChainBuilder { put("executionTime", new ExecutionTimeConnectionPluginFactory()); put("logQuery", new LogQueryConnectionPluginFactory()); put("dataCache", new DataCacheConnectionPluginFactory()); - put("dataRemoteCache", DataRemoteCachePluginFactory.class); + put("dataRemoteCache", new DataRemoteCachePluginFactory()); put("customEndpoint", new CustomEndpointPluginFactory()); put("efm", new HostMonitoringConnectionPluginFactory()); put("efm2", new software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPluginFactory()); diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index da7a12742..16b1f450b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -33,8 +33,8 @@ import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPlugin; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; -import software.amazon.jdbc.plugin.DataRemoteCachePlugin; +import software.amazon.jdbc.plugin.cache.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataRemoteCachePlugin; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPlugin; import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; diff --git a/wrapper/src/main/java/software/amazon/jdbc/Driver.java b/wrapper/src/main/java/software/amazon/jdbc/Driver.java index 7d59e83ff..673ced102 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/Driver.java +++ b/wrapper/src/main/java/software/amazon/jdbc/Driver.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider; import 
software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider; import software.amazon.jdbc.plugin.AwsSecretsManagerCacheHolder; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataCacheConnectionPlugin; import software.amazon.jdbc.plugin.OpenedConnectionTracker; import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; import software.amazon.jdbc.plugin.efm.HostMonitorThreadContainer; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java similarity index 63% rename from wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index e199ce2b6..4ea434f94 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -1,6 +1,7 @@ -package software.amazon.jdbc.util; +package software.amazon.jdbc.plugin.cache; import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisCommandExecutionException; import io.lettuce.core.RedisURI; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.codec.ByteArrayCodec; @@ -9,18 +10,23 @@ import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.time.Duration; import java.util.Properties; +import java.util.logging.Logger; import io.lettuce.core.support.ConnectionPoolSupport; import org.apache.commons.pool2.impl.GenericObjectPool; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; +import software.amazon.jdbc.util.StringUtils; // Abstraction layer on top of a connection to a remote cache server public class CacheConnection { + private static final Logger LOGGER = Logger.getLogger(CacheConnection.class.getName()); // Adding 
support for read and write connection pools to the remote cache server private static volatile GenericObjectPool> readConnectionPool; private static volatile GenericObjectPool> writeConnectionPool; private static final GenericObjectPoolConfig> poolConfig = createPoolConfig(); - private final String cacheServerAddr; + private final String cacheRwServerAddr; // read-write cache server + private final String cacheRoServerAddr; // read-only cache server private MessageDigest msgHashDigest = null; private static final int DEFAULT_POOL_SIZE = 10; @@ -31,14 +37,21 @@ public class CacheConnection { private static final Object READ_LOCK = new Object(); private static final Object WRITE_LOCK = new Object(); - private static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR = + protected static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR = new AwsWrapperProperty( "cacheEndpointAddrRw", null, - "The cache server endpoint address."); + "The cache read-write server endpoint address."); + + private static final AwsWrapperProperty CACHE_RO_ENDPOINT_ADDR = + new AwsWrapperProperty( + "cacheEndpointAddrRo", + null, + "The cache read-only server endpoint address."); public CacheConnection(final Properties properties) { - this.cacheServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties); + this.cacheRwServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties); + this.cacheRoServerAddr = CACHE_RO_ENDPOINT_ADDR.getString(properties); } /* Here we check if we need to initialise connection pool for read or write to cache. @@ -47,7 +60,7 @@ public CacheConnection(final Properties properties) { If isRead is false, we initialise connection pool for write. 
*/ private void initializeCacheConnectionIfNeeded(boolean isRead) { - if (StringUtils.isNullOrEmpty(cacheServerAddr)) return; + if (StringUtils.isNullOrEmpty(cacheRwServerAddr)) return; // Initialize the message digest if (msgHashDigest == null) { try { @@ -59,9 +72,8 @@ private void initializeCacheConnectionIfNeeded(boolean isRead) { GenericObjectPool> cacheConnectionPool = isRead ? readConnectionPool : writeConnectionPool; - Object lock = isRead ? READ_LOCK : WRITE_LOCK; - if (cacheConnectionPool == null) { + Object lock = isRead ? READ_LOCK : WRITE_LOCK; synchronized (lock) { if ((isRead && readConnectionPool == null) || (!isRead && writeConnectionPool == null)) { createConnectionPool(isRead); @@ -73,13 +85,37 @@ private void initializeCacheConnectionIfNeeded(boolean isRead) { private void createConnectionPool(boolean isRead) { ClientResources resources = ClientResources.builder().build(); try { - RedisURI redisUriCluster = RedisURI.Builder.redis(cacheServerAddr). - withPort(6379).withSsl(true).withVerifyPeer(false).build(); + // cache server addr string is in the format ":" + String serverAddr = cacheRwServerAddr; + // If read-only server is specified, use it for the read-only connections + if (isRead && !StringUtils.isNullOrEmpty(cacheRoServerAddr)) { + serverAddr = cacheRoServerAddr; + } + String[] hostnameAndPort = serverAddr.split(":"); + RedisURI redisUriCluster = RedisURI.Builder.redis(hostnameAndPort[0]) + .withPort(Integer.parseInt(hostnameAndPort[1])) + .withSsl(true).withVerifyPeer(false).build(); RedisClient client = RedisClient.create(resources, redisUriCluster); GenericObjectPool> pool = ConnectionPoolSupport.createGenericObjectPool( - () -> client.connect(new ByteArrayCodec()), + () -> { + StatefulRedisConnection connection = client.connect(new ByteArrayCodec()); + // In cluster mode, we need to send READONLY command to the server for reading from replica. + // Note: we gracefully ignore ERR reply to support non cluster mode. 
+ if (isRead) { + try { + connection.sync().readOnly(); + } catch (RedisCommandExecutionException e) { + if (e.getMessage().contains("ERR This instance has cluster support disabled")) { + LOGGER.fine("------ Note: this cache cluster has cluster support disabled ------"); + } else { + throw e; + } + } + } + return connection; + }, poolConfig ); @@ -92,7 +128,7 @@ private void createConnectionPool(boolean isRead) { } catch (Exception e) { String poolType = isRead ? "read" : "write"; String errorMsg = String.format("Failed to create Cache %s connection pool", poolType); - System.err.println(errorMsg + ": " + e.getMessage()); + LOGGER.warning(errorMsg + ": " + e.getMessage()); throw new RuntimeException(errorMsg, e); } } @@ -102,7 +138,7 @@ private static GenericObjectPoolConfig> poolConfig.setMaxTotal(DEFAULT_POOL_SIZE); poolConfig.setMaxIdle(DEFAULT_POOL_MAX_IDLE); poolConfig.setMinIdle(DEFAULT_POOL_MIN_IDLE); - poolConfig.setMaxWaitMillis(DEFAULT_MAX_BORROW_WAIT_MS); + poolConfig.setMaxWait(Duration.ofMillis(DEFAULT_MAX_BORROW_WAIT_MS)); return poolConfig; } @@ -115,47 +151,48 @@ private byte[] computeHashDigest(byte[] key) { public byte[] readFromCache(String key) { boolean isBroken = false; StatefulRedisConnection conn = null; - initializeCacheConnectionIfNeeded(true); // get a connection from the read connection pool try { + initializeCacheConnectionIfNeeded(true); conn = readConnectionPool.borrowObject(); return conn.sync().get(computeHashDigest(key.getBytes(StandardCharsets.UTF_8))); } catch (Exception e) { if (conn != null) { isBroken = true; } - System.err.println("Failed to read from cache: " + e.getMessage()); + LOGGER.warning("Failed to read result from cache. 
Treating it as a cache miss: " + e.getMessage()); return null; } finally { if (conn != null && readConnectionPool != null) { try { this.returnConnectionBackToPool(conn, isBroken, true); } catch (Exception ex) { - System.err.println("Error closing read connection: " + ex.getMessage()); + LOGGER.warning("Error closing read connection: " + ex.getMessage()); } } } } - public void writeToCache(String key, byte[] value) { + public void writeToCache(String key, byte[] value, int expiry) { boolean isBroken = false; - initializeCacheConnectionIfNeeded(false); - // get a connection from the write connection pool StatefulRedisConnection conn = null; try { + initializeCacheConnectionIfNeeded(false); + // get a connection from the write connection pool conn = writeConnectionPool.borrowObject(); - conn.sync().setex(computeHashDigest(key.getBytes(StandardCharsets.UTF_8)), 300, value); + // TODO: make the write to the cache to be async. + conn.sync().setex(computeHashDigest(key.getBytes(StandardCharsets.UTF_8)), expiry, value); } catch (Exception e) { if (conn != null){ isBroken = true; } - System.err.println("Failed to write to cache: " + e.getMessage()); + LOGGER.warning("Failed to write to cache: " + e.getMessage()); } finally { if (conn != null && writeConnectionPool != null) { try { this.returnConnectionBackToPool(conn, isBroken, false); } catch (Exception ex) { - System.err.println("Error closing write connection: " + ex.getMessage()); + LOGGER.warning("Error closing write connection: " + ex.getMessage()); } } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java similarity index 74% rename from wrapper/src/main/java/software/amazon/jdbc/util/CachedResultSet.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java index aeb2a2c2b..5df5c37a6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/CachedResultSet.java +++ 
b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java @@ -1,10 +1,13 @@ -package software.amazon.jdbc.util; +package software.amazon.jdbc.plugin.cache; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.InputStream; import java.io.Reader; import java.math.BigDecimal; +import java.math.RoundingMode; import java.net.URL; import java.sql.Array; import java.sql.Blob; @@ -21,11 +24,21 @@ import java.sql.Statement; import java.sql.Time; import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalTime; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.OffsetTime; +import java.time.ZoneId; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; import java.util.ArrayList; -import java.util.Calendar; -import java.util.HashMap; import java.util.List; +import java.util.HashMap; import java.util.Map; +import java.util.TimeZone; +import java.util.Calendar; +import java.util.GregorianCalendar; @SuppressWarnings({"RedundantThrows", "checkstyle:OverloadMethodsDeclarationOrder"}) public class CachedResultSet implements ResultSet { @@ -52,17 +65,20 @@ public Object get(final String columnName) { protected ArrayList rows; protected int currentRow; + protected ResultSetMetaData metadata; + protected static ObjectMapper mapper = new ObjectMapper(); + protected static final TimeZone defaultTimeZone = TimeZone.getDefault(); + private static final Calendar calendarWithUserTz = new GregorianCalendar(); public CachedResultSet(final ResultSet resultSet) throws SQLException { - - final ResultSetMetaData md = resultSet.getMetaData(); - final int columns = md.getColumnCount(); + metadata = resultSet.getMetaData(); + final int columns = metadata.getColumnCount(); rows = new ArrayList<>(); 
while (resultSet.next()) { final CachedRow row = new CachedRow(); for (int i = 1; i <= columns; ++i) { - row.put(i, md.getColumnName(i), resultSet.getObject(i)); + row.put(i, metadata.getColumnName(i), resultSet.getObject(i)); } rows.add(row); } @@ -71,41 +87,51 @@ public CachedResultSet(final ResultSet resultSet) throws SQLException { public CachedResultSet(final List> resultList) { rows = new ArrayList<>(); + CachedResultSetMetaData.Field[] fields = new CachedResultSetMetaData.Field[resultList.get(0).size()]; + boolean fieldsInitialized = false; for (Map rowMap : resultList) { final CachedRow row = new CachedRow(); - int i = 1; - for (String columnName : rowMap.keySet()) { - row.put(i, columnName, rowMap.get(columnName)); + int i = 0; + for (Map.Entry entry : rowMap.entrySet()) { + String columnName = entry.getKey(); + if (!fieldsInitialized) { + fields[i] = new CachedResultSetMetaData.Field(columnName, columnName); + } + row.put(++i, columnName, entry.getValue()); } rows.add(row); + fieldsInitialized = true; } currentRow = -1; + metadata = new CachedResultSetMetaData(fields); } - public static String serializeIntoJsonString(ResultSet rs) throws SQLException { - ObjectMapper mapper = new ObjectMapper(); + public String serializeIntoJsonString() throws SQLException { + mapper.registerModule(new JavaTimeModule()); + // Serialize Date/LocalDateTime etc. into standard string format (i.e. 
ISO) + mapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); + List> resultList = new ArrayList<>(); - ResultSetMetaData metaData = rs.getMetaData(); + ResultSetMetaData metaData = this.getMetaData(); int columns = metaData.getColumnCount(); - while (rs.next()) { + while (this.next()) { Map rowMap = new HashMap<>(); for (int i = 1; i <= columns; i++) { - rowMap.put(metaData.getColumnName(i), rs.getObject(i)); + rowMap.put(metaData.getColumnName(i), this.getObject(i)); } resultList.add(rowMap); } try { return mapper.writeValueAsString(resultList); } catch (JsonProcessingException e) { - throw new SQLException("Error serializing ResultSet to JSON", e); + throw new SQLException("Error serializing ResultSet to JSON: " + e.getMessage(), e); } } public static ResultSet deserializeFromJsonString(String jsonString) throws SQLException { if (jsonString == null || jsonString.isEmpty()) { return null; } try { - ObjectMapper mapper = new ObjectMapper(); List> resultList = mapper.readValue(jsonString, mapper.getTypeFactory().constructCollectionType(List.class, Map.class)); return new CachedResultSet(resultList); @@ -116,7 +142,7 @@ public static ResultSet deserializeFromJsonString(String jsonString) throws SQLE @Override public boolean next() throws SQLException { - if (rows.size() == 0 || isLast()) { + if (rows.isEmpty() || isLast()) { return false; } currentRow++; @@ -136,48 +162,64 @@ public boolean wasNull() throws SQLException { // TODO: implement all the getXXX APIs. 
@Override public String getString(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + Object value = this.getObject(columnIndex); + if (value == null) return null; + return value.toString(); } @Override public boolean getBoolean(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnIndex); + if (value == null) return false; + return Boolean.parseBoolean(value); } @Override public byte getByte(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + return (byte)this.getInt(columnIndex); } @Override public short getShort(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnIndex); + if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); + return Short.parseShort(value); } @Override public int getInt(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnIndex); + if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); + return Integer.parseInt(value); } @Override public long getLong(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnIndex); + if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); + return Long.parseLong(value); } @Override public float getFloat(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnIndex); + if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); + return Float.parseFloat(value); } @Override public double getDouble(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + String value = 
this.getString(columnIndex); + if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); + return Double.parseDouble(value); } @Override @Deprecated public BigDecimal getBigDecimal(final int columnIndex, final int scale) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnIndex); + if (value == null) return null; + return new BigDecimal(value).setScale(scale, RoundingMode.HALF_UP); } @Override @@ -185,19 +227,137 @@ public byte[] getBytes(final int columnIndex) throws SQLException { throw new UnsupportedOperationException(); } + private Timestamp convertLocalTimeToTimestamp(final LocalDateTime localTime, Calendar cal) { + long epochTimeInMillis; + if (cal != null) { + epochTimeInMillis = localTime.atZone(cal.getTimeZone().toZoneId()).toInstant().toEpochMilli(); + } else { + epochTimeInMillis = localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(); + } + return new Timestamp(epochTimeInMillis); + } + + private Timestamp parseIntoTimestamp(String timestampStr, Calendar cal) { + if (timestampStr.endsWith("Z")) { // ISO format timestamp in UTC like 2025-06-03T11:59:21.822364Z + return Timestamp.from(Instant.parse(timestampStr)); + } else if (timestampStr.contains("+") || timestampStr.contains("-")) { // Offset timestamp format like 2023-10-27T10:00:00+02:00 + try { + OffsetDateTime offsetDateTime = OffsetDateTime.parse(timestampStr); + return Timestamp.from(offsetDateTime.toInstant()); + } catch (DateTimeParseException e) { + // swallow this exception and move on with parsing + } + } + + if (timestampStr.contains(":")) { // timestamp without time zone info with HH:MM:ss info + // The timestamp string doesn't contain time zone information (not recommended for storage). We need + // to use the specified calendar for timezone. If calendar is not specified, use the local time zone. 
+      String ts = timestampStr;
+      if (timestampStr.contains(" ")) {
+        ts = timestampStr.replace(" ", "T");
+      }
+      // Obtains an instance of LocalDateTime from a text string that is in ISO_LOCAL_DATE_TIME format
+      return convertLocalTimeToTimestamp(LocalDateTime.parse(ts), cal);
+    } else { // timestamp without time zone info without HH:MM:ss info
+      return new Timestamp(Date.valueOf(timestampStr).getTime());
+    }
+  }
+
+  private Date convertToDate(Object dateObj, Calendar cal) {
+    if (dateObj == null) return null;
+    Timestamp timestamp;
+    if (dateObj instanceof Date) {
+      // Already a java.sql.Date; return it directly, no conversion needed
+      return (Date)dateObj;
+    } else if (dateObj instanceof Timestamp) {
+      timestamp = (Timestamp) dateObj;
+    } else {
+      // Try to parse as Date with hour/minute/second.
+      // If Date parsing fails, try to parse it as Timestamp
+      try {
+        return Date.valueOf(dateObj.toString());
+      } catch (IllegalArgumentException e) {
+        // Failed to parse the string as Date object. Try parsing it as Timestamp instead
+        timestamp = parseIntoTimestamp(dateObj.toString(), cal);
+      }
+    }
+
+    // If the dateObj is not already the Date type, then the value cached is the
+    // epoch time in milliseconds.
Here we need to de-serialize it as a long + if (cal == null) { + calendarWithUserTz.setTimeZone(defaultTimeZone); + } else { + calendarWithUserTz.setTimeZone(cal.getTimeZone()); + } + calendarWithUserTz.setTimeInMillis(timestamp.getTime()); + return new Date(calendarWithUserTz.getTimeInMillis()); + } + @Override public Date getDate(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + // The value cached is the string representation of epoch time in milliseconds + return convertToDate(this.getObject(columnIndex), null); + } + + private Time convertToTime(Object timeObj, Calendar cal) { + if (timeObj == null) return null; + Timestamp ts; + if (timeObj instanceof Time) { + return (Time) timeObj; + } else if (timeObj instanceof Timestamp) { + ts = (Timestamp) timeObj; + } else { + // Parse the time object from string. If it can't be parsed + // as a Time object, then try to parse it as Timestamp. + try { + String timeStr = timeObj.toString(); + if (timeStr.contains("Z")) { + // TODO: fix getTime with a different time zone + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("HH:mm:ssX"); + LocalTime localTime = LocalTime.parse(timeStr, formatter); + return Time.valueOf(localTime); + } else if (timeStr.contains("+") || timeStr.contains("-")) { + LocalTime localTime = OffsetTime.parse(timeStr).toLocalTime(); + return Time.valueOf(localTime); + } else { + LocalTime localTime = LocalTime.parse(timeObj.toString()); + return Time.valueOf(localTime); + } + } catch (DateTimeParseException e) { + ts = parseIntoTimestamp(timeObj.toString(), cal); + } + } + // use the timezone in the cal (if set) to indicate proper time for "1:00:00" + // e.g. 
10:00:00 in EST is 07:00:00 in local time zone (PST) + if (cal == null) { + calendarWithUserTz.setTimeZone(defaultTimeZone); + } else { + calendarWithUserTz.setTimeZone(cal.getTimeZone()); + } + calendarWithUserTz.setTimeInMillis(ts.getTime()); + return new Time(calendarWithUserTz.getTimeInMillis()); } @Override public Time getTime(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + return convertToTime(this.getObject(columnIndex), null); + } + + private Timestamp convertToTimestamp(Object timestampObj, Calendar calendar) { + if (timestampObj == null) return null; + if (timestampObj instanceof Timestamp) { + return (Timestamp) timestampObj; + } else if (timestampObj instanceof LocalDateTime) { + return convertLocalTimeToTimestamp((LocalDateTime) timestampObj, calendar); + } else { + // De-serialize it from string representation + return parseIntoTimestamp(timestampObj.toString(), calendar); + } } @Override public Timestamp getTimestamp(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + return convertToTimestamp(this.getObject(columnIndex), null); } @Override @@ -218,48 +378,64 @@ public InputStream getBinaryStream(final int columnIndex) throws SQLException { @Override public String getString(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + Object value = this.getObject(columnLabel); + if (value == null) return null; + return value.toString(); } @Override public boolean getBoolean(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnLabel); + if (value == null) return false; + return Boolean.parseBoolean(value); } @Override public byte getByte(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return (byte)this.getInt(columnLabel); } @Override public short getShort(final String columnLabel) throws SQLException { - throw 
new UnsupportedOperationException(); + String value = this.getString(columnLabel); + if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); + return Short.parseShort(value); } @Override public int getInt(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnLabel); + if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); + return Integer.parseInt(value); } @Override public long getLong(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnLabel); + if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); + return Long.parseLong(value); } @Override public float getFloat(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnLabel); + if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); + return Float.parseFloat(value); } @Override public double getDouble(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnLabel); + if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); + return Double.parseDouble(value); } @Override @Deprecated public BigDecimal getBigDecimal(final String columnLabel, final int scale) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnLabel); + if (value == null) return null; + return new BigDecimal(value).setScale(scale, RoundingMode.HALF_UP); } @Override @@ -269,17 +445,17 @@ public byte[] getBytes(final String columnLabel) throws SQLException { @Override public Date getDate(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return convertToDate(this.getObject(columnLabel), null); } 
@Override public Time getTime(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return convertToTime(this.getObject(columnLabel), null); } @Override public Timestamp getTimestamp(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return convertToTimestamp(this.getObject(columnLabel), null); } @Override @@ -315,7 +491,7 @@ public String getCursorName() throws SQLException { @Override public ResultSetMetaData getMetaData() throws SQLException { - throw new UnsupportedOperationException(); + return metadata; } @Override @@ -359,12 +535,16 @@ public Reader getCharacterStream(final String columnLabel) throws SQLException { @Override public BigDecimal getBigDecimal(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnIndex); + if (value == null) return null; + return new BigDecimal(value); } @Override public BigDecimal getBigDecimal(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + String value = this.getString(columnLabel); + if (value == null) return null; + return new BigDecimal(value); } @Override @@ -769,32 +949,32 @@ public Array getArray(final String columnLabel) throws SQLException { @Override public Date getDate(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); + return convertToDate(this.getObject(columnIndex), cal); } @Override public Date getDate(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); + return convertToDate(this.getObject(columnLabel), cal); } @Override public Time getTime(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); + return convertToTime(this.getObject(columnIndex), null); } @Override public Time getTime(final String columnLabel, final Calendar cal) throws 
SQLException { - throw new UnsupportedOperationException(); + return convertToTime(this.getObject(columnLabel), cal); } @Override public Timestamp getTimestamp(final int columnIndex, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); + return convertToTimestamp(this.getObject(columnIndex), cal); } @Override public Timestamp getTimestamp(final String columnLabel, final Calendar cal) throws SQLException { - throw new UnsupportedOperationException(); + return convertToTimestamp(this.getObject(columnLabel), cal); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java new file mode 100644 index 000000000..8819e2737 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java @@ -0,0 +1,148 @@ +package software.amazon.jdbc.plugin.cache; + +import java.sql.ResultSetMetaData; +import java.sql.SQLException; + +public class CachedResultSetMetaData implements ResultSetMetaData { + + public static class Field { + // TODO: support binary format + private final String columnLabel; + private final String columnName; + public Field(String columnLabel, String columnName) { + this.columnLabel = columnLabel; + this.columnName = columnName; + } + + public String getColumnLabel() { + return columnLabel; + } + + public String getColumnName() { + return columnName; + } + } + + protected Field[] fields; + + public CachedResultSetMetaData(Field[] fields) { + this.fields = fields; + } + + @Override + public int getColumnCount() throws SQLException { + return this.fields.length; + } + + @Override + public boolean isAutoIncrement(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCaseSensitive(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean 
isSearchable(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCurrency(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int isNullable(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isSigned(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getColumnDisplaySize(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getColumnLabel(int column) throws SQLException { + return fields[column-1].getColumnLabel(); + } + + @Override + public String getColumnName(int column) throws SQLException { + return fields[column-1].getColumnName(); + } + + // TODO + @Override + public String getSchemaName(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getPrecision(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getScale(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + // TODO + @Override + public String getTableName(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getCatalogName(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public int getColumnType(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getColumnTypeName(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isReadOnly(int column) throws SQLException { + return true; + } + + @Override + public boolean isWritable(int column) throws SQLException { + return false; + } + + @Override + public boolean isDefinitelyWritable(int 
column) throws SQLException { + return false; + } + + @Override + public String getColumnClassName(int column) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T unwrap(Class iface) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + throw new UnsupportedOperationException(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPlugin.java similarity index 98% rename from wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPlugin.java index 1c95f953d..91a3f92e1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPlugin.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import java.sql.ResultSet; import java.sql.SQLException; @@ -31,7 +31,7 @@ import software.amazon.jdbc.JdbcMethod; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.util.CachedResultSet; +import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.telemetry.TelemetryCounter; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginFactory.java similarity index 96% rename from wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginFactory.java index 555ed55cf..1dfbe8f9d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginFactory.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginFactory.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import java.util.Properties; import software.amazon.jdbc.ConnectionPlugin; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java similarity index 56% rename from wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePlugin.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index 9c4c4f6b2..4f2be4aeb 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import java.nio.charset.StandardCharsets; import java.sql.Connection; @@ -27,41 +27,42 @@ import java.util.Properties; import java.util.Set; import java.util.logging.Logger; -import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.states.SessionStateService; -import software.amazon.jdbc.util.CacheConnection; -import software.amazon.jdbc.util.CachedResultSet; +import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; public class DataRemoteCachePlugin extends AbstractConnectionPlugin { - - private static long connectTime = 0L; private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName()); - private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( Arrays.asList("Statement.executeQuery", 
"Statement.execute", "PreparedStatement.execute", "PreparedStatement.executeQuery", - "CallableStatement.execute", "CallableStatement.executeQuery", - "connect", "forceConnect"))); + "CallableStatement.execute", "CallableStatement.executeQuery"))); static { PropertyDefinition.registerPluginProperties(DataRemoteCachePlugin.class); } - private final PluginService pluginService; - private final TelemetryFactory telemetryFactory; - private final TelemetryCounter hitCounter; - private final TelemetryCounter missCounter; - private final TelemetryCounter totalCallsCounter; - private final CacheConnection cacheConnection; + private PluginService pluginService; + private TelemetryFactory telemetryFactory; + private TelemetryCounter hitCounter; + private TelemetryCounter missCounter; + private TelemetryCounter totalCallsCounter; + private CacheConnection cacheConnection; public DataRemoteCachePlugin(final PluginService pluginService, final Properties properties) { + try { + Class.forName("io.lettuce.core.RedisClient"); // Lettuce dependency + Class.forName("org.apache.commons.pool2.impl.GenericObjectPool"); // Object pool dependency + Class.forName("com.fasterxml.jackson.databind.ObjectMapper"); // Jackson dependency + Class.forName("com.fasterxml.jackson.datatype.jsr310.JavaTimeModule"); // JSR310 dependency + } catch (final ClassNotFoundException e) { + throw new RuntimeException(Messages.get("DataRemoteCachePlugin.notInClassPath", new Object[] {e.getMessage()})); + } this.pluginService = pluginService; this.telemetryFactory = pluginService.getTelemetryFactory(); this.hitCounter = telemetryFactory.createCounter("remoteCache.cache.hit"); @@ -70,38 +71,14 @@ public DataRemoteCachePlugin(final PluginService pluginService, final Properties this.cacheConnection = new CacheConnection(properties); } - @Override - public Set getSubscribedMethods() { - return subscribedMethods; - } - - private Connection connectHelper(JdbcCallable connectFunc) throws SQLException { - final long 
startTime = System.nanoTime(); - - final Connection result = connectFunc.call(); - - final long elapsedTimeNanos = System.nanoTime() - startTime; - connectTime += elapsedTimeNanos; - LOGGER.fine( - () -> Messages.get( - "DataRemoteCachePlugin.cacheConnectTime", - new Object[] {elapsedTimeNanos})); - return result; + // Used for unit testing purposes only + protected void setCacheConnection(CacheConnection conn) { + this.cacheConnection = conn; } @Override - public Connection connect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable connectFunc) throws SQLException { - System.out.println("DataRemoteCachingPlugin.connect()..."); - return this.connectHelper(connectFunc); - } - - @Override - public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable forceConnectFunc) - throws SQLException { - System.out.println("DataRemoteCachingPlugin.forceConnect()..."); - return this.connectHelper(forceConnectFunc); + public Set getSubscribedMethods() { + return subscribedMethods; } private String getCacheQueryKey(String query) { @@ -110,8 +87,7 @@ private String getCacheQueryKey(String query) { try { Connection currentConn = pluginService.getCurrentConnection(); DatabaseMetaData metadata = currentConn.getMetaData(); - SessionStateService sessionStateService = pluginService.getSessionStateService(); - System.out.println("DB driver protocol " + pluginService.getDriverProtocol() + LOGGER.finest("DB driver protocol " + pluginService.getDriverProtocol() + ", schema: " + currentConn.getSchema() + ", database product: " + metadata.getDatabaseProductName() + " " + metadata.getDatabaseProductVersion() + ", user: " + metadata.getUserName() @@ -120,7 +96,7 @@ private String getCacheQueryKey(String query) { String[] words = {currentConn.getSchema(), metadata.getUserName(), query}; return String.join("_", words); } catch (SQLException e) { - System.out.println("Error 
getting session state: " + e.getMessage()); + LOGGER.warning("Error getting session state: " + e.getMessage()); return null; } } @@ -132,24 +108,60 @@ private ResultSet fetchResultSetFromCache(String queryStr) { if (cacheQueryKey == null) return null; // Treat this as a cache miss byte[] result = cacheConnection.readFromCache(cacheQueryKey); if (result == null) return null; - // Convert result into ResultSet try { return CachedResultSet.deserializeFromJsonString(new String(result, StandardCharsets.UTF_8)); } catch (Exception e) { - System.out.println("Error de-serializing cached result: " + e.getMessage()); + LOGGER.warning("Error de-serializing cached result: " + e.getMessage()); return null; // Treat this as a cache miss } } - private void cacheResultSet(String queryStr, ResultSet rs) throws SQLException { - System.out.println("Caching resultSet returned from postgres database ....... "); - String jsonValue = CachedResultSet.serializeIntoJsonString(rs); - + /** + * Cache the given ResultSet object. + * The ResultSet object passed in would be consumed to create a CacheResultSet object. It is returned + * for consumer consumption. + */ + private ResultSet cacheResultSet(String queryStr, ResultSet rs, int expiry) throws SQLException { // Write the resultSet into the cache as a single key String cacheQueryKey = getCacheQueryKey(queryStr); - if (cacheQueryKey == null) return; // Treat this condition as un-cacheable - cacheConnection.writeToCache(cacheQueryKey, jsonValue.getBytes(StandardCharsets.UTF_8)); + if (cacheQueryKey == null) return rs; // Treat this condition as un-cacheable + CachedResultSet crs = new CachedResultSet(rs); + String jsonValue = crs.serializeIntoJsonString(); + cacheConnection.writeToCache(cacheQueryKey, jsonValue.getBytes(StandardCharsets.UTF_8), expiry); + crs.beforeFirst(); + return crs; + } + + /** + * Determine the TTL based on an input query + * @param queryHint string. e.g. 
"NO CACHE", or "cacheTTL=100s" + * @return TTL in seconds to cache the query. + * null if the query is not cacheable. + */ + protected Integer getTtlForQuery(String queryHint) { + // Empty query is not cacheable + if (StringUtils.isNullOrEmpty(queryHint)) return null; + // Query longer than 16K is not cacheable + String[] tokens = queryHint.toLowerCase().split("cache"); + if (tokens.length >= 2) { + // Handle "no cache". + if (!StringUtils.isNullOrEmpty(tokens[0]) && "no".equals(tokens[0])) return null; + // Handle "cacheTTL=Xs" + if (!StringUtils.isNullOrEmpty(tokens[1]) && tokens[1].startsWith("ttl=")) { + int endIndex = tokens[1].indexOf('s'); + if (endIndex > 0) { + try { + return Integer.parseInt(tokens[1].substring(4, endIndex)); + } catch (Exception e) { + LOGGER.warning("Encountered exception when parsing Cache TTL: " + e.getMessage()); + } + } + } + } + + LOGGER.finest("Query hint " + queryHint + " indicates the query is not cacheable"); + return null; } @Override @@ -167,20 +179,31 @@ public T execute( boolean needToCache = false; final String sql = getQuery(jdbcMethodArgs); - // Try to fetch SELECT query from the cache - if (!StringUtils.isNullOrEmpty(sql) && sql.startsWith("select ")) { - result = fetchResultSetFromCache(sql); + // If the query is cacheable, we try to fetch the query result from the cache. 
+ boolean isInTransaction = pluginService.isInTransaction(); + // Get the query hint part in front of the query itself + String mainQuery = sql; // The main part of the query with the query hint prefix trimmed + int endOfQueryHint = 0; + Integer configuredQueryTtl = null; + if ((sql.length() < 16000) && sql.startsWith("/*")) { + endOfQueryHint = sql.indexOf("*/"); + if (endOfQueryHint > 0) { + configuredQueryTtl = getTtlForQuery(sql.substring(2, endOfQueryHint).trim()); + mainQuery = sql.substring(endOfQueryHint + 2).trim(); + } + } + + // Query result can be served from the cache if it has a configured TTL value, and it is + // not executed in a transaction as a transaction typically need to return consistent results. + if (!isInTransaction && (configuredQueryTtl != null)) { + result = fetchResultSetFromCache(mainQuery); if (result == null) { - System.out.println("We got a cache MISS........."); // Cache miss. Need to fetch result from the database needToCache = true; missCounter.inc(); - LOGGER.finest( - () -> Messages.get( - "DataRemoteCachePlugin.queryResultsCached", - new Object[]{methodName, sql})); + LOGGER.finest("Got a cache miss for SQL: " + sql); } else { - System.out.println("We got a cache hit........."); + LOGGER.finest("Got a cache hit for SQL: " + sql); // Cache hit. 
Return the cached result hitCounter.inc(); try { @@ -199,11 +222,10 @@ public T execute( if (needToCache) { try { - cacheResultSet(sql, result); - result.beforeFirst(); + result = cacheResultSet(mainQuery, result, configuredQueryTtl); } catch (final SQLException ex) { // ignore exception - System.out.println("Encountered SQLException when caching results..."); + LOGGER.warning("Encountered SQLException when caching results: " + ex.getMessage()); } } @@ -213,7 +235,7 @@ public T execute( protected String getQuery(final Object[] jdbcMethodArgs) { // Get query from method argument if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { - return jdbcMethodArgs[0].toString(); + return jdbcMethodArgs[0].toString().trim(); } return null; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java similarity index 96% rename from wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePluginFactory.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java index efce32fa7..fb15d69c5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/DataRemoteCachePluginFactory.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginFactory.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import java.util.Properties; import software.amazon.jdbc.ConnectionPlugin; diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts index 3e2f17251..3ed1011f9 100644 --- a/wrapper/src/test/build.gradle.kts +++ b/wrapper/src/test/build.gradle.kts @@ -55,6 +55,7 @@ dependencies { testImplementation("org.apache.poi:poi-ooxml:5.3.0") testImplementation("org.slf4j:slf4j-simple:2.0.13") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.17.1") + testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.0") testImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") testImplementation("io.opentelemetry:opentelemetry-sdk:1.42.1") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.43.0") diff --git a/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java b/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java index dc8f6afa3..68b9ad995 100644 --- a/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java +++ b/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java @@ -39,8 +39,8 @@ import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin; -import software.amazon.jdbc.plugin.DataCacheConnectionPlugin.CachedResultSet; +import software.amazon.jdbc.plugin.cache.CachedResultSet; +import software.amazon.jdbc.plugin.cache.DataCacheConnectionPlugin; @TestMethodOrder(MethodOrderer.MethodName.class) @ExtendWith(TestDriverProvider.class) diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java new file mode 100644 index 000000000..60d87b93c --- /dev/null +++ 
b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java @@ -0,0 +1,328 @@ +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.*; + +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.time.*; +import java.util.*; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; + +public class CachedResultSetTest { + static List> testResultList = new ArrayList<>(); + static Calendar estCalendar = Calendar.getInstance(TimeZone.getTimeZone("America/New_York")); + + @BeforeAll + static void setUp() { + Map row = new HashMap<>(); + row.put("fieldInt", 1); // Integer + row.put("fieldString", "John Doe"); // String + row.put("fieldBoolean", true); + row.put("fieldByte", (byte)100); // 100 in ASCII is letter d + row.put("fieldShort", (short)55); + row.put("fieldLong", 8589934592L); // 2^33 + row.put("fieldFloat", 3.14159f); + row.put("fieldDouble", 2345.23345d); + row.put("fieldBigDecimal", new BigDecimal("15.33")); + row.put("fieldDate", Date.valueOf("2025-03-15")); + row.put("fieldTime", Time.valueOf("22:54:00")); + row.put("fieldDateTime", Timestamp.valueOf("2025-03-15 22:54:00")); + testResultList.add(row); + Map row2 = new HashMap<>(); + row2.put("fieldInt", 123456); // Integer + row2.put("fieldString", "Tony Stark"); // String + row2.put("fieldBoolean", false); + row2.put("fieldByte", (byte)70); // 70 in ASCII is letter F + row2.put("fieldShort", (short)135); + row2.put("fieldLong", -34359738368L); // -2^35 + row2.put("fieldFloat", -233.14159f); + row2.put("fieldDouble", -2344355.4543d); + row2.put("fieldBigDecimal", new BigDecimal("-12.45")); + row2.put("fieldDate", Date.valueOf("1102-01-15")); + row2.put("fieldTime", Time.valueOf("01:10:00")); + row2.put("fieldDateTime", LocalDateTime.of(1981, 3, 10, 1, 10, 20)); + 
testResultList.add(row2); + } + + private void verifyRow1(ResultSet rs) throws SQLException { + Map colNameToIndexMap = new HashMap(); + ResultSetMetaData rsmd = rs.getMetaData(); + for (int i = 1; i <= rsmd.getColumnCount(); i++) { + colNameToIndexMap.put(rsmd.getColumnName(i), i); + } + assertEquals(1, rs.getInt(colNameToIndexMap.get("fieldInt"))); + assertEquals("John Doe", rs.getString(colNameToIndexMap.get("fieldString"))); + assertTrue(rs.getBoolean(colNameToIndexMap.get("fieldBoolean"))); + assertEquals(100, rs.getByte(colNameToIndexMap.get("fieldByte"))); + assertEquals(55, rs.getShort(colNameToIndexMap.get("fieldShort"))); + assertEquals(8589934592L, rs.getLong(colNameToIndexMap.get("fieldLong"))); + assertEquals(3.14159f, rs.getFloat(colNameToIndexMap.get("fieldFloat")), 0); + assertEquals(2345.23345d, rs.getDouble(colNameToIndexMap.get("fieldDouble"))); + assertEquals(0, rs.getBigDecimal(colNameToIndexMap.get("fieldBigDecimal")).compareTo(new BigDecimal("15.33"))); + Date date = rs.getDate(colNameToIndexMap.get("fieldDate")); + assertEquals(1742022000000L, date.getTime()); + Time time = rs.getTime(colNameToIndexMap.get("fieldTime")); + assertEquals(111240000, time.getTime()); + Timestamp ts = rs.getTimestamp(colNameToIndexMap.get("fieldDateTime")); + assertEquals(1742104440000L, ts.getTime()); + } + + private void verifyRow2(ResultSet rs) throws SQLException { + assertEquals(123456, rs.getInt("fieldInt")); + assertEquals("Tony Stark", rs.getString("fieldString")); + assertFalse(rs.getBoolean("fieldBoolean")); + assertEquals(70, rs.getByte("fieldByte")); + assertEquals(135, rs.getShort("fieldShort")); + assertEquals(-34359738368L, rs.getLong("fieldLong")); + assertEquals(-233.14159f, rs.getFloat("fieldFloat")); + assertEquals(-2344355.4543d, rs.getDouble("fieldDouble")); + assertEquals(0, rs.getBigDecimal("fieldBigDecimal").compareTo(new BigDecimal("-12.45"))); + Date date = rs.getDate("fieldDate"); + assertEquals("1102-01-15", date.toString()); + Time 
time = rs.getTime("fieldTime"); + assertEquals("01:10:00", time.toString()); + Timestamp ts = rs.getTimestamp("fieldDateTime"); + assertTrue(ts.toString().startsWith("1981-03-10 01:10:20")); + } + + @Test + void test_create_and_verify_basic() throws Exception { + ResultSet rs = new CachedResultSet(testResultList); + verifyMetadata(rs); + verifyContent(rs); + rs.beforeFirst(); + CachedResultSet cachedRs = new CachedResultSet(rs); + verifyMetadata(cachedRs); + verifyContent(cachedRs); + } + + @Test + void test_serialize_and_deserialize_basic() throws Exception { + CachedResultSet cachedRs = new CachedResultSet(testResultList); + String serialized_data = cachedRs.serializeIntoJsonString(); + ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); + verifyContent(rs); + } + + private void verifyContent(ResultSet rs) throws SQLException { + assertTrue(rs.next()); + if (rs.getInt("fieldInt") == 1) { + verifyRow1(rs); + assertTrue(rs.next()); + verifyRow2(rs); + } else { + verifyRow2(rs); + assertTrue(rs.next()); + verifyRow1(rs); + } + assertFalse(rs.next()); + } + + private void verifyMetadata(ResultSet rs) throws SQLException { + ResultSetMetaData md = rs.getMetaData(); + List expectedCols = Arrays.asList("fieldInt", "fieldString", "fieldBoolean", "fieldByte", "fieldShort", "fieldLong", "fieldFloat", "fieldDouble", "fieldBigDecimal", "fieldDate", "fieldTime", "fieldDateTime"); + assertEquals(md.getColumnCount(), testResultList.get(0).size()); + List actualColNames = new ArrayList<>(); + List actualColLabels = new ArrayList<>(); + for (int i = 1; i <= md.getColumnCount(); i++) { + actualColNames.add(md.getColumnName(i)); + actualColLabels.add(md.getColumnLabel(i)); + } + assertTrue(actualColNames.containsAll(expectedCols)); + assertTrue(expectedCols.containsAll(actualColNames)); + assertTrue(actualColLabels.containsAll(expectedCols)); + assertTrue(expectedCols.containsAll(actualColLabels)); + } + + @Test + void test_get_timestamp() throws SQLException 
{ + // Timestamp string that is in ISO format with time zone information in UTC + Map row = new HashMap<>(); + row.put("fieldTimestamp0", "2025-06-03T11:59:21.822364Z"); + // Timestamp string that is in ISO format with time zone information as offset + row.put("fieldTimestamp1", "2024-02-13T07:40:30.822364-05:00"); + row.put("fieldTimestamp2", "2023-10-27T10:00:00+02:00"); + // Timestamp string doesn't contain time zone information. + row.put("fieldTimestamp3", "1760-06-03T11:59:21.822364"); + row.put("fieldTimestamp4", "2020-05-04 10:06:10.822364"); + row.put("fieldTimestamp5", "2015-09-01 23:33:00"); + // Timestamp string doesn't contain time zone or HH:MM:SS information + row.put("fieldTimestamp6", "2019-03-15"); + row.put("fieldTimestamp7", Timestamp.from(Instant.parse("2024-08-01T10:30:20.822364Z"))); + row.put("fieldTimestamp8", LocalDateTime.parse("2025-04-01T21:55:21.822364")); + List> testTimestamps = Collections.singletonList(row); + CachedResultSet cachedRs = new CachedResultSet(testTimestamps); + assertTrue(cachedRs.next()); + verifyTimestamps(cachedRs); + cachedRs.beforeFirst(); + String serialized_data = cachedRs.serializeIntoJsonString(); + ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); + assertTrue(rs.next()); + verifyTimestamps(rs); + } + + private void verifyTimestamps(ResultSet rs) throws SQLException { + // Verifying the timestamp with time zone information. The specified calendar doesn't matter. 
+ Timestamp expectedTs = Timestamp.from(Instant.parse("2025-06-03T11:59:21.822364Z")); + assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp0").getTime()); + assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp0", estCalendar).getTime()); + + expectedTs = Timestamp.from(OffsetDateTime.parse("2024-02-13T07:40:30.822364-05:00").toInstant()); + assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp1").getTime()); + assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp1", estCalendar).getTime()); + + expectedTs = Timestamp.from(OffsetDateTime.parse("2023-10-27T10:00:00+02:00").toInstant()); + assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp2").getTime()); + assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp2", estCalendar).getTime()); + + // Verify timestamp without time zone information. The specified calendar matters here + LocalDateTime localTime = LocalDateTime.parse("1760-06-03T11:59:21.822364"); + ZoneId estZone = ZoneId.of("America/New_York"); + assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp3").getTime()); + assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp3", estCalendar).getTime()); + + localTime = LocalDateTime.parse("2020-05-04T10:06:10.822364"); + assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp4").getTime()); + assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp4", estCalendar).getTime()); + + localTime = LocalDateTime.parse("2015-09-01T23:33:00"); + assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp5").getTime()); + assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp5", estCalendar).getTime()); + + localTime = 
LocalDateTime.parse("2019-03-15T00:00:00"); + assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp6").getTime()); + assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp6", estCalendar).getTime()); + + expectedTs = Timestamp.from(Instant.parse("2024-08-01T10:30:20.822364Z")); + assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp7").getTime()); + assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp7", estCalendar).getTime()); + + localTime = LocalDateTime.parse("2025-04-01T21:55:21.822364"); + assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp8").getTime()); + assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp8", estCalendar).getTime()); + + assertNull(rs.getTimestamp("nonExistingField")); + } + + @Test + void test_parse_time() throws SQLException { + // Timestamp string that is in ISO format with time zone information in UTC + Map row = new HashMap<>(); + row.put("fieldTime0", Time.valueOf("18:45:20")); + row.put("fieldTime1", Timestamp.from(Instant.parse("2024-08-01T10:30:20.822364Z"))); + row.put("fieldTime2", "10:30:00"); + row.put("fieldTime3", "11:59:21.822364"); + // Timestamp string that is in ISO format with time zone information + row.put("fieldTime4", "10:00:00Z"); + row.put("fieldTime5", "05:30:00-02:00"); + row.put("fieldTime6", "08:25:10+02:00"); + // Timestamp string doesn't contain time zone information. 
+ row.put("fieldTime7", "2025-06-03T11:59:21.822364"); + row.put("fieldTime8", "1901-05-04 10:06:10.822364"); + row.put("fieldTime9", "2015-09-01 23:33:00"); + // Timestamp string doesn't contain time zone or HH:MM:SS information + row.put("fieldTime10", "2019-03-15"); + List> testTimes = Collections.singletonList(row); + CachedResultSet cachedRs = new CachedResultSet(testTimes); + assertTrue(cachedRs.next()); + verifyTimes(cachedRs); + cachedRs.beforeFirst(); + String serialized_data = cachedRs.serializeIntoJsonString(); + ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); + assertTrue(rs.next()); + verifyTimes(rs); + } + + private void verifyTimes(ResultSet rs) throws SQLException { + // Verifying the timestamp with time zone information. The specified calendar doesn't matter. + assertEquals("18:45:20", rs.getTime("fieldTime0").toString()); + assertEquals("18:45:20", rs.getTime("fieldTime0", estCalendar).toString()); + + // Convert from timestamp with time zone info + assertEquals("03:30:20", rs.getTime("fieldTime1").toString()); + assertEquals("03:30:20", rs.getTime("fieldTime1", estCalendar).toString()); + + // Verify timestamp without time zone information. 
The specified calendar matters here + assertEquals("10:30:00", rs.getTime("fieldTime2").toString()); + assertEquals("10:30:00", rs.getTime("fieldTime2", estCalendar).toString()); // Should be 07:30:00 + + assertEquals("11:59:21", rs.getTime("fieldTime3").toString()); + assertEquals("11:59:21", rs.getTime("fieldTime3", estCalendar).toString()); + + assertEquals("10:00:00", rs.getTime("fieldTime4").toString()); + assertEquals("10:00:00", rs.getTime("fieldTime4", estCalendar).toString()); + + assertEquals("05:30:00", rs.getTime("fieldTime5").toString()); + assertEquals("05:30:00", rs.getTime("fieldTime5", estCalendar).toString()); + + assertEquals("08:25:10", rs.getTime("fieldTime6").toString()); + assertEquals("08:25:10", rs.getTime("fieldTime6", estCalendar).toString()); + + assertEquals("11:59:21", rs.getTime("fieldTime7").toString()); + assertEquals("08:59:21", rs.getTime("fieldTime7", estCalendar).toString()); + + assertEquals("10:06:10", rs.getTime("fieldTime8").toString()); + assertEquals("07:06:10", rs.getTime("fieldTime8", estCalendar).toString()); + + assertEquals("23:33:00", rs.getTime("fieldTime9").toString()); + assertEquals("20:33:00", rs.getTime("fieldTime9", estCalendar).toString()); + + assertEquals("00:00:00", rs.getTime("fieldTime10").toString()); + assertEquals("00:00:00", rs.getTime("fieldTime10", estCalendar).toString()); + + assertNull(rs.getTime("nonExistingField")); + } + + @Test + void test_parse_date() throws SQLException { + Map row = new HashMap<>(); + row.put("fieldDate0", Date.valueOf("2009-09-30")); + row.put("fieldDate1", Timestamp.from(Instant.parse("2024-08-01T10:30:20.822364Z"))); + row.put("fieldDate2", "2012-10-01"); + row.put("fieldDate3", "1930-03-20T05:30:20.822364Z"); + // Timestamp string doesn't contain time zone information. 
+ row.put("fieldDate4", "2025-06-03T11:59:21.822364"); + row.put("fieldDate5", "1901-05-04 10:06:10.822364"); + row.put("fieldDate6", "2015-09-01 23:33:00"); + // Timestamp string doesn't contain time zone or HH:MM:SS information + List> testTimes = Collections.singletonList(row); + CachedResultSet cachedRs = new CachedResultSet(testTimes); + assertTrue(cachedRs.next()); + verifyDates(cachedRs); + cachedRs.beforeFirst(); + String serialized_data = cachedRs.serializeIntoJsonString(); + ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); + assertTrue(rs.next()); + verifyDates(rs); + } + + private void verifyDates(ResultSet rs) throws SQLException { + assertEquals("2009-09-30", rs.getDate("fieldDate0").toString()); + assertEquals("2009-09-30", rs.getDate("fieldDate0", estCalendar).toString()); + + assertEquals("2024-08-01", rs.getDate("fieldDate1").toString()); + assertEquals("2024-08-01", rs.getDate("fieldDate1", estCalendar).toString()); + + assertEquals("2012-10-01", rs.getDate("fieldDate2").toString()); + assertEquals("2012-10-01", rs.getDate("fieldDate2", estCalendar).toString()); + + assertEquals("1930-03-19", rs.getDate("fieldDate3").toString()); + assertEquals("1930-03-19", rs.getDate("fieldDate3", estCalendar).toString()); + + assertEquals("2025-06-03", rs.getDate("fieldDate4").toString()); + assertEquals("2025-06-03", rs.getDate("fieldDate4", estCalendar).toString()); + + assertEquals("1901-05-04", rs.getDate("fieldDate5").toString()); + assertEquals("1901-05-04", rs.getDate("fieldDate5", estCalendar).toString()); + + assertEquals("2015-09-01", rs.getDate("fieldDate6").toString()); + assertEquals("2015-09-01", rs.getDate("fieldDate6", estCalendar).toString()); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginTest.java similarity index 99% rename from 
wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java rename to wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginTest.java index 46e27337f..6b9c39f22 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/DataCacheConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package software.amazon.jdbc.plugin; +package software.amazon.jdbc.plugin.cache; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.anyString; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java new file mode 100644 index 000000000..b8315e2b4 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -0,0 +1,216 @@ +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.nio.charset.StandardCharsets; +import java.sql.*; +import java.util.Properties; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import 
software.amazon.jdbc.util.telemetry.TelemetryFactory; + +public class DataRemoteCachePluginTest { + private static final Properties props = new Properties(); + private final String methodName = "Statement.executeQuery"; + private AutoCloseable closeable; + + private DataRemoteCachePlugin plugin; + @Mock PluginService mockPluginService; + @Mock TelemetryFactory mockTelemetryFactory; + @Mock TelemetryCounter mockHitCounter; + @Mock TelemetryCounter mockMissCounter; + @Mock TelemetryCounter mockTotalCallsCounter; + @Mock ResultSet mockResult1; + @Mock Statement mockStatement; + @Mock ResultSetMetaData mockMetaData; + @Mock Connection mockConnection; + @Mock DatabaseMetaData mockDbMetadata; + @Mock CacheConnection mockCacheConn; + @Mock JdbcCallable mockCallable; + + @BeforeEach + void setUp() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter("remoteCache.cache.hit")).thenReturn(mockHitCounter); + when(mockTelemetryFactory.createCounter("remoteCache.cache.miss")).thenReturn(mockMissCounter); + when(mockTelemetryFactory.createCounter("remoteCache.cache.totalCalls")).thenReturn(mockTotalCallsCounter); + + when(mockResult1.getMetaData()).thenReturn(mockMetaData); + when(mockMetaData.getColumnCount()).thenReturn(1); + when(mockMetaData.getColumnName(1)).thenReturn("fooName"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @Test + void test_getTTLFromQueryHint() throws Exception { + // Null and empty query string are not cacheable + assertNull(plugin.getTtlForQuery(null)); + assertNull(plugin.getTtlForQuery("")); + assertNull(plugin.getTtlForQuery(" ")); + // Some 
other query hint + assertNull(plugin.getTtlForQuery("/* cacheNotEnabled */ select * from T")); + // Rule set is empty. All select queries are cached with 300 seconds TTL + String selectQuery1 = "cachettl=300s"; + String selectQuery2 = " /* CACHETTL=100s */ SELECT ID from mytable2 "; + String selectQuery3 = "/*CacheTTL=35s*/select * from table3 where ID = 1 and name = 'tom'"; + // Query hints that are not cacheable + String selectQueryNoHint = "select * from table4"; + String selectQueryNoCache1 = "no cache"; + String selectQueryNoCache2 = " /* NO CACHE */ SELECT count(*) FROM (select player_id from roster where id = 1 FOR UPDATE) really_long_name_alias"; + String selectQueryNoCache3 = "/* cachettl=300 */ SELECT count(*) FROM (select player_id from roster where id = 1) really_long_name_alias"; + + // Non select queries are not cacheable + String veryShortQuery = "BEGIN"; + String insertQuery = "/* This is an insert query */ insert into mytable values (1, 2)"; + String updateQuery = "/* Update query */ Update /* Another hint */ mytable set val = 1"; + assertEquals(300, plugin.getTtlForQuery(selectQuery1)); + assertEquals(100, plugin.getTtlForQuery(selectQuery2)); + assertEquals(35, plugin.getTtlForQuery(selectQuery3)); + assertNull(plugin.getTtlForQuery(selectQueryNoHint)); + assertNull(plugin.getTtlForQuery(selectQueryNoCache1)); + assertNull(plugin.getTtlForQuery(selectQueryNoCache2)); + assertNull(plugin.getTtlForQuery(selectQueryNoCache3)); + assertNull(plugin.getTtlForQuery(veryShortQuery)); + assertNull(plugin.getTtlForQuery(insertQuery)); + assertNull(plugin.getTtlForQuery(updateQuery)); + } + + @Test + void test_execute_noCaching() throws Exception { + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"select * from mytable where ID = 2"}); + + 
// Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockPluginService).isInTransaction(); + verify(mockCallable).call(); + verify(mockTotalCallsCounter).inc(); + verify(mockHitCounter, never()).inc(); + verify(mockMissCounter, never()).inc(); + } + + @Test + void test_execute_noCachingLongQuery() throws Exception { + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/* cacheTTL=30s */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockCallable).call(); + verify(mockTotalCallsCounter).inc(); + verify(mockHitCounter, never()).inc(); + verify(mockMissCounter, never()).inc(); + } + + @Test + void test_execute_cachingMiss() throws Exception { + // Query is cacheable but not present in the cache + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockDbMetadata.getUserName()).thenReturn("user"); + when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(null); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new 
String[]{"/*CACHETTL=100s*/ select * from A"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals(rs.getString("fooName"), "bar1"); + assertFalse(rs.next()); + verify(mockPluginService, times(2)).getCurrentConnection(); + verify(mockPluginService).isInTransaction(); + verify(mockCacheConn).readFromCache("public_user_select * from A"); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache("public_user_select * from A", "[{\"fooName\":\"bar1\"}]".getBytes(StandardCharsets.UTF_8), 100); + verify(mockTotalCallsCounter).inc(); + verify(mockMissCounter).inc(); + } + + @Test + void test_execute_cachingHit() throws Exception { + final String cachedResult = "[{\"date\":\"2009-09-30\",\"code\":\"avata\"},{\"date\":\"2015-05-30\",\"code\":\"dracu\"}]"; + + // Query is cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockDbMetadata.getUserName()).thenReturn("user"); + when(mockCacheConn.readFromCache("public_user_select * from table")).thenReturn(cachedResult.getBytes()); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{" /* CacheTtl=50s */select * from table"}); + + // Cached result set contains 2 rows + assertTrue(rs.next()); + assertEquals(rs.getString("date"), "2009-09-30"); + assertEquals(rs.getString("code"), "avata"); + assertTrue(rs.next()); + assertEquals(rs.getString("date"), "2015-05-30"); + assertEquals(rs.getString("code"), "dracu"); + assertFalse(rs.next()); + verify(mockPluginService).getCurrentConnection(); + 
verify(mockPluginService).isInTransaction(); + verify(mockCacheConn).readFromCache("public_user_select * from table"); + verify(mockCallable, never()).call(); + verify(mockCacheConn, never()).writeToCache("public_user_select * from table", "[{\"fooName\":\"bar1\"}]".getBytes(StandardCharsets.UTF_8), 50); + verify(mockTotalCallsCounter).inc(); + verify(mockHitCounter).inc(); + } + + void compareResults(final ResultSet expected, final ResultSet actual) throws SQLException { + int i = 1; + while (expected.next() && actual.next()) { + assertEquals(expected.getObject(i), actual.getObject(i)); + i++; + } + } +} From 8c4f65d2da7ec559bbbe5696325a83e65c925f55 Mon Sep 17 00:00:00 2001 From: Nihal Mehta Date: Mon, 9 Jun 2025 12:01:08 -0700 Subject: [PATCH 04/24] Make write operation to the cache async Signed-off-by: Nihal Mehta --- .../jdbc/plugin/cache/CacheConnection.java | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index 4ea434f94..44cfda6be 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -4,6 +4,7 @@ import io.lettuce.core.RedisCommandExecutionException; import io.lettuce.core.RedisURI; import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.codec.ByteArrayCodec; import io.lettuce.core.resource.ClientResources; import software.amazon.jdbc.AwsWrapperProperty; @@ -174,23 +175,42 @@ public byte[] readFromCache(String key) { } public void writeToCache(String key, byte[] value, int expiry) { - boolean isBroken = false; StatefulRedisConnection conn = null; try { initializeCacheConnectionIfNeeded(false); // get a connection from the write connection pool conn = 
writeConnectionPool.borrowObject(); - // TODO: make the write to the cache to be async. - conn.sync().setex(computeHashDigest(key.getBytes(StandardCharsets.UTF_8)), expiry, value); + // Add support to make write to the cache to be async. + RedisAsyncCommands asyncCommands = conn.async(); + byte[] keyHash = computeHashDigest(key.getBytes(StandardCharsets.UTF_8)); + + StatefulRedisConnection finalConn = conn; + asyncCommands.setex(keyHash, expiry, value) + .whenComplete((result, exception) -> { + if (exception != null) { + LOGGER.warning("Failed to write to cache: " + exception.getMessage()); + if (writeConnectionPool != null) { + try { + returnConnectionBackToPool(finalConn, true, false); + } catch (Exception ex) { + LOGGER.warning("Error returning broken write connection back to pool: " + ex.getMessage()); + } + } + } else { + if (writeConnectionPool != null) { + try { + returnConnectionBackToPool(finalConn, false, false); + } catch (Exception ex) { + LOGGER.warning("Error returning write connection back to pool: " + ex.getMessage()); + } + } + } + }); } catch (Exception e) { - if (conn != null){ - isBroken = true; - } LOGGER.warning("Failed to write to cache: " + e.getMessage()); - } finally { if (conn != null && writeConnectionPool != null) { try { - this.returnConnectionBackToPool(conn, isBroken, false); + returnConnectionBackToPool(conn, true, false); } catch (Exception ex) { LOGGER.warning("Error closing write connection: " + ex.getMessage()); } From 213d89f6beaf00a907a58264678cd63ac577eec9 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Thu, 26 Jun 2025 13:50:17 -0700 Subject: [PATCH 05/24] Caching - don't cache queries that are part of a multi-statement transaction. Add unit test for CacheConnection logic. 
--- .../jdbc/plugin/cache/CacheConnection.java | 56 +++++---- .../plugin/cache/CacheConnectionTest.java | 110 ++++++++++++++++++ .../cache/DataRemoteCachePluginTest.java | 27 ++++- 3 files changed, 164 insertions(+), 29 deletions(-) create mode 100644 wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index 44cfda6be..589b87d3b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -3,6 +3,7 @@ import io.lettuce.core.RedisClient; import io.lettuce.core.RedisCommandExecutionException; import io.lettuce.core.RedisURI; +import io.lettuce.core.SetArgs; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.codec.ByteArrayCodec; @@ -174,38 +175,40 @@ public byte[] readFromCache(String key) { } } + protected void handleCompletedCacheWrite(StatefulRedisConnection conn, Throwable ex) { + // Note: this callback upon completion of cache write is on a different thread + if (ex != null) { + LOGGER.warning("Failed to write to cache: " + ex.getMessage()); + if (writeConnectionPool != null) { + try { + returnConnectionBackToPool(conn, true, false); + } catch (Exception e) { + LOGGER.warning("Error returning broken write connection back to pool: " + e.getMessage()); + } + } + } else { + if (writeConnectionPool != null) { + try { + returnConnectionBackToPool(conn, false, false); + } catch (Exception e) { + LOGGER.warning("Error returning write connection back to pool: " + e.getMessage()); + } + } + } + } + public void writeToCache(String key, byte[] value, int expiry) { StatefulRedisConnection conn = null; try { initializeCacheConnectionIfNeeded(false); // get a connection from the write 
connection pool conn = writeConnectionPool.borrowObject(); - // Add support to make write to the cache to be async. + // Write to the cache is async. RedisAsyncCommands asyncCommands = conn.async(); byte[] keyHash = computeHashDigest(key.getBytes(StandardCharsets.UTF_8)); - StatefulRedisConnection finalConn = conn; - asyncCommands.setex(keyHash, expiry, value) - .whenComplete((result, exception) -> { - if (exception != null) { - LOGGER.warning("Failed to write to cache: " + exception.getMessage()); - if (writeConnectionPool != null) { - try { - returnConnectionBackToPool(finalConn, true, false); - } catch (Exception ex) { - LOGGER.warning("Error returning broken write connection back to pool: " + ex.getMessage()); - } - } - } else { - if (writeConnectionPool != null) { - try { - returnConnectionBackToPool(finalConn, false, false); - } catch (Exception ex) { - LOGGER.warning("Error returning write connection back to pool: " + ex.getMessage()); - } - } - } - }); + asyncCommands.set(keyHash, value, SetArgs.Builder.ex(expiry)) + .whenComplete((result, exception) -> handleCompletedCacheWrite(finalConn, exception)); } catch (Exception e) { LOGGER.warning("Failed to write to cache: " + e.getMessage()); if (conn != null && writeConnectionPool != null) { @@ -230,4 +233,11 @@ private void returnConnectionBackToPool(StatefulRedisConnection pool.returnObject(connection); } } + + // Used for unit testing only + protected void setConnectionPools(GenericObjectPool> readPool, + GenericObjectPool> writePool) { + readConnectionPool = readPool; + writeConnectionPool = writePool; + } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java new file mode 100644 index 000000000..51f7338e1 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java @@ -0,0 +1,110 @@ +package software.amazon.jdbc.plugin.cache; + +import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +import io.lettuce.core.RedisFuture; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.async.RedisAsyncCommands; +import io.lettuce.core.api.sync.RedisCommands; +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import java.nio.charset.StandardCharsets; +import java.util.function.BiConsumer; +import java.util.Properties; + +import static org.mockito.Mockito.*; + +public class CacheConnectionTest { + @Mock GenericObjectPool> mockReadConnPool; + @Mock GenericObjectPool> mockWriteConnPool; + @Mock StatefulRedisConnection mockConnection; + @Mock RedisCommands mockSyncCommands; + @Mock RedisAsyncCommands mockAsyncCommands; + @Mock RedisFuture mockCacheResult; + private AutoCloseable closeable; + private CacheConnection cacheConnection; + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + Properties props = new Properties(); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheEndpointAddrRo", "localhost:6380"); + cacheConnection = new CacheConnection(props); + cacheConnection.setConnectionPools(mockReadConnPool, mockWriteConnPool); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + } + + @Test + void test_writeToCache() throws Exception { + String key = "myQueryKey"; + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenReturn(mockCacheResult); + 
when(mockCacheResult.whenComplete(any(BiConsumer.class))).thenReturn(null); + cacheConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(mockConnection).async(); + verify(mockAsyncCommands).set(any(), any(), any()); + verify(mockCacheResult).whenComplete(any(BiConsumer.class)); + } + + @Test + void test_writeToCacheException() throws Exception { + String key = "myQueryKey"; + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockWriteConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.async()).thenReturn(mockAsyncCommands); + when(mockAsyncCommands.set(any(), any(), any())).thenThrow(new RuntimeException("test exception")); + cacheConnection.writeToCache(key, value, 100); + verify(mockWriteConnPool).borrowObject(); + verify(mockConnection).async(); + verify(mockAsyncCommands).set(any(), any(), any()); + verify(mockWriteConnPool).invalidateObject(mockConnection); + } + + @Test + void test_handleCompletedCacheWrite() throws Exception { + cacheConnection.handleCompletedCacheWrite(mockConnection, null); + verify(mockWriteConnPool).returnObject(mockConnection); + cacheConnection.handleCompletedCacheWrite(mockConnection, new RuntimeException("test")); + verify(mockWriteConnPool).invalidateObject(mockConnection); + } + + @Test + void test_readFromCache() throws Exception { + byte[] value = "myValue".getBytes(StandardCharsets.UTF_8); + when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + when(mockSyncCommands.get(any())).thenReturn(value); + byte[] result = cacheConnection.readFromCache("myQueryKey"); + assertEquals(value, result); + verify(mockReadConnPool).borrowObject(); + verify(mockConnection).sync(); + verify(mockSyncCommands).get(any()); + verify(mockReadConnPool).returnObject(mockConnection); + } + + @Test + void test_readFromCacheException() throws Exception { + 
when(mockReadConnPool.borrowObject()).thenReturn(mockConnection); + when(mockConnection.sync()).thenReturn(mockSyncCommands); + when(mockSyncCommands.get(any())).thenThrow(new RuntimeException("test")); + assertNull(cacheConnection.readFromCache("myQueryKey")); + verify(mockReadConnPool).borrowObject(); + verify(mockConnection).sync(); + verify(mockSyncCommands).get(any()); + verify(mockReadConnPool).invalidateObject(mockConnection); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java index b8315e2b4..2cbb82076 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -52,7 +52,6 @@ void setUp() throws SQLException { when(mockTelemetryFactory.createCounter("remoteCache.cache.hit")).thenReturn(mockHitCounter); when(mockTelemetryFactory.createCounter("remoteCache.cache.miss")).thenReturn(mockMissCounter); when(mockTelemetryFactory.createCounter("remoteCache.cache.totalCalls")).thenReturn(mockTotalCallsCounter); - when(mockResult1.getMetaData()).thenReturn(mockMetaData); when(mockMetaData.getColumnCount()).thenReturn(1); when(mockMetaData.getColumnName(1)).thenReturn("fooName"); @@ -99,6 +98,22 @@ void test_getTTLFromQueryHint() throws Exception { assertNull(plugin.getTtlForQuery(updateQuery)); } + @Test + void test_inTransaction_noCaching() throws Exception { + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockCallable.call()).thenReturn(mockResult1); + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/* cacheTtl=50s */ select * from B"}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false, false); + 
when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockCallable).call(); + verify(mockTotalCallsCounter).inc(); + } + @Test void test_execute_noCaching() throws Exception { // Query is not cacheable @@ -158,7 +173,7 @@ void test_execute_cachingMiss() throws Exception { // Cached result set contains 1 row assertTrue(rs.next()); - assertEquals(rs.getString("fooName"), "bar1"); + assertEquals("bar1", rs.getString("fooName")); assertFalse(rs.next()); verify(mockPluginService, times(2)).getCurrentConnection(); verify(mockPluginService).isInTransaction(); @@ -191,11 +206,11 @@ void test_execute_cachingHit() throws Exception { // Cached result set contains 2 rows assertTrue(rs.next()); - assertEquals(rs.getString("date"), "2009-09-30"); - assertEquals(rs.getString("code"), "avata"); + assertEquals("2009-09-30", rs.getString("date")); + assertEquals("avata", rs.getString("code")); assertTrue(rs.next()); - assertEquals(rs.getString("date"), "2015-05-30"); - assertEquals(rs.getString("code"), "dracu"); + assertEquals("2015-05-30", rs.getString("date")); + assertEquals("dracu", rs.getString("code")); assertFalse(rs.next()); verify(mockPluginService).getCurrentConnection(); verify(mockPluginService).isInTransaction(); From 699c5e530848b0de3b71e26563a971f9ab6b9a0b Mon Sep 17 00:00:00 2001 From: Roberto Luna Rojas Date: Tue, 27 May 2025 17:05:16 -0400 Subject: [PATCH 06/24] Customize TLS connection parameter for DB remote caching --------- Co-authored-by: Roberto Luna Rojas --- .../amazon/jdbc/plugin/cache/CacheConnection.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index 589b87d3b..957019a2a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ 
b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -51,9 +51,18 @@ public class CacheConnection { null, "The cache read-only server endpoint address."); + protected static final AwsWrapperProperty CACHE_USE_SSL = + new AwsWrapperProperty( + "cacheUseSSL", + "true", + "Whether to use SSL for cache connections."); + + private final boolean useSSL; + public CacheConnection(final Properties properties) { this.cacheRwServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties); this.cacheRoServerAddr = CACHE_RO_ENDPOINT_ADDR.getString(properties); + this.useSSL = Boolean.parseBoolean(CACHE_USE_SSL.getString(properties)); } /* Here we check if we need to initialise connection pool for read or write to cache. @@ -96,7 +105,7 @@ private void createConnectionPool(boolean isRead) { String[] hostnameAndPort = serverAddr.split(":"); RedisURI redisUriCluster = RedisURI.Builder.redis(hostnameAndPort[0]) .withPort(Integer.parseInt(hostnameAndPort[1])) - .withSsl(true).withVerifyPeer(false).build(); + .withSsl(useSSL).withVerifyPeer(false).build(); RedisClient client = RedisClient.create(resources, redisUriCluster); GenericObjectPool> pool = From 51e0bd66e3727f8b269bca021ae14cd8e91cd90e Mon Sep 17 00:00:00 2001 From: Roberto Luna Rojas Date: Tue, 27 May 2025 17:05:16 -0400 Subject: [PATCH 07/24] Read properties definition from a separate config file, and rename example file to reflect database-agnostic functionality --------- Co-authored-by: Roberto Luna Rojas --- .env.example | 5 ++ ...> DatabaseConnectionWithCacheExample.java} | 20 +++-- .../java/software/amazon/util/EnvLoader.java | 83 +++++++++++++++++++ 3 files changed, 100 insertions(+), 8 deletions(-) create mode 100644 .env.example rename examples/AWSDriverExample/src/main/java/software/amazon/{PgConnectionWithCacheExample.java => DatabaseConnectionWithCacheExample.java} (63%) create mode 100644 examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java diff --git a/.env.example 
b/.env.example new file mode 100644 index 000000000..83ca4e9c7 --- /dev/null +++ b/.env.example @@ -0,0 +1,5 @@ +DB_CONNECTION_STRING=jdbc:aws-wrapper:postgresql://localhost:5432/dbname +CACHE_RW_SERVER_ADDR=localhost:6379 +CACHE_RO_SERVER_ADDR=localhost:6380 +DB_USERNAME=postgres +DB_PASSWORD=admin diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java similarity index 63% rename from examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java rename to examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java index 9ba0a9b68..d8eaefff2 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/PgConnectionWithCacheExample.java +++ b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java @@ -1,15 +1,19 @@ package software.amazon; +import software.amazon.util.EnvLoader; import java.sql.*; import java.util.*; -public class PgConnectionWithCacheExample { +public class DatabaseConnectionWithCacheExample { - private static final String CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com:5432/postgres"; - private static final String CACHE_RW_SERVER_ADDR = "dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com:6379"; - private static final String CACHE_RO_SERVER_ADDR = "dev-dsk-quchen-2a-3a165932.us-west-2.amazon.com:6380"; - private static final String USERNAME = "postgres"; - private static final String PASSWORD = "admin"; + private static final EnvLoader env = new EnvLoader(); + + private static final String DB_CONNECTION_STRING = env.get("DB_CONNECTION_STRING"); + private static final String CACHE_RW_SERVER_ADDR = env.get("CACHE_RW_SERVER_ADDR"); + private static final String CACHE_RO_SERVER_ADDR = env.get("CACHE_RO_SERVER_ADDR"); + private static final String 
USERNAME = env.get("DB_USERNAME"); + private static final String PASSWORD = env.get("DB_PASSWORD"); + private static final String USE_SSL = env.get("USE_SSL"); public static void main(String[] args) throws SQLException { final Properties properties = new Properties(); @@ -22,6 +26,7 @@ public static void main(String[] args) throws SQLException { properties.setProperty("wrapperPlugins", "dataRemoteCache"); properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); + properties.setProperty("cacheUseSSL", USE_SSL); // "true" or "false" properties.setProperty("wrapperLogUnclosedConnections", "true"); String queryStr = "select * from cinemas"; String queryStr2 = "SELECT * from cinemas"; @@ -29,7 +34,7 @@ public static void main(String[] args) throws SQLException { for (int i = 0 ; i < 5; i++) { // Create a new database connection and issue queries to it try { - Connection conn = DriverManager.getConnection(CONNECTION_STRING, properties); + Connection conn = DriverManager.getConnection(DB_CONNECTION_STRING, properties); Statement stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(queryStr); ResultSet rs2 = stmt.executeQuery(queryStr2); @@ -40,5 +45,4 @@ public static void main(String[] args) throws SQLException { } } } - } diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java b/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java new file mode 100644 index 000000000..7b12d91f5 --- /dev/null +++ b/examples/AWSDriverExample/src/main/java/software/amazon/util/EnvLoader.java @@ -0,0 +1,83 @@ +package software.amazon.util; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +/** + * A simple utility class to load environment variables from a .env file. 
+ */ +public class EnvLoader { + private final Map envVars = new HashMap<>(); + + /** + * Loads environment variables from a .env file in the current directory. + */ + public EnvLoader() { + this(Paths.get(".env")); + } + + /** + * Loads environment variables from the specified file path. + * + * @param envPath Path to the .env file + */ + public EnvLoader(Path envPath) { + if (Files.exists(envPath)) { + try (BufferedReader reader = new BufferedReader(new FileReader(envPath.toFile()))) { + String line; + while ((line = reader.readLine()) != null) { + parseLine(line); + } + } catch (IOException e) { + System.err.println("Error reading .env file: " + e.getMessage()); + } + } + } + + private void parseLine(String line) { + line = line.trim(); + // Skip empty lines and comments + if (line.isEmpty() || line.startsWith("#")) { + return; + } + + // Split on the first equals sign + int delimiterPos = line.indexOf('='); + if (delimiterPos > 0) { + String key = line.substring(0, delimiterPos).trim(); + String value = line.substring(delimiterPos + 1).trim(); + + // Remove quotes if present + if ((value.startsWith("\"") && value.endsWith("\"")) || + (value.startsWith("'") && value.endsWith("'"))) { + value = value.substring(1, value.length() - 1); + } + + envVars.put(key, value); + } + } + + /** + * Gets the value of an environment variable. + * + * @param key The name of the environment variable + * @return The value of the environment variable, or null if not found + */ + public String get(String key) { + // First check the loaded .env file + String value = envVars.get(key); + + // If not found, check system environment variables + if (value == null) { + value = System.getenv(key); + } + + return value; + } +} From 98a5038a0e8249eb6aafc8728aa8fec321c5ecc6 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Fri, 4 Jul 2025 08:36:18 -0700 Subject: [PATCH 08/24] Caching - minor fixes after rebase including doing null check for telemetry. 
--- .../jdbc/plugin/cache/CacheConnection.java | 20 ++- .../jdbc/plugin/cache/CachedResultSet.java | 120 ++++++++++++------ .../plugin/cache/DataRemoteCachePlugin.java | 38 ++++-- .../amazon/jdbc/states/SessionState.java | 2 - ..._advanced_jdbc_wrapper_messages.properties | 4 + .../plugin/cache/CachedResultSetTest.java | 82 +++++++++++- .../cache/DataRemoteCachePluginTest.java | 71 ++++++++--- 7 files changed, 257 insertions(+), 80 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index 957019a2a..03c382686 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -14,10 +14,12 @@ import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.util.Properties; +import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Logger; import io.lettuce.core.support.ConnectionPoolSupport; import org.apache.commons.pool2.impl.GenericObjectPool; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; +import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.util.StringUtils; // Abstraction layer on top of a connection to a remote cache server @@ -36,8 +38,8 @@ public class CacheConnection { private static final int DEFAULT_POOL_MIN_IDLE = 0; private static final long DEFAULT_MAX_BORROW_WAIT_MS = 50; - private static final Object READ_LOCK = new Object(); - private static final Object WRITE_LOCK = new Object(); + private static final ReentrantLock READ_LOCK = new ReentrantLock(); + private static final ReentrantLock WRITE_LOCK = new ReentrantLock(); protected static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR = new AwsWrapperProperty( @@ -59,6 +61,10 @@ public class CacheConnection { private final boolean useSSL; + static { + 
PropertyDefinition.registerPluginProperties(CacheConnection.class); + } + public CacheConnection(final Properties properties) { this.cacheRwServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties); this.cacheRoServerAddr = CACHE_RO_ENDPOINT_ADDR.getString(properties); @@ -84,11 +90,14 @@ private void initializeCacheConnectionIfNeeded(boolean isRead) { GenericObjectPool> cacheConnectionPool = isRead ? readConnectionPool : writeConnectionPool; if (cacheConnectionPool == null) { - Object lock = isRead ? READ_LOCK : WRITE_LOCK; - synchronized (lock) { + ReentrantLock connectionPoolLock = isRead ? READ_LOCK : WRITE_LOCK; + connectionPoolLock.lock(); + try { if ((isRead && readConnectionPool == null) || (!isRead && writeConnectionPool == null)) { createConnectionPool(isRead); } + } finally { + connectionPoolLock.unlock(); } } } @@ -105,7 +114,7 @@ private void createConnectionPool(boolean isRead) { String[] hostnameAndPort = serverAddr.split(":"); RedisURI redisUriCluster = RedisURI.Builder.redis(hostnameAndPort[0]) .withPort(Integer.parseInt(hostnameAndPort[1])) - .withSsl(useSSL).withVerifyPeer(false).build(); + .withSsl(useSSL).withVerifyPeer(false).withLibraryName("aws-jdbc-lettuce").build(); RedisClient client = RedisClient.create(resources, redisUriCluster); GenericObjectPool> pool = @@ -219,6 +228,7 @@ public void writeToCache(String key, byte[] value, int expiry) { asyncCommands.set(keyHash, value, SetArgs.Builder.ex(expiry)) .whenComplete((result, exception) -> handleCompletedCacheWrite(finalConn, exception)); } catch (Exception e) { + // Failed to trigger the async write to the cache, return the cache connection to the pool as broken LOGGER.warning("Failed to write to cache: " + e.getMessage()); if (conn != null && writeConnectionPool != null) { try { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java index 5df5c37a6..82557a41a 100644 --- 
a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java @@ -40,7 +40,6 @@ import java.util.Calendar; import java.util.GregorianCalendar; -@SuppressWarnings({"RedundantThrows", "checkstyle:OverloadMethodsDeclarationOrder"}) public class CachedResultSet implements ResultSet { public static class CachedRow { @@ -65,8 +64,10 @@ public Object get(final String columnName) { protected ArrayList rows; protected int currentRow; + protected boolean wasNullFlag; protected ResultSetMetaData metadata; protected static ObjectMapper mapper = new ObjectMapper(); + protected static boolean mapperInitialized = false; protected static final TimeZone defaultTimeZone = TimeZone.getDefault(); private static final Calendar calendarWithUserTz = new GregorianCalendar(); @@ -83,34 +84,45 @@ public CachedResultSet(final ResultSet resultSet) throws SQLException { rows.add(row); } currentRow = -1; + initializeObjectMapper(); } public CachedResultSet(final List> resultList) { rows = new ArrayList<>(); - CachedResultSetMetaData.Field[] fields = new CachedResultSetMetaData.Field[resultList.get(0).size()]; - boolean fieldsInitialized = false; - for (Map rowMap : resultList) { - final CachedRow row = new CachedRow(); - int i = 0; - for (Map.Entry entry : rowMap.entrySet()) { - String columnName = entry.getKey(); - if (!fieldsInitialized) { - fields[i] = new CachedResultSetMetaData.Field(columnName, columnName); + int numFields = resultList.isEmpty() ? 
0 : resultList.get(0).size(); + CachedResultSetMetaData.Field[] fields = new CachedResultSetMetaData.Field[numFields]; + if (!resultList.isEmpty()) { + boolean fieldsInitialized = false; + for (Map rowMap : resultList) { + final CachedRow row = new CachedRow(); + int i = 0; + for (Map.Entry entry : rowMap.entrySet()) { + String columnName = entry.getKey(); + if (!fieldsInitialized) { + fields[i] = new CachedResultSetMetaData.Field(columnName, columnName); + } + row.put(++i, columnName, entry.getValue()); } - row.put(++i, columnName, entry.getValue()); + rows.add(row); + fieldsInitialized = true; } - rows.add(row); - fieldsInitialized = true; } currentRow = -1; metadata = new CachedResultSetMetaData(fields); + initializeObjectMapper(); } - public String serializeIntoJsonString() throws SQLException { + // Helper method to initialize the object mapper for serialization of objects + private void initializeObjectMapper() { + if (mapperInitialized) return; + // For serialization of Date/LocalDateTime etc, set up the time module, + // and use standard string format (i.e. ISO) mapper.registerModule(new JavaTimeModule()); - // Serialize Date/LocalDateTime etc. into standard string format (i.e. ISO) mapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); + mapperInitialized = true; + } + public String serializeIntoJsonString() throws SQLException { List> resultList = new ArrayList<>(); ResultSetMetaData metaData = this.getMetaData(); int columns = metaData.getColumnCount(); @@ -156,7 +168,10 @@ public void close() throws SQLException { @Override public boolean wasNull() throws SQLException { - throw new UnsupportedOperationException(); + if (isClosed()) { + throw new SQLException("This result set is closed"); + } + return this.wasNullFlag; } // TODO: implement all the getXXX APIs. 
@@ -476,12 +491,12 @@ public InputStream getBinaryStream(final String columnLabel) throws SQLException @Override public SQLWarning getWarnings() throws SQLException { - throw new UnsupportedOperationException(); + return null; } @Override public void clearWarnings() throws SQLException { - throw new UnsupportedOperationException(); + // no-op } @Override @@ -494,28 +509,34 @@ public ResultSetMetaData getMetaData() throws SQLException { return metadata; } - @Override - public Object getObject(final int columnIndex) throws SQLException { + private void checkCurrentRow() throws SQLException { if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { - return null; // out of boundaries + throw new SQLException("The current row index " + this.currentRow + " is out of range."); } + } + + @Override + public Object getObject(final int columnIndex) throws SQLException { + checkCurrentRow(); final CachedRow row = this.rows.get(this.currentRow); if (!row.columnByIndex.containsKey(columnIndex)) { - return null; // column index out of boundaries + throw new SQLException("The column index: " + columnIndex + " is out of range, number of columns: " + row.columnByIndex.size()); } - return row.columnByIndex.get(columnIndex); + Object obj = row.columnByIndex.get(columnIndex); + this.wasNullFlag = (obj == null); + return obj; } @Override public Object getObject(final String columnLabel) throws SQLException { - if (this.currentRow < 0 || this.currentRow >= this.rows.size()) { - return null; // out of boundaries - } + checkCurrentRow(); final CachedRow row = this.rows.get(this.currentRow); if (!row.columnByName.containsKey(columnLabel)) { - return null; // column name not found + throw new SQLException("The column label: " + columnLabel + " is not found"); } - return row.columnByName.get(columnLabel); + Object obj = row.columnByName.get(columnLabel); + this.wasNullFlag = (obj == null); + return obj; } @Override @@ -559,12 +580,12 @@ public boolean isAfterLast() throws 
SQLException { @Override public boolean isFirst() throws SQLException { - return this.currentRow == 0 && this.rows.size() > 0; + return this.currentRow == 0 && !this.rows.isEmpty(); } @Override public boolean isLast() throws SQLException { - return this.currentRow == (this.rows.size() - 1) && this.rows.size() > 0; + return this.currentRow == (this.rows.size() - 1) && !this.rows.isEmpty(); } @Override @@ -596,24 +617,49 @@ public int getRow() throws SQLException { @Override public boolean absolute(final int row) throws SQLException { - if (row > 0) { - this.currentRow = row - 1; + if (row == 0) { + this.beforeFirst(); + return false; } else { - this.currentRow = this.rows.size() + row; + int rowsSize = this.rows.size(); + if (row < 0) { + if (row < -rowsSize) { + this.beforeFirst(); + return false; + } + this.currentRow = rowsSize + row; + } else { // row > 0 + if (row > rowsSize) { + this.afterLast(); + return false; + } + this.currentRow = row - 1; + } } - return this.currentRow >= 0 && this.currentRow < this.rows.size(); + return true; } @Override public boolean relative(final int rows) throws SQLException { this.currentRow += rows; - return this.currentRow >= 0 && this.currentRow < this.rows.size(); + if (this.currentRow < 0) { + this.beforeFirst(); + return false; + } else if (this.currentRow >= this.rows.size()) { + this.afterLast(); + return false; + } + return true; } @Override public boolean previous() throws SQLException { + if (this.currentRow < 1) { + this.beforeFirst(); + return false; + } this.currentRow--; - return this.currentRow >= 0 && this.currentRow < this.rows.size(); + return true; } @Override @@ -1054,7 +1100,7 @@ public int getHoldability() throws SQLException { @Override public boolean isClosed() throws SQLException { - return false; + return this.rows == null; } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java 
b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index 4f2be4aeb..9e709743e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -28,20 +28,25 @@ import java.util.Set; import java.util.logging.Logger; import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.JdbcMethod; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.WrapperUtils; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; public class DataRemoteCachePlugin extends AbstractConnectionPlugin { private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName()); private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( - Arrays.asList("Statement.executeQuery", "Statement.execute", - "PreparedStatement.execute", "PreparedStatement.executeQuery", - "CallableStatement.execute", "CallableStatement.executeQuery"))); + Arrays.asList(JdbcMethod.STATEMENT_EXECUTEQUERY.methodName, + JdbcMethod.STATEMENT_EXECUTE.methodName, + JdbcMethod.PREPAREDSTATEMENT_EXECUTE.methodName, + JdbcMethod.PREPAREDSTATEMENT_EXECUTEQUERY.methodName, + JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName, + JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName))); static { PropertyDefinition.registerPluginProperties(DataRemoteCachePlugin.class); @@ -173,8 +178,6 @@ public T execute( final JdbcCallable jdbcMethodFunc, final Object[] jdbcMethodArgs) throws E { - totalCallsCounter.inc(); - ResultSet result; boolean needToCache = false; final String sql = getQuery(jdbcMethodArgs); @@ -196,23 +199,21 @@ public T execute( // 
Query result can be served from the cache if it has a configured TTL value, and it is // not executed in a transaction as a transaction typically need to return consistent results. if (!isInTransaction && (configuredQueryTtl != null)) { + incrCounter(totalCallsCounter); result = fetchResultSetFromCache(mainQuery); if (result == null) { // Cache miss. Need to fetch result from the database needToCache = true; - missCounter.inc(); + incrCounter(missCounter); LOGGER.finest("Got a cache miss for SQL: " + sql); } else { LOGGER.finest("Got a cache hit for SQL: " + sql); // Cache hit. Return the cached result - hitCounter.inc(); + incrCounter(hitCounter); try { result.beforeFirst(); } catch (final SQLException ex) { - if (exceptionClass.isAssignableFrom(ex.getClass())) { - throw exceptionClass.cast(ex); - } - throw new RuntimeException(ex); + throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, ex); } return resultClass.cast(result); } @@ -220,18 +221,29 @@ public T execute( result = (ResultSet) jdbcMethodFunc.call(); + // We need to cache the query result if we got a cache miss for the query result, + // or the query is cacheable and executed inside a transaction. 
+ if (isInTransaction && (configuredQueryTtl != null)) { + needToCache = true; + } if (needToCache) { try { result = cacheResultSet(mainQuery, result, configuredQueryTtl); } catch (final SQLException ex) { - // ignore exception - LOGGER.warning("Encountered SQLException when caching results: " + ex.getMessage()); + // Log and re-throw exception + LOGGER.warning("Encountered SQLException when caching query results: " + ex.getMessage()); + throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, ex); } } return resultClass.cast(result); } + private void incrCounter(TelemetryCounter counter) { + if (counter == null) return; + counter.inc(); + } + protected String getQuery(final Object[] jdbcMethodArgs) { // Get query from method argument if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/states/SessionState.java b/wrapper/src/main/java/software/amazon/jdbc/states/SessionState.java index 708b0442f..f29f915f0 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/states/SessionState.java +++ b/wrapper/src/main/java/software/amazon/jdbc/states/SessionState.java @@ -29,8 +29,6 @@ public class SessionState { public SessionStateField transactionIsolation = new SessionStateField<>(); public SessionStateField>> typeMap = new SessionStateField<>(); - // TODO: add support for session states that affects the query result from the database - public SessionState copy() { final SessionState newSessionState = new SessionState(); newSessionState.autoCommit = this.autoCommit.copy(); diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties index c5ea7c275..ff67e7fd1 100644 --- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties +++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties @@ -128,6 +128,10 @@ 
CustomEndpointPluginFactory.awsSdkNotInClasspath=Required dependency 'AWS Java S DataCacheConnectionPlugin.queryResultsCached=[{0}] Query results will be cached: {1} +# Data Remote Cache Plugin +DataRemoteCachePlugin.notInClassPath=Required dependency for DataRemoteCachePlugin is not on the classpath: ''{0}'' + +# Default Connection Plugin DefaultConnectionPlugin.executingMethod=Executing method: ''{0}'' DefaultConnectionPlugin.noHostsAvailable=The default connection plugin received an empty host list from the plugin service. DefaultConnectionPlugin.unknownRoleRequested=A HostSpec with a role of HostRole.UNKNOWN was requested via getHostSpecByStrategy. The requested role must be either HostRole.WRITER or HostRole.READER diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java index 60d87b93c..61cfe478e 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java @@ -1,6 +1,7 @@ package software.amazon.jdbc.plugin.cache; import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.sql.ResultSet; import java.sql.ResultSetMetaData; @@ -22,6 +23,7 @@ public class CachedResultSetTest { @BeforeAll static void setUp() { Map row = new HashMap<>(); + row.put("fieldNull", null); // null row.put("fieldInt", 1); // Integer row.put("fieldString", "John Doe"); // String row.put("fieldBoolean", true); @@ -36,6 +38,7 @@ static void setUp() { row.put("fieldDateTime", Timestamp.valueOf("2025-03-15 22:54:00")); testResultList.add(row); Map row2 = new HashMap<>(); + row2.put("fieldNull", null); // null row2.put("fieldInt", 123456); // Integer row2.put("fieldString", "Tony Stark"); // String row2.put("fieldBoolean", false); @@ -58,42 +61,78 @@ private void verifyRow1(ResultSet rs) throws 
SQLException { colNameToIndexMap.put(rsmd.getColumnName(i), i); } assertEquals(1, rs.getInt(colNameToIndexMap.get("fieldInt"))); + assertFalse(rs.wasNull()); assertEquals("John Doe", rs.getString(colNameToIndexMap.get("fieldString"))); + assertFalse(rs.wasNull()); assertTrue(rs.getBoolean(colNameToIndexMap.get("fieldBoolean"))); + assertFalse(rs.wasNull()); assertEquals(100, rs.getByte(colNameToIndexMap.get("fieldByte"))); + assertFalse(rs.wasNull()); assertEquals(55, rs.getShort(colNameToIndexMap.get("fieldShort"))); + assertFalse(rs.wasNull()); + assertNull(rs.getObject(colNameToIndexMap.get("fieldNull"))); + assertTrue(rs.wasNull()); assertEquals(8589934592L, rs.getLong(colNameToIndexMap.get("fieldLong"))); + assertFalse(rs.wasNull()); assertEquals(3.14159f, rs.getFloat(colNameToIndexMap.get("fieldFloat")), 0); + assertFalse(rs.wasNull()); assertEquals(2345.23345d, rs.getDouble(colNameToIndexMap.get("fieldDouble"))); + assertFalse(rs.wasNull()); assertEquals(0, rs.getBigDecimal(colNameToIndexMap.get("fieldBigDecimal")).compareTo(new BigDecimal("15.33"))); + assertFalse(rs.wasNull()); + assertNull(rs.getObject(colNameToIndexMap.get("fieldNull"))); + assertTrue(rs.wasNull()); Date date = rs.getDate(colNameToIndexMap.get("fieldDate")); assertEquals(1742022000000L, date.getTime()); + assertFalse(rs.wasNull()); Time time = rs.getTime(colNameToIndexMap.get("fieldTime")); assertEquals(111240000, time.getTime()); + assertFalse(rs.wasNull()); Timestamp ts = rs.getTimestamp(colNameToIndexMap.get("fieldDateTime")); assertEquals(1742104440000L, ts.getTime()); + assertFalse(rs.wasNull()); } private void verifyRow2(ResultSet rs) throws SQLException { assertEquals(123456, rs.getInt("fieldInt")); + assertFalse(rs.wasNull()); assertEquals("Tony Stark", rs.getString("fieldString")); + assertFalse(rs.wasNull()); assertFalse(rs.getBoolean("fieldBoolean")); + assertFalse(rs.wasNull()); assertEquals(70, rs.getByte("fieldByte")); + assertFalse(rs.wasNull()); assertEquals(135, 
rs.getShort("fieldShort")); + assertFalse(rs.wasNull()); + assertNull(rs.getObject("fieldNull")); + assertTrue(rs.wasNull()); assertEquals(-34359738368L, rs.getLong("fieldLong")); + assertFalse(rs.wasNull()); assertEquals(-233.14159f, rs.getFloat("fieldFloat")); + assertFalse(rs.wasNull()); assertEquals(-2344355.4543d, rs.getDouble("fieldDouble")); + assertFalse(rs.wasNull()); assertEquals(0, rs.getBigDecimal("fieldBigDecimal").compareTo(new BigDecimal("-12.45"))); + assertFalse(rs.wasNull()); Date date = rs.getDate("fieldDate"); assertEquals("1102-01-15", date.toString()); + assertFalse(rs.wasNull()); Time time = rs.getTime("fieldTime"); assertEquals("01:10:00", time.toString()); + assertFalse(rs.wasNull()); Timestamp ts = rs.getTimestamp("fieldDateTime"); assertTrue(ts.toString().startsWith("1981-03-10 01:10:20")); + assertFalse(rs.wasNull()); } @Test void test_create_and_verify_basic() throws Exception { + // An empty result set + ResultSet rs0 = new CachedResultSet(new ArrayList<>()); + assertFalse(rs0.next()); + ResultSetMetaData md = rs0.getMetaData(); + assertEquals(0, md.getColumnCount()); + // Result set containing data ResultSet rs = new CachedResultSet(testResultList); verifyMetadata(rs); verifyContent(rs); @@ -101,6 +140,8 @@ void test_create_and_verify_basic() throws Exception { CachedResultSet cachedRs = new CachedResultSet(rs); verifyMetadata(cachedRs); verifyContent(cachedRs); + rs.clearWarnings(); + assertNull(rs.getWarnings()); } @Test @@ -117,17 +158,39 @@ private void verifyContent(ResultSet rs) throws SQLException { verifyRow1(rs); assertTrue(rs.next()); verifyRow2(rs); + rs.previous(); + verifyRow1(rs); + verifyNonexistingField(rs); + rs.relative(1); // Advances to next row + verifyRow2(rs); + rs.absolute(2); + verifyRow2(rs); } else { verifyRow2(rs); assertTrue(rs.next()); verifyRow1(rs); + rs.previous(); + verifyRow2(rs); + verifyNonexistingField(rs); + rs.relative(1); // Advances to next row + verifyRow1(rs); + rs.absolute(2); + 
verifyRow1(rs); } assertFalse(rs.next()); + rs.relative(-10); + assertTrue(rs.isBeforeFirst()); + rs.relative(10); + assertTrue(rs.isAfterLast()); + rs.absolute(-10); + assertTrue(rs.isBeforeFirst()); + rs.absolute(10); + assertTrue(rs.isAfterLast()); } private void verifyMetadata(ResultSet rs) throws SQLException { ResultSetMetaData md = rs.getMetaData(); - List expectedCols = Arrays.asList("fieldInt", "fieldString", "fieldBoolean", "fieldByte", "fieldShort", "fieldLong", "fieldFloat", "fieldDouble", "fieldBigDecimal", "fieldDate", "fieldTime", "fieldDateTime"); + List expectedCols = Arrays.asList("fieldNull", "fieldInt", "fieldString", "fieldBoolean", "fieldByte", "fieldShort", "fieldLong", "fieldFloat", "fieldDouble", "fieldBigDecimal", "fieldDate", "fieldTime", "fieldDateTime"); assertEquals(md.getColumnCount(), testResultList.get(0).size()); List actualColNames = new ArrayList<>(); List actualColLabels = new ArrayList<>(); @@ -161,11 +224,22 @@ void test_get_timestamp() throws SQLException { CachedResultSet cachedRs = new CachedResultSet(testTimestamps); assertTrue(cachedRs.next()); verifyTimestamps(cachedRs); + verifyNonexistingField(cachedRs); cachedRs.beforeFirst(); String serialized_data = cachedRs.serializeIntoJsonString(); ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); assertTrue(rs.next()); verifyTimestamps(rs); + verifyNonexistingField(rs); + } + + private void verifyNonexistingField(ResultSet rs) { + try { + rs.getTimestamp("nonExistingField"); + throw new IllegalStateException("Expected an exception due to column doesn't exist"); + } catch (SQLException e) { + // Expected an exception if the column doesn't exist + } } private void verifyTimestamps(ResultSet rs) throws SQLException { @@ -207,8 +281,6 @@ private void verifyTimestamps(ResultSet rs) throws SQLException { localTime = LocalDateTime.parse("2025-04-01T21:55:21.822364"); assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), 
rs.getTimestamp("fieldTimestamp8").getTime()); assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp8", estCalendar).getTime()); - - assertNull(rs.getTimestamp("nonExistingField")); } @Test @@ -233,6 +305,7 @@ void test_parse_time() throws SQLException { CachedResultSet cachedRs = new CachedResultSet(testTimes); assertTrue(cachedRs.next()); verifyTimes(cachedRs); + verifyNonexistingField(cachedRs); cachedRs.beforeFirst(); String serialized_data = cachedRs.serializeIntoJsonString(); ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); @@ -276,8 +349,6 @@ private void verifyTimes(ResultSet rs) throws SQLException { assertEquals("00:00:00", rs.getTime("fieldTime10").toString()); assertEquals("00:00:00", rs.getTime("fieldTime10", estCalendar).toString()); - - assertNull(rs.getTime("nonExistingField")); } @Test @@ -296,6 +367,7 @@ void test_parse_date() throws SQLException { CachedResultSet cachedRs = new CachedResultSet(testTimes); assertTrue(cachedRs.next()); verifyDates(cachedRs); + verifyNonexistingField(cachedRs); cachedRs.beforeFirst(); String serialized_data = cachedRs.serializeIntoJsonString(); ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java index 2cbb82076..1961df445 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; 
import static org.mockito.Mockito.verify; @@ -98,22 +99,6 @@ void test_getTTLFromQueryHint() throws Exception { assertNull(plugin.getTtlForQuery(updateQuery)); } - @Test - void test_inTransaction_noCaching() throws Exception { - // Query is not cacheable - when(mockPluginService.isInTransaction()).thenReturn(true); - when(mockCallable.call()).thenReturn(mockResult1); - ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{"/* cacheTtl=50s */ select * from B"}); - - // Mock result set containing 1 row - when(mockResult1.next()).thenReturn(true, true, false, false); - when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); - compareResults(mockResult1, rs); - verify(mockCallable).call(); - verify(mockTotalCallsCounter).inc(); - } - @Test void test_execute_noCaching() throws Exception { // Query is not cacheable @@ -129,7 +114,7 @@ void test_execute_noCaching() throws Exception { compareResults(mockResult1, rs); verify(mockPluginService).isInTransaction(); verify(mockCallable).call(); - verify(mockTotalCallsCounter).inc(); + verify(mockTotalCallsCounter, never()).inc(); verify(mockHitCounter, never()).inc(); verify(mockMissCounter, never()).inc(); } @@ -148,7 +133,7 @@ void test_execute_noCachingLongQuery() throws Exception { when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); compareResults(mockResult1, rs); verify(mockCallable).call(); - verify(mockTotalCallsCounter).inc(); + verify(mockTotalCallsCounter, never()).inc(); verify(mockHitCounter, never()).inc(); verify(mockMissCounter, never()).inc(); } @@ -221,6 +206,56 @@ void test_execute_cachingHit() throws Exception { verify(mockHitCounter).inc(); } + @Test + void test_transaction_cacheQuery() throws Exception { + // Query is cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); 
+ when(mockConnection.getSchema()).thenReturn("public"); + when(mockDbMetadata.getUserName()).thenReturn("user"); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/* cacheTTL=300s */ select * from T"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + verify(mockPluginService).getCurrentConnection(); + verify(mockPluginService).isInTransaction(); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache("public_user_select * from T", "[{\"fooName\":\"bar1\"}]".getBytes(StandardCharsets.UTF_8), 300); + verify(mockTotalCallsCounter, never()).inc(); + verify(mockHitCounter, never()).inc(); + verify(mockMissCounter, never()).inc(); + } + + @Test + void test_transaction_noCaching() throws Exception { + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockCallable.call()).thenReturn(mockResult1); + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"delete from mytable"}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockTotalCallsCounter, never()).inc(); + verify(mockHitCounter, never()).inc(); + verify(mockMissCounter, never()).inc(); + } + void compareResults(final ResultSet expected, final ResultSet actual) throws SQLException { int i = 1; while (expected.next() && 
actual.next()) { From 90e3d05eeb686ba5f89313194631dba65259862c Mon Sep 17 00:00:00 2001 From: Nihal Mehta Date: Tue, 15 Jul 2025 14:14:12 -0700 Subject: [PATCH 09/24] Create connection pool without relying on Lettuce methods --- .../jdbc/plugin/cache/CacheConnection.java | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index 03c382686..ed55a3555 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -16,9 +16,11 @@ import java.util.Properties; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Logger; -import io.lettuce.core.support.ConnectionPoolSupport; +import org.apache.commons.pool2.BasePooledObjectFactory; import org.apache.commons.pool2.impl.GenericObjectPool; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.commons.pool2.PooledObject; import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.util.StringUtils; @@ -117,27 +119,29 @@ private void createConnectionPool(boolean isRead) { .withSsl(useSSL).withVerifyPeer(false).withLibraryName("aws-jdbc-lettuce").build(); RedisClient client = RedisClient.create(resources, redisUriCluster); - GenericObjectPool> pool = - ConnectionPoolSupport.createGenericObjectPool( - () -> { - StatefulRedisConnection connection = client.connect(new ByteArrayCodec()); - // In cluster mode, we need to send READONLY command to the server for reading from replica. - // Note: we gracefully ignore ERR reply to support non cluster mode. 
- if (isRead) { - try { - connection.sync().readOnly(); - } catch (RedisCommandExecutionException e) { - if (e.getMessage().contains("ERR This instance has cluster support disabled")) { - LOGGER.fine("------ Note: this cache cluster has cluster support disabled ------"); - } else { - throw e; - } + GenericObjectPool> pool = new GenericObjectPool<>( + new BasePooledObjectFactory>() { + public StatefulRedisConnection create() { + StatefulRedisConnection connection = client.connect(new ByteArrayCodec()); + // In cluster mode, we need to send READONLY command to the server for reading from replica. + // Note: we gracefully ignore ERR reply to support non cluster mode. + if (isRead) { + try { + connection.sync().readOnly(); + } catch (RedisCommandExecutionException e) { + if (e.getMessage().contains("ERR This instance has cluster support disabled")) { + LOGGER.fine("------ Note: this cache cluster has cluster support disabled ------"); + } else { + throw e; } } - return connection; - }, - poolConfig - ); + } + return connection; + } + public PooledObject> wrap(StatefulRedisConnection connection) { + return new DefaultPooledObject<>(connection); + } + }, poolConfig); if (isRead) { readConnectionPool = pool; From 6fe29b3ecd5f6e0adf8c5472afbc1d1bd5021390 Mon Sep 17 00:00:00 2001 From: Nihal Mehta Date: Wed, 23 Jul 2025 10:14:16 -0700 Subject: [PATCH 10/24] Add Multi-threaded concurrent environment to DBConnectionWithCacheExample program --- .../DatabaseConnectionWithCacheExample.java | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java index d8eaefff2..5b21f605f 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java +++ 
b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java @@ -3,6 +3,7 @@ import software.amazon.util.EnvLoader; import java.sql.*; import java.util.*; +import java.util.logging.Logger; public class DatabaseConnectionWithCacheExample { @@ -14,9 +15,12 @@ public class DatabaseConnectionWithCacheExample { private static final String USERNAME = env.get("DB_USERNAME"); private static final String PASSWORD = env.get("DB_PASSWORD"); private static final String USE_SSL = env.get("USE_SSL"); + private static final int THREAD_COUNT = 8; //Use 8 Threads + private static final long TEST_DURATION_MS = 16000; //Test duration for 16 seconds public static void main(String[] args) throws SQLException { final Properties properties = new Properties(); + final Logger LOGGER = Logger.getLogger(DatabaseConnectionWithCacheExample.class.getName()); // Configuring connection properties for the underlying JDBC driver. properties.setProperty("user", USERNAME); @@ -28,20 +32,35 @@ public static void main(String[] args) throws SQLException { properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); properties.setProperty("cacheUseSSL", USE_SSL); // "true" or "false" properties.setProperty("wrapperLogUnclosedConnections", "true"); - String queryStr = "select * from cinemas"; - String queryStr2 = "SELECT * from cinemas"; + String queryStr = "/* cacheTTL=300s */ select * from cinemas"; - for (int i = 0 ; i < 5; i++) { - // Create a new database connection and issue queries to it + // Create threads for concurrent connection testing + Thread[] threads = new Thread[THREAD_COUNT]; + for (int t = 0; t < THREAD_COUNT; t++) { + // Each thread uses a single connection for multiple queries + threads[t] = new Thread(() -> { + try { + try (Connection conn = DriverManager.getConnection(DB_CONNECTION_STRING, properties)) { + long endTime = System.currentTimeMillis() + TEST_DURATION_MS; + try (Statement stmt = conn.createStatement()) { + while 
(System.currentTimeMillis() < endTime) { + ResultSet rs = stmt.executeQuery(queryStr); + System.out.println("Executed the SQL query with result sets: " + rs.toString()); + } + } + } + } catch (Exception e) { + LOGGER.warning("Error: " + e.getMessage()); + } + }); + threads[t].start(); + } + // Wait for all threads to complete + for (Thread thread : threads) { try { - Connection conn = DriverManager.getConnection(DB_CONNECTION_STRING, properties); - Statement stmt = conn.createStatement(); - ResultSet rs = stmt.executeQuery(queryStr); - ResultSet rs2 = stmt.executeQuery(queryStr2); - System.out.println("Executed the SQL query with result sets: " + rs.toString() + " and " + rs2.toString()); - Thread.sleep(2000); + thread.join(); } catch (InterruptedException e) { - throw new RuntimeException(e); + LOGGER.warning("Thread interrupted: " + e.getMessage()); } } } From fa106dc4342dc0ebdbdbf29cc0d82b03eecb0cc3 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Thu, 14 Aug 2025 10:32:32 -0700 Subject: [PATCH 11/24] Properly implement CachedResultSet and modified the serialization logic to rely on standard Java object serialization instead of custom JacksonMapper logic. Namely removed custom serialization logic for several special data types that are returned by the underlying SQL driver. Handle conversion from Number for getBigDecimal() and added unit tests for it. 
--- examples/AWSDriverExample/build.gradle.kts | 1 - .../DatabaseConnectionWithCacheExample.java | 2 +- wrapper/build.gradle.kts | 3 - .../jdbc/plugin/cache/CachedResultSet.java | 564 +++++++------- .../plugin/cache/CachedResultSetMetaData.java | 120 +-- .../plugin/cache/DataRemoteCachePlugin.java | 17 +- wrapper/src/test/build.gradle.kts | 1 - .../plugin/cache/CachedResultSetTest.java | 703 ++++++++++-------- .../cache/DataRemoteCachePluginTest.java | 64 +- 9 files changed, 750 insertions(+), 725 deletions(-) diff --git a/examples/AWSDriverExample/build.gradle.kts b/examples/AWSDriverExample/build.gradle.kts index 1115a5128..ae4b9ab61 100644 --- a/examples/AWSDriverExample/build.gradle.kts +++ b/examples/AWSDriverExample/build.gradle.kts @@ -22,7 +22,6 @@ dependencies { implementation("software.amazon.awssdk:secretsmanager:2.33.5") implementation("software.amazon.awssdk:sts:2.33.5") implementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") - implementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.0") implementation(project(":aws-advanced-jdbc-wrapper")) implementation("io.opentelemetry:opentelemetry-api:1.52.0") implementation("io.opentelemetry:opentelemetry-sdk:1.51.0") diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java index 5b21f605f..3441ecd4a 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java +++ b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java @@ -32,7 +32,7 @@ public static void main(String[] args) throws SQLException { properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); properties.setProperty("cacheUseSSL", USE_SSL); // "true" or "false" properties.setProperty("wrapperLogUnclosedConnections", "true"); - String queryStr = "/* cacheTTL=300s */ select * 
from cinemas"; + String queryStr = "/*+ CACHE_PARAM(ttl=300s) */ select * from cinemas"; // Create threads for concurrent connection testing Thread[] threads = new Thread[THREAD_COUNT]; diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index f676949ad..36d1f0784 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -50,7 +50,6 @@ dependencies { optionalImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") - compileOnly("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.0") compileOnly("io.lettuce:lettuce-core:6.6.0.RELEASE") compileOnly("org.apache.commons:commons-pool2:2.11.1") compileOnly("org.checkerframework:checker-qual:3.49.5") @@ -100,7 +99,6 @@ dependencies { testImplementation("org.apache.poi:poi-ooxml:5.4.1") testImplementation("org.slf4j:slf4j-simple:2.0.17") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") - testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.0") testImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") testImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") testImplementation("io.opentelemetry:opentelemetry-api:1.52.0") @@ -112,7 +110,6 @@ dependencies { testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 testImplementation("jakarta.persistence:jakarta.persistence-api:2.2.3") - testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.2") } repositories { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java index 82557a41a..6aa792c05 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java 
@@ -1,13 +1,16 @@ package software.amazon.jdbc.plugin.cache; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.SerializationFeature; -import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.InputStream; +import java.io.IOException; import java.io.Reader; +import java.io.Serializable; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.math.BigDecimal; import java.math.RoundingMode; +import java.net.MalformedURLException; import java.net.URL; import java.sql.Array; import java.sql.Blob; @@ -24,137 +27,125 @@ import java.sql.Statement; import java.sql.Time; import java.sql.Timestamp; -import java.time.Instant; +import java.time.LocalDate; import java.time.LocalTime; import java.time.LocalDateTime; import java.time.OffsetDateTime; import java.time.OffsetTime; +import java.time.ZonedDateTime; import java.time.ZoneId; -import java.time.format.DateTimeFormatter; -import java.time.format.DateTimeParseException; import java.util.ArrayList; -import java.util.List; import java.util.HashMap; import java.util.Map; -import java.util.TimeZone; import java.util.Calendar; -import java.util.GregorianCalendar; public class CachedResultSet implements ResultSet { - public static class CachedRow { - protected final HashMap columnByIndex = new HashMap<>(); - protected final HashMap columnByName = new HashMap<>(); + public static class CachedRow implements Serializable { + private final Object[] rowData; - public void put(final int columnIndex, final String columnName, final Object columnValue) { - columnByIndex.put(columnIndex, columnValue); - columnByName.put(columnName, columnValue); + public CachedRow(int numColumns) { + rowData = new Object[numColumns]; } - @SuppressWarnings("unused") - public Object get(final int columnIndex) { - return 
columnByIndex.get(columnIndex); + public void put(final int columnIndex, final Object columnValue) throws SQLException { + if (columnIndex < 1 || columnIndex > rowData.length) { + throw new SQLException("Invalid Column Index when populating CachedRow: " + columnIndex); + } + rowData[columnIndex-1] = columnValue; } - @SuppressWarnings("unused") - public Object get(final String columnName) { - return columnByName.get(columnName); + public Object get(final int columnIndex) throws SQLException { + if (columnIndex < 1 || columnIndex > rowData.length) { + throw new SQLException("Invalid Column Index when getting CachedRow value: " + columnIndex); + } + return rowData[columnIndex - 1]; } } protected ArrayList rows; protected int currentRow; protected boolean wasNullFlag; - protected ResultSetMetaData metadata; - protected static ObjectMapper mapper = new ObjectMapper(); - protected static boolean mapperInitialized = false; - protected static final TimeZone defaultTimeZone = TimeZone.getDefault(); - private static final Calendar calendarWithUserTz = new GregorianCalendar(); + private final CachedResultSetMetaData metadata; + protected static final ZoneId defaultTimeZoneId = ZoneId.systemDefault(); + private final HashMap columnNames; + private volatile boolean closed; public CachedResultSet(final ResultSet resultSet) throws SQLException { - metadata = resultSet.getMetaData(); - final int columns = metadata.getColumnCount(); + ResultSetMetaData srcMetadata = resultSet.getMetaData(); + final int numColumns = srcMetadata.getColumnCount(); + CachedResultSetMetaData.Field[] fields = new CachedResultSetMetaData.Field[numColumns]; + for (int i = 0; i < numColumns; i++) { + fields[i] = new CachedResultSetMetaData.Field(srcMetadata, i+1); + } + metadata = new CachedResultSetMetaData(fields); rows = new ArrayList<>(); - + this.columnNames = new HashMap<>(); + for (int i = 1; i <= numColumns; i++) { + this.columnNames.put(srcMetadata.getColumnLabel(i), i); + } while 
(resultSet.next()) { - final CachedRow row = new CachedRow(); - for (int i = 1; i <= columns; ++i) { - row.put(i, metadata.getColumnName(i), resultSet.getObject(i)); + final CachedRow row = new CachedRow(numColumns); + for (int i = 1; i <= numColumns; ++i) { + row.put(i, resultSet.getObject(i)); } rows.add(row); } currentRow = -1; - initializeObjectMapper(); + closed = false; + wasNullFlag = false; } - public CachedResultSet(final List> resultList) { - rows = new ArrayList<>(); - int numFields = resultList.isEmpty() ? 0 : resultList.get(0).size(); - CachedResultSetMetaData.Field[] fields = new CachedResultSetMetaData.Field[numFields]; - if (!resultList.isEmpty()) { - boolean fieldsInitialized = false; - for (Map rowMap : resultList) { - final CachedRow row = new CachedRow(); - int i = 0; - for (Map.Entry entry : rowMap.entrySet()) { - String columnName = entry.getKey(); - if (!fieldsInitialized) { - fields[i] = new CachedResultSetMetaData.Field(columnName, columnName); - } - row.put(++i, columnName, entry.getValue()); - } - rows.add(row); - fieldsInitialized = true; - } + private CachedResultSet(final CachedResultSetMetaData md, final ArrayList resultRows) throws SQLException { + int numColumns = md.getColumnCount(); + this.columnNames = new HashMap<>(); + for (int i = 1; i <= numColumns; i++) { + this.columnNames.put(md.getColumnLabel(i), i); } currentRow = -1; - metadata = new CachedResultSetMetaData(fields); - initializeObjectMapper(); - } - - // Helper method to initialize the object mapper for serialization of objects - private void initializeObjectMapper() { - if (mapperInitialized) return; - // For serialization of Date/LocalDateTime etc, set up the time module, - // and use standard string format (i.e. 
ISO) - mapper.registerModule(new JavaTimeModule()); - mapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS); - mapperInitialized = true; - } - - public String serializeIntoJsonString() throws SQLException { - List> resultList = new ArrayList<>(); - ResultSetMetaData metaData = this.getMetaData(); - int columns = metaData.getColumnCount(); - - while (this.next()) { - Map rowMap = new HashMap<>(); - for (int i = 1; i <= columns; i++) { - rowMap.put(metaData.getColumnName(i), this.getObject(i)); + rows = resultRows; + metadata = md; + closed = false; + wasNullFlag = false; + } + + public byte[] serializeIntoByteArray() throws SQLException { + // Serialize the metadata and then the rows + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream output = new ObjectOutputStream(baos)) { + output.writeObject(this.metadata); + output.writeInt(rows.size()); + while (this.next()) { + output.writeObject(rows.get(currentRow)); } - resultList.add(rowMap); - } - try { - return mapper.writeValueAsString(resultList); - } catch (JsonProcessingException e) { - throw new SQLException("Error serializing ResultSet to JSON: " + e.getMessage(), e); + output.flush(); + return baos.toByteArray(); + } catch (IOException e) { + throw new SQLException("Error while serializing the ResultSet for caching: ", e); } } - public static ResultSet deserializeFromJsonString(String jsonString) throws SQLException { - if (jsonString == null || jsonString.isEmpty()) { return null; } - try { - List> resultList = mapper.readValue(jsonString, - mapper.getTypeFactory().constructCollectionType(List.class, Map.class)); - return new CachedResultSet(resultList); - } catch (JsonProcessingException e) { - throw new SQLException("Error de-serializing ResultSet from JSON", e); + public static ResultSet deserializeFromByteArray(byte[] data) throws SQLException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(data); ObjectInputStream ois = new ObjectInputStream(bis)) { + 
CachedResultSetMetaData metadata = (CachedResultSetMetaData) ois.readObject(); + int numRows = ois.readInt(); + ArrayList resultRows = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + resultRows.add((CachedRow) ois.readObject()); + } + return new CachedResultSet(metadata, resultRows); + } catch (ClassNotFoundException e) { + throw new SQLException("ClassNotFoundException while de-serializing resultSet for caching", e); + } catch (IOException e) { + throw new SQLException("IOException while de-serializing resultSet for caching", e); } } @Override public boolean next() throws SQLException { - if (rows.isEmpty() || isLast()) { + if (rows.isEmpty()) return false; + if (this.currentRow >= rows.size() - 1) { + afterLast(); return false; } currentRow++; @@ -164,6 +155,7 @@ public boolean next() throws SQLException { @Override public void close() throws SQLException { currentRow = rows.size() - 1; + closed = true; } @Override @@ -174,205 +166,192 @@ public boolean wasNull() throws SQLException { return this.wasNullFlag; } - // TODO: implement all the getXXX APIs. 
@Override public String getString(final int columnIndex) throws SQLException { - Object value = this.getObject(columnIndex); + Object value = checkAndGetColumnValue(columnIndex); if (value == null) return null; return value.toString(); } @Override public boolean getBoolean(final int columnIndex) throws SQLException { - String value = this.getString(columnIndex); - if (value == null) return false; - return Boolean.parseBoolean(value); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return false; + if (val instanceof Boolean) return (Boolean) val; + if (val instanceof Number) return ((Number) val).intValue() == 0; + return Boolean.parseBoolean(val.toString()); } @Override public byte getByte(final int columnIndex) throws SQLException { - return (byte)this.getInt(columnIndex); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Byte) return (Byte) val; + if (val instanceof Number) return ((Number) val).byteValue(); + return Byte.parseByte(val.toString()); } @Override public short getShort(final int columnIndex) throws SQLException { - String value = this.getString(columnIndex); - if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); - return Short.parseShort(value); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Short) return (Short) val; + if (val instanceof Number) return ((Number) val).shortValue(); + return Short.parseShort(val.toString()); } @Override public int getInt(final int columnIndex) throws SQLException { - String value = this.getString(columnIndex); - if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); - return Integer.parseInt(value); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Integer) return (Integer) val; + if (val instanceof Number) return ((Number) 
val).intValue(); + return Integer.parseInt(val.toString()); } @Override public long getLong(final int columnIndex) throws SQLException { - String value = this.getString(columnIndex); - if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); - return Long.parseLong(value); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Long) return (Long) val; + if (val instanceof Number) return ((Number) val).longValue(); + return Long.parseLong(val.toString()); } @Override public float getFloat(final int columnIndex) throws SQLException { - String value = this.getString(columnIndex); - if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); - return Float.parseFloat(value); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Float) return (Float) val; + if (val instanceof Number) return ((Number) val).floatValue(); + return Float.parseFloat(val.toString()); } @Override public double getDouble(final int columnIndex) throws SQLException { - String value = this.getString(columnIndex); - if (value == null) throw new SQLException("Column index " + columnIndex + " doesn't exist"); - return Double.parseDouble(value); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return 0; + if (val instanceof Double) return (Double) val; + if (val instanceof Number) return ((Number) val).doubleValue(); + return Double.parseDouble(val.toString()); } @Override @Deprecated public BigDecimal getBigDecimal(final int columnIndex, final int scale) throws SQLException { - String value = this.getString(columnIndex); - if (value == null) return null; - return new BigDecimal(value).setScale(scale, RoundingMode.HALF_UP); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof BigDecimal) return (BigDecimal) val; + if (val instanceof Number) 
return new BigDecimal(((Number)val).doubleValue()).setScale(scale, RoundingMode.HALF_UP); + return new BigDecimal(Double.parseDouble(val.toString())).setScale(scale, RoundingMode.HALF_UP); } @Override public byte[] getBytes(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); - } - - private Timestamp convertLocalTimeToTimestamp(final LocalDateTime localTime, Calendar cal) { - long epochTimeInMillis; - if (cal != null) { - epochTimeInMillis = localTime.atZone(cal.getTimeZone().toZoneId()).toInstant().toEpochMilli(); - } else { - epochTimeInMillis = localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(); - } - return new Timestamp(epochTimeInMillis); - } - - private Timestamp parseIntoTimestamp(String timestampStr, Calendar cal) { - if (timestampStr.endsWith("Z")) { // ISO format timestamp in UTC like 2025-06-03T11:59:21.822364Z - return Timestamp.from(Instant.parse(timestampStr)); - } else if (timestampStr.contains("+") || timestampStr.contains("-")) { // Offset timestamp format like 2023-10-27T10:00:00+02:00 - try { - OffsetDateTime offsetDateTime = OffsetDateTime.parse(timestampStr); - return Timestamp.from(offsetDateTime.toInstant()); - } catch (DateTimeParseException e) { - // swallow this exception and move on with parsing - } - } - - if (timestampStr.contains(":")) { // timestamp without time zone info with HH:MM:ss info - // The timestamp string doesn't contain time zone information (not recommended for storage). We need - // to use the specified calendar for timezone. If calendar is not specified, use the local time zone. 
- String ts = timestampStr; - if (timestampStr.contains(" ")) { - ts = timestampStr.replace(" ", "T"); - } - // Obtains an instance of LocalDateTime from a text string that is in ISO_LOCAL_DATE_TIME format - return convertLocalTimeToTimestamp(LocalDateTime.parse(ts), cal); - } else { // timestamp without time zone info without HH:MM:ss info - return new Timestamp(Date.valueOf(timestampStr).getTime()); - } + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof byte[]) return (byte[]) val; + return new byte[0]; } - private Date convertToDate(Object dateObj, Calendar cal) { + private Date convertToDate(Object dateObj, Calendar cal) throws SQLException { if (dateObj == null) return null; - Timestamp timestamp; - if (dateObj instanceof Date) { - // Create and return a Timestamp from the milliseconds - return (Date)dateObj; - } else if (dateObj instanceof Timestamp) { - timestamp = (Timestamp) dateObj; - } else { - // Try to parse as Date with hour/minute/second. - // If Date parsing fails, try to parse it as Timestamp - try { - return Date.valueOf(dateObj.toString()); - } catch (IllegalArgumentException e) { - // Failed to parse the string as Date object. Try parsing it as Timestamp instead - timestamp = parseIntoTimestamp(dateObj.toString(), cal); - } + if (dateObj instanceof Date) return (Date)dateObj; + if (dateObj instanceof Number) return new Date(((Number)dateObj).longValue()); + if (dateObj instanceof LocalDate) { + // Convert the LocalDate for the specified time zone into Date representing + // the same instant of time for the default time zone. 
+ LocalDate localDate = (LocalDate)dateObj; + if (cal == null) return Date.valueOf(localDate); + LocalDateTime localDateTime = localDate.atStartOfDay(); + ZonedDateTime originalZonedDateTime = localDateTime.atZone(cal.getTimeZone().toZoneId()); + ZonedDateTime targetZonedDateTime = originalZonedDateTime.withZoneSameInstant(defaultTimeZoneId); + return Date.valueOf(targetZonedDateTime.toLocalDate()); } - // If the dateObj is not already the Date type, then the value cached is the - // epoch time in milliseconds. Here we need to de-serialize it as a long - if (cal == null) { - calendarWithUserTz.setTimeZone(defaultTimeZone); - } else { - calendarWithUserTz.setTimeZone(cal.getTimeZone()); - } - calendarWithUserTz.setTimeInMillis(timestamp.getTime()); - return new Date(calendarWithUserTz.getTimeInMillis()); + // Note: normally the user should properly store the Date object in the DB column and + // the underlying PG/MySQL/MariaDB driver would convert it into Date already in getObject() + // prior to reaching this point in our caching logic. This is mainly to handle the case when the user + // stores a generic string in the DB column and wants to convert this into Date. We try to do a + // best-effort string parsing into Date with standard format "YYYY-MM-DD". The user is then + // expected to handle parsing failure and implement custom logic to fetch this as String. 
+ return Date.valueOf(dateObj.toString()); } @Override public Date getDate(final int columnIndex) throws SQLException { // The value cached is the string representation of epoch time in milliseconds - return convertToDate(this.getObject(columnIndex), null); + return convertToDate(checkAndGetColumnValue(columnIndex), null); } - private Time convertToTime(Object timeObj, Calendar cal) { + private Time convertToTime(Object timeObj, Calendar cal) throws SQLException { if (timeObj == null) return null; - Timestamp ts; - if (timeObj instanceof Time) { - return (Time) timeObj; - } else if (timeObj instanceof Timestamp) { - ts = (Timestamp) timeObj; - } else { - // Parse the time object from string. If it can't be parsed - // as a Time object, then try to parse it as Timestamp. - try { - String timeStr = timeObj.toString(); - if (timeStr.contains("Z")) { - // TODO: fix getTime with a different time zone - DateTimeFormatter formatter = DateTimeFormatter.ofPattern("HH:mm:ssX"); - LocalTime localTime = LocalTime.parse(timeStr, formatter); - return Time.valueOf(localTime); - } else if (timeStr.contains("+") || timeStr.contains("-")) { - LocalTime localTime = OffsetTime.parse(timeStr).toLocalTime(); - return Time.valueOf(localTime); - } else { - LocalTime localTime = LocalTime.parse(timeObj.toString()); - return Time.valueOf(localTime); - } - } catch (DateTimeParseException e) { - ts = parseIntoTimestamp(timeObj.toString(), cal); - } + if (timeObj instanceof Time) return (Time) timeObj; + if (timeObj instanceof Number) return new Time(((Number)timeObj).longValue()); // TODO: test + if (timeObj instanceof LocalTime) { + // Convert the LocalTime for the specified time zone into Time representing + // the same instant of time for the default time zone. 
+ LocalTime localTime = (LocalTime)timeObj; + if (cal == null) return Time.valueOf(localTime); + LocalDateTime localDateTime = LocalDateTime.of(LocalDate.now(), localTime); + ZonedDateTime originalZonedDateTime = localDateTime.atZone(cal.getTimeZone().toZoneId()); + ZonedDateTime targetZonedDateTime = originalZonedDateTime.withZoneSameInstant(defaultTimeZoneId); + return Time.valueOf(targetZonedDateTime.toLocalTime()); } - // use the timezone in the cal (if set) to indicate proper time for "1:00:00" - // e.g. 10:00:00 in EST is 07:00:00 in local time zone (PST) - if (cal == null) { - calendarWithUserTz.setTimeZone(defaultTimeZone); - } else { - calendarWithUserTz.setTimeZone(cal.getTimeZone()); + if (timeObj instanceof OffsetTime) { + OffsetTime localTime = ((OffsetTime)timeObj).withOffsetSameInstant(OffsetDateTime.now().getOffset()); + return Time.valueOf(localTime.toLocalTime()); } - calendarWithUserTz.setTimeInMillis(ts.getTime()); - return new Time(calendarWithUserTz.getTimeInMillis()); + + // Note: normally the user should properly store the Time object in the DB column and + // the underlying PG/MySQL/MariaDB driver would convert it into Time already in getObject() + // prior to reaching this point in our caching logic. This is mainly to handle the case when the user + // stores a generic string in the DB column and wants to convert this into Time. We try to do a + // best-effort string parsing into Time with standard format "HH:MM:SS". The user is then + // expected to handle parsing failure and implement custom logic to fetch this as String. 
+ return Time.valueOf(timeObj.toString()); } @Override public Time getTime(final int columnIndex) throws SQLException { - return convertToTime(this.getObject(columnIndex), null); + return convertToTime(checkAndGetColumnValue(columnIndex), null); } private Timestamp convertToTimestamp(Object timestampObj, Calendar calendar) { if (timestampObj == null) return null; - if (timestampObj instanceof Timestamp) { - return (Timestamp) timestampObj; - } else if (timestampObj instanceof LocalDateTime) { - return convertLocalTimeToTimestamp((LocalDateTime) timestampObj, calendar); - } else { - // De-serialize it from string representation - return parseIntoTimestamp(timestampObj.toString(), calendar); + if (timestampObj instanceof Timestamp) return (Timestamp) timestampObj; + if (timestampObj instanceof Number) return new Timestamp(((Number)timestampObj).longValue()); + if (timestampObj instanceof LocalDateTime) { + // Convert LocalDateTime based on the specified calendar time zone info into a + // Timestamp based on the JVM's default time zone representing the same instant + long epochTimeInMillis; + LocalDateTime localTime = (LocalDateTime)timestampObj; + if (calendar != null) { + epochTimeInMillis = localTime.atZone(calendar.getTimeZone().toZoneId()).toInstant().toEpochMilli(); + } else { + epochTimeInMillis = localTime.atZone(defaultTimeZoneId).toInstant().toEpochMilli(); + } + return new Timestamp(epochTimeInMillis); + } + if (timestampObj instanceof OffsetDateTime) { + return Timestamp.from(((OffsetDateTime)timestampObj).toInstant()); } + if (timestampObj instanceof ZonedDateTime) { + return Timestamp.from(((ZonedDateTime)timestampObj).toInstant()); + } + + // Note: normally the user should properly store the Timestamp/DateTime object in the DB column and + // the underlying PG/MySQL/MariaDB driver would convert it into Timestamp already in getObject() + // prior to reaching this point in our caching logic. 
This is mainly to handle the case when the user + // stores a generic string in the DB column and wants to convert this into Timestamp. We try to do a + // best-effort string parsing into Timestamp with standard format "YYYY-MM-DD HH:MM:SS". The user is + // then expected to handle parsing failure and implement custom logic to fetch this as String. + return Timestamp.valueOf(timestampObj.toString()); } @Override public Timestamp getTimestamp(final int columnIndex) throws SQLException { - return convertToTimestamp(this.getObject(columnIndex), null); + return convertToTimestamp(checkAndGetColumnValue(columnIndex), null); } @Override @@ -393,84 +372,68 @@ public InputStream getBinaryStream(final int columnIndex) throws SQLException { @Override public String getString(final String columnLabel) throws SQLException { - Object value = this.getObject(columnLabel); - if (value == null) return null; - return value.toString(); + return getString(checkAndGetColumnIndex(columnLabel)); } @Override public boolean getBoolean(final String columnLabel) throws SQLException { - String value = this.getString(columnLabel); - if (value == null) return false; - return Boolean.parseBoolean(value); + return getBoolean(checkAndGetColumnIndex(columnLabel)); } @Override public byte getByte(final String columnLabel) throws SQLException { - return (byte)this.getInt(columnLabel); + return getByte(checkAndGetColumnIndex(columnLabel)); } @Override public short getShort(final String columnLabel) throws SQLException { - String value = this.getString(columnLabel); - if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); - return Short.parseShort(value); + return getShort(checkAndGetColumnIndex(columnLabel)); } @Override public int getInt(final String columnLabel) throws SQLException { - String value = this.getString(columnLabel); - if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); - return Integer.parseInt(value); + return 
getInt(checkAndGetColumnIndex(columnLabel)); } @Override public long getLong(final String columnLabel) throws SQLException { - String value = this.getString(columnLabel); - if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); - return Long.parseLong(value); + return getLong(checkAndGetColumnIndex(columnLabel)); } @Override public float getFloat(final String columnLabel) throws SQLException { - String value = this.getString(columnLabel); - if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); - return Float.parseFloat(value); + return getFloat(checkAndGetColumnIndex(columnLabel)); } @Override public double getDouble(final String columnLabel) throws SQLException { - String value = this.getString(columnLabel); - if (value == null) throw new SQLException("Column " + columnLabel + " doesn't exist"); - return Double.parseDouble(value); + return getDouble(checkAndGetColumnIndex(columnLabel)); } @Override @Deprecated public BigDecimal getBigDecimal(final String columnLabel, final int scale) throws SQLException { - String value = this.getString(columnLabel); - if (value == null) return null; - return new BigDecimal(value).setScale(scale, RoundingMode.HALF_UP); + return getBigDecimal(checkAndGetColumnIndex(columnLabel), scale); } @Override public byte[] getBytes(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return getBytes(checkAndGetColumnIndex(columnLabel)); } @Override public Date getDate(final String columnLabel) throws SQLException { - return convertToDate(this.getObject(columnLabel), null); + return getDate(checkAndGetColumnIndex(columnLabel)); } @Override public Time getTime(final String columnLabel) throws SQLException { - return convertToTime(this.getObject(columnLabel), null); + return getTime(checkAndGetColumnIndex(columnLabel)); } @Override public Timestamp getTimestamp(final String columnLabel) throws SQLException { - return 
convertToTimestamp(this.getObject(columnLabel), null); + return getTimestamp(checkAndGetColumnIndex(columnLabel)); } @Override @@ -518,30 +481,38 @@ private void checkCurrentRow() throws SQLException { @Override public Object getObject(final int columnIndex) throws SQLException { checkCurrentRow(); - final CachedRow row = this.rows.get(this.currentRow); - if (!row.columnByIndex.containsKey(columnIndex)) { - throw new SQLException("The column index: " + columnIndex + " is out of range, number of columns: " + row.columnByIndex.size()); - } - Object obj = row.columnByIndex.get(columnIndex); - this.wasNullFlag = (obj == null); - return obj; + return checkAndGetColumnValue(columnIndex); } @Override public Object getObject(final String columnLabel) throws SQLException { checkCurrentRow(); + return checkAndGetColumnValue(checkAndGetColumnIndex(columnLabel)); + } + + // Check the column index passed in is proper, and return the value of the column from the current row + private Object checkAndGetColumnValue(final int columnIndex) throws SQLException { + if (columnIndex == 0 || columnIndex > this.columnNames.size()) throw new SQLException("Column out of bounds"); final CachedRow row = this.rows.get(this.currentRow); - if (!row.columnByName.containsKey(columnLabel)) { - throw new SQLException("The column label: " + columnLabel + " is not found"); - } - Object obj = row.columnByName.get(columnLabel); - this.wasNullFlag = (obj == null); - return obj; + final Object val = row.get(columnIndex); + this.wasNullFlag = (val == null); + return val; + } + + // Check column label exists and returns the column index corresponding to the column name + private int checkAndGetColumnIndex(final String columnLabel) throws SQLException { + final Integer colIndex = columnNames.get(columnLabel); + if (colIndex == null) throw new SQLException("Column not found: " + columnLabel); + return colIndex; } @Override public int findColumn(final String columnLabel) throws SQLException { - throw new 
UnsupportedOperationException(); + final Integer colIndex = columnNames.get(columnLabel); + if (colIndex == null) { + throw new SQLException("The column " + columnLabel + " is not found in this ResultSet."); + } + return colIndex; } @Override @@ -556,16 +527,16 @@ public Reader getCharacterStream(final String columnLabel) throws SQLException { @Override public BigDecimal getBigDecimal(final int columnIndex) throws SQLException { - String value = this.getString(columnIndex); - if (value == null) return null; - return new BigDecimal(value); + final Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof BigDecimal) return (BigDecimal) val; + if (val instanceof Number) return BigDecimal.valueOf(((Number) val).doubleValue()); + return new BigDecimal(val.toString()); } @Override public BigDecimal getBigDecimal(final String columnLabel) throws SQLException { - String value = this.getString(columnLabel); - if (value == null) return null; - return new BigDecimal(value); + return getBigDecimal(checkAndGetColumnIndex(columnLabel)); } @Override @@ -612,7 +583,10 @@ public boolean last() throws SQLException { @Override public int getRow() throws SQLException { - return this.currentRow + 1; + if (this.currentRow >= 0 && this.currentRow < this.rows.size()) { + return this.currentRow + 1; + } + return 0; } @Override @@ -694,17 +668,17 @@ public int getConcurrency() throws SQLException { @Override public boolean rowUpdated() throws SQLException { - throw new UnsupportedOperationException(); + return false; } @Override public boolean rowInserted() throws SQLException { - throw new UnsupportedOperationException(); + return false; } @Override public boolean rowDeleted() throws SQLException { - throw new UnsupportedOperationException(); + return false; } @Override @@ -995,42 +969,49 @@ public Array getArray(final String columnLabel) throws SQLException { @Override public Date getDate(final int columnIndex, final Calendar cal) throws 
SQLException { - return convertToDate(this.getObject(columnIndex), cal); + return convertToDate(checkAndGetColumnValue(columnIndex), cal); } @Override public Date getDate(final String columnLabel, final Calendar cal) throws SQLException { - return convertToDate(this.getObject(columnLabel), cal); + return getDate(checkAndGetColumnIndex(columnLabel), cal); } @Override public Time getTime(final int columnIndex, final Calendar cal) throws SQLException { - return convertToTime(this.getObject(columnIndex), null); + return convertToTime(checkAndGetColumnValue(columnIndex), cal); } @Override public Time getTime(final String columnLabel, final Calendar cal) throws SQLException { - return convertToTime(this.getObject(columnLabel), cal); + return getTime(checkAndGetColumnIndex(columnLabel), cal); } @Override public Timestamp getTimestamp(final int columnIndex, final Calendar cal) throws SQLException { - return convertToTimestamp(this.getObject(columnIndex), cal); + return convertToTimestamp(checkAndGetColumnValue(columnIndex), cal); } @Override public Timestamp getTimestamp(final String columnLabel, final Calendar cal) throws SQLException { - return convertToTimestamp(this.getObject(columnLabel), cal); + return getTimestamp(checkAndGetColumnIndex(columnLabel), cal); } @Override public URL getURL(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof URL) return (URL) val; + try { + return new URL(val.toString()); + } catch (MalformedURLException e) { + throw new SQLException("Cannot extract url: " + val, e); + } } @Override public URL getURL(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return getURL(checkAndGetColumnIndex(columnLabel)); } @Override @@ -1075,12 +1056,15 @@ public void updateArray(final String columnLabel, final Array x) throws SQLExcep @Override public RowId getRowId(final 
int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof RowId) return (RowId) val; + throw new SQLException("Cannot extract rowId: " + val); } @Override public RowId getRowId(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return getRowId(checkAndGetColumnIndex(columnLabel)); } @Override @@ -1100,7 +1084,7 @@ public int getHoldability() throws SQLException { @Override public boolean isClosed() throws SQLException { - return this.rows == null; + return closed; } @Override @@ -1160,12 +1144,12 @@ public void updateSQLXML(final String columnLabel, final SQLXML xmlObject) throw @Override public String getNString(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + return getString(columnIndex); } @Override public String getNString(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return getString(columnLabel); } @Override @@ -1344,7 +1328,11 @@ public T getObject(final String columnLabel, final Class type) throws SQL @Override public T unwrap(final Class iface) throws SQLException { - return iface == ResultSet.class ? 
iface.cast(this) : null; + if (iface.isAssignableFrom(this.getClass())) { + return iface.cast(this); + } else { + throw new SQLException("Cannot unwrap to " + iface.getName()); + } } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java index 8819e2737..bf295cb6b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSetMetaData.java @@ -1,139 +1,171 @@ package software.amazon.jdbc.plugin.cache; +import java.io.Serializable; import java.sql.ResultSetMetaData; import java.sql.SQLException; -public class CachedResultSetMetaData implements ResultSetMetaData { - - public static class Field { - // TODO: support binary format - private final String columnLabel; - private final String columnName; - public Field(String columnLabel, String columnName) { - this.columnLabel = columnLabel; - this.columnName = columnName; - } - - public String getColumnLabel() { - return columnLabel; - } - - public String getColumnName() { - return columnName; +class CachedResultSetMetaData implements ResultSetMetaData, Serializable { + protected final Field[] columns; + + protected static class Field implements Serializable { + String catalog; + String className; + String label; + String name; + String typeName; + int type; + int displaySize; + int precision; + String tableName; + int scale; + String schemaName; + boolean isAutoIncrement; + boolean isCaseSensitive; + boolean isCurrency; + boolean isDefinitelyWritable; + int isNullable; + boolean isReadOnly; + boolean isSearchable; + boolean isSigned; + boolean isWritable; + + protected Field(final ResultSetMetaData srcMetadata, int column) throws SQLException { + catalog = srcMetadata.getCatalogName(column); + className = srcMetadata.getColumnClassName(column); + label = 
srcMetadata.getColumnLabel(column); + name = srcMetadata.getColumnName(column); + typeName = srcMetadata.getColumnTypeName(column); + type = srcMetadata.getColumnType(column); + displaySize = srcMetadata.getColumnDisplaySize(column); + precision = srcMetadata.getPrecision(column); + tableName = srcMetadata.getTableName(column); + scale = srcMetadata.getScale(column); + schemaName = srcMetadata.getSchemaName(column); + isAutoIncrement = srcMetadata.isAutoIncrement(column); + isCaseSensitive = srcMetadata.isCaseSensitive(column); + isCurrency = srcMetadata.isCurrency(column); + isDefinitelyWritable = srcMetadata.isDefinitelyWritable(column); + isNullable = srcMetadata.isNullable(column); + isReadOnly = srcMetadata.isReadOnly(column); + isSearchable = srcMetadata.isSearchable(column); + isSigned = srcMetadata.isSigned(column); + isWritable = srcMetadata.isWritable(column); } } - protected Field[] fields; - - public CachedResultSetMetaData(Field[] fields) { - this.fields = fields; + CachedResultSetMetaData(Field[] columns) { + this.columns = columns; } @Override public int getColumnCount() throws SQLException { - return this.fields.length; + return columns.length; + } + + private Field getColumns(final int column) throws SQLException { + if (column == 0 || column > columns.length) + throw new SQLException("Wrong column number: " + column); + return columns[column - 1]; } @Override public boolean isAutoIncrement(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).isAutoIncrement; } @Override public boolean isCaseSensitive(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).isCaseSensitive; } @Override public boolean isSearchable(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).isSearchable; } @Override public boolean isCurrency(int column) throws SQLException { - throw new UnsupportedOperationException(); 
+ return getColumns(column).isCurrency; } @Override public int isNullable(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).isNullable; } @Override public boolean isSigned(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).isSigned; } @Override public int getColumnDisplaySize(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).displaySize; } @Override public String getColumnLabel(int column) throws SQLException { - return fields[column-1].getColumnLabel(); + return getColumns(column).label; } @Override public String getColumnName(int column) throws SQLException { - return fields[column-1].getColumnName(); + return getColumns(column).name; } - // TODO @Override public String getSchemaName(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).schemaName; } @Override public int getPrecision(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).precision; } @Override public int getScale(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).scale; } - // TODO @Override public String getTableName(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).tableName; } @Override public String getCatalogName(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).catalog; } @Override public int getColumnType(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).type; } @Override public String getColumnTypeName(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).typeName; } @Override public boolean isReadOnly(int column) throws SQLException { - return 
true; + return getColumns(column).isReadOnly; } @Override public boolean isWritable(int column) throws SQLException { - return false; + return getColumns(column).isWritable; } @Override public boolean isDefinitelyWritable(int column) throws SQLException { - return false; + return getColumns(column).isDefinitelyWritable; } @Override public String getColumnClassName(int column) throws SQLException { - throw new UnsupportedOperationException(); + return getColumns(column).className; } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index 9e709743e..59b67d40b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -16,7 +16,6 @@ package software.amazon.jdbc.plugin.cache; -import java.nio.charset.StandardCharsets; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.ResultSet; @@ -48,10 +47,6 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin { JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName, JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName))); - static { - PropertyDefinition.registerPluginProperties(DataRemoteCachePlugin.class); - } - private PluginService pluginService; private TelemetryFactory telemetryFactory; private TelemetryCounter hitCounter; @@ -63,8 +58,6 @@ public DataRemoteCachePlugin(final PluginService pluginService, final Properties try { Class.forName("io.lettuce.core.RedisClient"); // Lettuce dependency Class.forName("org.apache.commons.pool2.impl.GenericObjectPool"); // Object pool dependency - Class.forName("com.fasterxml.jackson.databind.ObjectMapper"); // Jackson dependency - Class.forName("com.fasterxml.jackson.datatype.jsr310.JavaTimeModule"); // JSR310 dependency } catch (final ClassNotFoundException e) { throw new 
RuntimeException(Messages.get("DataRemoteCachePlugin.notInClassPath", new Object[] {e.getMessage()})); } @@ -111,11 +104,11 @@ private ResultSet fetchResultSetFromCache(String queryStr) { String cacheQueryKey = getCacheQueryKey(queryStr); if (cacheQueryKey == null) return null; // Treat this as a cache miss - byte[] result = cacheConnection.readFromCache(cacheQueryKey); - if (result == null) return null; + byte[] cachedResult = cacheConnection.readFromCache(cacheQueryKey); + if (cachedResult == null) return null; // Convert result into ResultSet try { - return CachedResultSet.deserializeFromJsonString(new String(result, StandardCharsets.UTF_8)); + return CachedResultSet.deserializeFromByteArray(cachedResult); } catch (Exception e) { LOGGER.warning("Error de-serializing cached result: " + e.getMessage()); return null; // Treat this as a cache miss @@ -132,8 +125,8 @@ private ResultSet cacheResultSet(String queryStr, ResultSet rs, int expiry) thro String cacheQueryKey = getCacheQueryKey(queryStr); if (cacheQueryKey == null) return rs; // Treat this condition as un-cacheable CachedResultSet crs = new CachedResultSet(rs); - String jsonValue = crs.serializeIntoJsonString(); - cacheConnection.writeToCache(cacheQueryKey, jsonValue.getBytes(StandardCharsets.UTF_8), expiry); + byte[] serializedResult = crs.serializeIntoByteArray(); + cacheConnection.writeToCache(cacheQueryKey, serializedResult, expiry); crs.beforeFirst(); return crs; } diff --git a/wrapper/src/test/build.gradle.kts b/wrapper/src/test/build.gradle.kts index 3ed1011f9..3e2f17251 100644 --- a/wrapper/src/test/build.gradle.kts +++ b/wrapper/src/test/build.gradle.kts @@ -55,7 +55,6 @@ dependencies { testImplementation("org.apache.poi:poi-ooxml:5.3.0") testImplementation("org.slf4j:slf4j-simple:2.0.13") testImplementation("com.fasterxml.jackson.core:jackson-databind:2.17.1") - testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.0") 
testImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") testImplementation("io.opentelemetry:opentelemetry-sdk:1.42.1") testImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.43.0") diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java index 61cfe478e..e22b7e13d 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java @@ -2,399 +2,444 @@ import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; +import java.sql.*; +import java.sql.Date; import java.time.*; import java.util.*; -import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import java.math.BigDecimal; -import java.sql.Date; -import java.sql.Time; -import java.sql.Timestamp; public class CachedResultSetTest { - static List> testResultList = new ArrayList<>(); - static Calendar estCalendar = Calendar.getInstance(TimeZone.getTimeZone("America/New_York")); + private CachedResultSet testResultSet; + @Mock ResultSet mockResultSet; + @Mock ResultSetMetaData mockResultSetMetadata; + private AutoCloseable closeable; + private static Calendar estCal = Calendar.getInstance(TimeZone.getTimeZone("America/New_York")); + private TimeZone defaultTimeZone = TimeZone.getDefault(); + + // Column values: label, name, typeName, type, displaySize, precision, tableName, + // scale, schemaName, isAutoIncrement, isCaseSensitive, isCurrency, isDefinitelyWritable, + // isNullable, isReadOnly, isSearchable, isSigned, 
isWritable + private static final Object [][] testColumnMetadata = { + {"fieldNull", "fieldNull", "String", Types.VARCHAR, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldInt", "fieldInt", "Integer", Types.INTEGER, 10, 2, "table", 1, "public", true, false, false, false, 0, false, true, true, true}, + {"fieldString", "fieldString", "String", Types.VARCHAR, 10, 2, "table", 1, "public", false, false, false, false, 0, false, true, false, true}, + {"fieldBoolean", "fieldBoolean", "Boolean", Types.BOOLEAN, 10, 2, "table", 1, "public", false, false, false, false, 0, false, true, false, true}, + {"fieldByte", "fieldByte", "Byte", Types.TINYINT, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldShort", "fieldShort", "Short", Types.SMALLINT, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldLong", "fieldLong", "Long", Types.BIGINT, 10, 2, "table", 1, "public", false, false, false, false, 1, false, true, false, false}, + {"fieldFloat", "fieldFloat", "Float", Types.REAL, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldDouble", "fieldDouble", "Double", Types.DOUBLE, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldBigDecimal", "fieldBigDecimal", "BigDecimal", Types.DECIMAL, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldDate", "fieldDate", "Date", Types.DATE, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldTime", "fieldTime", "Time", Types.TIME, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, + {"fieldDateTime", "fieldDateTime", "Timestamp", Types.TIMESTAMP, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false} + }; + + private static final Object [][] testColumnValues = { + {null, 
null}, + {1, 123456}, + {"John Doe", "Tony Stark"}, + {true, false}, + {(byte)100, (byte)70}, // Letter d and F in ASCII + {(short)55, (short)135}, + {8589934592L, -34359738368L}, // 2^33 and -2^35 (note: Java's ^ is XOR, not exponentiation) + {3.14159f, -233.14159f}, + {2345.23345d, -2344355.4543d}, + {new BigDecimal("15.33"), new BigDecimal("-12.45")}, + {Date.valueOf("2025-03-15"), Date.valueOf("1102-01-15")}, + {Time.valueOf("22:54:00"), Time.valueOf("01:10:00")}, + {Timestamp.valueOf("2025-03-15 22:54:00"), Timestamp.valueOf("1950-01-18 21:50:05")} + }; + + private void mockGetMetadataFields(int column, int testMetadataCol) throws SQLException { + when(mockResultSetMetadata.getCatalogName(column)).thenReturn(""); + when(mockResultSetMetadata.getColumnClassName(column)).thenReturn("MyClass" + testMetadataCol); + when(mockResultSetMetadata.getColumnLabel(column)).thenReturn((String) testColumnMetadata[testMetadataCol][0]); + when(mockResultSetMetadata.getColumnName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][1]); + when(mockResultSetMetadata.getColumnTypeName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][2]); + when(mockResultSetMetadata.getColumnType(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][3]); + when(mockResultSetMetadata.getColumnDisplaySize(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][4]); + when(mockResultSetMetadata.getPrecision(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][5]); + when(mockResultSetMetadata.getTableName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][6]); + when(mockResultSetMetadata.getScale(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][7]); + when(mockResultSetMetadata.getSchemaName(column)).thenReturn((String) testColumnMetadata[testMetadataCol][8]); + when(mockResultSetMetadata.isAutoIncrement(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][9]); + when(mockResultSetMetadata.isCaseSensitive(column)).thenReturn((Boolean) 
testColumnMetadata[testMetadataCol][10]); + when(mockResultSetMetadata.isCurrency(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][11]); + when(mockResultSetMetadata.isDefinitelyWritable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][12]); + when(mockResultSetMetadata.isNullable(column)).thenReturn((Integer) testColumnMetadata[testMetadataCol][13]); + when(mockResultSetMetadata.isReadOnly(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][14]); + when(mockResultSetMetadata.isSearchable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][15]); + when(mockResultSetMetadata.isSigned(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][16]); + when(mockResultSetMetadata.isWritable(column)).thenReturn((Boolean) testColumnMetadata[testMetadataCol][17]); + } + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")); + } + + @AfterEach + void cleanUp() { + TimeZone.setDefault(defaultTimeZone); + } - @BeforeAll - static void setUp() { - Map row = new HashMap<>(); - row.put("fieldNull", null); // null - row.put("fieldInt", 1); // Integer - row.put("fieldString", "John Doe"); // String - row.put("fieldBoolean", true); - row.put("fieldByte", (byte)100); // 100 in ASCII is letter d - row.put("fieldShort", (short)55); - row.put("fieldLong", 8589934592L); // 2^33 - row.put("fieldFloat", 3.14159f); - row.put("fieldDouble", 2345.23345d); - row.put("fieldBigDecimal", new BigDecimal("15.33")); - row.put("fieldDate", Date.valueOf("2025-03-15")); - row.put("fieldTime", Time.valueOf("22:54:00")); - row.put("fieldDateTime", Timestamp.valueOf("2025-03-15 22:54:00")); - testResultList.add(row); - Map row2 = new HashMap<>(); - row2.put("fieldNull", null); // null - row2.put("fieldInt", 123456); // Integer - row2.put("fieldString", "Tony Stark"); // String - row2.put("fieldBoolean", false); - row2.put("fieldByte", (byte)70); 
// 100 in ASCII is letter F - row2.put("fieldShort", (short)135); - row2.put("fieldLong", -34359738368L); // -2^35 - row2.put("fieldFloat", -233.14159f); - row2.put("fieldDouble", -2344355.4543d); - row2.put("fieldBigDecimal", new BigDecimal("-12.45")); - row2.put("fieldDate", Date.valueOf("1102-01-15")); - row2.put("fieldTime", Time.valueOf("01:10:00")); - row2.put("fieldDateTime", LocalDateTime.of(1981, 3, 10, 1, 10, 20)); - testResultList.add(row2); + void setUpDefaultTestResultSet() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(13); + for (int i = 0; i < 13; i++) { + mockGetMetadataFields(1+i, i); + when(mockResultSet.getObject(1+i)).thenReturn(testColumnValues[i][0], testColumnValues[i][1]); + } + when(mockResultSet.next()).thenReturn(true, true, false); + testResultSet = new CachedResultSet(mockResultSet); } - private void verifyRow1(ResultSet rs) throws SQLException { - Map colNameToIndexMap = new HashMap(); - ResultSetMetaData rsmd = rs.getMetaData(); - for (int i = 1; i <= rsmd.getColumnCount(); i++) { - colNameToIndexMap.put(rsmd.getColumnName(i), i); + private void verifyDefaultMetadata(ResultSet rs) throws SQLException { + ResultSetMetaData md = rs.getMetaData(); + for (int i = 0; i < md.getColumnCount(); i++) { + assertEquals("", md.getCatalogName(i+1)); + assertEquals("MyClass" + i, md.getColumnClassName(i+1)); + assertEquals(testColumnMetadata[i][0], md.getColumnLabel(i+1)); + assertEquals(testColumnMetadata[i][1], md.getColumnName(i+1)); + assertEquals(testColumnMetadata[i][2], md.getColumnTypeName(i+1)); + assertEquals(testColumnMetadata[i][3], md.getColumnType(i+1)); + assertEquals(testColumnMetadata[i][4], md.getColumnDisplaySize(i+1)); + assertEquals(testColumnMetadata[i][5], md.getPrecision(i+1)); + assertEquals(testColumnMetadata[i][6], md.getTableName(i+1)); + 
assertEquals(testColumnMetadata[i][7], md.getScale(i+1)); + assertEquals(testColumnMetadata[i][8], md.getSchemaName(i+1)); + assertEquals(testColumnMetadata[i][9], md.isAutoIncrement(i+1)); + assertEquals(testColumnMetadata[i][10], md.isCaseSensitive(i+1)); + assertEquals(testColumnMetadata[i][11], md.isCurrency(i+1)); + assertEquals(testColumnMetadata[i][12], md.isDefinitelyWritable(i+1)); + assertEquals(testColumnMetadata[i][13], md.isNullable(i+1)); + assertEquals(testColumnMetadata[i][14], md.isReadOnly(i+1)); + assertEquals(testColumnMetadata[i][15], md.isSearchable(i+1)); + assertEquals(testColumnMetadata[i][16], md.isSigned(i+1)); + assertEquals(testColumnMetadata[i][17], md.isWritable(i+1)); } - assertEquals(1, rs.getInt(colNameToIndexMap.get("fieldInt"))); + } + + private void verifyDefaultRow(ResultSet rs, int row) throws SQLException { assertFalse(rs.wasNull()); - assertEquals("John Doe", rs.getString(colNameToIndexMap.get("fieldString"))); + assertNull(rs.getObject(1)); // fieldNull + assertEquals(1, rs.findColumn("fieldNull")); + assertTrue(rs.wasNull()); + assertEquals((int) testColumnValues[1][row], rs.getInt(2)); // fieldInt assertFalse(rs.wasNull()); - assertTrue(rs.getBoolean(colNameToIndexMap.get("fieldBoolean"))); + assertEquals((int) testColumnValues[1][row], rs.getInt("fieldInt")); + assertEquals(2, rs.findColumn("fieldInt")); assertFalse(rs.wasNull()); - assertEquals(100, rs.getByte(colNameToIndexMap.get("fieldByte"))); + assertEquals(testColumnValues[2][row], rs.getString(3)); // fieldString assertFalse(rs.wasNull()); - assertEquals(55, rs.getShort(colNameToIndexMap.get("fieldShort"))); + assertEquals(testColumnValues[2][row], rs.getString("fieldString")); + assertEquals(3, rs.findColumn("fieldString")); assertFalse(rs.wasNull()); - assertNull(rs.getObject(colNameToIndexMap.get("fieldNull"))); - assertTrue(rs.wasNull()); - assertEquals(8589934592L, rs.getLong(colNameToIndexMap.get("fieldLong"))); + assertEquals(testColumnValues[3][row], 
rs.getBoolean(4)); // fieldBoolean + assertFalse(rs.wasNull()); + assertEquals(testColumnValues[3][row], rs.getBoolean("fieldBoolean")); + assertEquals(4, rs.findColumn("fieldBoolean")); assertFalse(rs.wasNull()); - assertEquals(3.14159f, rs.getFloat(colNameToIndexMap.get("fieldFloat")), 0); + assertEquals((byte) testColumnValues[4][row], rs.getByte(5)); // fieldByte assertFalse(rs.wasNull()); - assertEquals(2345.23345d, rs.getDouble(colNameToIndexMap.get("fieldDouble"))); + assertEquals((byte) testColumnValues[4][row], rs.getByte("fieldByte")); + assertEquals(5, rs.findColumn("fieldByte")); assertFalse(rs.wasNull()); - assertEquals(0, rs.getBigDecimal(colNameToIndexMap.get("fieldBigDecimal")).compareTo(new BigDecimal("15.33"))); + assertEquals((short) testColumnValues[5][row], rs.getShort(6)); // fieldShort assertFalse(rs.wasNull()); - assertNull(rs.getObject(colNameToIndexMap.get("fieldNull"))); + assertEquals((short) testColumnValues[5][row], rs.getShort("fieldShort")); + assertEquals(6, rs.findColumn("fieldShort")); + assertFalse(rs.wasNull()); + assertNull(rs.getObject("fieldNull")); assertTrue(rs.wasNull()); - Date date = rs.getDate(colNameToIndexMap.get("fieldDate")); - assertEquals(1742022000000L, date.getTime()); + assertEquals((Long) testColumnValues[6][row], rs.getLong(7)); // fieldLong assertFalse(rs.wasNull()); - Time time = rs.getTime(colNameToIndexMap.get("fieldTime")); - assertEquals(111240000, time.getTime()); + assertEquals((Long) testColumnValues[6][row], rs.getLong("fieldLong")); + assertEquals(7, rs.findColumn("fieldLong")); assertFalse(rs.wasNull()); - Timestamp ts = rs.getTimestamp(colNameToIndexMap.get("fieldDateTime")); - assertEquals(1742104440000L, ts.getTime()); + assertEquals((float) testColumnValues[7][row], rs.getFloat(8), 0); // fieldFloat assertFalse(rs.wasNull()); - } - - private void verifyRow2(ResultSet rs) throws SQLException { - assertEquals(123456, rs.getInt("fieldInt")); + assertEquals((float) testColumnValues[7][row], 
rs.getFloat("fieldFloat"), 0); + assertEquals(8, rs.findColumn("fieldFloat")); assertFalse(rs.wasNull()); - assertEquals("Tony Stark", rs.getString("fieldString")); + assertEquals((double) testColumnValues[8][row], rs.getDouble(9)); // fieldDouble assertFalse(rs.wasNull()); - assertFalse(rs.getBoolean("fieldBoolean")); + assertEquals((double) testColumnValues[8][row], rs.getDouble("fieldDouble")); + assertEquals(9, rs.findColumn("fieldDouble")); assertFalse(rs.wasNull()); - assertEquals(70, rs.getByte("fieldByte")); + assertEquals(0, rs.getBigDecimal(10).compareTo((BigDecimal) testColumnValues[9][row])); // fieldBigDecimal assertFalse(rs.wasNull()); - assertEquals(135, rs.getShort("fieldShort")); + assertEquals(0, rs.getBigDecimal("fieldBigDecimal").compareTo((BigDecimal) testColumnValues[9][row])); + assertEquals(10, rs.findColumn("fieldBigDecimal")); assertFalse(rs.wasNull()); - assertNull(rs.getObject("fieldNull")); + assertNull(rs.getObject(1)); // fieldNull assertTrue(rs.wasNull()); - assertEquals(-34359738368L, rs.getLong("fieldLong")); - assertFalse(rs.wasNull()); - assertEquals(-233.14159f, rs.getFloat("fieldFloat")); + assertEquals(testColumnValues[10][row], rs.getDate(11)); // fieldDate assertFalse(rs.wasNull()); - assertEquals(-2344355.4543d, rs.getDouble("fieldDouble")); + assertEquals(testColumnValues[10][row], rs.getDate("fieldDate")); + assertEquals(11, rs.findColumn("fieldDate")); assertFalse(rs.wasNull()); - assertEquals(0, rs.getBigDecimal("fieldBigDecimal").compareTo(new BigDecimal("-12.45"))); + assertEquals(testColumnValues[11][row], rs.getTime(12)); // fieldTime assertFalse(rs.wasNull()); - Date date = rs.getDate("fieldDate"); - assertEquals("1102-01-15", date.toString()); + assertEquals(testColumnValues[11][row], rs.getTime("fieldTime")); + assertEquals(12, rs.findColumn("fieldTime")); assertFalse(rs.wasNull()); - Time time = rs.getTime("fieldTime"); - assertEquals("01:10:00", time.toString()); + assertEquals(testColumnValues[12][row], 
rs.getTimestamp(13)); // fieldDateTime assertFalse(rs.wasNull()); - Timestamp ts = rs.getTimestamp("fieldDateTime"); - assertTrue(ts.toString().startsWith("1981-03-10 01:10:20")); + assertEquals(testColumnValues[12][row], rs.getTimestamp("fieldDateTime")); + assertEquals(13, rs.findColumn("fieldDateTime")); assertFalse(rs.wasNull()); + verifyNonexistingField(rs); } - @Test - void test_create_and_verify_basic() throws Exception { - // An empty result set - ResultSet rs0 = new CachedResultSet(new ArrayList<>()); - assertFalse(rs0.next()); - ResultSetMetaData md = rs0.getMetaData(); - assertEquals(0, md.getColumnCount()); - // Result set containing data - ResultSet rs = new CachedResultSet(testResultList); - verifyMetadata(rs); - verifyContent(rs); - rs.beforeFirst(); - CachedResultSet cachedRs = new CachedResultSet(rs); - verifyMetadata(cachedRs); - verifyContent(cachedRs); - rs.clearWarnings(); - assertNull(rs.getWarnings()); + private void verifyNonexistingField(ResultSet rs) { + try { + rs.getObject("nonExistingField"); + throw new IllegalStateException("Expected an exception due to column doesn't exist"); + } catch (SQLException e) { + // Expected an exception if the column doesn't exist + } + try { + rs.findColumn("nonExistingField"); + throw new IllegalStateException("Expected an exception due to column doesn't exist"); + } catch (SQLException e) { + // Expected an exception if the column doesn't exist + } } @Test - void test_serialize_and_deserialize_basic() throws Exception { - CachedResultSet cachedRs = new CachedResultSet(testResultList); - String serialized_data = cachedRs.serializeIntoJsonString(); - ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); - verifyContent(rs); - } - - private void verifyContent(ResultSet rs) throws SQLException { + void test_basic_cached_result_set() throws Exception { + // Basic verification of the test result set + setUpDefaultTestResultSet(); + verifyDefaultMetadata(testResultSet); + assertEquals(0, 
testResultSet.getRow()); + assertTrue(testResultSet.next()); + assertEquals(1, testResultSet.getRow()); + verifyDefaultRow(testResultSet, 0); + assertTrue(testResultSet.next()); + assertEquals(2, testResultSet.getRow()); + verifyDefaultRow(testResultSet, 1); + assertFalse(testResultSet.next()); + assertEquals(0, testResultSet.getRow()); + assertNull(testResultSet.getWarnings()); + testResultSet.clearWarnings(); + assertNull(testResultSet.getWarnings()); + testResultSet.beforeFirst(); + // Test serialization and de-serialization of the result set + byte[] serialized_data = testResultSet.serializeIntoByteArray(); + ResultSet rs = CachedResultSet.deserializeFromByteArray(serialized_data); + verifyDefaultMetadata(rs); assertTrue(rs.next()); - if (rs.getInt("fieldInt") == 1) { - verifyRow1(rs); - assertTrue(rs.next()); - verifyRow2(rs); - rs.previous(); - verifyRow1(rs); - verifyNonexistingField(rs); - rs.relative(1); // Advances to next row - verifyRow2(rs); - rs.absolute(2); - verifyRow2(rs); - } else { - verifyRow2(rs); - assertTrue(rs.next()); - verifyRow1(rs); - rs.previous(); - verifyRow2(rs); - verifyNonexistingField(rs); - rs.relative(1); // Advances to next row - verifyRow1(rs); - rs.absolute(2); - verifyRow1(rs); - } + verifyDefaultRow(rs, 0); + assertTrue(rs.next()); + verifyDefaultRow(rs, 1); assertFalse(rs.next()); - rs.relative(-10); + assertNull(rs.getWarnings()); + rs.relative(-10); // We should be before the start of the rows assertTrue(rs.isBeforeFirst()); - rs.relative(10); + assertEquals(0, rs.getRow()); + rs.relative(10); // We should be after the end of the rows assertTrue(rs.isAfterLast()); - rs.absolute(-10); + assertEquals(0, rs.getRow()); + rs.absolute(-10); // We should be before the start of the rows assertTrue(rs.isBeforeFirst()); - rs.absolute(10); + assertFalse(rs.absolute(100)); // Jump to after the end of the rows assertTrue(rs.isAfterLast()); - } - - private void verifyMetadata(ResultSet rs) throws SQLException { - ResultSetMetaData md 
= rs.getMetaData(); - List expectedCols = Arrays.asList("fieldNull", "fieldInt", "fieldString", "fieldBoolean", "fieldByte", "fieldShort", "fieldLong", "fieldFloat", "fieldDouble", "fieldBigDecimal", "fieldDate", "fieldTime", "fieldDateTime"); - assertEquals(md.getColumnCount(), testResultList.get(0).size()); - List actualColNames = new ArrayList<>(); - List actualColLabels = new ArrayList<>(); - for (int i = 1; i <= md.getColumnCount(); i++) { - actualColNames.add(md.getColumnName(i)); - actualColLabels.add(md.getColumnLabel(i)); - } - assertTrue(actualColNames.containsAll(expectedCols)); - assertTrue(expectedCols.containsAll(actualColNames)); - assertTrue(actualColLabels.containsAll(expectedCols)); - assertTrue(expectedCols.containsAll(actualColLabels)); + assertEquals(0, rs.getRow()); + assertFalse(rs.absolute(0)); // Go to the beginning of rows + assertTrue(rs.isBeforeFirst()); + assertTrue(rs.next()); // We are at first row + verifyDefaultRow(rs, 0); + rs.relative(1); // Advances to next row + verifyDefaultRow(rs, 1); + assertTrue(rs.previous()); // Go back to first row + verifyDefaultRow(rs, 0); + assertFalse(rs.previous()); + assertTrue(rs.absolute(2)); // Jump to second row + verifyDefaultRow(rs, 1); + assertTrue(rs.first()); // go to first row + verifyDefaultRow(rs, 0); + assertEquals(1, rs.getRow()); + assertTrue(rs.last()); // go to last row + verifyDefaultRow(rs, 1); + assertEquals(2, rs.getRow()); } @Test - void test_get_timestamp() throws SQLException { - // Timestamp string that is in ISO format with time zone information in UTC - Map row = new HashMap<>(); - row.put("fieldTimestamp0", "2025-06-03T11:59:21.822364Z"); - // Timestamp string that is in ISO format with time zone information as offset - row.put("fieldTimestamp1", "2024-02-13T07:40:30.822364-05:00"); - row.put("fieldTimestamp2", "2023-10-27T10:00:00+02:00"); - // Timestamp string doesn't contain time zone information. 
- row.put("fieldTimestamp3", "1760-06-03T11:59:21.822364"); - row.put("fieldTimestamp4", "2020-05-04 10:06:10.822364"); - row.put("fieldTimestamp5", "2015-09-01 23:33:00"); - // Timestamp string doesn't contain time zone or HH:MM:SS information - row.put("fieldTimestamp6", "2019-03-15"); - row.put("fieldTimestamp7", Timestamp.from(Instant.parse("2024-08-01T10:30:20.822364Z"))); - row.put("fieldTimestamp8", LocalDateTime.parse("2025-04-01T21:55:21.822364")); - List> testTimestamps = Collections.singletonList(row); - CachedResultSet cachedRs = new CachedResultSet(testTimestamps); - assertTrue(cachedRs.next()); - verifyTimestamps(cachedRs); - verifyNonexistingField(cachedRs); - cachedRs.beforeFirst(); - String serialized_data = cachedRs.serializeIntoJsonString(); - ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); + void test_get_special_bigDecimal() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 9); + when(mockResultSet.getObject(1)).thenReturn( + 12450.567, + -132.45, + "142.346", + "invalid", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + CachedResultSet rs = new CachedResultSet(mockResultSet); + assertTrue(rs.next()); - verifyTimestamps(rs); - verifyNonexistingField(rs); - } + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("12450.567"))); - private void verifyNonexistingField(ResultSet rs) { + assertTrue(rs.next()); + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("-132.45"))); + assertTrue(rs.next()); + assertEquals(0, rs.getBigDecimal(1).compareTo(new BigDecimal("142.346"))); + assertTrue(rs.next()); try { - rs.getTimestamp("nonExistingField"); - throw new IllegalStateException("Expected an exception due to column doesn't exist"); - } catch (SQLException e) { - // Expected an exception 
if the column doesn't exist + rs.getBigDecimal(1); + fail("Invalid value should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass } - } - - private void verifyTimestamps(ResultSet rs) throws SQLException { - // Verifying the timestamp with time zone information. The specified calendar doesn't matter. - Timestamp expectedTs = Timestamp.from(Instant.parse("2025-06-03T11:59:21.822364Z")); - assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp0").getTime()); - assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp0", estCalendar).getTime()); - - expectedTs = Timestamp.from(OffsetDateTime.parse("2024-02-13T07:40:30.822364-05:00").toInstant()); - assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp1").getTime()); - assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp1", estCalendar).getTime()); - - expectedTs = Timestamp.from(OffsetDateTime.parse("2023-10-27T10:00:00+02:00").toInstant()); - assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp2").getTime()); - assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp2", estCalendar).getTime()); - - // Verify timestamp without time zone information. 
The specified calendar matters here - LocalDateTime localTime = LocalDateTime.parse("1760-06-03T11:59:21.822364"); - ZoneId estZone = ZoneId.of("America/New_York"); - assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp3").getTime()); - assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp3", estCalendar).getTime()); - - localTime = LocalDateTime.parse("2020-05-04T10:06:10.822364"); - assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp4").getTime()); - assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp4", estCalendar).getTime()); - - localTime = LocalDateTime.parse("2015-09-01T23:33:00"); - assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp5").getTime()); - assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp5", estCalendar).getTime()); - - localTime = LocalDateTime.parse("2019-03-15T00:00:00"); - assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp6").getTime()); - assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp6", estCalendar).getTime()); - - expectedTs = Timestamp.from(Instant.parse("2024-08-01T10:30:20.822364Z")); - assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp7").getTime()); - assertEquals(expectedTs.getTime(), rs.getTimestamp("fieldTimestamp7", estCalendar).getTime()); - - localTime = LocalDateTime.parse("2025-04-01T21:55:21.822364"); - assertEquals(localTime.atZone(ZoneId.systemDefault()).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp8").getTime()); - assertEquals(localTime.atZone(estZone).toInstant().toEpochMilli(), rs.getTimestamp("fieldTimestamp8", estCalendar).getTime()); + // 
Value is null + assertTrue(rs.next()); + assertNull(rs.getBigDecimal(1)); } @Test - void test_parse_time() throws SQLException { - // Timestamp string that is in ISO format with time zone information in UTC - Map row = new HashMap<>(); - row.put("fieldTime0", Time.valueOf("18:45:20")); - row.put("fieldTime1", Timestamp.from(Instant.parse("2024-08-01T10:30:20.822364Z"))); - row.put("fieldTime2", "10:30:00"); - row.put("fieldTime3", "11:59:21.822364"); - // Timestamp string that is in ISO format with time zone information - row.put("fieldTime4", "10:00:00Z"); - row.put("fieldTime5", "05:30:00-02:00"); - row.put("fieldTime6", "08:25:10+02:00"); - // Timestamp string doesn't contain time zone information. - row.put("fieldTime7", "2025-06-03T11:59:21.822364"); - row.put("fieldTime8", "1901-05-04 10:06:10.822364"); - row.put("fieldTime9", "2015-09-01 23:33:00"); - // Timestamp string doesn't contain time zone or HH:MM:SS information - row.put("fieldTime10", "2019-03-15"); - List> testTimes = Collections.singletonList(row); - CachedResultSet cachedRs = new CachedResultSet(testTimes); + void test_get_special_timestamp() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 12); + when(mockResultSet.getObject(1)).thenReturn( + 1504844311000L, + LocalDateTime.of(1981, 3, 10, 1, 10, 20), + OffsetDateTime.parse("2025-08-10T10:00:00+03:00"), + ZonedDateTime.parse("2024-07-30T10:00:00+02:00[Europe/Berlin]"), + "2015-03-15 12:50:04", + "invalidDateTime", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Timestamp from a number assertTrue(cachedRs.next()); - verifyTimes(cachedRs); - verifyNonexistingField(cachedRs); - cachedRs.beforeFirst(); - String serialized_data = 
cachedRs.serializeIntoJsonString(); - ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); - assertTrue(rs.next()); - verifyTimes(rs); - } - - private void verifyTimes(ResultSet rs) throws SQLException { - // Verifying the timestamp with time zone information. The specified calendar doesn't matter. - assertEquals("18:45:20", rs.getTime("fieldTime0").toString()); - assertEquals("18:45:20", rs.getTime("fieldTime0", estCalendar).toString()); - - // Convert from timestamp with time zone info - assertEquals("03:30:20", rs.getTime("fieldTime1").toString()); - assertEquals("03:30:20", rs.getTime("fieldTime1", estCalendar).toString()); - - // Verify timestamp without time zone information. The specified calendar matters here - assertEquals("10:30:00", rs.getTime("fieldTime2").toString()); - assertEquals("10:30:00", rs.getTime("fieldTime2", estCalendar).toString()); // Should be 07:30:00 - - assertEquals("11:59:21", rs.getTime("fieldTime3").toString()); - assertEquals("11:59:21", rs.getTime("fieldTime3", estCalendar).toString()); - - assertEquals("10:00:00", rs.getTime("fieldTime4").toString()); - assertEquals("10:00:00", rs.getTime("fieldTime4", estCalendar).toString()); - - assertEquals("05:30:00", rs.getTime("fieldTime5").toString()); - assertEquals("05:30:00", rs.getTime("fieldTime5", estCalendar).toString()); - - assertEquals("08:25:10", rs.getTime("fieldTime6").toString()); - assertEquals("08:25:10", rs.getTime("fieldTime6", estCalendar).toString()); - - assertEquals("11:59:21", rs.getTime("fieldTime7").toString()); - assertEquals("08:59:21", rs.getTime("fieldTime7", estCalendar).toString()); - - assertEquals("10:06:10", rs.getTime("fieldTime8").toString()); - assertEquals("07:06:10", rs.getTime("fieldTime8", estCalendar).toString()); - - assertEquals("23:33:00", rs.getTime("fieldTime9").toString()); - assertEquals("20:33:00", rs.getTime("fieldTime9", estCalendar).toString()); - - assertEquals("00:00:00", rs.getTime("fieldTime10").toString()); - 
assertEquals("00:00:00", rs.getTime("fieldTime10", estCalendar).toString()); + assertEquals(new Timestamp(1504844311000L), cachedRs.getTimestamp(1)); + // Timestamp from LocalDateTime + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("1981-03-10 01:10:20"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("1981-03-09 22:10:20"), cachedRs.getTimestamp(1, estCal)); + // Timestamp from OffsetDateTime (containing time zone info) + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2025-08-10 00:00:00"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2025-08-10 00:00:00"), cachedRs.getTimestamp(1, estCal)); + // Timestmap from ZonedDateTime (containing time zone info) + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2024-07-30 01:00:00"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2024-07-30 01:00:00"), cachedRs.getTimestamp(1, estCal)); + // Timestamp from String + assertTrue(cachedRs.next()); + assertEquals(Timestamp.valueOf("2015-03-15 12:50:04"), cachedRs.getTimestamp(1)); + assertEquals(Timestamp.valueOf("2015-03-15 12:50:04"), cachedRs.getTimestamp(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getTimestamp(1); + fail("Invalid timestamp should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Timestamp is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getTimestamp(1)); } @Test - void test_parse_date() throws SQLException { - Map row = new HashMap<>(); - row.put("fieldDate0", Date.valueOf("2009-09-30")); - row.put("fieldDate1", Timestamp.from(Instant.parse("2024-08-01T10:30:20.822364Z"))); - row.put("fieldDate2", "2012-10-01"); - row.put("fieldDate3", "1930-03-20T05:30:20.822364Z"); - // Timestamp string doesn't contain time zone information. 
- row.put("fieldDate4", "2025-06-03T11:59:21.822364"); - row.put("fieldDate5", "1901-05-04 10:06:10.822364"); - row.put("fieldDate6", "2015-09-01 23:33:00"); - // Timestamp string doesn't contain time zone or HH:MM:SS information - List> testTimes = Collections.singletonList(row); - CachedResultSet cachedRs = new CachedResultSet(testTimes); + void test_get_special_time() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 11); + when(mockResultSet.getObject(1)).thenReturn( + 4362000L, + LocalTime.of(10, 20, 30), + OffsetTime.of(12, 15, 30, 0, ZoneOffset.UTC), + "15:34:20", + "InvalidTime", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Time from a number assertTrue(cachedRs.next()); - verifyDates(cachedRs); - verifyNonexistingField(cachedRs); - cachedRs.beforeFirst(); - String serialized_data = cachedRs.serializeIntoJsonString(); - ResultSet rs = CachedResultSet.deserializeFromJsonString(serialized_data); - assertTrue(rs.next()); - verifyDates(rs); + assertEquals(new Time(4362000L), cachedRs.getTime(1)); + // Time from LocalTime + assertTrue(cachedRs.next()); + assertEquals(Time.valueOf("10:20:30"), cachedRs.getTime(1)); + assertEquals(Time.valueOf("07:20:30"), cachedRs.getTime(1, estCal)); + // Time from OffsetTime + assertTrue(cachedRs.next()); + assertEquals(Time.valueOf("05:15:30"), cachedRs.getTime(1)); + assertEquals(Time.valueOf("05:15:30"), cachedRs.getTime(1, estCal)); + // Timestamp from String + assertTrue(cachedRs.next()); + assertEquals(Time.valueOf("15:34:20"), cachedRs.getTime(1)); + assertEquals(Time.valueOf("15:34:20"), cachedRs.getTime(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getTime(1); + fail("Invalid time should cause a test 
failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Time is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getTime(1)); } - private void verifyDates(ResultSet rs) throws SQLException { - assertEquals("2009-09-30", rs.getDate("fieldDate0").toString()); - assertEquals("2009-09-30", rs.getDate("fieldDate0", estCalendar).toString()); - - assertEquals("2024-08-01", rs.getDate("fieldDate1").toString()); - assertEquals("2024-08-01", rs.getDate("fieldDate1", estCalendar).toString()); - - assertEquals("2012-10-01", rs.getDate("fieldDate2").toString()); - assertEquals("2012-10-01", rs.getDate("fieldDate2", estCalendar).toString()); - - assertEquals("1930-03-19", rs.getDate("fieldDate3").toString()); - assertEquals("1930-03-19", rs.getDate("fieldDate3", estCalendar).toString()); - - assertEquals("2025-06-03", rs.getDate("fieldDate4").toString()); - assertEquals("2025-06-03", rs.getDate("fieldDate4", estCalendar).toString()); - - assertEquals("1901-05-04", rs.getDate("fieldDate5").toString()); - assertEquals("1901-05-04", rs.getDate("fieldDate5", estCalendar).toString()); - - assertEquals("2015-09-01", rs.getDate("fieldDate6").toString()); - assertEquals("2015-09-01", rs.getDate("fieldDate6", estCalendar).toString()); + @Test + void test_get_special_date() throws SQLException { + // Create the default CachedResultSet for testing + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 10); + when(mockResultSet.getObject(1)).thenReturn( + 1515944311000L, + -1000000000L, + LocalDate.of(2010, 10, 30), + "2025-03-15", + "InvalidDate", + null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Date from a number + assertTrue(cachedRs.next()); + Date date = cachedRs.getDate(1); + assertEquals(new Date(1515944311000L), date); + 
assertTrue(cachedRs.next()); + assertEquals(new Date(-1000000000L), cachedRs.getDate(1)); + // Date from LocalDate + assertTrue(cachedRs.next()); + assertEquals(Date.valueOf("2010-10-30"), cachedRs.getDate(1)); + assertEquals(Date.valueOf("2010-10-29"), cachedRs.getDate(1, estCal)); + // Date from String + assertTrue(cachedRs.next()); + assertEquals(Date.valueOf("2025-03-15"), cachedRs.getDate(1)); + assertEquals(Date.valueOf("2025-03-15"), cachedRs.getDate(1, estCal)); + assertTrue(cachedRs.next()); + try { + cachedRs.getDate(1); + fail("Invalid date should cause a test failure"); + } catch (IllegalArgumentException e) { + // pass + } + // Date is null + assertTrue(cachedRs.next()); + assertNull(cachedRs.getDate(1)); } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java index 1961df445..8de187932 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -4,13 +4,12 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.*; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import java.nio.charset.StandardCharsets; import java.sql.*; import java.util.Properties; import org.apache.commons.lang3.RandomStringUtils; @@ -55,7 +54,7 @@ void setUp() throws SQLException { when(mockTelemetryFactory.createCounter("remoteCache.cache.totalCalls")).thenReturn(mockTotalCallsCounter); when(mockResult1.getMetaData()).thenReturn(mockMetaData); 
when(mockMetaData.getColumnCount()).thenReturn(1); - when(mockMetaData.getColumnName(1)).thenReturn("fooName"); + when(mockMetaData.getColumnLabel(1)).thenReturn("fooName"); plugin = new DataRemoteCachePlugin(mockPluginService, props); plugin.setCacheConnection(mockCacheConn); } @@ -139,7 +138,7 @@ void test_execute_noCachingLongQuery() throws Exception { } @Test - void test_execute_cachingMiss() throws Exception { + void test_execute_cachingMissAndHit() throws Exception { // Query is not cacheable when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.isInTransaction()).thenReturn(false); @@ -160,49 +159,22 @@ void test_execute_cachingMiss() throws Exception { assertTrue(rs.next()); assertEquals("bar1", rs.getString("fooName")); assertFalse(rs.next()); - verify(mockPluginService, times(2)).getCurrentConnection(); - verify(mockPluginService).isInTransaction(); - verify(mockCacheConn).readFromCache("public_user_select * from A"); + rs.beforeFirst(); + byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); + when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(serializedTestResultSet); + ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{" /* CacheTtl=50s */select * from A"}); + + assertTrue(rs2.next()); + assertEquals("bar1", rs2.getString("fooName")); + assertFalse(rs2.next()); + verify(mockPluginService, times(3)).getCurrentConnection(); + verify(mockPluginService, times(2)).isInTransaction(); + verify(mockCacheConn, times(2)).readFromCache("public_user_select * from A"); verify(mockCallable).call(); - verify(mockCacheConn).writeToCache("public_user_select * from A", "[{\"fooName\":\"bar1\"}]".getBytes(StandardCharsets.UTF_8), 100); - verify(mockTotalCallsCounter).inc(); + verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(100)); + verify(mockTotalCallsCounter, 
times(2)).inc(); verify(mockMissCounter).inc(); - } - - @Test - void test_execute_cachingHit() throws Exception { - final String cachedResult = "[{\"date\":\"2009-09-30\",\"code\":\"avata\"},{\"date\":\"2015-05-30\",\"code\":\"dracu\"}]"; - - // Query is cacheable - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.isInTransaction()).thenReturn(false); - when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); - when(mockConnection.getSchema()).thenReturn("public"); - when(mockDbMetadata.getUserName()).thenReturn("user"); - when(mockCacheConn.readFromCache("public_user_select * from table")).thenReturn(cachedResult.getBytes()); - when(mockCallable.call()).thenReturn(mockResult1); - - // Result set contains 1 row - when(mockResult1.next()).thenReturn(true, false); - when(mockResult1.getObject(1)).thenReturn("bar1"); - - ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{" /* CacheTtl=50s */select * from table"}); - - // Cached result set contains 2 rows - assertTrue(rs.next()); - assertEquals("2009-09-30", rs.getString("date")); - assertEquals("avata", rs.getString("code")); - assertTrue(rs.next()); - assertEquals("2015-05-30", rs.getString("date")); - assertEquals("dracu", rs.getString("code")); - assertFalse(rs.next()); - verify(mockPluginService).getCurrentConnection(); - verify(mockPluginService).isInTransaction(); - verify(mockCacheConn).readFromCache("public_user_select * from table"); - verify(mockCallable, never()).call(); - verify(mockCacheConn, never()).writeToCache("public_user_select * from table", "[{\"fooName\":\"bar1\"}]".getBytes(StandardCharsets.UTF_8), 50); - verify(mockTotalCallsCounter).inc(); verify(mockHitCounter).inc(); } @@ -231,7 +203,7 @@ void test_transaction_cacheQuery() throws Exception { verify(mockPluginService).isInTransaction(); verify(mockCacheConn, never()).readFromCache(anyString()); 
verify(mockCallable).call(); - verify(mockCacheConn).writeToCache("public_user_select * from T", "[{\"fooName\":\"bar1\"}]".getBytes(StandardCharsets.UTF_8), 300); + verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); verify(mockTotalCallsCounter, never()).inc(); verify(mockHitCounter, never()).inc(); verify(mockMissCounter, never()).inc(); From 94cbbeb61f7d5e8e28fcfe8e5d15920cc38908c8 Mon Sep 17 00:00:00 2001 From: Shaopeng Gu Date: Fri, 15 Aug 2025 13:41:49 -0700 Subject: [PATCH 12/24] Implemented query hint feature that supports multiple query parameters (pull request #1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace /* cacheTTL=300s */ with /*+ CACHE_PARAM(ttl=300s) */ format - Add case-insensitive CACHE_PARAM parsing with flexible placement - Implement malformed hint detection with JdbcCacheMalformedQueryHint telemetry - Add comprehensive test coverage for valid/invalid/malformed cases - Support multiple parameters for future extensibility Fixes: Query hint parsing to follow Oracle SQL hint standards feat: add TTL validation for cache query hints - Reject TTL values <= 0 as malformed hints - Allow large TTL values (no upper limit) - Increment malformed hint counter for invalid TTL values - Add comprehensive test coverage for TTL edge cases Validates: ttl=0s, ttl=-10s → not cacheable Allows: ttl=999999s → cache with large TTL updated test cases to reflect new implementation --- .../plugin/cache/DataRemoteCachePlugin.java | 78 ++++++--- .../cache/DataRemoteCachePluginTest.java | 158 ++++++++++++++---- 2 files changed, 185 insertions(+), 51 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index 59b67d40b..92e9dc400 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java +++ 
b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -29,7 +29,6 @@ import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.JdbcMethod; import software.amazon.jdbc.PluginService; -import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.StringUtils; @@ -39,6 +38,9 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin { private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName()); + private static final String QUERY_HINT_START_PATTERN = "/*+"; + private static final String QUERY_HINT_END_PATTERN = "*/"; + private static final String CACHE_PARAM_PATTERN = "CACHE_PARAM("; private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( Arrays.asList(JdbcMethod.STATEMENT_EXECUTEQUERY.methodName, JdbcMethod.STATEMENT_EXECUTE.methodName, @@ -52,6 +54,7 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin { private TelemetryCounter hitCounter; private TelemetryCounter missCounter; private TelemetryCounter totalCallsCounter; + private TelemetryCounter malformedHintCounter; private CacheConnection cacheConnection; public DataRemoteCachePlugin(final PluginService pluginService, final Properties properties) { @@ -66,6 +69,7 @@ public DataRemoteCachePlugin(final PluginService pluginService, final Properties this.hitCounter = telemetryFactory.createCounter("remoteCache.cache.hit"); this.missCounter = telemetryFactory.createCounter("remoteCache.cache.miss"); this.totalCallsCounter = telemetryFactory.createCounter("remoteCache.cache.totalCalls"); + this.malformedHintCounter = telemetryFactory.createCounter("JdbcCacheMalformedQueryHint"); this.cacheConnection = new CacheConnection(properties); } @@ -133,33 +137,68 @@ private ResultSet cacheResultSet(String queryStr, ResultSet rs, int expiry) thro /** * Determine the 
TTL based on an input query - * @param queryHint string. e.g. "NO CACHE", or "cacheTTL=100s" + * @param queryHint string. e.g. "CACHE_PARAM(ttl=100s, key=custom)" * @return TTL in seconds to cache the query. * null if the query is not cacheable. */ protected Integer getTtlForQuery(String queryHint) { // Empty query is not cacheable if (StringUtils.isNullOrEmpty(queryHint)) return null; - // Query longer than 16K is not cacheable - String[] tokens = queryHint.toLowerCase().split("cache"); - if (tokens.length >= 2) { - // Handle "no cache". - if (!StringUtils.isNullOrEmpty(tokens[0]) && "no".equals(tokens[0])) return null; - // Handle "cacheTTL=Xs" - if (!StringUtils.isNullOrEmpty(tokens[1]) && tokens[1].startsWith("ttl=")) { - int endIndex = tokens[1].indexOf('s'); - if (endIndex > 0) { + // Find CACHE_PARAM anywhere in the hint string (case insensitive) + String upperHint = queryHint.toUpperCase(); + int cacheParamStart = upperHint.indexOf(CACHE_PARAM_PATTERN); + if (cacheParamStart == -1) return null; + + // Find the matching closing parenthesis + int paramsStart = cacheParamStart + CACHE_PARAM_PATTERN.length(); + int paramsEnd = upperHint.indexOf(")", paramsStart); + if (paramsEnd == -1) return null; + + // Extract parameters between parentheses + String cacheParams = upperHint.substring(paramsStart, paramsEnd).trim(); + // Empty parameters + if (StringUtils.isNullOrEmpty(cacheParams)) { + LOGGER.warning("Empty CACHE_PARAM parameters"); + incrCounter(malformedHintCounter); + return null; + } + + // Parse comma-separated parameters + String[] params = cacheParams.split(","); + Integer ttlValue = null; + + for (String param : params) { + String[] keyValue = param.trim().split("="); + if (keyValue.length != 2) { + LOGGER.warning("Invalid caching parameter format: " + param); + incrCounter(malformedHintCounter); + return null; + } + String key = keyValue[0].trim(); + String value = keyValue[1].trim(); + + if ("TTL".equals(key)) { + if (!value.endsWith("S")) { + 
LOGGER.warning("TTL must end with 's': " + value); + incrCounter(malformedHintCounter); + return null; + } else{ + // Parse TTL value (e.g., "300s") try { - return Integer.parseInt(tokens[1].substring(4, endIndex)); - } catch (Exception e) { - LOGGER.warning("Encountered exception when parsing Cache TTL: " + e.getMessage()); + ttlValue = Integer.parseInt(value.substring(0, value.length() - 1)); + // treat negative and 0 ttls as not cacheable + if (ttlValue <= 0) { + return null; + } + } catch (NumberFormatException e) { + LOGGER.warning(String.format("Invalid TTL format of %s for query %s", value, queryHint)); + incrCounter(malformedHintCounter); + return null; } } } } - - LOGGER.finest("Query hint " + queryHint + " indicates the query is not cacheable"); - return null; + return ttlValue; } @Override @@ -181,8 +220,9 @@ public T execute( String mainQuery = sql; // The main part of the query with the query hint prefix trimmed int endOfQueryHint = 0; Integer configuredQueryTtl = null; - if ((sql.length() < 16000) && sql.startsWith("/*")) { - endOfQueryHint = sql.indexOf("*/"); + // Queries longer than 16KB is not cacheable + if ((sql.length() < 16000) && sql.startsWith(QUERY_HINT_START_PATTERN)) { + endOfQueryHint = sql.indexOf(QUERY_HINT_END_PATTERN); if (endOfQueryHint > 0) { configuredQueryTtl = getTtlForQuery(sql.substring(2, endOfQueryHint).trim()); mainQuery = sql.substring(endOfQueryHint + 2).trim(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java index 8de187932..743461a96 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -34,6 +34,7 @@ public class DataRemoteCachePluginTest { @Mock TelemetryCounter mockHitCounter; @Mock TelemetryCounter mockMissCounter; @Mock TelemetryCounter 
mockTotalCallsCounter; + @Mock TelemetryCounter mockMalformedHintCounter; @Mock ResultSet mockResult1; @Mock Statement mockStatement; @Mock ResultSetMetaData mockMetaData; @@ -52,6 +53,7 @@ void setUp() throws SQLException { when(mockTelemetryFactory.createCounter("remoteCache.cache.hit")).thenReturn(mockHitCounter); when(mockTelemetryFactory.createCounter("remoteCache.cache.miss")).thenReturn(mockMissCounter); when(mockTelemetryFactory.createCounter("remoteCache.cache.totalCalls")).thenReturn(mockTotalCallsCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheMalformedQueryHint")).thenReturn(mockMalformedHintCounter); when(mockResult1.getMetaData()).thenReturn(mockMetaData); when(mockMetaData.getColumnCount()).thenReturn(1); when(mockMetaData.getColumnLabel(1)).thenReturn("fooName"); @@ -66,36 +68,68 @@ void cleanUp() throws Exception { @Test void test_getTTLFromQueryHint() throws Exception { - // Null and empty query string are not cacheable + // Null and empty query hint content are not cacheable assertNull(plugin.getTtlForQuery(null)); assertNull(plugin.getTtlForQuery("")); assertNull(plugin.getTtlForQuery(" ")); - // Some other query hint - assertNull(plugin.getTtlForQuery("/* cacheNotEnabled */ select * from T")); - // Rule set is empty. 
All select queries are cached with 300 seconds TTL - String selectQuery1 = "cachettl=300s"; - String selectQuery2 = " /* CACHETTL=100s */ SELECT ID from mytable2 "; - String selectQuery3 = "/*CacheTTL=35s*/select * from table3 where ID = 1 and name = 'tom'"; - // Query hints that are not cacheable - String selectQueryNoHint = "select * from table4"; - String selectQueryNoCache1 = "no cache"; - String selectQueryNoCache2 = " /* NO CACHE */ SELECT count(*) FROM (select player_id from roster where id = 1 FOR UPDATE) really_long_name_alias"; - String selectQueryNoCache3 = "/* cachettl=300 */ SELECT count(*) FROM (select player_id from roster where id = 1) really_long_name_alias"; - - // Non select queries are not cacheable - String veryShortQuery = "BEGIN"; - String insertQuery = "/* This is an insert query */ insert into mytable values (1, 2)"; - String updateQuery = "/* Update query */ Update /* Another hint */ mytable set val = 1"; - assertEquals(300, plugin.getTtlForQuery(selectQuery1)); - assertEquals(100, plugin.getTtlForQuery(selectQuery2)); - assertEquals(35, plugin.getTtlForQuery(selectQuery3)); - assertNull(plugin.getTtlForQuery(selectQueryNoHint)); - assertNull(plugin.getTtlForQuery(selectQueryNoCache1)); - assertNull(plugin.getTtlForQuery(selectQueryNoCache2)); - assertNull(plugin.getTtlForQuery(selectQueryNoCache3)); - assertNull(plugin.getTtlForQuery(veryShortQuery)); - assertNull(plugin.getTtlForQuery(insertQuery)); - assertNull(plugin.getTtlForQuery(updateQuery)); + // Valid CACHE_PARAM cases - these are the hint contents after /*+ and before */ + assertEquals(300, plugin.getTtlForQuery("CACHE_PARAM(ttl=300s)")); + assertEquals(100, plugin.getTtlForQuery("CACHE_PARAM(ttl=100s)")); + assertEquals(35, plugin.getTtlForQuery("CACHE_PARAM(ttl=35s)")); + + // Case insensitive + assertEquals(200, plugin.getTtlForQuery("cache_param(ttl=200s)")); + assertEquals(150, plugin.getTtlForQuery("Cache_Param(ttl=150s)")); + assertEquals(200, 
plugin.getTtlForQuery("cache_param(tTl=200s)")); + assertEquals(150, plugin.getTtlForQuery("Cache_Param(ttl=150S)")); + assertEquals(200, plugin.getTtlForQuery("cache_param(TTL=200S)")); + + // CACHE_PARAM anywhere in hint content (mixed with other hint directives) + assertEquals(250, plugin.getTtlForQuery("INDEX(table1 idx1) CACHE_PARAM(ttl=250s)")); + assertEquals(200, plugin.getTtlForQuery("CACHE_PARAM(ttl=200s) USE_NL(t1 t2)")); + assertEquals(180, plugin.getTtlForQuery("FIRST_ROWS(10) CACHE_PARAM(ttl=180s) PARALLEL(4)")); + assertEquals(200, plugin.getTtlForQuery("foo=bar,CACHE_PARAM(ttl=200s),baz=qux")); + + // Whitespace handling + assertEquals(400, plugin.getTtlForQuery("CACHE_PARAM( ttl=400s )")); + assertEquals(500, plugin.getTtlForQuery("CACHE_PARAM(ttl = 500s)")); + assertEquals(200, plugin.getTtlForQuery("CACHE_PARAM( ttl = 200s , key = test )")); + + // Invalid cases - no CACHE_PARAM in hint content + assertNull(plugin.getTtlForQuery("INDEX(table1 idx1)")); + assertNull(plugin.getTtlForQuery("FIRST_ROWS(100)")); + assertNull(plugin.getTtlForQuery("cachettl=300s")); // old format + assertNull(plugin.getTtlForQuery("NO_CACHE")); + + // Missing parentheses + assertNull(plugin.getTtlForQuery("CACHE_PARAM ttl=300s")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=300s")); + + // Multiple parameters (future-proofing) + assertEquals(300, plugin.getTtlForQuery("CACHE_PARAM(ttl=300s, key=test)")); + + // Large TTL values should work + assertEquals(999999, plugin.getTtlForQuery("CACHE_PARAM(ttl=999999s)")); + assertEquals(86400, plugin.getTtlForQuery("CACHE_PARAM(ttl=86400s)")); // 24 hours + } + + @Test + void test_getTTLFromQueryHint_MalformedHints() throws Exception { + // Test malformed cases + assertNull(plugin.getTtlForQuery("CACHE_PARAM()")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=abc)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=300)")); // missing 's' + + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=)")); + 
assertNull(plugin.getTtlForQuery("CACHE_PARAM(invalid_format)")); + + // Invalid TTL values (negative and zero) does not count toward malformed hints + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=0s)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=-10s)")); + assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=-1s)")); + + // Verify counter was incremented 8 times (5 original + 3 new) + verify(mockMalformedHintCounter, times(5)).inc(); } @Test @@ -125,7 +159,7 @@ void test_execute_noCachingLongQuery() throws Exception { when(mockCallable.call()).thenReturn(mockResult1); ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{"/* cacheTTL=30s */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)}); + methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)}); // Mock result set containing 1 row when(mockResult1.next()).thenReturn(true, true, false, false); @@ -153,7 +187,7 @@ void test_execute_cachingMissAndHit() throws Exception { when(mockResult1.getObject(1)).thenReturn("bar1"); ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{"/*CACHETTL=100s*/ select * from A"}); + methodName, mockCallable, new String[]{"/*+CACHE_PARAM(ttl=50s)*/ select * from A"}); // Cached result set contains 1 row assertTrue(rs.next()); @@ -163,7 +197,7 @@ void test_execute_cachingMissAndHit() throws Exception { byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(serializedTestResultSet); ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{" /* CacheTtl=50s */select * from A"}); + methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"}); 
assertTrue(rs2.next()); assertEquals("bar1", rs2.getString("fooName")); @@ -172,7 +206,7 @@ void test_execute_cachingMissAndHit() throws Exception { verify(mockPluginService, times(2)).isInTransaction(); verify(mockCacheConn, times(2)).readFromCache("public_user_select * from A"); verify(mockCallable).call(); - verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(100)); + verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(50)); verify(mockTotalCallsCounter, times(2)).inc(); verify(mockMissCounter).inc(); verify(mockHitCounter).inc(); @@ -193,7 +227,7 @@ void test_transaction_cacheQuery() throws Exception { when(mockResult1.getObject(1)).thenReturn("bar1"); ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{"/* cacheTTL=300s */ select * from T"}); + methodName, mockCallable, new String[]{"/*+ CACHE_PARAM(ttl=300s) */ select * from T"}); // Cached result set contains 1 row assertTrue(rs.next()); @@ -210,6 +244,66 @@ void test_transaction_cacheQuery() throws Exception { } @Test + void test_transaction_cacheQuery_multiple_query_params() throws Exception { + // Query is cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockDbMetadata.getUserName()).thenReturn("user"); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, methodName, mockCallable, new String[]{"/*+ CACHE_PARAM(ttl=300s, otherParam=abc) */ select * from T"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + 
assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + verify(mockPluginService).getCurrentConnection(); + verify(mockPluginService).isInTransaction(); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); + verify(mockTotalCallsCounter, never()).inc(); + verify(mockHitCounter, never()).inc(); + verify(mockMissCounter, never()).inc(); + } + + @Test + void test_transaction_cacheQuery_multiple_query_hints() throws Exception {// Query is cacheable + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockDbMetadata.getUserName()).thenReturn("user"); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+ hello CACHE_PARAM(ttl=300s, otherParam=abc) world */ select * from T"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + verify(mockPluginService).getCurrentConnection(); + verify(mockPluginService).isInTransaction(); + verify(mockCacheConn, never()).readFromCache(anyString()); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); + verify(mockTotalCallsCounter, never()).inc(); + verify(mockHitCounter, never()).inc(); + verify(mockMissCounter, never()).inc(); + } + + @Test void test_transaction_noCaching() throws Exception { // Query is not cacheable 
when(mockPluginService.isInTransaction()).thenReturn(true); From b8d3db471e638becdccace4c55b42d786acd946a Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Thu, 21 Aug 2025 16:25:31 -0700 Subject: [PATCH 13/24] Caching - fetch and store the schema name for local session state to improve the caching performance. Renamed DataCacheConnectionPlugin to DataLocalCacheConnectionPlugin for clarity --- .../UsingTheJdbcDriver.md | 2 +- wrapper/build.gradle.kts | 4 +- .../jdbc/ConnectionPluginChainBuilder.java | 6 +-- .../amazon/jdbc/ConnectionPluginManager.java | 4 +- .../java/software/amazon/jdbc/Driver.java | 4 +- ...va => DataLocalCacheConnectionPlugin.java} | 10 ++-- ...ataLocalCacheConnectionPluginFactory.java} | 4 +- .../plugin/cache/DataRemoteCachePlugin.java | 15 ++++-- .../amazon/jdbc/util/WrapperUtils.java | 2 +- ..._advanced_jdbc_wrapper_messages.properties | 2 +- .../container/tests/DataCachePluginTests.java | 12 ++--- ...> DataLocalCacheConnectionPluginTest.java} | 10 ++-- .../cache/DataRemoteCachePluginTest.java | 53 ++++++++++++++----- 13 files changed, 82 insertions(+), 46 deletions(-) rename wrapper/src/main/java/software/amazon/jdbc/plugin/cache/{DataCacheConnectionPlugin.java => DataLocalCacheConnectionPlugin.java} (92%) rename wrapper/src/main/java/software/amazon/jdbc/plugin/cache/{DataCacheConnectionPluginFactory.java => DataLocalCacheConnectionPluginFactory.java} (86%) rename wrapper/src/test/java/software/amazon/jdbc/plugin/cache/{DataCacheConnectionPluginTest.java => DataLocalCacheConnectionPluginTest.java} (90%) diff --git a/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md b/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md index dea51924f..b8069fc24 100644 --- a/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md +++ b/docs/using-the-jdbc-driver/UsingTheJdbcDriver.md @@ -220,7 +220,7 @@ The AWS JDBC Driver has several built-in plugins that are available to use. Plea [^2]: Federated Identity and Okta rely on IAM. 
Due to [^1], RDS Multi-AZ Clusters are not supported. > [!NOTE]\ -> To see information logged by plugins such as `DataCacheConnectionPlugin` and `LogQueryConnectionPlugin`, see the [Logging](#logging) section. +> To see information logged by plugins such as `DataLocalCacheConnectionPlugin` and `LogQueryConnectionPlugin`, see the [Logging](#logging) section. In addition to the built-in plugins, you can also create custom plugins more suitable for your needs. For more information, see [Custom Plugins](../development-guide/LoadablePlugins.md#using-custom-plugins). diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index 36d1f0784..e42f4e4cb 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -211,7 +211,7 @@ if (useJacoco) { "software/amazon/jdbc/wrapper/*", "software/amazon/jdbc/util/*", "software/amazon/jdbc/profile/*", - "software/amazon/jdbc/plugin/DataCacheConnectionPlugin*" + "software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin*" ) } })) @@ -226,7 +226,7 @@ if (useJacoco) { "software/amazon/jdbc/wrapper/*", "software/amazon/jdbc/util/*", "software/amazon/jdbc/profile/*", - "software/amazon/jdbc/plugin/DataCacheConnectionPlugin*" + "software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin*" ) } })) diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 10209d291..ef07e611c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -33,7 +33,7 @@ import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPluginFactory; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPluginFactory; import software.amazon.jdbc.plugin.ConnectTimeConnectionPluginFactory; -import software.amazon.jdbc.plugin.cache.DataCacheConnectionPluginFactory; +import 
software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPluginFactory; import software.amazon.jdbc.plugin.cache.DataRemoteCachePluginFactory; import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.DriverMetaDataConnectionPluginFactory; @@ -69,7 +69,7 @@ public class ConnectionPluginChainBuilder { { put("executionTime", new ExecutionTimeConnectionPluginFactory()); put("logQuery", new LogQueryConnectionPluginFactory()); - put("dataCache", new DataCacheConnectionPluginFactory()); + put("dataCache", new DataLocalCacheConnectionPluginFactory()); put("dataRemoteCache", new DataRemoteCachePluginFactory()); put("customEndpoint", new CustomEndpointPluginFactory()); put("efm", new HostMonitoringConnectionPluginFactory()); @@ -102,7 +102,7 @@ public class ConnectionPluginChainBuilder { new HashMap, Integer>() { { put(DriverMetaDataConnectionPluginFactory.class, 100); - put(DataCacheConnectionPluginFactory.class, 200); + put(DataLocalCacheConnectionPluginFactory.class, 200); put(DataRemoteCachePluginFactory.class, 250); put(CustomEndpointPluginFactory.class, 380); put(AuroraInitialConnectionStrategyPluginFactory.class, 390); diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 16b1f450b..b711f4617 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -33,7 +33,7 @@ import software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin; import software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPlugin; import software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin; -import software.amazon.jdbc.plugin.cache.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; import software.amazon.jdbc.plugin.cache.DataRemoteCachePlugin; import 
software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPlugin; @@ -73,7 +73,7 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { put(ExecutionTimeConnectionPlugin.class, "plugin:executionTime"); put(AuroraConnectionTrackerPlugin.class, "plugin:auroraConnectionTracker"); put(LogQueryConnectionPlugin.class, "plugin:logQuery"); - put(DataCacheConnectionPlugin.class, "plugin:dataCache"); + put(DataLocalCacheConnectionPlugin.class, "plugin:dataCache"); put(DataRemoteCachePlugin.class, "plugin:dataRemoteCache"); put(HostMonitoringConnectionPlugin.class, "plugin:efm"); put(software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin.class, "plugin:efm2"); diff --git a/wrapper/src/main/java/software/amazon/jdbc/Driver.java b/wrapper/src/main/java/software/amazon/jdbc/Driver.java index 673ced102..60db2c8f4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/Driver.java +++ b/wrapper/src/main/java/software/amazon/jdbc/Driver.java @@ -40,7 +40,7 @@ import software.amazon.jdbc.hostlistprovider.RdsHostListProvider; import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider; import software.amazon.jdbc.plugin.AwsSecretsManagerCacheHolder; -import software.amazon.jdbc.plugin.cache.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; import software.amazon.jdbc.plugin.OpenedConnectionTracker; import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; import software.amazon.jdbc.plugin.efm.HostMonitorThreadContainer; @@ -430,7 +430,7 @@ public static void clearCaches() { CustomEndpointMonitorImpl.clearCache(); OpenedConnectionTracker.clearCache(); AwsSecretsManagerCacheHolder.clearCache(); - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); FederatedAuthCacheHolder.clearCache(); OktaAuthCacheHolder.clearCache(); IamAuthCacheHolder.clearCache(); diff --git 
a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java similarity index 92% rename from wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPlugin.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java index 91a3f92e1..3d4743fea 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPlugin.java @@ -38,9 +38,9 @@ import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; -public class DataCacheConnectionPlugin extends AbstractConnectionPlugin { +public class DataLocalCacheConnectionPlugin extends AbstractConnectionPlugin { - private static final Logger LOGGER = Logger.getLogger(DataCacheConnectionPlugin.class.getName()); + private static final Logger LOGGER = Logger.getLogger(DataLocalCacheConnectionPlugin.class.getName()); private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( Arrays.asList( @@ -61,7 +61,7 @@ public class DataCacheConnectionPlugin extends AbstractConnectionPlugin { protected final String dataCacheTriggerCondition; static { - PropertyDefinition.registerPluginProperties(DataCacheConnectionPlugin.class); + PropertyDefinition.registerPluginProperties(DataLocalCacheConnectionPlugin.class); } private final TelemetryFactory telemetryFactory; @@ -70,7 +70,7 @@ public class DataCacheConnectionPlugin extends AbstractConnectionPlugin { private final TelemetryCounter totalCallsCounter; private final TelemetryGauge cacheSizeGauge; - public DataCacheConnectionPlugin(final PluginService pluginService, final Properties props) { + public DataLocalCacheConnectionPlugin(final PluginService pluginService, final Properties props) { 
this.telemetryFactory = pluginService.getTelemetryFactory(); this.dataCacheTriggerCondition = DATA_CACHE_TRIGGER_CONDITION.getString(props); @@ -120,7 +120,7 @@ public T execute( } LOGGER.finest( () -> Messages.get( - "DataCacheConnectionPlugin.queryResultsCached", + "DataLocalCacheConnectionPlugin.queryResultsCached", new Object[]{methodName, sql})); } else { if (this.hitCounter != null) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java similarity index 86% rename from wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginFactory.java rename to wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java index 1dfbe8f9d..c28e03e89 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginFactory.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginFactory.java @@ -21,10 +21,10 @@ import software.amazon.jdbc.ConnectionPluginFactory; import software.amazon.jdbc.PluginService; -public class DataCacheConnectionPluginFactory implements ConnectionPluginFactory { +public class DataLocalCacheConnectionPluginFactory implements ConnectionPluginFactory { @Override public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { - return new DataCacheConnectionPlugin(pluginService, props); + return new DataLocalCacheConnectionPlugin(pluginService, props); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index 92e9dc400..80ad88ab4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java 
@@ -30,6 +30,7 @@ import software.amazon.jdbc.JdbcMethod; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; +import software.amazon.jdbc.states.SessionStateService; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.WrapperUtils; @@ -89,13 +90,21 @@ private String getCacheQueryKey(String query) { try { Connection currentConn = pluginService.getCurrentConnection(); DatabaseMetaData metadata = currentConn.getMetaData(); + // Fetch and record the schema name if the session state doesn't currently have it + SessionStateService sessionStateService = pluginService.getSessionStateService(); + String schema = sessionStateService.getSchema().orElse(null); + if (schema == null) { + // Fetch the current schema name and store it in sessionStateService + schema = currentConn.getSchema(); + sessionStateService.setSchema(schema); + } + LOGGER.finest("DB driver protocol " + pluginService.getDriverProtocol() - + ", schema: " + currentConn.getSchema() + ", database product: " + metadata.getDatabaseProductName() + " " + metadata.getDatabaseProductVersion() - + ", user: " + metadata.getUserName() + + ", schema: " + schema + ", user: " + metadata.getUserName() + ", driver: " + metadata.getDriverName() + " " + metadata.getDriverVersion()); // The cache key contains the schema name, user name, and the query string - String[] words = {currentConn.getSchema(), metadata.getUserName(), query}; + String[] words = {schema, metadata.getUserName(), query}; return String.join("_", words); } catch (SQLException e) { LOGGER.warning("Error getting session state: " + e.getMessage()); diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java index 0037bdbd9..837a90d5c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java +++ 
b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java @@ -573,7 +573,7 @@ public static Connection getConnectionFromSqlObject(final Object obj) { } } catch (final SQLException | UnsupportedOperationException e) { // Do nothing. The UnsupportedOperationException comes from ResultSets returned by - // DataCacheConnectionPlugin and will be triggered when getStatement is called. + // DataLocalCacheConnectionPlugin and will be triggered when getStatement is called. } return null; diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties index ff67e7fd1..94afe5a1f 100644 --- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties +++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties @@ -126,7 +126,7 @@ CustomEndpointPlugin.waitingForCustomEndpointInfo=Custom endpoint info for ''{0} CustomEndpointPluginFactory.awsSdkNotInClasspath=Required dependency 'AWS Java SDK RDS v2.x' is not on the classpath. 
-DataCacheConnectionPlugin.queryResultsCached=[{0}] Query results will be cached: {1} +DataLocalCacheConnectionPlugin.queryResultsCached=[{0}] Query results will be cached: {1} # Data Remote Cache Plugin DataRemoteCachePlugin.notInClassPath=Required dependency for DataRemoteCachePlugin is not on the classpath: ''{0}'' diff --git a/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java b/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java index 68b9ad995..d6b77fee5 100644 --- a/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java +++ b/wrapper/src/test/java/integration/container/tests/DataCachePluginTests.java @@ -40,7 +40,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.plugin.cache.CachedResultSet; -import software.amazon.jdbc.plugin.cache.DataCacheConnectionPlugin; +import software.amazon.jdbc.plugin.cache.DataLocalCacheConnectionPlugin; @TestMethodOrder(MethodOrderer.MethodName.class) @ExtendWith(TestDriverProvider.class) @@ -58,20 +58,20 @@ public class DataCachePluginTests { @BeforeEach public void beforeEach() { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); } @TestTemplate public void testQueryCacheable() throws SQLException { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); final Properties props = ConnectionStringHelper.getDefaultProperties(); PropertyDefinition.CONNECT_TIMEOUT.set(props, "30000"); PropertyDefinition.SOCKET_TIMEOUT.set(props, "30000"); props.setProperty(PropertyDefinition.PLUGINS.name, "dataCache"); - props.setProperty(DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*testTable.*"); + props.setProperty(DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*testTable.*"); Connection conn = DriverManager.getConnection(ConnectionStringHelper.getWrapperUrl(), props); @@ -174,14 +174,14 
@@ private void printTable() { @TestTemplate public void testQueryNotCacheable() throws SQLException { - DataCacheConnectionPlugin.clearCache(); + DataLocalCacheConnectionPlugin.clearCache(); final Properties props = ConnectionStringHelper.getDefaultProperties(); PropertyDefinition.CONNECT_TIMEOUT.set(props, "30000"); PropertyDefinition.SOCKET_TIMEOUT.set(props, "30000"); props.setProperty(PropertyDefinition.PLUGINS.name, "dataCache"); props.setProperty( - DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*WRONG_EXPRESSION.*"); + DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, ".*WRONG_EXPRESSION.*"); Connection conn = DriverManager.getConnection(ConnectionStringHelper.getWrapperUrl(), props); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java similarity index 90% rename from wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginTest.java rename to wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java index 6b9c39f22..cd307d2bf 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataCacheConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataLocalCacheConnectionPluginTest.java @@ -36,7 +36,7 @@ import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; -class DataCacheConnectionPluginTest { +class DataLocalCacheConnectionPluginTest { private static final Properties props = new Properties(); @@ -55,8 +55,8 @@ class DataCacheConnectionPluginTest { @BeforeEach void setUp() throws SQLException { closeable = MockitoAnnotations.openMocks(this); - props.setProperty(DataCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, "foo"); - DataCacheConnectionPlugin.clearCache(); + 
props.setProperty(DataLocalCacheConnectionPlugin.DATA_CACHE_TRIGGER_CONDITION.name, "foo"); + DataLocalCacheConnectionPlugin.clearCache(); when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); when(mockTelemetryFactory.createCounter(anyString())).thenReturn(mockTelemetryCounter); @@ -82,7 +82,7 @@ void cleanUp() throws Exception { void test_execute_withEmptyCache() throws SQLException { final String methodName = "Statement.executeQuery"; - final DataCacheConnectionPlugin plugin = new DataCacheConnectionPlugin(mockPluginService, props); + final DataLocalCacheConnectionPlugin plugin = new DataLocalCacheConnectionPlugin(mockPluginService, props); final ResultSet rs = plugin.execute( ResultSet.class, @@ -99,7 +99,7 @@ void test_execute_withEmptyCache() throws SQLException { void test_execute_withCache() throws Exception { final String methodName = "Statement.executeQuery"; - final DataCacheConnectionPlugin plugin = new DataCacheConnectionPlugin(mockPluginService, props); + final DataLocalCacheConnectionPlugin plugin = new DataLocalCacheConnectionPlugin(mockPluginService, props); when(mockCallable.call()).thenReturn(mockResult1, mockResult2); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java index 743461a96..39e352e79 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.when; import java.sql.*; +import java.util.Optional; import java.util.Properties; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; @@ -20,6 +21,7 @@ import org.mockito.MockitoAnnotations; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.PluginService; +import 
software.amazon.jdbc.states.SessionStateService; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -39,6 +41,7 @@ public class DataRemoteCachePluginTest { @Mock Statement mockStatement; @Mock ResultSetMetaData mockMetaData; @Mock Connection mockConnection; + @Mock SessionStateService mockSessionStateService; @Mock DatabaseMetaData mockDbMetadata; @Mock CacheConnection mockCacheConn; @Mock JdbcCallable mockCallable; @@ -177,6 +180,8 @@ void test_execute_cachingMissAndHit() throws Exception { when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.isInTransaction()).thenReturn(false); when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()).thenReturn(Optional.of("public")); when(mockConnection.getSchema()).thenReturn("public"); when(mockDbMetadata.getUserName()).thenReturn("user"); when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(null); @@ -205,6 +210,10 @@ void test_execute_cachingMissAndHit() throws Exception { verify(mockPluginService, times(3)).getCurrentConnection(); verify(mockPluginService, times(2)).isInTransaction(); verify(mockCacheConn, times(2)).readFromCache("public_user_select * from A"); + verify(mockPluginService, times(3)).getSessionStateService(); + verify(mockSessionStateService, times(3)).getSchema(); + verify(mockConnection).getSchema(); + verify(mockSessionStateService).setSchema("public"); verify(mockCallable).call(); verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(50)); verify(mockTotalCallsCounter, times(2)).inc(); @@ -218,6 +227,8 @@ void test_transaction_cacheQuery() throws Exception { when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); 
when(mockPluginService.isInTransaction()).thenReturn(true); when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); when(mockConnection.getSchema()).thenReturn("public"); when(mockDbMetadata.getUserName()).thenReturn("user"); when(mockCallable.call()).thenReturn(mockResult1); @@ -235,6 +246,10 @@ void test_transaction_cacheQuery() throws Exception { assertFalse(rs.next()); verify(mockPluginService).getCurrentConnection(); verify(mockPluginService).isInTransaction(); + verify(mockPluginService).getSessionStateService(); + verify(mockSessionStateService).getSchema(); + verify(mockConnection).getSchema(); + verify(mockSessionStateService).setSchema("public"); verify(mockCacheConn, never()).readFromCache(anyString()); verify(mockCallable).call(); verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); @@ -249,6 +264,8 @@ void test_transaction_cacheQuery_multiple_query_params() throws Exception { when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.isInTransaction()).thenReturn(true); when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); when(mockConnection.getSchema()).thenReturn("public"); when(mockDbMetadata.getUserName()).thenReturn("user"); when(mockCallable.call()).thenReturn(mockResult1); @@ -265,6 +282,10 @@ void test_transaction_cacheQuery_multiple_query_params() throws Exception { assertFalse(rs.next()); verify(mockPluginService).getCurrentConnection(); verify(mockPluginService).isInTransaction(); + verify(mockPluginService).getSessionStateService(); + verify(mockSessionStateService).getSchema(); + verify(mockConnection).getSchema(); + 
verify(mockSessionStateService).setSchema("public"); verify(mockCacheConn, never()).readFromCache(anyString()); verify(mockCallable).call(); verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); @@ -275,19 +296,21 @@ void test_transaction_cacheQuery_multiple_query_params() throws Exception { @Test void test_transaction_cacheQuery_multiple_query_hints() throws Exception {// Query is cacheable - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.isInTransaction()).thenReturn(true); - when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); - when(mockConnection.getSchema()).thenReturn("public"); - when(mockDbMetadata.getUserName()).thenReturn("user"); - when(mockCallable.call()).thenReturn(mockResult1); - - // Result set contains 1 row - when(mockResult1.next()).thenReturn(true, false); - when(mockResult1.getObject(1)).thenReturn("bar1"); - - ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{"/*+ hello CACHE_PARAM(ttl=300s, otherParam=abc) world */ select * from T"}); + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(true); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockDbMetadata.getUserName()).thenReturn("user"); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+ hello CACHE_PARAM(ttl=300s, otherParam=abc) world */ 
select * from T"}); // Cached result set contains 1 row assertTrue(rs.next()); @@ -295,6 +318,10 @@ void test_transaction_cacheQuery_multiple_query_hints() throws Exception {// Que assertFalse(rs.next()); verify(mockPluginService).getCurrentConnection(); verify(mockPluginService).isInTransaction(); + verify(mockPluginService).getSessionStateService(); + verify(mockSessionStateService).getSchema(); + verify(mockConnection).getSchema(); + verify(mockSessionStateService).setSchema("public"); verify(mockCacheConn, never()).readFromCache(anyString()); verify(mockCallable).call(); verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); From 9a5bfa67e42dba2a2bbd5cafe550c39f6c11135e Mon Sep 17 00:00:00 2001 From: Shaopeng Gu Date: Mon, 18 Aug 2025 10:43:49 -0700 Subject: [PATCH 14/24] Unit test for getter functions of CachedResultSet and added Timestamp support to CachedResultSet (PR #2) fix: Add Timestamp support to CachedResultSet getTime() and getDate() - Implement missing instanceof Timestamp cases in convertToTime() and convertToDate() - Use new Time(timestamp.getTime()) and new Date(timestamp.getTime()) for standard JDBC behavior - Add comprehensive test coverage for Timestamp conversion scenarios - Fix object equality issues in tests by using consistent constructors Resolves missing Timestamp handling that previously fell through to string parsing. 
added more tests and cleaned up debug logs removed unnecessary imports, fixed testing logic for getTime() and getDate() and fixed implementation error for getBytes() and getBoolean() in the original codebase updated future timestamp testing to make it more robust instead of hardcoding the future timestamp updated convertToTime() to use static timezone offset -- functionality unchanged --- .../jdbc/plugin/cache/CachedResultSet.java | 23 +- .../plugin/cache/CachedResultSetTest.java | 365 +++++++++++++++++- 2 files changed, 382 insertions(+), 6 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java index 6aa792c05..e7dafa9ac 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java @@ -38,6 +38,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Calendar; +import java.util.TimeZone; public class CachedResultSet implements ResultSet { @@ -68,6 +69,7 @@ public Object get(final int columnIndex) throws SQLException { protected boolean wasNullFlag; private final CachedResultSetMetaData metadata; protected static final ZoneId defaultTimeZoneId = ZoneId.systemDefault(); + protected static final TimeZone defaultTimeZone = TimeZone.getDefault(); private final HashMap columnNames; private volatile boolean closed; @@ -178,7 +180,7 @@ public boolean getBoolean(final int columnIndex) throws SQLException { final Object val = checkAndGetColumnValue(columnIndex); if (val == null) return false; if (val instanceof Boolean) return (Boolean) val; - if (val instanceof Number) return ((Number) val).intValue() == 0; + if (val instanceof Number) return ((Number) val).intValue() != 0; return Boolean.parseBoolean(val.toString()); } @@ -251,7 +253,8 @@ public byte[] getBytes(final int columnIndex) throws SQLException { final Object val =
checkAndGetColumnValue(columnIndex); if (val == null) return null; if (val instanceof byte[]) return (byte[]) val; - return new byte[0]; + // Convert non-byte data to string, then to bytes (standard JDBC behavior) + return val.toString().getBytes(); } private Date convertToDate(Object dateObj, Calendar cal) throws SQLException { @@ -268,6 +271,14 @@ private Date convertToDate(Object dateObj, Calendar cal) throws SQLException { ZonedDateTime targetZonedDateTime = originalZonedDateTime.withZoneSameInstant(defaultTimeZoneId); return Date.valueOf(targetZonedDateTime.toLocalDate()); } + if (dateObj instanceof Timestamp) { + Timestamp timestamp = (Timestamp) dateObj; + long millis = timestamp.getTime(); + if (cal == null) return new Date(millis); + long adjustedMillis = millis - cal.getTimeZone().getOffset(millis) + + defaultTimeZone.getOffset(millis); + return new Date(adjustedMillis); + } // Note: normally the user should properly store the Date object in the DB column and // the underlying PG/MySQL/MariaDB driver would convert it into Date already in getObject() @@ -302,6 +313,14 @@ private Time convertToTime(Object timeObj, Calendar cal) throws SQLException { OffsetTime localTime = ((OffsetTime)timeObj).withOffsetSameInstant(OffsetDateTime.now().getOffset()); return Time.valueOf(localTime.toLocalTime()); } + if (timeObj instanceof Timestamp) { + Timestamp timestamp = (Timestamp) timeObj; + long millis = timestamp.getTime(); + if (cal == null) return new Time(millis); + long adjustedMillis = millis - cal.getTimeZone().getOffset(millis) + + defaultTimeZone.getOffset(millis); + return new Time(adjustedMillis); + } // Note: normally the user should properly store the Time object in the DB column and // the underlying PG/MySQL/MariaDB driver would convert it into Time already in getObject() diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java index 
e22b7e13d..0b5ed376d 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java @@ -13,6 +13,8 @@ import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import java.net.URL; +import java.net.MalformedURLException; import java.math.BigDecimal; @@ -96,8 +98,8 @@ void cleanUp() { void setUpDefaultTestResultSet() throws SQLException { // Create the default CachedResultSet for testing when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); - when(mockResultSetMetadata.getColumnCount()).thenReturn(13); - for (int i = 0; i < 13; i++) { + when(mockResultSetMetadata.getColumnCount()).thenReturn(testColumnMetadata.length); + for (int i = 0; i < testColumnMetadata.length; i++) { mockGetMetadataFields(1+i, i); when(mockResultSet.getObject(1+i)).thenReturn(testColumnValues[i][0], testColumnValues[i][1]); } @@ -368,10 +370,15 @@ void test_get_special_time() throws SQLException { 4362000L, LocalTime.of(10, 20, 30), OffsetTime.of(12, 15, 30, 0, ZoneOffset.UTC), + new Timestamp(1755621000000L), // Date and time (GMT): Tuesday, August 19, 2025 4:30:00 PM + new Timestamp(1735713000000L), // Date and time (GMT): Wednesday, January 1, 2025 6:30:00 AM + new Timestamp(0L), // 1970-01-01 00:00:00 UTC (epoch) + new Timestamp(Timestamp.valueOf(LocalDateTime.now().plusYears(1).withHour(9).withMinute(30).withSecond(0).withNano(0)).getTime()), // Future Date: next year same date at 9:30 AM "15:34:20", "InvalidTime", null); - when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, false); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, + true, true, true, true, false); CachedResultSet cachedRs = new CachedResultSet(mockResultSet); // Time from a number @@ -385,6 +392,32 @@ void test_get_special_time() throws SQLException { assertTrue(cachedRs.next()); 
assertEquals(Time.valueOf("05:15:30"), cachedRs.getTime(1)); assertEquals(Time.valueOf("05:15:30"), cachedRs.getTime(1, estCal)); + // Time from Timestamp + assertTrue(cachedRs.next()); + Timestamp timestampOne = new Timestamp(1755621000000L); + // Compare underlying millis + assertEquals(timestampOne.getTime(), cachedRs.getTime(1).getTime()); + // Compare logical wall-clock time + assertEquals(LocalTime.of(9, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(6, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Time from Timestamp Edge Case + assertTrue(cachedRs.next()); + Timestamp timestampTwo = new Timestamp(1735713000000L); + assertEquals(timestampTwo.getTime(), cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(22, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(19, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Epoch time of 0 + assertTrue(cachedRs.next()); + assertEquals(new Time(0), cachedRs.getTime(1)); + assertEquals(0L, cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(16, 0, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(13, 0, 0), cachedRs.getTime(1, estCal).toLocalTime()); + // Future date + assertTrue(cachedRs.next()); + Timestamp futureTimestamp = new Timestamp(Timestamp.valueOf(LocalDateTime.now().plusYears(1).withHour(9).withMinute(30).withSecond(0).withNano(0)).getTime()); + assertEquals(futureTimestamp.getTime(), cachedRs.getTime(1).getTime()); + assertEquals(LocalTime.of(9, 30, 0), cachedRs.getTime(1).toLocalTime()); + assertEquals(LocalTime.of(6, 30, 0), cachedRs.getTime(1, estCal).toLocalTime()); // Timestamp from String assertTrue(cachedRs.next()); assertEquals(Time.valueOf("15:34:20"), cachedRs.getTime(1)); @@ -411,10 +444,16 @@ void test_get_special_date() throws SQLException { 1515944311000L, -1000000000L, LocalDate.of(2010, 10, 30), + new Timestamp(1755621000000L), // Date and time (GMT): Tuesday, August 19, 2025 4:30:00 PM + new 
Timestamp(1735713000000L), // Date and time (GMT): Wednesday, January 1, 2025 6:30:00 AM + new Timestamp(1755673200000L), // Date and time (GMT): Wednesday, August 20, 2025 7:00:00 AM --> PDT Aug 20 12AM + new Timestamp(1735718400000L), // Date and time (GMT): Wednesday, January 1, 2025 8:00:00 AM --> PST Jan 1 12AM + new Timestamp(0L), // 1970-01-01 00:00:00 UTC (epoch) "2025-03-15", "InvalidDate", null); - when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, false); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, + true, true, true, true, false); CachedResultSet cachedRs = new CachedResultSet(mockResultSet); // Date from a number @@ -424,9 +463,38 @@ void test_get_special_date() throws SQLException { assertTrue(cachedRs.next()); assertEquals(new Date(-1000000000L), cachedRs.getDate(1)); // Date from LocalDate + assertTrue(cachedRs.next()); assertEquals(Date.valueOf("2010-10-30"), cachedRs.getDate(1)); assertEquals(Date.valueOf("2010-10-29"), cachedRs.getDate(1, estCal)); + // Date from Timestamp + assertTrue(cachedRs.next()); + Timestamp tsForDate1 = new Timestamp(1755621000000L); + assertEquals(new Date(tsForDate1.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025, 8, 19), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2025, 8, 19), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate2 = new Timestamp(1735713000000L); + assertEquals(new Date(tsForDate2.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2024, 12, 31), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2024, 12, 31), cachedRs.getDate(1, estCal).toLocalDate()); + // Date from Timestamp Edge Case + assertTrue(cachedRs.next()); + Timestamp tsForDate3 = new Timestamp(1755673200000L); + assertEquals(new Date(tsForDate3.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025,8,20), cachedRs.getDate(1).toLocalDate()); + 
assertEquals(LocalDate.of(2025,8,19), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate4 = new Timestamp(1735718400000L); + assertEquals(new Date(tsForDate4.getTime()), cachedRs.getDate(1)); + assertEquals(LocalDate.of(2025,1,1), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(2024,12,31), cachedRs.getDate(1, estCal).toLocalDate()); + assertTrue(cachedRs.next()); + Timestamp tsForDate5 = new Timestamp(0L); + assertEquals(new Date(tsForDate5.getTime()), cachedRs.getDate(1)); + assertEquals(new Date(0L), cachedRs.getDate(1)); + assertEquals(LocalDate.of(1969,12,31), cachedRs.getDate(1).toLocalDate()); + assertEquals(LocalDate.of(1969,12,31), cachedRs.getDate(1, estCal).toLocalDate()); // Date from String assertTrue(cachedRs.next()); assertEquals(Date.valueOf("2025-03-15"), cachedRs.getDate(1)); @@ -442,4 +510,293 @@ void test_get_special_date() throws SQLException { assertTrue(cachedRs.next()); assertNull(cachedRs.getDate(1)); } + + @Test + void test_get_nstring() throws SQLException { + // Setup single column with String metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + when(mockResultSet.getObject(1)).thenReturn("test string", 123, null); + when(mockResultSet.next()).thenReturn(true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test string value - both index and label versions + assertTrue(cachedRs.next()); + assertEquals("test string", cachedRs.getNString(1)); + assertFalse(cachedRs.wasNull()); + assertEquals("test string", cachedRs.getNString("fieldString")); + assertFalse(cachedRs.wasNull()); + + // Test number conversion + assertTrue(cachedRs.next()); + assertEquals("123", cachedRs.getNString(1)); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getNString(1)); + 
assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_bytes() throws SQLException { + // Setup single column with String metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 4); + // Test data + byte[] testBytes = {1, 2, 3, 4, 5}; + when(mockResultSet.getObject(1)).thenReturn(testBytes, "not bytes", 123, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test bytes values - both index and label versions + assertTrue(cachedRs.next()); + assertArrayEquals(testBytes, cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + assertArrayEquals(testBytes, cachedRs.getBytes("fieldByte")); + assertFalse(cachedRs.wasNull()); + + // Test non-byte array input (should convert to bytes) + assertTrue(cachedRs.next()); + assertArrayEquals("not bytes".getBytes(), cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + + // Test number input (should convert to bytes) + assertTrue(cachedRs.next()); + assertArrayEquals("123".getBytes(), cachedRs.getBytes(1)); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getBytes(1)); + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_boolean() throws SQLException { + // Setup single column with String metadata + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 3); + // Test data: boolean, numbers, strings, null + when(mockResultSet.getObject(1)).thenReturn( + true, false, 0, 1, -5, "true", "false", "invalid", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual boolean values - both index and label 
versions + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); + assertFalse(cachedRs.wasNull()); + assertTrue(cachedRs.getBoolean("fieldBoolean")); + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); + assertFalse(cachedRs.wasNull()); + + // Test number conversions: 0 = true, non-zero = false + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // 0 → false + + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // 1 → true + + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // -5 → true + + // Test string conversions + assertTrue(cachedRs.next()); + assertTrue(cachedRs.getBoolean(1)); // "true" → true + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // "false" → false + + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // "invalid" → false (parseBoolean) + + // Test null handling + assertTrue(cachedRs.next()); + assertFalse(cachedRs.getBoolean(1)); // null → false + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_URL() throws SQLException { + // Setup single column with string metadata (URLs stored as strings) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: URL object, valid URL string, invalid URL string, null + // URL object setup + URL testUrl = null; + try { + testUrl = new URL("https://example.com"); + } catch (MalformedURLException e) { + fail("Test setup failed"); + } + + when(mockResultSet.getObject(1)).thenReturn( + testUrl, "https://valid.com", "invalid-url", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual URL object - both index and label versions + assertTrue(cachedRs.next()); + assertEquals(testUrl, cachedRs.getURL(1)); + assertFalse(cachedRs.wasNull()); + 
assertEquals(testUrl, cachedRs.getURL("fieldString")); + + // Test valid URL string conversion + assertTrue(cachedRs.next()); + URL validURL = null; + try { + validURL = new URL("https://valid.com"); + } catch (MalformedURLException e) { + fail("Failed setting up new valid URL"); + } + assertEquals(validURL, cachedRs.getURL(1)); + assertFalse(cachedRs.wasNull()); + + // Test invalid URL string (should throw SQLException) + assertTrue(cachedRs.next()); + assertThrows(SQLException.class, () -> cachedRs.getURL(1)); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getURL(1)); + assertTrue(cachedRs.wasNull()); + } + + @Test + void test_get_object_with_index_and_type() throws SQLException { + // Setup single column with string metadata (mixed data types) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: string, integer, boolean, null + when(mockResultSet.getObject(1)).thenReturn("test", 123, true, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid type conversions + assertTrue(cachedRs.next()); + assertEquals("test", cachedRs.getObject(1, String.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Integer.valueOf(123), cachedRs.getObject(1, Integer.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Boolean.TRUE, cachedRs.getObject(1, Boolean.class)); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getObject(1, String.class)); + assertTrue(cachedRs.wasNull()); + + // Test invalid type conversion (should throw ClassCastException) + cachedRs.beforeFirst(); + // Wraps around + assertTrue(cachedRs.next()); + assertThrows(ClassCastException.class, () -> 
cachedRs.getObject(1, Integer.class)); + } + + @Test + void test_get_object_with_label_and_type() throws SQLException { + // Setup single column with string metadata (mixed data types) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + // Test data: string, integer, boolean, HashSet (unsupported type), null + HashSet testSet = new HashSet<>(); + testSet.add("item1"); + testSet.add("item2"); + + when(mockResultSet.getObject(1)).thenReturn("test", 123, true, testSet, null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid type conversions + assertTrue(cachedRs.next()); + assertEquals("test", cachedRs.getObject("fieldString", String.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Integer.valueOf(123), cachedRs.getObject("fieldString", Integer.class)); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertEquals(Boolean.TRUE, cachedRs.getObject("fieldString", Boolean.class)); + assertFalse(cachedRs.wasNull()); + + // Test unsupported data type (HashSet) - should work with getObject() + assertTrue(cachedRs.next()); + HashSet retrievedSet = cachedRs.getObject("fieldString", HashSet.class); + assertEquals(testSet, retrievedSet); + assertFalse(cachedRs.wasNull()); + + // Test null handling + assertTrue(cachedRs.next()); + assertNull(cachedRs.getObject("fieldString", String.class)); + assertTrue(cachedRs.wasNull()); + + // Test invalid type conversion (should throw ClassCastException) + cachedRs.beforeFirst(); + // Wraps around + assertTrue(cachedRs.next()); + assertThrows(ClassCastException.class, () -> cachedRs.getObject(1, Integer.class)); + } + + @Test + void test_unwrap() throws SQLException { + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + 
when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid unwrap to ResultSet interface + ResultSet unwrappedResultSet = cachedRs.unwrap(ResultSet.class); + assertSame(cachedRs, unwrappedResultSet); + + // Test valid unwrap to CachedResultSet class + CachedResultSet unwrappedCachedResultSet = cachedRs.unwrap(CachedResultSet.class); + assertSame(cachedRs, unwrappedCachedResultSet); + + // Test invalid unwrap attempts should throw SQLException + assertThrows(SQLException.class, () -> cachedRs.unwrap(String.class)); + assertThrows(SQLException.class, () -> cachedRs.unwrap(Integer.class)); + } + + @Test + void test_is_wrapper_for() throws SQLException { + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 2); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test valid wrapper checks + assertTrue(cachedRs.isWrapperFor(ResultSet.class)); + assertTrue(cachedRs.isWrapperFor(CachedResultSet.class)); + + // Test invalid wrapper checks + assertFalse(cachedRs.isWrapperFor(String.class)); + assertFalse(cachedRs.isWrapperFor(Integer.class)); + + // Test null class parameter + assertFalse(cachedRs.isWrapperFor(null)); + } } + From 54098683a655fe57a5a1143aced5b8bce7c9c6d8 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Wed, 27 Aug 2025 13:06:44 -0700 Subject: [PATCH 15/24] Support getSQLXML() for CachedResultSet. Update CacheConnection for better error logging. 
--- .../jdbc/plugin/cache/CacheConnection.java | 8 +- .../jdbc/plugin/cache/CachedResultSet.java | 14 +- .../jdbc/plugin/cache/CachedSQLXML.java | 118 ++++++++++++ .../plugin/cache/CachedResultSetTest.java | 71 ++++++- .../jdbc/plugin/cache/CachedSQLXMLTest.java | 174 ++++++++++++++++++ 5 files changed, 374 insertions(+), 11 deletions(-) create mode 100644 wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java create mode 100644 wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index ed55a3555..64c39798f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -35,10 +35,10 @@ public class CacheConnection { private final String cacheRoServerAddr; // read-only cache server private MessageDigest msgHashDigest = null; - private static final int DEFAULT_POOL_SIZE = 10; - private static final int DEFAULT_POOL_MAX_IDLE = 10; + private static final int DEFAULT_POOL_SIZE = 20; + private static final int DEFAULT_POOL_MAX_IDLE = 20; private static final int DEFAULT_POOL_MIN_IDLE = 0; - private static final long DEFAULT_MAX_BORROW_WAIT_MS = 50; + private static final long DEFAULT_MAX_BORROW_WAIT_MS = 100; private static final ReentrantLock READ_LOCK = new ReentrantLock(); private static final ReentrantLock WRITE_LOCK = new ReentrantLock(); @@ -233,7 +233,7 @@ public void writeToCache(String key, byte[] value, int expiry) { .whenComplete((result, exception) -> handleCompletedCacheWrite(finalConn, exception)); } catch (Exception e) { // Failed to trigger the async write to the cache, return the cache connection to the pool as broken - LOGGER.warning("Failed to write to cache: " + e.getMessage()); + LOGGER.warning("Unable to start writing to cache: " 
+ e.getMessage()); if (conn != null && writeConnectionPool != null) { try { returnConnectionBackToPool(conn, true, false); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java index e7dafa9ac..73b4d2d2c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java @@ -89,7 +89,12 @@ public CachedResultSet(final ResultSet resultSet) throws SQLException { while (resultSet.next()) { final CachedRow row = new CachedRow(numColumns); for (int i = 1; i <= numColumns; ++i) { - row.put(i, resultSet.getObject(i)); + Object rowObj = resultSet.getObject(i); + // For SQLXML object, convert into CachedSQLXML object that is serializable + if (rowObj instanceof SQLXML) { + rowObj = new CachedSQLXML(((SQLXML)rowObj).getString()); + } + row.put(i, rowObj); } rows.add(row); } @@ -1143,12 +1148,15 @@ public NClob getNClob(final String columnLabel) throws SQLException { @Override public SQLXML getSQLXML(final int columnIndex) throws SQLException { - throw new UnsupportedOperationException(); + Object val = checkAndGetColumnValue(columnIndex); + if (val == null) return null; + if (val instanceof SQLXML) return (SQLXML) val; + return new CachedSQLXML(val.toString()); } @Override public SQLXML getSQLXML(final String columnLabel) throws SQLException { - throw new UnsupportedOperationException(); + return getSQLXML(checkAndGetColumnIndex(columnLabel)); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java new file mode 100644 index 000000000..a49240172 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSQLXML.java @@ -0,0 +1,118 @@ +package software.amazon.jdbc.plugin.cache; + +import org.xml.sax.InputSource; +import 
org.xml.sax.XMLReader; +import org.xml.sax.helpers.XMLReaderFactory; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.stream.XMLInputFactory; +import javax.xml.stream.XMLStreamReader; +import javax.xml.transform.Result; +import javax.xml.transform.Source; +import javax.xml.transform.dom.DOMSource; +import javax.xml.transform.sax.SAXSource; +import javax.xml.transform.stax.StAXSource; +import javax.xml.transform.stream.StreamSource; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Reader; +import java.io.Serializable; +import java.io.StringReader; +import java.io.Writer; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; +import java.sql.SQLXML; + +public class CachedSQLXML implements SQLXML, Serializable { + private boolean freed; + private String data; + + public CachedSQLXML(String data) { + this.data = data; + this.freed = false; + } + + @Override + public void free() throws SQLException { + if (this.freed) return; + this.data = null; + this.freed = true; + } + + private void checkFreed() throws SQLException { + if (this.freed) { + throw new SQLException("This SQLXML object has already been freed."); + } + } + + @Override + public InputStream getBinaryStream() throws SQLException { + checkFreed(); + if (this.data == null) return null; + return new ByteArrayInputStream(this.data.getBytes(StandardCharsets.UTF_8)); + } + + @Override + public OutputStream setBinaryStream() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public Reader getCharacterStream() throws SQLException { + checkFreed(); + if (this.data == null) return null; + return new StringReader(this.data); + } + + @Override + public Writer setCharacterStream() throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public String getString() throws SQLException { + checkFreed(); + return 
this.data; + } + + @Override + public void setString(String value) throws SQLException { + throw new UnsupportedOperationException(); + } + + @Override + public T getSource(Class sourceClass) throws SQLException { + checkFreed(); + if (this.data == null) return null; + + try { + if (sourceClass == null || DOMSource.class.equals(sourceClass)) { + DocumentBuilder builder = DocumentBuilderFactory.newInstance().newDocumentBuilder(); + return (T) new DOMSource(builder.parse(new InputSource(new StringReader(data)))); + } + + if (SAXSource.class.equals(sourceClass)) { + XMLReader reader = XMLReaderFactory.createXMLReader(); + return sourceClass.cast(new SAXSource(reader, new InputSource(new StringReader(data)))); + } + + if (StreamSource.class.equals(sourceClass)) { + return sourceClass.cast(new StreamSource(new StringReader(data))); + } + + if (StAXSource.class.equals(sourceClass)) { + XMLStreamReader xsr = XMLInputFactory.newFactory().createXMLStreamReader(new StringReader(data)); + return sourceClass.cast(new StAXSource(xsr)); + } + throw new SQLException("Unsupported source class for XML data: " + sourceClass.getName()); + } catch (Exception e) { + throw new SQLException("Unable to decode XML data.", e); + } + } + + @Override + public T setResult(Class resultClass) throws SQLException { + throw new UnsupportedOperationException(); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java index 0b5ed376d..00ac96e34 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java @@ -23,8 +23,8 @@ public class CachedResultSetTest { @Mock ResultSet mockResultSet; @Mock ResultSetMetaData mockResultSetMetadata; private AutoCloseable closeable; - private static Calendar estCal = 
Calendar.getInstance(TimeZone.getTimeZone("America/New_York")); - private TimeZone defaultTimeZone = TimeZone.getDefault(); + private static final Calendar estCal = Calendar.getInstance(TimeZone.getTimeZone("America/New_York")); + private final TimeZone defaultTimeZone = TimeZone.getDefault(); // Column values: label, name, typeName, type, displaySize, precision, tableName, // scale, schemaName, isAutoIncrement, isCaseSensitive, isCurrency, isDefinitelyWritable, @@ -42,7 +42,8 @@ public class CachedResultSetTest { {"fieldBigDecimal", "fieldBigDecimal", "BigDecimal", Types.DECIMAL, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, {"fieldDate", "fieldDate", "Date", Types.DATE, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, {"fieldTime", "fieldTime", "Time", Types.TIME, 10, 2, "table", 1, "public", false, false, false, false, 1, true, true, false, false}, - {"fieldDateTime", "fieldDateTime", "Timestamp", Types.TIMESTAMP, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false} + {"fieldDateTime", "fieldDateTime", "Timestamp", Types.TIMESTAMP, 10, 2, "table", 1, "public", false, false, false, false, 0, true, true, false, false}, + {"fieldSqlXml", "fieldSqlXml", "SqlXml", Types.SQLXML, 100, 1, "table", 1, "public", false, false, false, false, 0, true, true, false, false} }; private static final Object [][] testColumnValues = { @@ -58,7 +59,8 @@ public class CachedResultSetTest { {new BigDecimal("15.33"), new BigDecimal("-12.45")}, {Date.valueOf("2025-03-15"), Date.valueOf("1102-01-15")}, {Time.valueOf("22:54:00"), Time.valueOf("01:10:00")}, - {Timestamp.valueOf("2025-03-15 22:54:00"), Timestamp.valueOf("1950-01-18 21:50:05")} + {Timestamp.valueOf("2025-03-15 22:54:00"), Timestamp.valueOf("1950-01-18 21:50:05")}, + {new CachedSQLXML("A"), new CachedSQLXML("Value AValue B")} }; private void mockGetMetadataFields(int column, int testMetadataCol) throws SQLException { 
@@ -202,6 +204,12 @@ private void verifyDefaultRow(ResultSet rs, int row) throws SQLException { assertEquals(testColumnValues[12][row], rs.getTimestamp("fieldDateTime")); assertEquals(13, rs.findColumn("fieldDateTime")); assertFalse(rs.wasNull()); + String sqlXmlString = ((SQLXML)testColumnValues[13][row]).getString(); + assertEquals(sqlXmlString, rs.getSQLXML(14).getString()); // fieldSqlXml + assertFalse(rs.wasNull()); + assertEquals(sqlXmlString, rs.getSQLXML("fieldSqlXml").getString()); + assertEquals(14, rs.findColumn("fieldSqlXml")); + assertFalse(rs.wasNull()); verifyNonexistingField(rs); } @@ -672,6 +680,61 @@ void test_get_URL() throws SQLException { assertTrue(cachedRs.wasNull()); } + @Test + void test_get_sql_xml() throws SQLException { + String longXml = + "\n" + + " TechCorp\n" + + "\n" + + " Intel i7\n" + + " 16GB\n" + + " 512GB SSD\n" + + "\n" + + " 1200.00\n" + + ""; + SQLXML testXml = new CachedSQLXML("PostgreSQL GuideJohn Doe"); + SQLXML testXml2 = new CachedSQLXML(longXml); + SQLXML invalidXml = new CachedSQLXML("A"); + // Setup single column with string metadata (URLs stored as strings) + when(mockResultSet.getMetaData()).thenReturn(mockResultSetMetadata); + when(mockResultSetMetadata.getColumnCount()).thenReturn(1); + mockGetMetadataFields(1, 13); + when(mockResultSet.getObject(1)).thenReturn(testXml, testXml2, invalidXml, "invalid-xml", null); + when(mockResultSet.next()).thenReturn(true, true, true, true, true, false); + + CachedResultSet cachedRs = new CachedResultSet(mockResultSet); + + // Test actual SQLXML objects - both index and label versions + assertTrue(cachedRs.next()); + assertEquals(testXml.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(testXml.getString(), cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals(testXml2.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(testXml2.getString(), 
cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals(invalidXml.getString(), cachedRs.getSQLXML(1).getString()); + assertFalse(cachedRs.wasNull()); + assertEquals(invalidXml.getString(), cachedRs.getSQLXML("fieldSqlXml").getString()); + + assertTrue(cachedRs.next()); + assertEquals("invalid-xml", cachedRs.getSQLXML(1).getString()); + assertEquals("invalid-xml", cachedRs.getSQLXML("fieldSqlXml").getString()); + assertFalse(cachedRs.wasNull()); + + assertTrue(cachedRs.next()); + assertNull(cachedRs.getSQLXML(1)); + assertTrue(cachedRs.wasNull()); + assertNull(cachedRs.getSQLXML("fieldSqlXml")); + assertTrue(cachedRs.wasNull()); + + assertFalse(cachedRs.next()); + } + + @Test void test_get_object_with_index_and_type() throws SQLException { // Setup single column with string metadata (mixed data types) diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java new file mode 100644 index 000000000..7340ac1b0 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedSQLXMLTest.java @@ -0,0 +1,174 @@ +package software.amazon.jdbc.plugin.cache; + +import org.junit.jupiter.api.Test; +import org.w3c.dom.*; +import org.xml.sax.Attributes; +import org.xml.sax.InputSource; +import org.xml.sax.XMLReader; +import org.xml.sax.helpers.DefaultHandler; +import java.io.InputStream; +import java.io.Reader; +import java.sql.SQLException; +import java.sql.SQLXML; + +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.stream.XMLStreamReader; +import javax.xml.transform.Source; +import javax.xml.transform.dom.DOMSource; +import javax.xml.transform.sax.SAXSource; +import javax.xml.transform.stax.StAXSource; +import javax.xml.transform.stream.StreamSource; + +import static org.junit.jupiter.api.Assertions.*; + +public class CachedSQLXMLTest { + + @Test + void 
test_basic_XML() throws Exception { + String xml = "Value AValue B"; + SQLXML sqlxml = new CachedSQLXML(xml); + assertEquals(xml, sqlxml.getString()); + + // Test binary stream + byte[] array = new byte[100]; + InputStream stream = sqlxml.getBinaryStream(); + assertEquals(xml.length(), stream.available()); + assertTrue(stream.read(array) > 0); + assertEquals(xml, new String(array, 0, xml.length())); + stream.close(); + + // Test character stream + char[] chars = new char[100]; + Reader reader = sqlxml.getCharacterStream(); + assertTrue(reader.read(chars) > 0); + assertEquals(xml, new String(chars, 0, xml.length())); + reader.close(); + + // Test free() + sqlxml.free(); + assertThrows(SQLException.class, sqlxml::getString); + assertThrows(SQLException.class, sqlxml::getCharacterStream); + assertThrows(SQLException.class, sqlxml::getBinaryStream); + assertThrows(SQLException.class, () -> sqlxml.getSource(DOMSource.class)); + } + + private void validateDOMElement(Document document, String elementName, String elementValue) { + NodeList elements = document.getElementsByTagName(elementName); + assertEquals(1, elements.getLength()); + Element element = (Element) elements.item(0); + assertEquals(elementName, element.getNodeName()); + assertEquals(elementValue, element.getTextContent()); + } + + private void validateSimpleDocument(Document document) { + Element rootElement = document.getDocumentElement(); + assertEquals("product", rootElement.getNodeName()); + NodeList elements = document.getElementsByTagName("product"); + assertEquals(1, elements.getLength()); // product has 3 elements + elements = document.getElementsByTagName("specs"); + assertEquals(1, elements.getLength()); // specs has 3 elements + validateDOMElement(document, "manufacturer", "TechCorp"); + validateDOMElement(document, "cpu", "Intel i7"); + validateDOMElement(document, "ram", "16GB"); + validateDOMElement(document, "storage", "512GB SSD"); + validateDOMElement(document, "price", "1200.00"); + } + + 
static private void validateDocElements(String name, String value) { + if (name.equalsIgnoreCase("manufacturer")) { + assertEquals("TechCorp", value); + } else if (name.equalsIgnoreCase("cpu")) { + assertEquals("Intel i7", value); + } else if (name.equalsIgnoreCase("ram")) { + assertEquals("16GB", value); + } else if (name.equalsIgnoreCase("storage")) { + assertEquals("512GB SSD", value); + } else if (name.equalsIgnoreCase("price")) { + assertEquals("1200.00", value); + } + } + + static private class XmlReaderContentHandler extends DefaultHandler { + private StringBuilder currentValue; + + @Override + public void startElement(String uri, String localName, String qName, Attributes attributes) { + currentValue = new StringBuilder(); // Reset for each new element + } + + @Override + public void endElement(String uri, String localName, String qName) { + // Verify the element's value + String value = currentValue.toString().trim(); + validateDocElements(qName, value); + } + + @Override + public void characters(char[] ch, int start, int length) { + currentValue.append(ch, start, length); + } + } + + @Test + void test_getSource_XML() throws Exception { + // Test parsing a more complex XML via getSource() + String xml = " \n" + + "\n" + + " TechCorp\n\n" + + "\n" + + " Intel i7\n" + + " 16GB\n" + + " 512GB SSD\n" + + "\n" + + " 1200.00\n" + + "\n"; + SQLXML sqlxml = new CachedSQLXML(xml); + assertEquals(xml, sqlxml.getString()); + + // DOM source + DOMSource domSource = sqlxml.getSource(null); + Node node = domSource.getNode(); + assertEquals(Node.DOCUMENT_NODE, node.getNodeType()); + validateSimpleDocument((Document) node); + domSource = sqlxml.getSource(DOMSource.class); + node = domSource.getNode(); + assertEquals(Node.DOCUMENT_NODE, node.getNodeType()); + validateSimpleDocument((Document) node); + + // SAX source + SAXSource src = sqlxml.getSource(SAXSource.class); + XMLReader xmlReader = src.getXMLReader(); + xmlReader.setContentHandler(new XmlReaderContentHandler()); 
+ xmlReader.parse(src.getInputSource()); + + // Streams source + StreamSource xmlSource = sqlxml.getSource(StreamSource.class); + DocumentBuilder db = DocumentBuilderFactory.newInstance().newDocumentBuilder(); + Document doc = db.parse(new InputSource(xmlSource.getReader())); + doc.getDocumentElement().normalize(); + validateSimpleDocument(doc); + + // StAX Source + StAXSource staxSource = sqlxml.getSource(StAXSource.class); + XMLStreamReader sReader = staxSource.getXMLStreamReader(); + String elementName = ""; + StringBuilder elementValue = new StringBuilder(); + while (sReader.hasNext()) { + int event = sReader.next(); + if (event == XMLStreamReader.START_ELEMENT) { + elementName = sReader.getLocalName(); + } else if (event == XMLStreamReader.CHARACTERS) { + elementValue.append(sReader.getText()); + } else if (event == XMLStreamReader.END_ELEMENT) { + validateDocElements(elementName, elementValue.toString().trim()); + elementName = ""; + elementValue = new StringBuilder(); + } + } + sReader.close(); // Close the reader when done + + // Invalid source class + assertThrows(SQLException.class, () -> sqlxml.getSource(Source.class)); + } +} From 758b8299efd04626e227d305ab584f6a0f87c6c8 Mon Sep 17 00:00:00 2001 From: Shaopeng Gu Date: Tue, 26 Aug 2025 13:48:01 -0700 Subject: [PATCH 16/24] added metrics JdbcCachedQueryCount, JdbcCacheMissCount, JdbcCacheBypassCount, and JdbcCachedQueryAfterUpdate (PR #3) JdbcCachedQueryAfterUpdate feature removed Add OpenTelemetry integration for cache latency monitoring fixed variable naming, added one more context for latency tracking took out cacheMissContext and removed dbContext setSuccess logging removed redundant variables handled closing cacheContext for SQLException and removed dbContext error handling minor syntax fix updated unit tests for the latest telemetry metrics --- .../plugin/cache/DataRemoteCachePlugin.java | 87 ++++-- .../cache/DataRemoteCachePluginTest.java | 270 +++++++++++++----- 2 files changed, 268 
insertions(+), 89 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index 80ad88ab4..fac8423e9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -36,12 +36,17 @@ import software.amazon.jdbc.util.WrapperUtils; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; public class DataRemoteCachePlugin extends AbstractConnectionPlugin { private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName()); + private static final int MAX_CACHEABLE_QUERY_SIZE = 16000; private static final String QUERY_HINT_START_PATTERN = "/*+"; private static final String QUERY_HINT_END_PATTERN = "*/"; private static final String CACHE_PARAM_PATTERN = "CACHE_PARAM("; + private static final String TELEMETRY_CACHE_LOOKUP = "jdbc-cache-lookup"; + private static final String TELEMETRY_DATABASE_QUERY = "jdbc-database-query"; private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>( Arrays.asList(JdbcMethod.STATEMENT_EXECUTEQUERY.methodName, JdbcMethod.STATEMENT_EXECUTE.methodName, @@ -52,10 +57,11 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin { private PluginService pluginService; private TelemetryFactory telemetryFactory; - private TelemetryCounter hitCounter; - private TelemetryCounter missCounter; - private TelemetryCounter totalCallsCounter; + private TelemetryCounter cacheHitCounter; + private TelemetryCounter cacheMissCounter; + private TelemetryCounter totalQueryCounter; private TelemetryCounter malformedHintCounter; + private 
TelemetryCounter cacheBypassCounter; private CacheConnection cacheConnection; public DataRemoteCachePlugin(final PluginService pluginService, final Properties properties) { @@ -67,10 +73,11 @@ public DataRemoteCachePlugin(final PluginService pluginService, final Properties } this.pluginService = pluginService; this.telemetryFactory = pluginService.getTelemetryFactory(); - this.hitCounter = telemetryFactory.createCounter("remoteCache.cache.hit"); - this.missCounter = telemetryFactory.createCounter("remoteCache.cache.miss"); - this.totalCallsCounter = telemetryFactory.createCounter("remoteCache.cache.totalCalls"); + this.cacheHitCounter = telemetryFactory.createCounter("JdbcCachedQueryCount"); + this.cacheMissCounter = telemetryFactory.createCounter("JdbcCacheMissCount"); + this.totalQueryCounter = telemetryFactory.createCounter("JdbcCacheTotalQueryCount"); this.malformedHintCounter = telemetryFactory.createCounter("JdbcCacheMalformedQueryHint"); + this.cacheBypassCounter = telemetryFactory.createCounter("JdbcCacheBypassCount"); this.cacheConnection = new CacheConnection(properties); } @@ -223,6 +230,8 @@ public T execute( boolean needToCache = false; final String sql = getQuery(jdbcMethodArgs); + TelemetryContext cacheContext = null; + TelemetryContext dbContext = null; // If the query is cacheable, we try to fetch the query result from the cache. 
boolean isInTransaction = pluginService.isInTransaction(); // Get the query hint part in front of the query itself @@ -230,7 +239,7 @@ public T execute( int endOfQueryHint = 0; Integer configuredQueryTtl = null; // Queries longer than 16KB is not cacheable - if ((sql.length() < 16000) && sql.startsWith(QUERY_HINT_START_PATTERN)) { + if ((sql.length() < MAX_CACHEABLE_QUERY_SIZE) && sql.startsWith(QUERY_HINT_START_PATTERN)) { endOfQueryHint = sql.indexOf(QUERY_HINT_END_PATTERN); if (endOfQueryHint > 0) { configuredQueryTtl = getTtlForQuery(sql.substring(2, endOfQueryHint).trim()); @@ -238,30 +247,60 @@ public T execute( } } + incrCounter(totalQueryCounter); + // Query result can be served from the cache if it has a configured TTL value, and it is // not executed in a transaction as a transaction typically need to return consistent results. if (!isInTransaction && (configuredQueryTtl != null)) { - incrCounter(totalCallsCounter); - result = fetchResultSetFromCache(mainQuery); - if (result == null) { - // Cache miss. Need to fetch result from the database - needToCache = true; - incrCounter(missCounter); - LOGGER.finest("Got a cache miss for SQL: " + sql); - } else { - LOGGER.finest("Got a cache hit for SQL: " + sql); - // Cache hit. Return the cached result - incrCounter(hitCounter); - try { - result.beforeFirst(); - } catch (final SQLException ex) { - throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, ex); + cacheContext = telemetryFactory.openTelemetryContext( + TELEMETRY_CACHE_LOOKUP, TelemetryTraceLevel.TOP_LEVEL); + Exception cacheException = null; + try{ + result = fetchResultSetFromCache(mainQuery); + if (result == null) { + // Cache miss. Need to fetch result from the database + needToCache = true; + incrCounter(cacheMissCounter); + LOGGER.finest("Got a cache miss for SQL: " + sql); + } else { + LOGGER.finest("Got a cache hit for SQL: " + sql); + // Cache hit. 
Return the cached result + incrCounter(cacheHitCounter); + try { + result.beforeFirst(); + } catch (final SQLException ex) { + cacheException = ex; + throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, ex); + } + return resultClass.cast(result); + } + } finally { + if (cacheContext != null) { + if (cacheException != null) { + cacheContext.setSuccess(false); + cacheContext.setException(cacheException); + cacheContext.closeContext(); + } else if (!needToCache) { // Cache hit + cacheContext.setSuccess(true); + cacheContext.closeContext(); + } else { // Cache miss - leave context open + cacheContext.setSuccess(false); + } } - return resultClass.cast(result); } + } else { + incrCounter(cacheBypassCounter); } - result = (ResultSet) jdbcMethodFunc.call(); + dbContext = telemetryFactory.openTelemetryContext( + TELEMETRY_DATABASE_QUERY, TelemetryTraceLevel.TOP_LEVEL); + + try { + result = (ResultSet) jdbcMethodFunc.call(); + } finally { + if (dbContext != null) dbContext.closeContext(); + if (cacheContext != null) cacheContext.closeContext(); + } // We need to cache the query result if we got a cache miss for the query result, // or the query is cacheable and executed inside a transaction. 
diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java index 39e352e79..9d6941a60 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -22,8 +22,10 @@ import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.states.SessionStateService; +import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; public class DataRemoteCachePluginTest { private static final Properties props = new Properties(); @@ -33,10 +35,12 @@ public class DataRemoteCachePluginTest { private DataRemoteCachePlugin plugin; @Mock PluginService mockPluginService; @Mock TelemetryFactory mockTelemetryFactory; - @Mock TelemetryCounter mockHitCounter; - @Mock TelemetryCounter mockMissCounter; - @Mock TelemetryCounter mockTotalCallsCounter; + @Mock TelemetryCounter mockCacheHitCounter; + @Mock TelemetryCounter mockCacheMissCounter; + @Mock TelemetryCounter mockTotalQueryCounter; @Mock TelemetryCounter mockMalformedHintCounter; + @Mock TelemetryCounter mockCacheBypassCounter; + @Mock TelemetryContext mockTelemetryContext; @Mock ResultSet mockResult1; @Mock Statement mockStatement; @Mock ResultSetMetaData mockMetaData; @@ -53,10 +57,12 @@ void setUp() throws SQLException { props.setProperty("cacheEndpointAddrRw", "localhost:6379"); when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); - when(mockTelemetryFactory.createCounter("remoteCache.cache.hit")).thenReturn(mockHitCounter); - 
when(mockTelemetryFactory.createCounter("remoteCache.cache.miss")).thenReturn(mockMissCounter); - when(mockTelemetryFactory.createCounter("remoteCache.cache.totalCalls")).thenReturn(mockTotalCallsCounter); + when(mockTelemetryFactory.createCounter("JdbcCachedQueryCount")).thenReturn(mockCacheHitCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheMissCount")).thenReturn(mockCacheMissCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheTotalQueryCount")).thenReturn(mockTotalQueryCounter); when(mockTelemetryFactory.createCounter("JdbcCacheMalformedQueryHint")).thenReturn(mockMalformedHintCounter); + when(mockTelemetryFactory.createCounter("JdbcCacheBypassCount")).thenReturn(mockCacheBypassCounter); + when(mockTelemetryFactory.openTelemetryContext(anyString(), any())).thenReturn(mockTelemetryContext); when(mockResult1.getMetaData()).thenReturn(mockMetaData); when(mockMetaData.getColumnCount()).thenReturn(1); when(mockMetaData.getColumnLabel(1)).thenReturn("fooName"); @@ -145,14 +151,19 @@ void test_execute_noCaching() throws Exception { methodName, mockCallable, new String[]{"select * from mytable where ID = 2"}); // Mock result set containing 1 row - when(mockResult1.next()).thenReturn(true, true, false, false); + when(mockResult1.next()).thenReturn(true, true, false); when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); compareResults(mockResult1, rs); verify(mockPluginService).isInTransaction(); verify(mockCallable).call(); - verify(mockTotalCallsCounter, never()).inc(); - verify(mockHitCounter, never()).inc(); - verify(mockMissCounter, never()).inc(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, 
never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryContext).closeContext(); } @Test @@ -165,13 +176,18 @@ void test_execute_noCachingLongQuery() throws Exception { methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)}); // Mock result set containing 1 row - when(mockResult1.next()).thenReturn(true, true, false, false); + when(mockResult1.next()).thenReturn(true, true, false); when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); compareResults(mockResult1, rs); verify(mockCallable).call(); - verify(mockTotalCallsCounter, never()).inc(); - verify(mockHitCounter, never()).inc(); - verify(mockMissCounter, never()).inc(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryContext).closeContext(); } @Test @@ -198,9 +214,11 @@ void test_execute_cachingMissAndHit() throws Exception { assertTrue(rs.next()); assertEquals("bar1", rs.getString("fooName")); assertFalse(rs.next()); + rs.beforeFirst(); byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(serializedTestResultSet); + ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockStatement, methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"}); @@ -216,9 +234,19 @@ void test_execute_cachingMissAndHit() throws Exception { verify(mockSessionStateService).setSchema("public"); verify(mockCallable).call(); 
verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(50)); - verify(mockTotalCallsCounter, times(2)).inc(); - verify(mockMissCounter).inc(); - verify(mockHitCounter).inc(); + verify(mockTotalQueryCounter, times(2)).inc(); + verify(mockCacheMissCounter, times(1)).inc(); + verify(mockCacheHitCounter, times(1)).inc(); + verify(mockCacheBypassCounter, never()).inc(); + // Verify TelemetryContext behavior for cache miss and hit scenario + // First call: Cache miss + Database call + verify(mockTelemetryFactory, times(2)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Cache context calls: 1 miss (setSuccess(false)) + 1 hit (setSuccess(true)) + verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss + verify(mockTelemetryContext, times(1)).setSuccess(true); // Cache hit + // Context closure: 2 cache contexts + 1 database context = 3 total + verify(mockTelemetryContext, times(3)).closeContext(); } @Test @@ -253,9 +281,16 @@ void test_transaction_cacheQuery() throws Exception { verify(mockCacheConn, never()).readFromCache(anyString()); verify(mockCallable).call(); verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); - verify(mockTotalCallsCounter, never()).inc(); - verify(mockHitCounter, never()).inc(); - verify(mockMissCounter, never()).inc(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, 
times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); } @Test @@ -289,64 +324,169 @@ void test_transaction_cacheQuery_multiple_query_params() throws Exception { verify(mockCacheConn, never()).readFromCache(anyString()); verify(mockCallable).call(); verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); - verify(mockTotalCallsCounter, never()).inc(); - verify(mockHitCounter, never()).inc(); - verify(mockMissCounter, never()).inc(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); } - @Test - void test_transaction_cacheQuery_multiple_query_hints() throws Exception {// Query is cacheable - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockPluginService.isInTransaction()).thenReturn(true); - when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); - when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); - when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); - when(mockConnection.getSchema()).thenReturn("public"); - when(mockDbMetadata.getUserName()).thenReturn("user"); - when(mockCallable.call()).thenReturn(mockResult1); - - // Result set contains 1 row - when(mockResult1.next()).thenReturn(true, false); - 
when(mockResult1.getObject(1)).thenReturn("bar1"); - - ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{"/*+ hello CACHE_PARAM(ttl=300s, otherParam=abc) world */ select * from T"}); - - // Cached result set contains 1 row - assertTrue(rs.next()); - assertEquals("bar1", rs.getString("fooName")); - assertFalse(rs.next()); - verify(mockPluginService).getCurrentConnection(); - verify(mockPluginService).isInTransaction(); - verify(mockPluginService).getSessionStateService(); - verify(mockSessionStateService).getSchema(); - verify(mockConnection).getSchema(); - verify(mockSessionStateService).setSchema("public"); - verify(mockCacheConn, never()).readFromCache(anyString()); - verify(mockCallable).call(); - verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); - verify(mockTotalCallsCounter, never()).inc(); - verify(mockHitCounter, never()).inc(); - verify(mockMissCounter, never()).inc(); - } - - @Test + @Test void test_transaction_noCaching() throws Exception { // Query is not cacheable when(mockPluginService.isInTransaction()).thenReturn(true); when(mockCallable.call()).thenReturn(mockResult1); ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{"delete from mytable"}); + "Statement.execute", mockCallable, new String[]{"delete from mytable"}); // Mock result set containing 1 row - when(mockResult1.next()).thenReturn(true, true, false, false); + when(mockResult1.next()).thenReturn(true, true, false); when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); compareResults(mockResult1, rs); verify(mockCacheConn, never()).readFromCache(anyString()); verify(mockCallable).call(); - verify(mockTotalCallsCounter, never()).inc(); - verify(mockHitCounter, never()).inc(); - verify(mockMissCounter, never()).inc(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, 
never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_JdbcCacheBypassCount_malformed_hint() throws Exception { + // Setup - not in transaction with malformed cache hint + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + // Query with malformed cache hint - should increment both malformed and bypass counters + String queryWithMalformedHint = "/*+ CACHE_PARAM(ttl=invalid) */ SELECT * FROM users WHERE id = 123"; + plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{queryWithMalformedHint}); + // Verify malformed counter incremented first + verify(mockMalformedHintCounter, times(1)).inc(); + // Verify bypass counter incremented (because configuredQueryTtl becomes null) + verify(mockCacheBypassCounter, times(1)).inc(); + // Verify cache flow counters were NOT called + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + 
verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_JdbcCacheBypassCount_double_bypass_prevention() throws Exception { + // Setup - query that meets MULTIPLE bypass conditions + when(mockPluginService.isInTransaction()).thenReturn(true); // Bypass condition #1 + when(mockCallable.call()).thenReturn(mockResult1); + + // Add mocks for caching flow (needed for transaction caching) + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockDbMetadata.getUserName()).thenReturn("testuser"); + + // Mock result set for caching + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("testdata"); + + // Query that is BOTH too large AND in transaction - double bypass conditions + String largeQueryInTransaction = "/*+ CACHE_PARAM(ttl=300s) */ SELECT * FROM table WHERE data = '" + + RandomStringUtils.randomAlphanumeric(16000) + "'"; // >16KB AND in transaction + + // Execute + plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{largeQueryInTransaction}); + + // Verify bypass counter incremented EXACTLY ONCE (not twice) + verify(mockCacheBypassCounter, times(1)).inc(); + + // Verify cache flow counters were NOT called + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheMissCounter, never()).inc(); + + // Verify malformed counter not called (hint is valid, just large query) + verify(mockMalformedHintCounter, never()).inc(); + // Verify TelemetryContext behavior for transaction scenario + // In transaction: No cache lookup attempted, only database call + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), 
eq(TelemetryTraceLevel.TOP_LEVEL)); + // Context closure: Only 1 database context + verify(mockTelemetryContext, times(1)).closeContext(); + } + + @Test + void test_execute_multipleCacheHits() throws Exception { + when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()).thenReturn(Optional.of("public")); + when(mockConnection.getSchema()).thenReturn("public"); + when(mockDbMetadata.getUserName()).thenReturn("user"); + when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(null); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{"/*+CACHE_PARAM(ttl=50s)*/ select * from A"}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + + rs.beforeFirst(); + byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); + when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(serializedTestResultSet); + + for (int i = 0; i < 10; i ++) { + ResultSet cur_rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"}); + + assertTrue(cur_rs.next()); + assertEquals("bar1", cur_rs.getString("fooName")); + assertFalse(cur_rs.next()); + } + + verify(mockPluginService, times(12)).getCurrentConnection(); + verify(mockPluginService, times(11)).isInTransaction(); + 
verify(mockCacheConn, times(11)).readFromCache("public_user_select * from A"); + verify(mockPluginService, times(12)).getSessionStateService(); + verify(mockSessionStateService, times(12)).getSchema(); + verify(mockConnection).getSchema(); + verify(mockSessionStateService).setSchema("public"); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(50)); + verify(mockTotalQueryCounter, times(11)).inc(); + verify(mockCacheMissCounter, times(1)).inc(); + verify(mockCacheHitCounter, times(10)).inc(); + verify(mockCacheBypassCounter, never()).inc(); + // Verify TelemetryContext behavior for cache miss and hit scenario + verify(mockTelemetryFactory, times(11)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss + verify(mockTelemetryContext, times(10)).setSuccess(true); // Cache hit + // Context closure: 2 cache contexts + 1 database context = 3 total + verify(mockTelemetryContext, times(12)).closeContext(); } void compareResults(final ResultSet expected, final ResultSet actual) throws SQLException { From c60c8b42831a30dd63bd56d14f57d5fa4cd7ea64 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Mon, 25 Aug 2025 11:20:21 -0700 Subject: [PATCH 17/24] Caching performance testing program on Postgres. The postgres test table contains 11 columns containing different data types. The example program populates the table with 400K records, each record has > 1KB of data, for a total of ~520MB of data in the table. It then performs SELECT queries continuously across all 400K rows. 
postgres=# CREATE TABLE test (id SERIAL PRIMARY KEY, int_col INTEGER, varchar_col varchar(50) NOT NULL, text_col TEXT, num_col DOUBLE PRECISION, date_col date, time_col TIME WITHOUT TIME ZONE, time_tz TIME WITH TIME ZONE, ts_col TIMESTAMP WITHOUT TIME ZONE, ts_tz TIMESTAMP WITH TIME ZONE, description TEXT); CREATE TABLE postgres=# select * from test; id | int_col | varchar_col | text_col | num_col | date_col | time_col | time_tz | ts_col | ts_tz | description ----+---------+-------------+----------+---------+----------+----------+---------+--------+-------+-------------- (0 rows) Suppress debug logs and only log at INFO level and above for example program. --- benchmarks/build.gradle.kts | 4 + .../jdbc/benchmarks/PgCacheBenchmarks.java | 125 ++++++++++++++++++ .../src/main/resources/logback.xml | 6 + 3 files changed, 135 insertions(+) create mode 100644 benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java create mode 100644 examples/AWSDriverExample/src/main/resources/logback.xml diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts index 359765fe5..09c2809ae 100644 --- a/benchmarks/build.gradle.kts +++ b/benchmarks/build.gradle.kts @@ -25,6 +25,10 @@ dependencies { implementation("org.mariadb.jdbc:mariadb-java-client:3.5.6") implementation("com.zaxxer:HikariCP:4.0.3") implementation("org.checkerframework:checker-qual:3.49.5") + implementation("io.lettuce:lettuce-core:6.6.0.RELEASE") + implementation("org.apache.commons:commons-pool2:2.11.1") + annotationProcessor("org.openjdk.jmh:jmh-core:1.36") + jmhAnnotationProcessor ("org.openjdk.jmh:jmh-generator-annprocess:1.36") testImplementation("org.junit.jupiter:junit-jupiter-api:5.12.2") testImplementation("org.mockito:mockito-inline:4.11.0") // 4.11.0 is the last version compatible with Java 8 diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java new file
mode 100644 index 000000000..b728da5dd --- /dev/null +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java @@ -0,0 +1,125 @@ +package software.amazon.jdbc.benchmarks; + +import org.openjdk.jmh.annotations.*; +import java.sql.*; +import java.util.*; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.profile.GCProfiler; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +/** + * Performance benchmark program against PG. + * + * This test program runs JMH benchmarks to test the performance of the remote cache plugin against + * a remote PG database and a remote cache server for both indexed queries and non-indexed queries. + * + * The database table schema is as follows: + * + * postgres=# CREATE TABLE test (id SERIAL PRIMARY KEY, int_col INTEGER, varchar_col varchar(50) NOT NULL, text_col TEXT, + * num_col DOUBLE PRECISION, date_col date, time_col TIME WITHOUT TIME ZONE, time_tz TIME WITH TIME ZONE, + * ts_col TIMESTAMP WITHOUT TIME ZONE, ts_tz TIMESTAMP WITH TIME ZONE, description TEXT); + * CREATE TABLE + * postgres=# select * from test; + * id | int_col | varchar_col | text_col | num_col | date_col | time_col | time_tz | ts_col | ts_tz | description + * ----+---------+-------------+----------+---------+----------+----------+---------+--------+-------+-------------- + * (0 rows) + * + */ +@State(Scope.Thread) +@Fork(1) +@Warmup(iterations = 1) +@Measurement(iterations = 60, time = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +public class PgCacheBenchmarks { + private static final String DB_CONNECTION_STRING = "jdbc:aws-wrapper:postgresql://db-0.XYZ.us-east-2.rds.amazonaws.com:5432/postgres"; + private static final String CACHE_RW_SERVER_ADDR = "cache-0.XYZ.us-east-2.rds.amazonaws.com:6379"; + private static final String CACHE_RO_SERVER_ADDR =
"cache-0.XYZ.us-east-2.rds.amazonaws.com:6380"; + + private Connection connection; + private int counter; + long startTime; + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder() + .include(PgCacheBenchmarks.class.getSimpleName()) + .addProfiler(GCProfiler.class) + .detectJvmArgs() + .build(); + + new Runner(opt).run(); + } + + @Setup(Level.Trial) + public void setup() throws SQLException { + try { + software.amazon.jdbc.Driver.register(); + } catch (IllegalStateException e) { + System.out.println("exception during register() is " + e.getMessage()); + } + Properties properties = new Properties(); + properties.setProperty("wrapperPlugins", "dataRemoteCache"); + properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); + properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); + properties.setProperty("wrapperLogUnclosedConnections", "true"); + counter = 0; + connection = DriverManager.getConnection(DB_CONNECTION_STRING, properties); + startTime = System.currentTimeMillis(); + } + + @TearDown(Level.Trial) + public void tearDown() throws SQLException { + connection.close(); + } + + // Code to warm up the data in the table + public void warmUpDataSet() throws SQLException { + String desc_1KB = 
"mP48pHrR5vreBo3N6ecmlDgvfEAz0kQEOUQ89U3Rh05BTG9LhB8R0HBFBp5RIqc8vVcrphu89kW1OE2c2xApwpczFMdDAuk2SxOl9OrLvfk9zGYrdfzedcepT8LVeE6NTtYDeP3yo6UFC6AiOeqRBY5NEaNcZ8fuoXVpqOrqAhz910v5XrFxeXUyPDFxuaKFLaHfEFq7BRasUc9nfhP8gblKAGfEEmgYBpUKio27Rfo0xnavfVJQkAA2kME2PT4qZRSqeDkLmn7VBAzT9ghHqe9D4kQLQKjIyIPKqYoS8kW3ShW44VqYENwPSRAXw7UqOJqlKJ4pnmx4sPZO2kI4NYOl1JZXNlbGaSzJR0cOloKiY0z2OmUNvmD0Wju1DC9TT4OY6a6DOfFvk265BfDVxT6ufN68YG9sZuVsl7jq8SZSJg3x2cqlJuAtdSTIoKmJT1a6cEXxVusmdO27kRRp1BfWR4gz4w9HawYf9nBQOq76ObctlNvj0fYUUG3I49s3iP33CL8qZjj9RnyNUus6ieiZgta6L3mZuMRYOgCLyJrAKUYEL9KND7qirCPzVgmJHWIOnVewu8mldYFhroL89yvV3bZx4MGeyPU4KvbCsRgdORCTN0XhuLYUdiehHXnDBfuZ5yyR0saWLh8gjkLV5GkxTeKpOhpoK1o1cMiCDPYqTa64g5JundlW707c9zxc3Xnf2pW7E74YJl5oBu5vWEyPqXtYOtZOjOIRxxDY8QpoW8mpbQXxgB8DjkZZMiUCe0qHZYxvktVZJmHoaYBwpYpXVTZCfq9WajmkIOdIad1VnH5HpaECLRs6loa259yH8qesak2feDiKjfb8p3uj3s7WZUvPJwAWX9PIW1p7x6OiszXQCntOFRC3bQFNz1c98wlCBJnBSxbbYhU057TDNnoaib1h9bH7LAcqD1caE5KwLMAc5HqugkkRzT5NszkdJcpF0SxakdrAQLOKS6sNwDUzBJA76F775vmaqe3XIYecPmGtfoAKMychfEI4vfNr"; + for (int i = 0; i < 400000; i++) { + Statement stmt = connection.createStatement(); + String description = "description " + i; + String text = "here is my text data " + i; + String query = "insert into test values (" + i + ", " + i * 10 + ", '" + description + "', '" + text + "', " + i * 100 + 0.1234 + ", '2024-01-10', '10:00:00', '10:00:00-07', '2025-07-15 10:00:00', '2025-07-15 10:00:00-07'" + ", '" + desc_1KB + "');"; + int rs = stmt.executeUpdate(query); + assert rs == 1; + } + } + + private void validateResultSet(ResultSet rs) throws SQLException { + while (rs.next()) { + assert rs.getInt(1) >= 0; + assert rs.getInt(2) >= 0; + assert rs.getString(3) != null; + assert rs.getString(4) != null; + assert rs.getDouble(5) >= 0.0; + assert rs.getDate(6) != null; + assert rs.getTime(7) != null; + assert rs.getTime(8) != null; + assert rs.getTimestamp(9) != null; + assert rs.getTimestamp(10) != null; + assert !rs.wasNull(); + } + } + + @Benchmark + public 
void runBenchmarkPrimaryKeyLookup() throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("/*+ CACHE_PARAM(ttl=172800s) */ SELECT * FROM test where id = " + counter)) { + validateResultSet(rs); + } + counter++; + } + + @Benchmark + public void runBenchmarkNonIndexedLookup() throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("/*+ CACHE_PARAM(ttl=172800s) */ SELECT * FROM test where int_col = " + counter*10)) { + validateResultSet(rs); + } + counter++; + } +} diff --git a/examples/AWSDriverExample/src/main/resources/logback.xml b/examples/AWSDriverExample/src/main/resources/logback.xml new file mode 100644 index 000000000..e03eaf554 --- /dev/null +++ b/examples/AWSDriverExample/src/main/resources/logback.xml @@ -0,0 +1,6 @@ + + + + + + From de27d6faf5d602cb93212374706f64be4029c865 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Mon, 22 Sep 2025 12:00:12 -0700 Subject: [PATCH 18/24] Address some review comments. Fix perf testing program. 
--- .../jdbc/benchmarks/PgCacheBenchmarks.java | 51 ++++++---- wrapper/build.gradle.kts | 5 +- .../jdbc/plugin/cache/CacheConnection.java | 2 +- .../plugin/cache/DataRemoteCachePlugin.java | 44 +++++++-- .../cache/DataRemoteCachePluginTest.java | 96 +++++++++++++------ 5 files changed, 140 insertions(+), 58 deletions(-) diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java index b728da5dd..8c18d4c44 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PgCacheBenchmarks.java @@ -1,6 +1,7 @@ package software.amazon.jdbc.benchmarks; import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; import java.sql.*; import java.util.*; import java.util.concurrent.TimeUnit; @@ -89,36 +90,54 @@ public void warmUpDataSet() throws SQLException { } } - private void validateResultSet(ResultSet rs) throws SQLException { + private void validateResultSet(ResultSet rs, Blackhole b) throws SQLException { while (rs.next()) { - assert rs.getInt(1) >= 0; - assert rs.getInt(2) >= 0; - assert rs.getString(3) != null; - assert rs.getString(4) != null; - assert rs.getDouble(5) >= 0.0; - assert rs.getDate(6) != null; - assert rs.getTime(7) != null; - assert rs.getTime(8) != null; - assert rs.getTimestamp(9) != null; - assert rs.getTimestamp(10) != null; - assert !rs.wasNull(); + b.consume(rs.getInt(1)); + b.consume(rs.getInt(2)); + b.consume(rs.getString(3)); + b.consume(rs.getString(4)); + b.consume(rs.getDouble(5)); + b.consume(rs.getDate(6)); + b.consume(rs.getTime(7)); + b.consume(rs.getTime(8)); + b.consume(rs.getTimestamp(9)); + b.consume(rs.getTimestamp(10)); + b.consume(rs.wasNull()); } } @Benchmark - public void runBenchmarkPrimaryKeyLookup() throws SQLException { + public void runBenchmarkPrimaryKeyLookupNoCaching(Blackhole b) throws SQLException { 
+ try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM test where id = " + counter)) { + validateResultSet(rs, b); + } + counter++; + } + + @Benchmark + public void runBenchmarkNonIndexedLookupNoCaching(Blackhole b) throws SQLException { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM test where int_col = " + counter*10)) { + validateResultSet(rs, b); + } + counter++; + } + + @Benchmark + public void runBenchmarkPrimaryKeyLookupWithCaching(Blackhole b) throws SQLException { try (Statement stmt = connection.createStatement(); ResultSet rs = stmt.executeQuery("/*+ CACHE_PARAM(ttl=172800s) */ SELECT * FROM test where id = " + counter)) { - validateResultSet(rs); + validateResultSet(rs, b); } counter++; } @Benchmark - public void runBenchmarkNonIndexedLookup() throws SQLException { + public void runBenchmarkNonIndexedLookupWithCaching(Blackhole b) throws SQLException { try (Statement stmt = connection.createStatement(); ResultSet rs = stmt.executeQuery("/*+ CACHE_PARAM(ttl=172800s) */ SELECT * FROM test where int_col = " + counter*10)) { - validateResultSet(rs); + validateResultSet(rs, b); } counter++; } diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index e42f4e4cb..7de1e94d1 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -44,14 +44,14 @@ dependencies { optionalImplementation("com.mchange:c3p0:0.11.0") optionalImplementation("org.apache.httpcomponents:httpclient:4.5.14") optionalImplementation("com.fasterxml.jackson.core:jackson-databind:2.19.0") + optionalImplementation("org.apache.commons:commons-pool2:2.11.1") optionalImplementation("org.jsoup:jsoup:1.21.1") optionalImplementation("com.amazonaws:aws-xray-recorder-sdk-core:2.18.2") + optionalImplementation("io.lettuce:lettuce-core:6.6.0.RELEASE") optionalImplementation("io.opentelemetry:opentelemetry-api:1.52.0") 
optionalImplementation("io.opentelemetry:opentelemetry-sdk:1.52.0") optionalImplementation("io.opentelemetry:opentelemetry-sdk-metrics:1.52.0") - compileOnly("io.lettuce:lettuce-core:6.6.0.RELEASE") - compileOnly("org.apache.commons:commons-pool2:2.11.1") compileOnly("org.checkerframework:checker-qual:3.49.5") compileOnly("com.mysql:mysql-connector-j:9.4.0") compileOnly("org.postgresql:postgresql:42.7.7") @@ -110,6 +110,7 @@ dependencies { testImplementation("de.vandermeer:asciitable:0.3.2") testImplementation("org.hibernate:hibernate-core:5.6.15.Final") // the latest version compatible with Java 8 testImplementation("jakarta.persistence:jakarta.persistence-api:2.2.3") + testImplementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.19.2") } repositories { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index 64c39798f..fca6008eb 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -116,7 +116,7 @@ private void createConnectionPool(boolean isRead) { String[] hostnameAndPort = serverAddr.split(":"); RedisURI redisUriCluster = RedisURI.Builder.redis(hostnameAndPort[0]) .withPort(Integer.parseInt(hostnameAndPort[1])) - .withSsl(useSSL).withVerifyPeer(false).withLibraryName("aws-jdbc-lettuce").build(); + .withSsl(useSSL).withVerifyPeer(false).withLibraryName("aws-sql-jdbc-lettuce").build(); RedisClient client = RedisClient.create(resources, redisUriCluster); GenericObjectPool> pool = new GenericObjectPool<>( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index fac8423e9..cf04ccdd2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java +++ 
b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -26,9 +26,11 @@ import java.util.Properties; import java.util.Set; import java.util.logging.Logger; +import software.amazon.jdbc.AwsWrapperProperty; import software.amazon.jdbc.JdbcCallable; import software.amazon.jdbc.JdbcMethod; import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.plugin.AbstractConnectionPlugin; import software.amazon.jdbc.states.SessionStateService; import software.amazon.jdbc.util.Messages; @@ -41,7 +43,6 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin { private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName()); - private static final int MAX_CACHEABLE_QUERY_SIZE = 16000; private static final String QUERY_HINT_START_PATTERN = "/*+"; private static final String QUERY_HINT_END_PATTERN = "*/"; private static final String CACHE_PARAM_PATTERN = "CACHE_PARAM("; @@ -55,6 +56,7 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin { JdbcMethod.CALLABLESTATEMENT_EXECUTE.methodName, JdbcMethod.CALLABLESTATEMENT_EXECUTEQUERY.methodName))); + private int maxCacheableQuerySize; private PluginService pluginService; private TelemetryFactory telemetryFactory; private TelemetryCounter cacheHitCounter; @@ -63,6 +65,17 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin { private TelemetryCounter malformedHintCounter; private TelemetryCounter cacheBypassCounter; private CacheConnection cacheConnection; + private String dbUserName; + + private static final AwsWrapperProperty CACHE_MAX_QUERY_SIZE = + new AwsWrapperProperty( + "cacheMaxQuerySize", + "16384", + "The max query size for remote caching"); + + static { + PropertyDefinition.registerPluginProperties(DataRemoteCachePlugin.class); + } public DataRemoteCachePlugin(final PluginService pluginService, final Properties properties) { try { @@ -78,7 +91,9 @@ public 
DataRemoteCachePlugin(final PluginService pluginService, final Properties this.totalQueryCounter = telemetryFactory.createCounter("JdbcCacheTotalQueryCount"); this.malformedHintCounter = telemetryFactory.createCounter("JdbcCacheMalformedQueryHint"); this.cacheBypassCounter = telemetryFactory.createCounter("JdbcCacheBypassCount"); + this.maxCacheableQuerySize = CACHE_MAX_QUERY_SIZE.getInteger(properties); this.cacheConnection = new CacheConnection(properties); + this.dbUserName = PropertyDefinition.USER.getString(properties); } // Used for unit testing purposes only @@ -93,25 +108,36 @@ public Set getSubscribedMethods() { private String getCacheQueryKey(String query) { // Check some basic session states. The important ones for caching include (but not limited to): - // schema name, username which can affect the query result from the DB in addition to the query string + // schema name, username which can affect the query result from the DB in addition to the query string try { Connection currentConn = pluginService.getCurrentConnection(); DatabaseMetaData metadata = currentConn.getMetaData(); // Fetch and record the schema name if the session state doesn't currently have it SessionStateService sessionStateService = pluginService.getSessionStateService(); + String catalog = sessionStateService.getCatalog().orElse(null); String schema = sessionStateService.getSchema().orElse(null); - if (schema == null) { + if (catalog == null && schema == null) { // Fetch the current schema name and store it in sessionStateService + catalog = currentConn.getCatalog(); schema = currentConn.getSchema(); - sessionStateService.setSchema(schema); + if (catalog != null) sessionStateService.setCatalog(catalog); + if (schema != null) sessionStateService.setSchema(schema); } + if (dbUserName == null) { + // For MySQL, metadata username is actually @. We just need the part before '@'. 
+ dbUserName = metadata.getUserName(); + int nameIndexEnd = dbUserName.indexOf('@'); + if (nameIndexEnd > 0) { + dbUserName = dbUserName.substring(0, nameIndexEnd); + } + } LOGGER.finest("DB driver protocol " + pluginService.getDriverProtocol() + ", database product: " + metadata.getDatabaseProductName() + " " + metadata.getDatabaseProductVersion() - + ", schema: " + schema + ", user: " + metadata.getUserName() + + ", catalog: " + catalog + ", schema: " + schema + ", user: " + dbUserName + ", driver: " + metadata.getDriverName() + " " + metadata.getDriverVersion()); // The cache key contains the schema name, user name, and the query string - String[] words = {schema, metadata.getUserName(), query}; + String[] words = {catalog, schema, dbUserName, query}; return String.join("_", words); } catch (SQLException e) { LOGGER.warning("Error getting session state: " + e.getMessage()); @@ -239,11 +265,11 @@ public T execute( int endOfQueryHint = 0; Integer configuredQueryTtl = null; // Queries longer than 16KB is not cacheable - if ((sql.length() < MAX_CACHEABLE_QUERY_SIZE) && sql.startsWith(QUERY_HINT_START_PATTERN)) { + if ((sql.length() < maxCacheableQuerySize) && sql.startsWith(QUERY_HINT_START_PATTERN)) { endOfQueryHint = sql.indexOf(QUERY_HINT_END_PATTERN); if (endOfQueryHint > 0) { - configuredQueryTtl = getTtlForQuery(sql.substring(2, endOfQueryHint).trim()); - mainQuery = sql.substring(endOfQueryHint + 2).trim(); + configuredQueryTtl = getTtlForQuery(sql.substring(QUERY_HINT_START_PATTERN.length(), endOfQueryHint).trim()); + mainQuery = sql.substring(endOfQueryHint + QUERY_HINT_END_PATTERN.length()).trim(); } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java index 9d6941a60..54b466e50 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java +++ 
b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -28,7 +28,7 @@ import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; public class DataRemoteCachePluginTest { - private static final Properties props = new Properties(); + private Properties props; private final String methodName = "Statement.executeQuery"; private AutoCloseable closeable; @@ -53,9 +53,9 @@ public class DataRemoteCachePluginTest { @BeforeEach void setUp() throws SQLException { closeable = MockitoAnnotations.openMocks(this); + props = new Properties(); props.setProperty("wrapperPlugins", "dataRemoteCache"); props.setProperty("cacheEndpointAddrRw", "localhost:6379"); - when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); when(mockTelemetryFactory.createCounter("JdbcCachedQueryCount")).thenReturn(mockCacheHitCounter); when(mockTelemetryFactory.createCounter("JdbcCacheMissCount")).thenReturn(mockCacheMissCounter); @@ -66,8 +66,6 @@ void setUp() throws SQLException { when(mockResult1.getMetaData()).thenReturn(mockMetaData); when(mockMetaData.getColumnCount()).thenReturn(1); when(mockMetaData.getColumnLabel(1)).thenReturn("fooName"); - plugin = new DataRemoteCachePlugin(mockPluginService, props); - plugin.setCacheConnection(mockCacheConn); } @AfterEach @@ -77,6 +75,8 @@ void cleanUp() throws Exception { @Test void test_getTTLFromQueryHint() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Null and empty query hint content are not cacheable assertNull(plugin.getTtlForQuery(null)); assertNull(plugin.getTtlForQuery("")); @@ -124,6 +124,8 @@ void test_getTTLFromQueryHint() throws Exception { @Test void test_getTTLFromQueryHint_MalformedHints() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Test malformed cases assertNull(plugin.getTtlForQuery("CACHE_PARAM()")); 
assertNull(plugin.getTtlForQuery("CACHE_PARAM(ttl=abc)")); @@ -143,6 +145,8 @@ void test_getTTLFromQueryHint_MalformedHints() throws Exception { @Test void test_execute_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Query is not cacheable when(mockPluginService.isInTransaction()).thenReturn(false); when(mockCallable.call()).thenReturn(mockResult1); @@ -168,6 +172,8 @@ void test_execute_noCaching() throws Exception { @Test void test_execute_noCachingLongQuery() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Query is not cacheable when(mockPluginService.isInTransaction()).thenReturn(false); when(mockCallable.call()).thenReturn(mockResult1); @@ -192,15 +198,19 @@ void test_execute_noCachingLongQuery() throws Exception { @Test void test_execute_cachingMissAndHit() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Query is not cacheable when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.isInTransaction()).thenReturn(false); when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); - when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()).thenReturn(Optional.of("public")); - when(mockConnection.getSchema()).thenReturn("public"); - when(mockDbMetadata.getUserName()).thenReturn("user"); - when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(null); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()).thenReturn(Optional.of("mysql")); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getCatalog()).thenReturn("mysql"); + when(mockConnection.getSchema()).thenReturn(null); + 
when(mockDbMetadata.getUserName()).thenReturn("user1@1.1.1.1"); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(null); when(mockCallable.call()).thenReturn(mockResult1); // Result set contains 1 row @@ -217,7 +227,7 @@ void test_execute_cachingMissAndHit() throws Exception { rs.beforeFirst(); byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); - when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(serializedTestResultSet); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(serializedTestResultSet); ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockStatement, methodName, mockCallable, new String[]{" /*+CACHE_PARAM(ttl=50s)*/select * from A"}); @@ -227,13 +237,16 @@ void test_execute_cachingMissAndHit() throws Exception { assertFalse(rs2.next()); verify(mockPluginService, times(3)).getCurrentConnection(); verify(mockPluginService, times(2)).isInTransaction(); - verify(mockCacheConn, times(2)).readFromCache("public_user_select * from A"); + verify(mockCacheConn, times(2)).readFromCache("mysql_null_user1_select * from A"); verify(mockPluginService, times(3)).getSessionStateService(); + verify(mockSessionStateService, times(3)).getCatalog(); verify(mockSessionStateService, times(3)).getSchema(); + verify(mockConnection).getCatalog(); verify(mockConnection).getSchema(); - verify(mockSessionStateService).setSchema("public"); + verify(mockSessionStateService).setCatalog("mysql"); + verify(mockDbMetadata).getUserName(); verify(mockCallable).call(); - verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(50)); + verify(mockCacheConn).writeToCache(eq("mysql_null_user1_select * from A"), any(), eq(50)); verify(mockTotalQueryCounter, times(2)).inc(); verify(mockCacheMissCounter, times(1)).inc(); verify(mockCacheHitCounter, times(1)).inc(); @@ -251,14 +264,18 @@ void test_execute_cachingMissAndHit() throws 
Exception { @Test void test_transaction_cacheQuery() throws Exception { + props.setProperty("user", "dbuser"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Query is cacheable when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.isInTransaction()).thenReturn(true); when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getCatalog()).thenReturn("postgres"); when(mockConnection.getSchema()).thenReturn("public"); - when(mockDbMetadata.getUserName()).thenReturn("user"); when(mockCallable.call()).thenReturn(mockResult1); // Result set contains 1 row @@ -276,11 +293,15 @@ void test_transaction_cacheQuery() throws Exception { verify(mockPluginService).isInTransaction(); verify(mockPluginService).getSessionStateService(); verify(mockSessionStateService).getSchema(); + verify(mockSessionStateService).getCatalog(); verify(mockConnection).getSchema(); + verify(mockConnection).getCatalog(); verify(mockSessionStateService).setSchema("public"); + verify(mockSessionStateService).setCatalog("postgres"); + verify(mockDbMetadata, never()).getUserName(); verify(mockCacheConn, never()).readFromCache(anyString()); verify(mockCallable).call(); - verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); + verify(mockCacheConn).writeToCache(eq("postgres_public_dbuser_select * from T"), any(), eq(300)); verify(mockTotalQueryCounter, times(1)).inc(); verify(mockCacheHitCounter, never()).inc(); verify(mockCacheMissCounter, never()).inc(); @@ -295,14 +316,18 @@ void test_transaction_cacheQuery() throws Exception { @Test void test_transaction_cacheQuery_multiple_query_params() throws Exception { + 
plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Query is cacheable when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.isInTransaction()).thenReturn(true); when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); - when(mockConnection.getSchema()).thenReturn("public"); - when(mockDbMetadata.getUserName()).thenReturn("user"); + when(mockDbMetadata.getUserName()).thenReturn("dbuser"); + when(mockConnection.getCatalog()).thenReturn(null); + when(mockConnection.getSchema()).thenReturn("mysql"); when(mockCallable.call()).thenReturn(mockResult1); // Result set contains 1 row @@ -318,12 +343,15 @@ void test_transaction_cacheQuery_multiple_query_params() throws Exception { verify(mockPluginService).getCurrentConnection(); verify(mockPluginService).isInTransaction(); verify(mockPluginService).getSessionStateService(); + verify(mockSessionStateService).getCatalog(); verify(mockSessionStateService).getSchema(); verify(mockConnection).getSchema(); - verify(mockSessionStateService).setSchema("public"); + verify(mockConnection).getCatalog(); + verify(mockSessionStateService).setSchema("mysql"); + verify(mockDbMetadata).getUserName(); verify(mockCacheConn, never()).readFromCache(anyString()); verify(mockCallable).call(); - verify(mockCacheConn).writeToCache(eq("public_user_select * from T"), any(), eq(300)); + verify(mockCacheConn).writeToCache(eq("null_mysql_dbuser_select * from T"), any(), eq(300)); verify(mockTotalQueryCounter, times(1)).inc(); verify(mockCacheHitCounter, never()).inc(); verify(mockCacheMissCounter, never()).inc(); @@ -338,6 +366,8 @@ void test_transaction_cacheQuery_multiple_query_params() throws Exception { @Test void 
test_transaction_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Query is not cacheable when(mockPluginService.isInTransaction()).thenReturn(true); when(mockCallable.call()).thenReturn(mockResult1); @@ -364,6 +394,8 @@ void test_transaction_noCaching() throws Exception { @Test void test_JdbcCacheBypassCount_malformed_hint() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Setup - not in transaction with malformed cache hint when(mockPluginService.isInTransaction()).thenReturn(false); when(mockCallable.call()).thenReturn(mockResult1); @@ -390,23 +422,20 @@ void test_JdbcCacheBypassCount_malformed_hint() throws Exception { @Test void test_JdbcCacheBypassCount_double_bypass_prevention() throws Exception { + props.setProperty("user", "testuser"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); // Setup - query that meets MULTIPLE bypass conditions when(mockPluginService.isInTransaction()).thenReturn(true); // Bypass condition #1 when(mockCallable.call()).thenReturn(mockResult1); - // Add mocks for caching flow (needed for transaction caching) - when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); - when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); - when(mockConnection.getSchema()).thenReturn("public"); - when(mockDbMetadata.getUserName()).thenReturn("testuser"); - // Mock result set for caching when(mockResult1.next()).thenReturn(true, false); when(mockResult1.getObject(1)).thenReturn("testdata"); // Query that is BOTH too large AND in transaction - double bypass conditions String largeQueryInTransaction = "/*+ CACHE_PARAM(ttl=300s) */ SELECT * FROM table WHERE data = '" - + RandomStringUtils.randomAlphanumeric(16000) + "'"; // >16KB AND in transaction + + RandomStringUtils.randomAlphanumeric(16384) + 
"'"; // >16KB AND in transaction // Execute plugin.execute(ResultSet.class, SQLException.class, mockStatement, @@ -432,14 +461,18 @@ void test_JdbcCacheBypassCount_double_bypass_prevention() throws Exception { @Test void test_execute_multipleCacheHits() throws Exception { + props.setProperty("user", "user"); + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); when(mockPluginService.isInTransaction()).thenReturn(false); when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()); when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()).thenReturn(Optional.of("public")); when(mockConnection.getSchema()).thenReturn("public"); - when(mockDbMetadata.getUserName()).thenReturn("user"); - when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(null); + when(mockConnection.getCatalog()).thenReturn(null); + when(mockCacheConn.readFromCache("null_public_user_select * from A")).thenReturn(null); when(mockCallable.call()).thenReturn(mockResult1); // Result set contains 1 row @@ -456,7 +489,7 @@ void test_execute_multipleCacheHits() throws Exception { rs.beforeFirst(); byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); - when(mockCacheConn.readFromCache("public_user_select * from A")).thenReturn(serializedTestResultSet); + when(mockCacheConn.readFromCache("null_public_user_select * from A")).thenReturn(serializedTestResultSet); for (int i = 0; i < 10; i ++) { ResultSet cur_rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, @@ -469,13 +502,16 @@ void test_execute_multipleCacheHits() throws Exception { verify(mockPluginService, times(12)).getCurrentConnection(); verify(mockPluginService, 
times(11)).isInTransaction(); - verify(mockCacheConn, times(11)).readFromCache("public_user_select * from A"); + verify(mockCacheConn, times(11)).readFromCache("null_public_user_select * from A"); verify(mockPluginService, times(12)).getSessionStateService(); + verify(mockSessionStateService, times(12)).getCatalog(); verify(mockSessionStateService, times(12)).getSchema(); verify(mockConnection).getSchema(); + verify(mockConnection).getCatalog(); verify(mockSessionStateService).setSchema("public"); + verify(mockDbMetadata, never()).getUserName(); verify(mockCallable).call(); - verify(mockCacheConn).writeToCache(eq("public_user_select * from A"), any(), eq(50)); + verify(mockCacheConn).writeToCache(eq("null_public_user_select * from A"), any(), eq(50)); verify(mockTotalQueryCounter, times(11)).inc(); verify(mockCacheMissCounter, times(1)).inc(); verify(mockCacheHitCounter, times(10)).inc(); From e96c4761944432b62df2b7d2eb6ceaa9cae3bcda Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Tue, 7 Oct 2025 10:47:00 -0700 Subject: [PATCH 19/24] Not de-serialize cached responses until the column field gets accessed. 
--- .../jdbc/plugin/cache/CachedResultSet.java | 82 ++++++++++++++++--- 1 file changed, 72 insertions(+), 10 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java index 73b4d2d2c..b4ecbfaf1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java @@ -1,11 +1,11 @@ package software.amazon.jdbc.plugin.cache; +import org.checkerframework.checker.nullness.qual.Nullable; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.IOException; import java.io.Reader; -import java.io.Serializable; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.math.BigDecimal; @@ -42,23 +42,44 @@ public class CachedResultSet implements ResultSet { - public static class CachedRow implements Serializable { + public static class CachedRow { private final Object[] rowData; + final byte[] @Nullable [] rawData; public CachedRow(int numColumns) { rowData = new Object[numColumns]; + rawData = new byte[numColumns][]; } - public void put(final int columnIndex, final Object columnValue) throws SQLException { + private void checkColumnIndex(final int columnIndex) throws SQLException { if (columnIndex < 1 || columnIndex > rowData.length) { - throw new SQLException("Invalid Column Index when populating CachedRow: " + columnIndex); + throw new SQLException("Invalid Column Index when operating CachedRow: " + columnIndex); } + } + + public void put(final int columnIndex, final Object columnValue) throws SQLException { + checkColumnIndex(columnIndex); rowData[columnIndex-1] = columnValue; } + public void putRaw(final int columnIndex, final byte[] rawColumnValue) throws SQLException { + checkColumnIndex(columnIndex); + rawData[columnIndex-1] = rawColumnValue; + } + public Object 
get(final int columnIndex) throws SQLException { - if (columnIndex < 1 || columnIndex > rowData.length) { - throw new SQLException("Invalid Column Index when getting CachedRow value: " + columnIndex); + checkColumnIndex(columnIndex); + // De-serialize the data object from raw bytes if needed. + if (rowData[columnIndex-1] == null && rawData[columnIndex-1] != null) { + try (ByteArrayInputStream bis = new ByteArrayInputStream(rawData[columnIndex - 1]); + ObjectInputStream ois = new ObjectInputStream(bis)) { + rowData[columnIndex - 1] = ois.readObject(); + rawData[columnIndex - 1] = null; + } catch (ClassNotFoundException e) { + throw new SQLException("ClassNotFoundException while de-serializing caching resultSet for column: " + columnIndex, e); + } catch (IOException e) { + throw new SQLException("IOException while de-serializing caching resultSet for column: " + columnIndex, e); + } } return rowData[columnIndex - 1]; } @@ -73,6 +94,12 @@ public Object get(final int columnIndex) throws SQLException { private final HashMap columnNames; private volatile boolean closed; + /** + * Create a CachedResultSet out of the original ResultSet queried from the database. + * @param resultSet The ResultSet queried from the underlying database (not a CachedResultSet). + * @return CachedResultSet that captures the metadata and the rows of the input ResultSet. 
+ * @throws SQLException + */ public CachedResultSet(final ResultSet resultSet) throws SQLException { ResultSetMetaData srcMetadata = resultSet.getMetaData(); final int numColumns = srcMetadata.getColumnCount(); @@ -116,14 +143,28 @@ private CachedResultSet(final CachedResultSetMetaData md, final ArrayList resultRows = new ArrayList<>(numRows); for (int i = 0; i < numRows; i++) { - resultRows.add((CachedRow) ois.readObject()); + // Store the raw bytes for each column object in CachedRow + final CachedRow row = new CachedRow(numColumns); + for(int j = 0; j < numColumns; j++) { + int nextObjSize = ois.readInt(); // The size of the next serialized object in its raw bytes form + byte[] objData = new byte[nextObjSize]; + int lengthRead = 0; + while (lengthRead < nextObjSize) { + int bytesRead = ois.read(objData, lengthRead, nextObjSize-lengthRead); + if (bytesRead == -1) { + throw new SQLException("End of stream reached when reading the data for CachedResultSet"); + } + lengthRead += bytesRead; + } + row.putRaw(j+1, objData); + } + resultRows.add(row); } return new CachedResultSet(metadata, resultRows); } catch (ClassNotFoundException e) { From 57551d4c78faa5a3dec18dab2843ed7e164f5bcb Mon Sep 17 00:00:00 2001 From: Shaopeng Gu Date: Mon, 15 Sep 2025 15:44:01 -0700 Subject: [PATCH 20/24] Caching - Add IAM authentication support for ElastiCache Valkey (PR #4) - Add ElastiCacheIamTokenUtility for token generation - Extend DataRemoteCachePlugin with IAM auth detection - Support serverless and regular ElastiCache endpoints - Implement 15-minute token refresh and 12-hour re-auth cycles - Add cacheIamAuthEnabled configuration property --- .../DatabaseConnectionWithCacheExample.java | 13 + .../jdbc/plugin/cache/CacheConnection.java | 166 +++++++++-- .../jdbc/plugin/cache/CachedSupplier.java | 76 +++++ .../iam/ElastiCacheIamTokenUtility.java | 125 ++++++++ .../plugin/cache/CacheConnectionTest.java | 268 +++++++++++++++++- .../jdbc/plugin/cache/CacheSupplierTest.java | 158 
+++++++++++ .../iam/ElastiCacheIamTokenUtilityTest.java | 193 +++++++++++++ 7 files changed, 980 insertions(+), 19 deletions(-) create mode 100644 wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java create mode 100644 wrapper/src/main/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtility.java create mode 100644 wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java create mode 100644 wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java index 3441ecd4a..0e73b0330 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java +++ b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java @@ -12,6 +12,13 @@ public class DatabaseConnectionWithCacheExample { private static final String DB_CONNECTION_STRING = env.get("DB_CONNECTION_STRING"); private static final String CACHE_RW_SERVER_ADDR = env.get("CACHE_RW_SERVER_ADDR"); private static final String CACHE_RO_SERVER_ADDR = env.get("CACHE_RO_SERVER_ADDR"); + // If the cache server is authenticated with IAM + private static final String CACHE_NAME = env.get("CACHE_NAME"); + // Both IAM and traditional auth uses the same CACHE_USERNAME + private static final String CACHE_USERNAME = env.get("CACHE_USERNAME"); // e.g., "iam-user-01" / "username" + private static final String CACHE_IAM_REGION = env.get("CACHE_IAM_REGION"); // e.g., "us-west-2" + // If the cache server is authenticated with traditional username password + // private static final String CACHE_PASSWORD = env.get("CACHE_PASSWORD"); private static final String USERNAME = env.get("DB_USERNAME"); private static final String PASSWORD = env.get("DB_PASSWORD"); private static final String 
USE_SSL = env.get("USE_SSL"); @@ -30,6 +37,12 @@ public static void main(String[] args) throws SQLException { properties.setProperty("wrapperPlugins", "dataRemoteCache"); properties.setProperty("cacheEndpointAddrRw", CACHE_RW_SERVER_ADDR); properties.setProperty("cacheEndpointAddrRo", CACHE_RO_SERVER_ADDR); + // If the cache server is authenticated with IAM + properties.setProperty("cacheName", CACHE_NAME); + properties.setProperty("cacheUsername", CACHE_USERNAME); + properties.setProperty("cacheIamRegion", CACHE_IAM_REGION); + // If the cache server is authenticated with traditional username password + // properties.setProperty("cachePassword", PASSWORD); properties.setProperty("cacheUseSSL", USE_SSL); // "true" or "false" properties.setProperty("wrapperLogUnclosedConnections", "true"); String queryStr = "/*+ CACHE_PARAM(ttl=300s) */ select * from cinemas"; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index fca6008eb..14f813ecf 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -1,27 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package software.amazon.jdbc.plugin.cache; import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisCommandExecutionException; +import io.lettuce.core.RedisCredentials; +import io.lettuce.core.RedisCredentialsProvider; import io.lettuce.core.RedisURI; +import io.lettuce.core.RedisCommandExecutionException; import io.lettuce.core.SetArgs; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.codec.ByteArrayCodec; import io.lettuce.core.resource.ClientResources; -import software.amazon.jdbc.AwsWrapperProperty; +import reactor.core.publisher.Mono; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.util.Properties; +import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; import java.util.logging.Logger; import org.apache.commons.pool2.BasePooledObjectFactory; import org.apache.commons.pool2.impl.GenericObjectPool; import org.apache.commons.pool2.impl.GenericObjectPoolConfig; import org.apache.commons.pool2.impl.DefaultPooledObject; import org.apache.commons.pool2.PooledObject; +import software.amazon.jdbc.AwsWrapperProperty; import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.authentication.AwsCredentialsManager; +import software.amazon.jdbc.plugin.iam.ElastiCacheIamTokenUtility; import software.amazon.jdbc.util.StringUtils; // Abstraction layer on top of a connection to a remote cache server @@ -31,18 +56,21 @@ public class CacheConnection { private static volatile GenericObjectPool> readConnectionPool; private static volatile GenericObjectPool> writeConnectionPool; private static final GenericObjectPoolConfig> poolConfig = createPoolConfig(); - private final 
String cacheRwServerAddr; // read-write cache server - private final String cacheRoServerAddr; // read-only cache server - private MessageDigest msgHashDigest = null; private static final int DEFAULT_POOL_SIZE = 20; private static final int DEFAULT_POOL_MAX_IDLE = 20; private static final int DEFAULT_POOL_MIN_IDLE = 0; private static final long DEFAULT_MAX_BORROW_WAIT_MS = 100; + private static final long TOKEN_CACHE_DURATION = 15 * 60 - 30; private static final ReentrantLock READ_LOCK = new ReentrantLock(); private static final ReentrantLock WRITE_LOCK = new ReentrantLock(); + private final String cacheRwServerAddr; // read-write cache server + private final String cacheRoServerAddr; // read-only cache server + private final String[] defaultCacheServerHostAndPort; + private MessageDigest msgHashDigest = null; + protected static final AwsWrapperProperty CACHE_RW_ENDPOINT_ADDR = new AwsWrapperProperty( "cacheEndpointAddrRw", @@ -61,7 +89,38 @@ public class CacheConnection { "true", "Whether to use SSL for cache connections."); + protected static final AwsWrapperProperty CACHE_IAM_REGION = + new AwsWrapperProperty( + "cacheIamRegion", + null, + "AWS region for ElastiCache IAM authentication."); + + protected static final AwsWrapperProperty CACHE_USERNAME = + new AwsWrapperProperty( + "cacheUsername", + null, + "Username for ElastiCache regular authentication."); + + protected static final AwsWrapperProperty CACHE_PASSWORD = + new AwsWrapperProperty( + "cachePassword", + null, + "Password for ElastiCache regular authentication."); + + protected static final AwsWrapperProperty CACHE_NAME = + new AwsWrapperProperty( + "cacheName", + null, + "Explicit cache name for ElastiCache IAM authentication. 
"); + private final boolean useSSL; + private final boolean iamAuthEnabled; + private final String cacheIamRegion; + private final String cacheUsername; + private final String cacheName; + private final String cachePassword; + private final Properties awsProfileProperties; + private final AwsCredentialsProvider credentialsProvider; static { PropertyDefinition.registerPluginProperties(CacheConnection.class); @@ -71,6 +130,44 @@ public CacheConnection(final Properties properties) { this.cacheRwServerAddr = CACHE_RW_ENDPOINT_ADDR.getString(properties); this.cacheRoServerAddr = CACHE_RO_ENDPOINT_ADDR.getString(properties); this.useSSL = Boolean.parseBoolean(CACHE_USE_SSL.getString(properties)); + this.cacheName = CACHE_NAME.getString(properties); + this.cacheIamRegion = CACHE_IAM_REGION.getString(properties); + this.cacheUsername = CACHE_USERNAME.getString(properties); + this.cachePassword = CACHE_PASSWORD.getString(properties); + this.iamAuthEnabled = !StringUtils.isNullOrEmpty(this.cacheIamRegion); + boolean hasTraditionalAuth = !StringUtils.isNullOrEmpty(this.cachePassword); + // Validate authentication configuration + if (this.iamAuthEnabled && hasTraditionalAuth) { + throw new IllegalArgumentException( + "Cannot specify both IAM authentication (cacheIamRegion) and traditional authentication (cachePassword). 
Choose one authentication method."); + } + if (this.cacheRwServerAddr == null) { + throw new IllegalArgumentException("Cache endpoint address is required"); + } + this.defaultCacheServerHostAndPort = getHostnameAndPort(this.cacheRwServerAddr); + if (this.iamAuthEnabled) { + if (this.cacheUsername == null || this.defaultCacheServerHostAndPort[0] == null || this.cacheName == null) { + throw new IllegalArgumentException("IAM authentication requires cache name, username, region, and hostname"); + } + } + if (PropertyDefinition.AWS_PROFILE.getString(properties) != null) { + this.awsProfileProperties = new Properties(); + this.awsProfileProperties.setProperty( + PropertyDefinition.AWS_PROFILE.name, + PropertyDefinition.AWS_PROFILE.getString(properties) + ); + } else { + this.awsProfileProperties = null; + } + if (this.iamAuthEnabled) { + // Handle null case + Properties propsToPass = (this.awsProfileProperties != null) + ? this.awsProfileProperties + : new Properties(); + this.credentialsProvider = AwsCredentialsManager.getProvider(null, propsToPass); + } else { + this.credentialsProvider = null; + } } /* Here we check if we need to initialise connection pool for read or write to cache. 
@@ -113,15 +210,14 @@ private void createConnectionPool(boolean isRead) { if (isRead && !StringUtils.isNullOrEmpty(cacheRoServerAddr)) { serverAddr = cacheRoServerAddr; } - String[] hostnameAndPort = serverAddr.split(":"); - RedisURI redisUriCluster = RedisURI.Builder.redis(hostnameAndPort[0]) - .withPort(Integer.parseInt(hostnameAndPort[1])) - .withSsl(useSSL).withVerifyPeer(false).withLibraryName("aws-sql-jdbc-lettuce").build(); + String[] hostnameAndPort = getHostnameAndPort(serverAddr); + RedisURI redisUriCluster = buildRedisURI(hostnameAndPort[0], Integer.parseInt(hostnameAndPort[1])); RedisClient client = RedisClient.create(resources, redisUriCluster); GenericObjectPool> pool = new GenericObjectPool<>( new BasePooledObjectFactory>() { public StatefulRedisConnection create() { + StatefulRedisConnection connection = client.connect(new ByteArrayCodec()); // In cluster mode, we need to send READONLY command to the server for reading from replica. // Note: we gracefully ignore ERR reply to support non cluster mode. @@ -148,7 +244,6 @@ public PooledObject> wrap(StatefulRedisC } else { writeConnectionPool = pool; } - } catch (Exception e) { String poolType = isRead ? "read" : "write"; String errorMsg = String.format("Failed to create Cache %s connection pool", poolType); @@ -247,13 +342,13 @@ public void writeToCache(String key, byte[] value, int expiry) { private void returnConnectionBackToPool(StatefulRedisConnection connection, boolean isConnectionBroken, boolean isRead) { GenericObjectPool> pool = isRead ? 
readConnectionPool : writeConnectionPool; if (isConnectionBroken) { - try { - pool.invalidateObject(connection); - } catch (Exception e) { - throw new RuntimeException("Could not invalidate connection for the pool", e); - } + try { + pool.invalidateObject(connection); + } catch (Exception e) { + throw new RuntimeException("Could not invalidate connection for the pool", e); + } } else { - pool.returnObject(connection); + pool.returnObject(connection); } } @@ -263,4 +358,43 @@ protected void setConnectionPools(GenericObjectPool { + // Create a cached token supplier that automatically refreshes tokens every 14.5 minutes + Supplier tokenSupplier = CachedSupplier.memoizeWithExpiration( + () -> { + ElastiCacheIamTokenUtility tokenUtility = new ElastiCacheIamTokenUtility(this.cacheName); + return tokenUtility.generateAuthenticationToken( + this.credentialsProvider, + Region.of(this.cacheIamRegion), + this.defaultCacheServerHostAndPort[0], + Integer.parseInt(this.defaultCacheServerHostAndPort[1]), + this.cacheUsername + ); + }, + TOKEN_CACHE_DURATION, + TimeUnit.SECONDS + ); + // Package the username and token (from cache or freshly generated) into Redis credentials + return Mono.just(RedisCredentials.just(this.cacheUsername, tokenSupplier.get())); + }; + uriBuilder.withAuthentication(credentialsProvider); + } else if (!StringUtils.isNullOrEmpty(this.cachePassword)) { + uriBuilder.withAuthentication(this.cacheUsername, this.cachePassword); + } + return uriBuilder.build(); + } + + private String[] getHostnameAndPort(String serverAddr) { + return serverAddr.split(":"); + } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java new file mode 100644 index 000000000..ac3d505e1 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedSupplier.java @@ -0,0 +1,76 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
/**
 * Factory for suppliers that cache a computed value for a fixed time window.
 *
 * <p>Mirrors the contract of Guava's {@code Suppliers.memoizeWithExpiration}:
 * the delegate is invoked at most once per expiry window, and callers that race
 * a refresh block until the new value has been published.
 */
public final class CachedSupplier {

  /** Static utility holder; never instantiated. */
  private CachedSupplier() {
    throw new UnsupportedOperationException("Utility class should not be instantiated");
  }

  /**
   * Wraps {@code delegate} so its result is reused until {@code duration} has
   * elapsed, after which the next call recomputes it.
   *
   * @param delegate supplier that produces the value; must not be null
   * @param duration how long a computed value stays valid; must be positive
   * @param unit time unit of {@code duration}; must not be null
   * @param <T> type of the supplied value
   * @return a caching supplier backed by {@code delegate}
   * @throws NullPointerException if {@code delegate} or {@code unit} is null
   * @throws IllegalArgumentException if {@code duration} is not positive
   */
  public static <T> Supplier<T> memoizeWithExpiration(
      Supplier<T> delegate, long duration, TimeUnit unit) {

    Objects.requireNonNull(delegate, "delegate Supplier must not be null");
    Objects.requireNonNull(unit, "TimeUnit must not be null");
    if (duration <= 0) {
      throw new IllegalArgumentException("duration must be > 0");
    }

    return new ExpiringMemoizingSupplier<>(delegate, duration, unit);
  }

  /** Supplier that recomputes its value once the expiry deadline passes. */
  private static final class ExpiringMemoizingSupplier<T> implements Supplier<T> {

    private final Supplier<T> delegate;
    private final long durationNanos;
    private final ReentrantLock lock = new ReentrantLock();

    // Both fields are volatile so readers on the lock-free fast path still see
    // a value consistently published by the most recent refresh.
    private volatile T value;
    private volatile long expirationNanos; // 0 means not yet initialized

    ExpiringMemoizingSupplier(Supplier<T> delegate, long duration, TimeUnit unit) {
      this.delegate = delegate;
      this.durationNanos = unit.toNanos(duration);
    }

    @Override
    public T get() {
      final long callTime = System.nanoTime();
      final long deadline = expirationNanos;

      // Fast path: a value exists and its deadline has not passed.
      if (deadline != 0 && callTime - deadline < 0) {
        return value;
      }

      lock.lock();
      try {
        // Re-check under the lock: another thread may have refreshed while we waited.
        if (expirationNanos == 0 || callTime - expirationNanos >= 0) {
          value = delegate.get();
          final long refreshed = callTime + durationNanos;
          expirationNanos = (refreshed == 0) ? 1 : refreshed; // avoid the 0 sentinel
        }
        return value;
      } finally {
        lock.unlock();
      }
    }
  }
}
+ */ + +package software.amazon.jdbc.plugin.iam; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.util.Objects; +import java.util.logging.Logger; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.CredentialUtils; +import software.amazon.awssdk.auth.signer.Aws4Signer; +import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.StringUtils; + +public class ElastiCacheIamTokenUtility implements IamTokenUtility { + + private static final Logger LOGGER = Logger.getLogger(ElastiCacheIamTokenUtility.class.getName()); + private static final String PARAM_ACTION = "Action"; + private static final String PARAM_USER = "User"; + private static final String ACTION_NAME = "connect"; + private static final String PARAM_RESOURCE_TYPE = "ResourceType"; + private static final String RESOURCE_TYPE_SERVERLESS_CACHE = "ServerlessCache"; + private static final String SERVICE_NAME = "elasticache"; + private static final String PROTOCOL = "http"; + private static final Duration EXPIRATION_DURATION = Duration.ofSeconds(15 * 60 - 30); + public static final String SERVERLESS_CACHE_IDENTIFIER = ".serverless."; + + private final Clock clock; + private String cacheName = null; + private final Aws4Signer signer; + + public ElastiCacheIamTokenUtility(String cacheName) { + this.cacheName = Objects.requireNonNull(cacheName, "cacheName cannot be null"); + this.clock = Clock.systemUTC(); + this.signer = Aws4Signer.create(); + } + + // For testing only + public 
ElastiCacheIamTokenUtility(String cacheName, Instant fixedInstant, Aws4Signer signer) { + this.cacheName = Objects.requireNonNull(cacheName, "cacheName cannot be null"); + this.clock = Clock.fixed(fixedInstant, ZoneId.of("UTC")); + this.signer = signer; + } + + @Override + public String generateAuthenticationToken( + final @NonNull AwsCredentialsProvider credentialsProvider, + final @NonNull Region region, + final @NonNull String hostname, + final int port, + final @NonNull String username) { + + boolean isServerless = isServerlessCache(hostname); + if (this.cacheName == null) { + throw new IllegalArgumentException("Cache name cannot be null for cache with IAM authentication"); + } + + SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol(PROTOCOL) // ElastiCache uses http, not https + .host(this.cacheName) + .encodedPath("/") + .putRawQueryParameter(PARAM_ACTION, ACTION_NAME) + .putRawQueryParameter(PARAM_USER, username); + + if (isServerless) { + requestBuilder.putRawQueryParameter(PARAM_RESOURCE_TYPE, RESOURCE_TYPE_SERVERLESS_CACHE); + } + + final SdkHttpFullRequest httpRequest = requestBuilder.build(); + + final Instant expirationTime = Instant.now(this.clock).plus(EXPIRATION_DURATION); + + final AwsCredentials credentials = CredentialUtils.toCredentials( + CompletableFutureUtils.joinLikeSync(credentialsProvider.resolveIdentity())); + + final Aws4PresignerParams presignRequest = Aws4PresignerParams.builder() + .signingClockOverride(this.clock) + .expirationTime(expirationTime) + .awsCredentials(credentials) + .signingName(SERVICE_NAME) + .signingRegion(region) + .build(); + + final SdkHttpFullRequest fullRequest = this.signer.presign(httpRequest, presignRequest); + final String signedUrl = fullRequest.getUri().toString(); + + // Format should be: + // Regular: 
/?Action=connect&User=&X-Amz-Security-Token=...&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=...&X-Amz-SignedHeaders=host&X-Amz-Expires=870&X-Amz-Credential=...&X-Amz-Signature=... + // Serverless: /?Action=connect&User=&ResourceType=ServerlessCache&X-Amz-Security-Token=...&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Date=...&X-Amz-SignedHeaders=host&X-Amz-Expires=870&X-Amz-Credential=...&X-Amz-Signature=... + // Note: This must be the real ElastiCache hostname, not proxy or tunnels + final String result = StringUtils.replacePrefixIgnoreCase(signedUrl, "http://", ""); + LOGGER.finest(() -> "Generated ElastiCache authentication token with expiration of " + expirationTime); + return result; + } + + private boolean isServerlessCache(String hostname) { + if (hostname == null) { + throw new IllegalArgumentException("Hostname cannot be null"); + } + return hostname.contains(SERVERLESS_CACHE_IDENTIFIER); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java index 51f7338e1..199865da0 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheConnectionTest.java @@ -1,9 +1,23 @@ -package software.amazon.jdbc.plugin.cache; +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; +package software.amazon.jdbc.plugin.cache; import io.lettuce.core.RedisFuture; +import io.lettuce.core.RedisURI; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.api.sync.RedisCommands; @@ -12,11 +26,17 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; +import org.mockito.MockedConstruction; import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.jdbc.plugin.iam.ElastiCacheIamTokenUtility; +import java.lang.reflect.Field; import java.nio.charset.StandardCharsets; import java.util.function.BiConsumer; import java.util.Properties; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; public class CacheConnectionTest { @@ -45,6 +65,242 @@ void cleanUp() throws Exception { closeable.close(); } + @Test + void testIamAuth_PropertyExtraction() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); + props.setProperty("cacheUsername", "myuser"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Verify all IAM fields are set correctly + assertEquals("us-west-2", getField(connection, "cacheIamRegion")); + assertEquals("myuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + } + + @Test + void testIamAuth_PropertyExtractionTraditional() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + 
props.setProperty("cacheUsername", "myuser"); + props.setProperty("cachePassword", "password"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Verify all IAM fields are set correctly + assertEquals("myuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("password", getField(connection, "cachePassword")); + } + + @Test + void testIamAuthEnabled_WhenRegionProvided() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + + CacheConnection connection = new CacheConnection(props); + + // Use reflection to verify iamAuthEnabled is true + Field field = CacheConnection.class.getDeclaredField("iamAuthEnabled"); + field.setAccessible(true); + assertTrue((boolean) field.get(connection)); + // Verify all IAM fields are set correctly + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("my-cache", getField(connection, "cacheName")); + } + + @Test + void testConstructor_IamAuthEnabled_MissingCacheName() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); + props.setProperty("cacheUsername", "myuser"); + // Missing cacheName property + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new CacheConnection(props) + ); + + assertTrue(exception.getMessage().contains("IAM authentication requires cache name, username, region, and hostname")); + } + + @Test + void testTraditionalAuth_WhenNoIamRegion() throws Exception { + Properties props = new Properties(); + 
props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheUsername", "user"); + props.setProperty("cachePassword", "pass"); + + CacheConnection connection = new CacheConnection(props); + + assertFalse((boolean) getField(connection, "iamAuthEnabled")); + assertNull(getField(connection, "credentialsProvider")); + assertEquals("user", getField(connection, "cacheUsername")); + assertEquals("pass", getField(connection, "cachePassword")); + } + + @Test + void testConstructor_NoRwAddress() { + Properties props = new Properties(); + props.setProperty("wrapperPlugins", "dataRemoteCache"); + props.setProperty("cacheEndpointAddrRo", "localhost:6379"); + + assertThrows(IllegalArgumentException.class, () -> new CacheConnection(props)); + } + + @Test + void testConstructor_IamAuthEnabled_MissingCacheUsername() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + + assertThrows(IllegalArgumentException.class, () -> new CacheConnection(props)); + } + + @Test + void testConstructor_ConflictingAuthenticationMethods() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-west-2"); // IAM auth + props.setProperty("cacheUsername", "myuser"); + props.setProperty("cachePassword", "mypassword"); // Traditional auth + props.setProperty("cacheName", "my-cache"); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new CacheConnection(props) + ); + + assertTrue(exception.getMessage().contains("Cannot specify both IAM authentication")); + } + + @Test + void testAwsCredentialsProvider_WithProfile() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + 
props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + props.setProperty("awsProfile", "test-profile"); + + CacheConnection connection = new CacheConnection(props); + + // Verify the awsProfileProperties field contains the correct profile + Properties awsProfileProps = (Properties) getField(connection, "awsProfileProperties"); + assertEquals("test-profile", awsProfileProps.getProperty("awsProfile")); + + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("test.cache.amazonaws.com:6379", getField(connection, "cacheRwServerAddr")); + } + + @Test + void testAwsCredentialsProvider_WithoutProfile() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "my-cache"); + // No awsProfile property + + CacheConnection connection = new CacheConnection(props); + + // Verify awsProfileProperties is null when no profile is specified + Properties awsProfileProps = (Properties) getField(connection, "awsProfileProperties"); + assertNull(awsProfileProps); + + assertEquals("my-cache", getField(connection, "cacheName")); + assertEquals("testuser", getField(connection, "cacheUsername")); + assertEquals("us-east-1", getField(connection, "cacheIamRegion")); + assertEquals("test.cache.amazonaws.com:6379", getField(connection, "cacheRwServerAddr")); + } + + @Test + void testBuildRedisURI_IamAuth() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "test-cache.cache.amazonaws.com:6379"); + props.setProperty("cacheIamRegion", "us-east-1"); + props.setProperty("cacheUsername", "testuser"); + props.setProperty("cacheName", "test-cache"); + + try 
(MockedConstruction mockedTokenUtility = mockConstruction(ElastiCacheIamTokenUtility.class)) { + + CacheConnection connection = new CacheConnection(props); + RedisURI uri = connection.buildRedisURI("test-cache.cache.amazonaws.com", 6379); + + // Verify URI properties + assertNotNull(uri); + assertEquals("test-cache.cache.amazonaws.com", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertTrue(uri.isSsl()); + assertNotNull(uri.getCredentialsProvider()); + + // Trigger the credentials provider to create the token utility + uri.getCredentialsProvider().resolveCredentials().block(); + + // Verify URI properties + assertNotNull(uri); + assertEquals("test-cache.cache.amazonaws.com", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertTrue(uri.isSsl()); + assertNotNull(uri.getCredentialsProvider()); // IAM credentials provider set + + // Verify ElastiCacheIamTokenUtility was constructed with correct parameters + // Verify token utility construction + assertEquals(1, mockedTokenUtility.constructed().size()); + ElastiCacheIamTokenUtility tokenUtility = mockedTokenUtility.constructed().get(0); + verify(tokenUtility).generateAuthenticationToken( + any(AwsCredentialsProvider.class), + eq(Region.US_EAST_1), + eq("test-cache.cache.amazonaws.com"), + eq(6379), + eq("testuser") + ); + } + } + + @Test + void testBuildRedisURI_TraditionalAuth() { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheUsername", "user"); + props.setProperty("cachePassword", "pass"); + + CacheConnection connection = new CacheConnection(props); + RedisURI uri = connection.buildRedisURI("localhost", 6379); + + assertNotNull(uri); + assertEquals("localhost", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertEquals("user", uri.getUsername()); + assertEquals("pass", new String(uri.getPassword())); + } + + @Test + void testBuildRedisURI_NoAuth() { + Properties props = new Properties(); + 
props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + + CacheConnection connection = new CacheConnection(props); + RedisURI uri = connection.buildRedisURI("localhost", 6379); + + assertNotNull(uri); + assertEquals("localhost", uri.getHost()); + assertEquals(6379, uri.getPort()); + assertNull(uri.getUsername()); + assertNull(uri.getPassword()); + } + @Test void test_writeToCache() throws Exception { String key = "myQueryKey"; @@ -107,4 +363,10 @@ void test_readFromCacheException() throws Exception { verify(mockSyncCommands).get(any()); verify(mockReadConnPool).invalidateObject(mockConnection); } + + private Object getField(Object obj, String fieldName) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(obj); + } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java new file mode 100644 index 000000000..458ae96cc --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CacheSupplierTest.java @@ -0,0 +1,158 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package software.amazon.jdbc.plugin.cache; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import org.junit.jupiter.api.Test; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +public class CacheSupplierTest { + + @Test + void testMemoizeWithExpiration_ValidParameters() { + Supplier delegate = () -> "test-value"; + + Supplier cached = CachedSupplier.memoizeWithExpiration(delegate, 1, TimeUnit.SECONDS); + + assertNotNull(cached); + assertEquals("test-value", cached.get()); + } + + @Test + void testMemoizeWithExpiration_NullDelegate() { + assertThrows(NullPointerException.class, () -> + CachedSupplier.memoizeWithExpiration(null, 1, TimeUnit.SECONDS)); + } + + @Test + void testMemoizeWithExpiration_NullTimeUnit() { + Supplier delegate = () -> "test"; + + assertThrows(NullPointerException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, 1, null)); + } + + @Test + void testMemoizeWithExpiration_ZeroDuration() { + Supplier delegate = () -> "test"; + + assertThrows(IllegalArgumentException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, 0, TimeUnit.SECONDS)); + } + + @Test + void testMemoizeWithExpiration_NegativeDuration() { + Supplier delegate = () -> "test"; + + assertThrows(IllegalArgumentException.class, () -> + CachedSupplier.memoizeWithExpiration(delegate, -1, TimeUnit.SECONDS)); + } + + @Test + void testCaching_DelegateCalledOnce() { + Supplier mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("cached-value"); + + Supplier cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 1, TimeUnit.SECONDS); + + // Call multiple times quickly + assertEquals("cached-value", cached.get()); + assertEquals("cached-value", cached.get()); + assertEquals("cached-value", cached.get()); + + // Delegate should only be called once due to caching + verify(mockDelegate, times(1)).get(); + } + + @Test + void 
testExpiration_DelegateCalledAgainAfterExpiry() throws InterruptedException { + Supplier mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("value1", "value2"); + + Supplier cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 50, TimeUnit.MILLISECONDS); + + // First call + assertEquals("value1", cached.get()); + verify(mockDelegate, times(1)).get(); + + // Wait for expiration + Thread.sleep(100); + + // Second call after expiration + assertEquals("value2", cached.get()); + verify(mockDelegate, times(2)).get(); + } + + @Test + void testConcurrentAccess() throws InterruptedException { + Supplier mockDelegate = mock(Supplier.class); + when(mockDelegate.get()).thenReturn("concurrent-value"); + + Supplier cached = CachedSupplier.memoizeWithExpiration(mockDelegate, 5, TimeUnit.SECONDS); + + // Simulate concurrent access + Thread[] threads = new Thread[10]; + String[] results = new String[10]; + + for (int i = 0; i < 10; i++) { + final int index = i; + threads[i] = new Thread(() -> results[index] = cached.get()); + threads[i].start(); + } + + // Wait for all threads + for (Thread thread : threads) { + thread.join(); + } + + // All should get the same cached value + for (String result : results) { + assertEquals("concurrent-value", result); + } + + // Delegate should only be called once despite concurrent access + verify(mockDelegate, times(1)).get(); + } + + @Test + void testExpirationNanos_EdgeCase() { + Supplier timeSupplier = () -> System.nanoTime(); + + Supplier cached = CachedSupplier.memoizeWithExpiration(timeSupplier, 1, TimeUnit.NANOSECONDS); + + Long first = cached.get(); + Long second = cached.get(); + + // Due to very short expiration, second call might get different value + assertNotNull(first); + assertNotNull(second); + } + + @Test + void testPrivateConstructor() { + // Verify utility class has private constructor + assertThrows(Exception.class, () -> { + java.lang.reflect.Constructor constructor = + 
CachedSupplier.class.getDeclaredConstructor(); + constructor.setAccessible(true); + constructor.newInstance(); + }); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java new file mode 100644 index 000000000..4eeb24e19 --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/ElastiCacheIamTokenUtilityTest.java @@ -0,0 +1,193 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package software.amazon.jdbc.plugin.iam; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.signer.Aws4Signer; +import software.amazon.awssdk.auth.signer.params.Aws4PresignerParams; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.regions.Region; + +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.CompletableFuture; + +public class ElastiCacheIamTokenUtilityTest { + @Mock private AwsCredentialsProvider mockCredentialsProvider; + @Mock private AwsCredentials mockCredentials; + @Mock private Aws4Signer mockSigner; + @Mock private SdkHttpFullRequest mockSignedRequest; + + private AutoCloseable closeable; + private ElastiCacheIamTokenUtility tokenUtility; + private final Instant fixedInstant = Instant.parse("2025-01-01T12:00:00Z"); + + @BeforeEach + void setUp() { + closeable = MockitoAnnotations.openMocks(this); + } + + @AfterEach + void tearDown() throws Exception { + closeable.close(); + } + + @Test + void testConstructor_WithCacheName() { + tokenUtility = new ElastiCacheIamTokenUtility("test-cache"); + assertNotNull(tokenUtility); + } + + @Test + void testConstructor_WithCacheNameAndFixedInstant() { + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + assertNotNull(tokenUtility); + } + + @Test + void testConstructor_NullCacheName() { + assertThrows(NullPointerException.class, () -> + new ElastiCacheIamTokenUtility(null)); + } + + @Test + void testConstructor_NullCacheNameWithInstant() { + 
assertThrows(NullPointerException.class, () -> + new ElastiCacheIamTokenUtility(null, fixedInstant, mockSigner)); + } + + @Test + void testGenerateAuthenticationToken_RegularCache() { + // Setup mock credentials provider to return mockCredentials + when(mockCredentialsProvider.resolveIdentity()) + .thenReturn((CompletableFuture) CompletableFuture.completedFuture(mockCredentials)); + + // Add custom presign behavior to capture and verify arguments + when(mockSigner.presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class))) + .thenAnswer(invocation -> { + SdkHttpFullRequest request = invocation.getArgument(0); + Aws4PresignerParams presignParams = invocation.getArgument(1); + + // Verify SdkHttpFullRequest + assertEquals("test-cache", request.host()); + assertEquals("/", request.encodedPath()); + assertEquals("connect", request.rawQueryParameters().get("Action").get(0)); + assertEquals("testuser", request.rawQueryParameters().get("User").get(0)); + assertFalse(request.rawQueryParameters().containsKey("ResourceType")); + + // Verify Aws4PresignerParams + assertEquals("elasticache", presignParams.signingName()); + assertEquals(Region.US_EAST_1, presignParams.signingRegion()); + assertEquals(mockCredentials, presignParams.awsCredentials()); + + Instant expectedExpiration = fixedInstant.plus(Duration.ofSeconds(15 * 60 - 30)); + assertEquals(expectedExpiration, presignParams.expirationTime().get()); + assertEquals(fixedInstant, presignParams.signingClockOverride().get().instant()); + + return mockSignedRequest; + }); + + when(mockSignedRequest.getUri()).thenReturn(java.net.URI.create("http://test-cache/result")); + + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + + String token = tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, Region.US_EAST_1, "test-cache.cache.amazonaws.com", 6379, "testuser"); + + assertEquals("test-cache/result", token); + 
verify(mockSigner).presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class)); + } + + @Test + void testGenerateAuthenticationToken_ServerlessCache() { + // Setup mock credentials provider to return mockCredentials + when(mockCredentialsProvider.resolveIdentity()) + .thenReturn((CompletableFuture) CompletableFuture.completedFuture(mockCredentials)); + + // Add custom presign behavior to capture and verify arguments + when(mockSigner.presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class))) + .thenAnswer(invocation -> { + SdkHttpFullRequest request = invocation.getArgument(0); + Aws4PresignerParams presignParams = invocation.getArgument(1); + + // Verify SdkHttpFullRequest + assertEquals("test-cache", request.host()); + assertEquals("/", request.encodedPath()); + assertEquals("connect", request.rawQueryParameters().get("Action").get(0)); + assertEquals("testuser", request.rawQueryParameters().get("User").get(0)); + assertEquals("ServerlessCache", request.rawQueryParameters().get("ResourceType").get(0)); + + // Verify Aws4PresignerParams + assertEquals("elasticache", presignParams.signingName()); + assertEquals(Region.US_EAST_1, presignParams.signingRegion()); + assertEquals(mockCredentials, presignParams.awsCredentials()); + + Instant expectedExpiration = fixedInstant.plus(Duration.ofSeconds(15 * 60 - 30)); + assertEquals(expectedExpiration, presignParams.expirationTime().get()); + assertEquals(fixedInstant, presignParams.signingClockOverride().get().instant()); + + return mockSignedRequest; + }); + + when(mockSignedRequest.getUri()).thenReturn(java.net.URI.create("http://test-cache.serverless.cache.amazonaws.com/result")); + + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + + String token = tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, Region.US_EAST_1, "test-cache.serverless.cache.amazonaws.com", 6379, "testuser"); + + 
assertEquals("test-cache.serverless.cache.amazonaws.com/result", token); + verify(mockSigner).presign(any(SdkHttpFullRequest.class), any(Aws4PresignerParams.class)); + } + + @Test + void testGenerateAuthenticationToken_NullCacheName() { + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + + // Use reflection to set cacheName to null to test the validation + try { + java.lang.reflect.Field field = ElastiCacheIamTokenUtility.class.getDeclaredField("cacheName"); + field.setAccessible(true); + field.set(tokenUtility, null); + + assertThrows(IllegalArgumentException.class, () -> + tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, Region.US_EAST_1, "test-host", 6379, "testuser")); + } catch (Exception e) { + fail("Reflection failed: " + e.getMessage()); + } + } + + @Test + void testGenerateAuthenticationToken_NullHostname() { + tokenUtility = new ElastiCacheIamTokenUtility("test-cache", fixedInstant, mockSigner); + + assertThrows(IllegalArgumentException.class, () -> + tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, Region.US_EAST_1, null, 6379, "testuser")); + } +} From b7cd939b0573ae220b7cc81cd3938ebe33feb6c6 Mon Sep 17 00:00:00 2001 From: narasimhaarun oruganti Date: Wed, 22 Oct 2025 16:06:59 -0700 Subject: [PATCH 21/24] Caching: Allow user to configure cache connection timeout and connection pool --- .../DatabaseConnectionWithCacheExample.java | 8 +- .../jdbc/plugin/cache/CacheConnection.java | 45 +++++++-- .../plugin/cache/CacheConnectionTest.java | 94 +++++++++++++++++++ 3 files changed, 136 insertions(+), 11 deletions(-) diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java index 0e73b0330..c199beb1e 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java +++ 
b/examples/AWSDriverExample/src/main/java/software/amazon/DatabaseConnectionWithCacheExample.java @@ -17,13 +17,15 @@ public class DatabaseConnectionWithCacheExample { // Both IAM and traditional auth uses the same CACHE_USERNAME private static final String CACHE_USERNAME = env.get("CACHE_USERNAME"); // e.g., "iam-user-01" / "username" private static final String CACHE_IAM_REGION = env.get("CACHE_IAM_REGION"); // e.g., "us-west-2" + private static final String CACHE_USE_SSL = env.get("CACHE_USE_SSL"); // If the cache server is authenticated with traditional username password // private static final String CACHE_PASSWORD = env.get("CACHE_PASSWORD"); private static final String USERNAME = env.get("DB_USERNAME"); private static final String PASSWORD = env.get("DB_PASSWORD"); - private static final String USE_SSL = env.get("USE_SSL"); private static final int THREAD_COUNT = 8; //Use 8 Threads private static final long TEST_DURATION_MS = 16000; //Test duration for 16 seconds + private static final String CACHE_CONNECTION_TIMEOUT = env.get("CACHE_CONNECTION_TIMEOUT"); //Set connection timeout in milliseconds + private static final String CACHE_CONNECTION_POOL_SIZE = env.get("CACHE_CONNECTION_POOL_SIZE"); //Set connection pool size public static void main(String[] args) throws SQLException { final Properties properties = new Properties(); @@ -43,8 +45,10 @@ public static void main(String[] args) throws SQLException { properties.setProperty("cacheIamRegion", CACHE_IAM_REGION); // If the cache server is authenticated with traditional username password // properties.setProperty("cachePassword", PASSWORD); - properties.setProperty("cacheUseSSL", USE_SSL); // "true" or "false" + properties.setProperty("cacheUseSSL", CACHE_USE_SSL); // "true" or "false" properties.setProperty("wrapperLogUnclosedConnections", "true"); + properties.setProperty("cacheConnectionTimeout", CACHE_CONNECTION_TIMEOUT); + properties.setProperty("cacheConnectionPoolSize", CACHE_CONNECTION_POOL_SIZE); 
String queryStr = "/*+ CACHE_PARAM(ttl=300s) */ select * from cinemas"; // Create threads for concurrent connection testing diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java index 14f813ecf..1c2d1aac7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CacheConnection.java @@ -52,14 +52,9 @@ // Abstraction layer on top of a connection to a remote cache server public class CacheConnection { private static final Logger LOGGER = Logger.getLogger(CacheConnection.class.getName()); - // Adding support for read and write connection pools to the remote cache server - private static volatile GenericObjectPool> readConnectionPool; - private static volatile GenericObjectPool> writeConnectionPool; - private static final GenericObjectPoolConfig> poolConfig = createPoolConfig(); - private static final int DEFAULT_POOL_SIZE = 20; - private static final int DEFAULT_POOL_MAX_IDLE = 20; private static final int DEFAULT_POOL_MIN_IDLE = 0; + private static final int DEFAULT_MAX_POOL_SIZE = 200; private static final long DEFAULT_MAX_BORROW_WAIT_MS = 100; private static final long TOKEN_CACHE_DURATION = 15 * 60 - 30; @@ -113,12 +108,31 @@ public class CacheConnection { null, "Explicit cache name for ElastiCache IAM authentication. 
"); + protected static final AwsWrapperProperty CACHE_CONNECTION_TIMEOUT = + new AwsWrapperProperty( + "cacheConnectionTimeout", + "2000", + "Cache connection request timeout duration in milliseconds."); + + protected static final AwsWrapperProperty CACHE_CONNECTION_POOL_SIZE = + new AwsWrapperProperty( + "cacheConnectionPoolSize", + "20", + "Cache connection pool size."); + + // Adding support for read and write connection pools to the remote cache server + private static volatile GenericObjectPool> readConnectionPool; + private static volatile GenericObjectPool> writeConnectionPool; + private static final GenericObjectPoolConfig> poolConfig = createPoolConfig(); + private final boolean useSSL; private final boolean iamAuthEnabled; private final String cacheIamRegion; private final String cacheUsername; private final String cacheName; private final String cachePassword; + private final Duration cacheConnectionTimeout; + private final int cacheConnectionPoolSize; private final Properties awsProfileProperties; private final AwsCredentialsProvider credentialsProvider; @@ -134,6 +148,15 @@ public CacheConnection(final Properties properties) { this.cacheIamRegion = CACHE_IAM_REGION.getString(properties); this.cacheUsername = CACHE_USERNAME.getString(properties); this.cachePassword = CACHE_PASSWORD.getString(properties); + this.cacheConnectionTimeout = Duration.ofMillis(CACHE_CONNECTION_TIMEOUT.getInteger(properties)); + this.cacheConnectionPoolSize = CACHE_CONNECTION_POOL_SIZE.getInteger(properties); + if (this.cacheConnectionPoolSize <= 0 || this.cacheConnectionPoolSize > DEFAULT_MAX_POOL_SIZE) { + throw new IllegalArgumentException( + "Cache connection pool size must be within valid range: 1-" + DEFAULT_MAX_POOL_SIZE + ", but was: " + this.cacheConnectionPoolSize); + } + // Update the static poolConfig with user values + poolConfig.setMaxTotal(this.cacheConnectionPoolSize); + poolConfig.setMaxIdle(this.cacheConnectionPoolSize); this.iamAuthEnabled = 
!StringUtils.isNullOrEmpty(this.cacheIamRegion); boolean hasTraditionalAuth = !StringUtils.isNullOrEmpty(this.cachePassword); // Validate authentication configuration @@ -254,8 +277,6 @@ public PooledObject> wrap(StatefulRedisC private static GenericObjectPoolConfig> createPoolConfig() { GenericObjectPoolConfig> poolConfig = new GenericObjectPoolConfig<>(); - poolConfig.setMaxTotal(DEFAULT_POOL_SIZE); - poolConfig.setMaxIdle(DEFAULT_POOL_MAX_IDLE); poolConfig.setMinIdle(DEFAULT_POOL_MIN_IDLE); poolConfig.setMaxWait(Duration.ofMillis(DEFAULT_MAX_BORROW_WAIT_MS)); return poolConfig; @@ -359,12 +380,18 @@ protected void setConnectionPools(GenericObjectPool> readPool = getStaticPool("readConnectionPool"); + GenericObjectPool> writePool = getStaticPool("writeConnectionPool"); + + assertNotNull(readPool, "read pool should be created"); + assertNotNull(writePool, "write pool should be created"); + + assertEquals(20, readPool.getMaxTotal()); + assertEquals(20, readPool.getMaxIdle()); + assertEquals(20, writePool.getMaxTotal()); + assertEquals(20, writePool.getMaxIdle()); + assertNotEquals(8, readPool.getMaxTotal()); // making sure it does not set the default values of Generic pool + assertNotEquals(8, writePool.getMaxIdle()); + } + + @Test + void test_cacheConnectionPoolSize_Initialization() throws Exception { + clearStaticPools(); + + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheEndpointAddrRo", "localhost:6380"); + props.setProperty("cacheConnectionPoolSize", "15"); + + CacheConnection connection = new CacheConnection(props); + + // Create real pools (no network until borrow) + connection.triggerPoolInit(true); + connection.triggerPoolInit(false); + + GenericObjectPool> readPool = getStaticPool("readConnectionPool"); + GenericObjectPool> writePool = getStaticPool("writeConnectionPool"); + + assertNotNull(readPool, "read pool should be created"); + assertNotNull(writePool, "write pool should 
be created"); + + assertEquals(15, readPool.getMaxTotal()); + assertEquals(15, readPool.getMaxIdle()); + assertEquals(15, writePool.getMaxTotal()); + assertEquals(15, writePool.getMaxIdle()); + } + + @Test + void test_cacheConnectionTimeout_Initialization() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + props.setProperty("cacheConnectionTimeout", "5000"); + + CacheConnection connection = new CacheConnection(props); + Duration timeout = (Duration) getField(connection, "cacheConnectionTimeout"); + assertEquals(Duration.ofMillis(5000), timeout); + } + + @Test + void test_cacheConnectionTimeout_default() throws Exception { + Properties props = new Properties(); + props.setProperty("cacheEndpointAddrRw", "localhost:6379"); + + CacheConnection connection = new CacheConnection(props); + Duration timeout = (Duration) getField(connection, "cacheConnectionTimeout"); + assertEquals(Duration.ofMillis(2000), timeout, "default should be 2000 ms"); + } + + @SuppressWarnings("unchecked") + private static GenericObjectPool> getStaticPool(String field) throws Exception { + Field f = CacheConnection.class.getDeclaredField(field); + f.setAccessible(true); + return (GenericObjectPool>) f.get(null); + } + + private static void clearStaticPools() throws Exception { + for (String fieldName : new String[]{"readConnectionPool", "writeConnectionPool"}) { + Field f = CacheConnection.class.getDeclaredField(fieldName); + f.setAccessible(true); + f.set(null, null); + } + } + private Object getField(Object obj, String fieldName) throws Exception { Field field = obj.getClass().getDeclaredField(fieldName); field.setAccessible(true); From c12d3ed3adaec8fa5a4f52e662c217092bb8cc59 Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Thu, 30 Oct 2025 10:55:51 -0700 Subject: [PATCH 22/24] Caching - properly bypass the cache for queries that return non ResultSet objects --- .../amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java | 8 
++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index cf04ccdd2..476c129d5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -252,6 +252,12 @@ public T execute( final JdbcCallable jdbcMethodFunc, final Object[] jdbcMethodArgs) throws E { + if (resultClass != ResultSet.class) { + return jdbcMethodFunc.call(); + } + + incrCounter(totalQueryCounter); + ResultSet result; boolean needToCache = false; final String sql = getQuery(jdbcMethodArgs); @@ -273,8 +279,6 @@ public T execute( } } - incrCounter(totalQueryCounter); - // Query result can be served from the cache if it has a configured TTL value, and it is // not executed in a transaction as a transaction typically need to return consistent results. if (!isInTransaction && (configuredQueryTtl != null)) { From 65433cf7039591d9df7b925db472b0a4b202c98c Mon Sep 17 00:00:00 2001 From: Qu Chen Date: Mon, 24 Nov 2025 17:33:19 -0800 Subject: [PATCH 23/24] Caching - handle empty query and prepared statement to support Hibernate integration. 
--- .../plugin/cache/DataRemoteCachePlugin.java | 18 ++- .../cache/DataRemoteCachePluginTest.java | 132 +++++++++++++++++- 2 files changed, 145 insertions(+), 5 deletions(-) diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java index 476c129d5..44f0a26af 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePlugin.java @@ -18,6 +18,7 @@ import java.sql.Connection; import java.sql.DatabaseMetaData; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.Arrays; @@ -43,7 +44,7 @@ public class DataRemoteCachePlugin extends AbstractConnectionPlugin { private static final Logger LOGGER = Logger.getLogger(DataRemoteCachePlugin.class.getName()); - private static final String QUERY_HINT_START_PATTERN = "/*+"; + private static final String QUERY_HINT_START_PATTERN = "/*"; private static final String QUERY_HINT_END_PATTERN = "*/"; private static final String CACHE_PARAM_PATTERN = "CACHE_PARAM("; private static final String TELEMETRY_CACHE_LOOKUP = "jdbc-cache-lookup"; @@ -260,7 +261,7 @@ public T execute( ResultSet result; boolean needToCache = false; - final String sql = getQuery(jdbcMethodArgs); + final String sql = getQuery(methodInvokeOn, jdbcMethodArgs); TelemetryContext cacheContext = null; TelemetryContext dbContext = null; @@ -271,7 +272,7 @@ public T execute( int endOfQueryHint = 0; Integer configuredQueryTtl = null; // Queries longer than 16KB is not cacheable - if ((sql.length() < maxCacheableQuerySize) && sql.startsWith(QUERY_HINT_START_PATTERN)) { + if (!StringUtils.isNullOrEmpty(sql) && (sql.length() < maxCacheableQuerySize) && sql.contains(QUERY_HINT_START_PATTERN)) { endOfQueryHint = sql.indexOf(QUERY_HINT_END_PATTERN); if (endOfQueryHint > 0) { configuredQueryTtl 
= getTtlForQuery(sql.substring(QUERY_HINT_START_PATTERN.length(), endOfQueryHint).trim()); @@ -355,11 +356,20 @@ private void incrCounter(TelemetryCounter counter) { counter.inc(); } - protected String getQuery(final Object[] jdbcMethodArgs) { + protected String getQuery(final Object methodInvokeOn, final Object[] jdbcMethodArgs) { // Get query from method argument if (jdbcMethodArgs != null && jdbcMethodArgs.length > 0 && jdbcMethodArgs[0] != null) { return jdbcMethodArgs[0].toString().trim(); } + + // If the query is not in the method arguments, check for prepared statement query. Get the query + // string from the prepared statement. The exact query string is dependent on the underlying driver. + if (methodInvokeOn instanceof PreparedStatement) { + // For postgres, this gives the raw query itself. i.e. "select * from T where A = 1". + // For MySQL, this gives "com.mysql.cj.jdbc.ClientPreparedStatement: select * from T where A = 1" + // For MariaDB, this gives "ClientPreparedStatement{sql:'select * from T where A=1', parameters:[]}" + return methodInvokeOn.toString(); + } return null; } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java index 54b466e50..54e3802a2 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/DataRemoteCachePluginTest.java @@ -43,6 +43,7 @@ public class DataRemoteCachePluginTest { @Mock TelemetryContext mockTelemetryContext; @Mock ResultSet mockResult1; @Mock Statement mockStatement; + @Mock PreparedStatement mockPreparedStatement; @Mock ResultSetMetaData mockMetaData; @Mock Connection mockConnection; @Mock SessionStateService mockSessionStateService; @@ -170,6 +171,67 @@ void test_execute_noCaching() throws Exception { verify(mockTelemetryContext).closeContext(); } + @Test + void 
test_execute_emptyQuery_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, + methodName, mockCallable, new String[]{}); + + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1"); + compareResults(mockResult1, rs); + verify(mockPluginService).isInTransaction(); + verify(mockCallable).call(); + verify(mockTotalQueryCounter, times(1)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(1)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + verify(mockTelemetryContext).closeContext(); + } + + @Test + void test_execute_emptyPreparedStatement_noCaching() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is not cacheable + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockPreparedStatement.toString()).thenReturn("", null); + when(mockCallable.call()).thenReturn(mockResult1); + + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement, + methodName, mockCallable, new String[]{}); + // Mock result set containing 1 row + when(mockResult1.next()).thenReturn(true, true, false, true, true, false); + when(mockResult1.getObject(1)).thenReturn("bar1", "bar1", "bar1", "bar1"); + compareResults(mockResult1, rs); + + rs = 
plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement, + methodName, mockCallable, new String[]{}); + // Mock result set containing 1 row + compareResults(mockResult1, rs); + + verify(mockPluginService, times(2)).isInTransaction(); + verify(mockCallable, times(2)).call(); + verify(mockTotalQueryCounter, times(2)).inc(); + verify(mockCacheHitCounter, never()).inc(); + verify(mockCacheBypassCounter, times(2)).inc(); + verify(mockCacheMissCounter, never()).inc(); + // Verify TelemetryContext behavior for no-caching scenario + verify(mockTelemetryFactory, times(2)).openTelemetryContext("jdbc-database-query", TelemetryTraceLevel.TOP_LEVEL); + verify(mockTelemetryFactory, never()).openTelemetryContext(eq("jdbc-cache-lookup"), any()); + //verify(mockPreparedStatement, times(2)).toString(); + verify(mockTelemetryContext, times(2)).closeContext(); + } + @Test void test_execute_noCachingLongQuery() throws Exception { plugin = new DataRemoteCachePlugin(mockPluginService, props); @@ -179,7 +241,7 @@ void test_execute_noCachingLongQuery() throws Exception { when(mockCallable.call()).thenReturn(mockResult1); ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockStatement, - methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T" + RandomStringUtils.randomAlphanumeric(15990)}); + methodName, mockCallable, new String[]{"/* CACHE_PARAM(ttl=20s) */ select * from T " + RandomStringUtils.randomAlphanumeric(16350)}); // Mock result set containing 1 row when(mockResult1.next()).thenReturn(true, true, false); @@ -262,6 +324,74 @@ void test_execute_cachingMissAndHit() throws Exception { verify(mockTelemetryContext, times(3)).closeContext(); } + @Test + void test_cachingMissAndHit_preparedStatement() throws Exception { + plugin = new DataRemoteCachePlugin(mockPluginService, props); + plugin.setCacheConnection(mockCacheConn); + // Query is a cache miss + 
when(mockPluginService.getCurrentConnection()).thenReturn(mockConnection); + when(mockPluginService.isInTransaction()).thenReturn(false); + when(mockConnection.getMetaData()).thenReturn(mockDbMetadata); + when(mockPluginService.getSessionStateService()).thenReturn(mockSessionStateService); + when(mockSessionStateService.getCatalog()).thenReturn(Optional.empty()).thenReturn(Optional.of("mysql")); + when(mockSessionStateService.getSchema()).thenReturn(Optional.empty()); + when(mockConnection.getCatalog()).thenReturn("mysql"); + when(mockConnection.getSchema()).thenReturn(null); + when(mockDbMetadata.getUserName()).thenReturn("user1@1.1.1.1"); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(null); + when(mockCallable.call()).thenReturn(mockResult1); + + // Result set contains 1 row + when(mockResult1.next()).thenReturn(true, false); + when(mockResult1.getObject(1)).thenReturn("bar1"); + when(mockPreparedStatement.toString()).thenReturn("/* CACHE_PARAM(ttl=50s) */ select * from A"); + + // Now query is a cache hit + ResultSet rs = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement, + methodName, mockCallable, new String[]{}); + + // Cached result set contains 1 row + assertTrue(rs.next()); + assertEquals("bar1", rs.getString("fooName")); + assertFalse(rs.next()); + + rs.beforeFirst(); + byte[] serializedTestResultSet = ((CachedResultSet)rs).serializeIntoByteArray(); + when(mockCacheConn.readFromCache("mysql_null_user1_select * from A")).thenReturn(serializedTestResultSet); + + ResultSet rs2 = plugin.execute(ResultSet.class, SQLException.class, mockPreparedStatement, + methodName, mockCallable, new String[]{}); + + assertTrue(rs2.next()); + assertEquals("bar1", rs2.getString("fooName")); + assertFalse(rs2.next()); + verify(mockPluginService, times(3)).getCurrentConnection(); + verify(mockPluginService, times(2)).isInTransaction(); + verify(mockCacheConn, times(2)).readFromCache("mysql_null_user1_select * from 
A"); + verify(mockPluginService, times(3)).getSessionStateService(); + verify(mockSessionStateService, times(3)).getCatalog(); + verify(mockSessionStateService, times(3)).getSchema(); + verify(mockConnection).getCatalog(); + verify(mockConnection).getSchema(); + verify(mockSessionStateService).setCatalog("mysql"); + verify(mockDbMetadata).getUserName(); + verify(mockCallable).call(); + verify(mockCacheConn).writeToCache(eq("mysql_null_user1_select * from A"), any(), eq(50)); + verify(mockTotalQueryCounter, times(2)).inc(); + verify(mockCacheMissCounter, times(1)).inc(); + verify(mockCacheHitCounter, times(1)).inc(); + verify(mockCacheBypassCounter, never()).inc(); + // Verify TelemetryContext behavior for cache miss and hit scenario + // First call: Cache miss + Database call + verify(mockTelemetryFactory, times(2)).openTelemetryContext(eq("jdbc-cache-lookup"), eq(TelemetryTraceLevel.TOP_LEVEL)); + verify(mockTelemetryFactory, times(1)).openTelemetryContext(eq("jdbc-database-query"), eq(TelemetryTraceLevel.TOP_LEVEL)); + // Cache context calls: 1 miss (setSuccess(false)) + 1 hit (setSuccess(true)) + verify(mockTelemetryContext, times(1)).setSuccess(false); // Cache miss + verify(mockTelemetryContext, times(1)).setSuccess(true); // Cache hit + // Context closure: 2 cache contexts + 1 database context = 3 total + verify(mockTelemetryContext, times(3)).closeContext(); + } + @Test void test_transaction_cacheQuery() throws Exception { props.setProperty("user", "dbuser"); From fd6b4327d242615cfc59f005004cde86756c1032 Mon Sep 17 00:00:00 2001 From: Shaopeng Gu Date: Mon, 24 Nov 2025 18:40:48 -0800 Subject: [PATCH 24/24] updated convertToTime function to handle offset time with calender and updated corresponding test cases to be DST independent --- .../amazon/jdbc/plugin/cache/CachedResultSet.java | 14 ++++++++++++-- .../jdbc/plugin/cache/CachedResultSetTest.java | 11 +++++++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git 
a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java index b4ecbfaf1..58d466ea2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/cache/CachedResultSet.java @@ -377,8 +377,18 @@ private Time convertToTime(Object timeObj, Calendar cal) throws SQLException { return Time.valueOf(targetZonedDateTime.toLocalTime()); } if (timeObj instanceof OffsetTime) { - OffsetTime localTime = ((OffsetTime)timeObj).withOffsetSameInstant(OffsetDateTime.now().getOffset()); - return Time.valueOf(localTime.toLocalTime()); + OffsetTime offsetTime = (OffsetTime) timeObj; + if (cal == null) { + // Convert to default timezone using ZonedDateTime conversion + ZonedDateTime zonedDateTime = offsetTime.atDate(LocalDate.now()) + .atZoneSameInstant(defaultTimeZoneId); + return Time.valueOf(zonedDateTime.toLocalTime()); + } else { + // Convert to specified calendar timezone + ZonedDateTime zonedDateTime = offsetTime.atDate(LocalDate.now()) + .atZoneSameInstant(cal.getTimeZone().toZoneId()); + return Time.valueOf(zonedDateTime.toLocalTime()); + } } if (timeObj instanceof Timestamp) { Timestamp timestamp = (Timestamp) timeObj; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java index 00ac96e34..eaa1587a2 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/cache/CachedResultSetTest.java @@ -398,8 +398,15 @@ void test_get_special_time() throws SQLException { assertEquals(Time.valueOf("07:20:30"), cachedRs.getTime(1, estCal)); // Time from OffsetTime assertTrue(cachedRs.next()); - assertEquals(Time.valueOf("05:15:30"), cachedRs.getTime(1)); - assertEquals(Time.valueOf("05:15:30"), 
cachedRs.getTime(1, estCal)); + // OffsetTime.of(12, 15, 30, 0, ZoneOffset.UTC) converted to default timezone + OffsetTime offsetTimeUtc = OffsetTime.of(12, 15, 30, 0, ZoneOffset.UTC); + ZonedDateTime expectedDefaultTz = offsetTimeUtc.atDate(LocalDate.now()) + .atZoneSameInstant(defaultTimeZone.toZoneId()); + assertEquals(Time.valueOf(expectedDefaultTz.toLocalTime()), cachedRs.getTime(1)); + // OffsetTime converted to EST timezone + ZonedDateTime expectedEstTz = offsetTimeUtc.atDate(LocalDate.now()) + .atZoneSameInstant(estCal.getTimeZone().toZoneId()); + assertEquals(Time.valueOf(expectedEstTz.toLocalTime()), cachedRs.getTime(1, estCal)); // Time from Timestamp assertTrue(cachedRs.next()); Timestamp timestampOne = new Timestamp(1755621000000L);