/*- * #%L * athena-jdbc * %% * Copyright (C) 2019 Amazon Web Services * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ package com.amazonaws.connectors.athena.jdbc.manager; import com.amazonaws.athena.connector.lambda.QueryStatusChecker; import com.amazonaws.athena.connector.lambda.data.Block; import com.amazonaws.athena.connector.lambda.data.BlockSpiller; import com.amazonaws.athena.connector.lambda.data.writers.GeneratedRowWriter; import com.amazonaws.athena.connector.lambda.data.writers.extractors.BigIntExtractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.BitExtractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.DateDayExtractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.DateMilliExtractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.DecimalExtractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.Extractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.Float4Extractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.Float8Extractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.IntExtractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.SmallIntExtractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.TinyIntExtractor; import com.amazonaws.athena.connector.lambda.data.writers.extractors.VarBinaryExtractor; import 
com.amazonaws.athena.connector.lambda.data.writers.extractors.VarCharExtractor; import com.amazonaws.athena.connector.lambda.data.writers.holders.NullableDecimalHolder; import com.amazonaws.athena.connector.lambda.data.writers.holders.NullableVarBinaryHolder; import com.amazonaws.athena.connector.lambda.data.writers.holders.NullableVarCharHolder; import com.amazonaws.athena.connector.lambda.domain.Split; import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints; import com.amazonaws.athena.connector.lambda.handlers.RecordHandler; import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest; import com.amazonaws.connectors.athena.jdbc.connection.DatabaseConnectionConfig; import com.amazonaws.connectors.athena.jdbc.connection.JdbcConnectionFactory; import com.amazonaws.connectors.athena.jdbc.connection.JdbcCredentialProvider; import com.amazonaws.connectors.athena.jdbc.connection.RdsSecretsCredentialProvider; import com.amazonaws.services.athena.AmazonAthena; import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.secretsmanager.AWSSecretsManager; import org.apache.arrow.vector.holders.NullableBigIntHolder; import org.apache.arrow.vector.holders.NullableBitHolder; import org.apache.arrow.vector.holders.NullableDateDayHolder; import org.apache.arrow.vector.holders.NullableDateMilliHolder; import org.apache.arrow.vector.holders.NullableFloat4Holder; import org.apache.arrow.vector.holders.NullableFloat8Holder; import org.apache.arrow.vector.holders.NullableIntHolder; import org.apache.arrow.vector.holders.NullableSmallIntHolder; import org.apache.arrow.vector.holders.NullableTinyIntHolder; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.Validate; import org.joda.time.DateTime; import 
org.joda.time.Days;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.Date;
import java.util.Map;

/**
 * Abstracts JDBC record handler and provides common reusable split records handling.
 */
public abstract class JdbcRecordHandler
        extends RecordHandler
{
    /**
     * The Unix epoch (1970-01-01T00:00:00Z).
     *
     * <p>Previously this was initialized with the no-arg {@code MutableDateTime()} constructor, which yields the
     * current time rather than the epoch. It is kept as a public constant for callers that reference it. The
     * instance is mutable and shared, so callers must not modify it.
     */
    public static final org.joda.time.MutableDateTime EPOCH =
            new org.joda.time.MutableDateTime(0L, org.joda.time.DateTimeZone.UTC);

    private static final Logger LOGGER = LoggerFactory.getLogger(JdbcRecordHandler.class);

    private final JdbcConnectionFactory jdbcConnectionFactory;
    private final DatabaseConnectionConfig databaseConnectionConfig;

    /**
     * Used only by Multiplexing handler. All invocations will be delegated to respective database handler.
     */
    protected JdbcRecordHandler()
    {
        super(null);
        this.jdbcConnectionFactory = null;
        this.databaseConnectionConfig = null;
    }

    /**
     * Creates a record handler bound to a single database connection configuration.
     *
     * @param amazonS3 S3 client passed through to {@link RecordHandler}.
     * @param secretsManager Secrets Manager client passed through to {@link RecordHandler}.
     * @param athena Athena client passed through to {@link RecordHandler}.
     * @param databaseConnectionConfig connection configuration; its type's database name is used as the source type.
     * @param jdbcConnectionFactory factory used to open a connection per split read.
     * @throws NullPointerException if either {@code databaseConnectionConfig} or {@code jdbcConnectionFactory} is null.
     */
    protected JdbcRecordHandler(final AmazonS3 amazonS3, final AWSSecretsManager secretsManager, AmazonAthena athena,
            final DatabaseConnectionConfig databaseConnectionConfig, final JdbcConnectionFactory jdbcConnectionFactory)
    {
        super(amazonS3, secretsManager, athena, databaseConnectionConfig.getType().getDbName());
        this.jdbcConnectionFactory = Validate.notNull(jdbcConnectionFactory, "jdbcConnectionFactory must not be null");
        this.databaseConnectionConfig = Validate.notNull(databaseConnectionConfig, "databaseConnectionConfig must not be null");
    }

    /**
     * Builds a credential provider from the configured secret, if any.
     *
     * @return an {@link RdsSecretsCredentialProvider} for the configured secret, or {@code null} when no secret
     * name is configured.
     */
    private JdbcCredentialProvider getCredentialProvider()
    {
        final String secretName = this.databaseConnectionConfig.getSecret();
        if (StringUtils.isNotBlank(secretName)) {
            return new RdsSecretsCredentialProvider(getSecret(secretName));
        }

        return null;
    }

    /**
     * Executes the split's SQL and writes every returned row into {@code blockSpiller}.
     *
     * <p>Reading stops early, without committing, as soon as {@code queryStatusChecker} reports that the query
     * is no longer running.
     *
     * @param blockSpiller destination for the rows.
     * @param readRecordsRequest request carrying the catalog, table, schema, constraints and split.
     * @param queryStatusChecker polled once per row to detect query cancellation.
     * @throws RuntimeException wrapping any {@link SQLException}; the message is prefixed with the vendor error code.
     */
    @Override
    public void readWithConstraint(BlockSpiller blockSpiller, ReadRecordsRequest readRecordsRequest, QueryStatusChecker queryStatusChecker)
    {
        LOGGER.info("{}: Catalog: {}, table {}, splits {}", readRecordsRequest.getQueryId(), readRecordsRequest.getCatalogName(), readRecordsRequest.getTableName(),
                readRecordsRequest.getSplit().getProperties());
        try (Connection connection = this.jdbcConnectionFactory.getConnection(getCredentialProvider())) {
            connection.setAutoCommit(false); // For consistency. This is needed to be false to enable streaming for some database types.
            try (PreparedStatement preparedStatement = buildSplitSql(connection, readRecordsRequest.getCatalogName(), readRecordsRequest.getTableName(),
                    readRecordsRequest.getSchema(), readRecordsRequest.getConstraints(), readRecordsRequest.getSplit());
                    ResultSet resultSet = preparedStatement.executeQuery()) {
                Map<String, String> partitionValues = readRecordsRequest.getSplit().getProperties();

                GeneratedRowWriter.RowWriterBuilder rowWriterBuilder = GeneratedRowWriter.newBuilder(readRecordsRequest.getConstraints());
                for (Field next : readRecordsRequest.getSchema().getFields()) {
                    Extractor extractor = makeExtractor(next, resultSet, partitionValues);
                    rowWriterBuilder.withExtractor(next.getName(), extractor);
                }

                GeneratedRowWriter rowWriter = rowWriterBuilder.build();
                int rowsReturnedFromDatabase = 0;
                while (resultSet.next()) {
                    if (!queryStatusChecker.isQueryRunning()) {
                        return;
                    }
                    // The extractors read from the ResultSet's current row, which the loop has just advanced.
                    blockSpiller.writeRows((Block block, int rowNum) -> rowWriter.writeRow(block, rowNum, resultSet) ? 1 : 0);
                    rowsReturnedFromDatabase++;
                }
                LOGGER.info("{} rows returned by database.", rowsReturnedFromDatabase);

                connection.commit();
            }
        }
        catch (SQLException sqlException) {
            throw new RuntimeException(sqlException.getErrorCode() + ": " + sqlException.getMessage(), sqlException);
        }
    }

    /**
     * Creates an Extractor that reads {@code field} from the current row of {@code resultSet}.
     *
     * <p>If the field name is a key of {@code partitionValues}, the extractor instead returns that partition
     * value as VARCHAR, regardless of the field's declared type. Otherwise the column is read by name with the
     * JDBC getter matching the field's Arrow type, and SQL NULL is reported via {@code isSet = 0}.
     *
     * @param field Arrow field to extract.
     * @param resultSet result set positioned on the row being written.
     * @param partitionValues split properties, keyed by column name.
     * @return an extractor for the field.
     * @throws RuntimeException if the field's Arrow type is not supported.
     */
    private Extractor makeExtractor(Field field, ResultSet resultSet, Map<String, String> partitionValues)
    {
        Types.MinorType fieldType = Types.getMinorTypeForArrowType(field.getType());

        final String fieldName = field.getName();

        if (partitionValues.containsKey(fieldName)) {
            return (VarCharExtractor) (Object context, NullableVarCharHolder dst) ->
            {
                dst.isSet = 1;
                dst.value = partitionValues.get(fieldName);
            };
        }

        switch (fieldType) {
            case BIT:
                return (BitExtractor) (Object context, NullableBitHolder dst) ->
                {
                    boolean value = resultSet.getBoolean(fieldName);
                    dst.value = value ? 1 : 0;
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case TINYINT:
                return (TinyIntExtractor) (Object context, NullableTinyIntHolder dst) ->
                {
                    dst.value = resultSet.getByte(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case SMALLINT:
                return (SmallIntExtractor) (Object context, NullableSmallIntHolder dst) ->
                {
                    dst.value = resultSet.getShort(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case INT:
                return (IntExtractor) (Object context, NullableIntHolder dst) ->
                {
                    dst.value = resultSet.getInt(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case BIGINT:
                return (BigIntExtractor) (Object context, NullableBigIntHolder dst) ->
                {
                    dst.value = resultSet.getLong(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case FLOAT4:
                return (Float4Extractor) (Object context, NullableFloat4Holder dst) ->
                {
                    dst.value = resultSet.getFloat(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case FLOAT8:
                return (Float8Extractor) (Object context, NullableFloat8Holder dst) ->
                {
                    dst.value = resultSet.getDouble(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case DECIMAL:
                return (DecimalExtractor) (Object context, NullableDecimalHolder dst) ->
                {
                    dst.value = resultSet.getBigDecimal(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case DATEDAY:
                return (DateDayExtractor) (Object context, NullableDateDayHolder dst) ->
                {
                    // Read the column once; a second getDate() call is wasted work on every row.
                    java.sql.Date date = resultSet.getDate(fieldName);
                    if (date != null) {
                        // DATEDAY is days since 1970-01-01. Derive it from the calendar date itself so the
                        // result does not depend on a reference instant or on the time-of-day component.
                        dst.value = Math.toIntExact(date.toLocalDate().toEpochDay());
                    }
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case DATEMILLI:
                return (DateMilliExtractor) (Object context, NullableDateMilliHolder dst) ->
                {
                    Timestamp timestamp = resultSet.getTimestamp(fieldName);
                    if (timestamp != null) {
                        dst.value = timestamp.getTime();
                    }
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case VARCHAR:
                return (VarCharExtractor) (Object context, NullableVarCharHolder dst) ->
                {
                    dst.value = resultSet.getString(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            case VARBINARY:
                return (VarBinaryExtractor) (Object context, NullableVarBinaryHolder dst) ->
                {
                    dst.value = resultSet.getBytes(fieldName);
                    dst.isSet = resultSet.wasNull() ? 0 : 1;
                };
            default:
                throw new RuntimeException("Unhandled type " + fieldType);
        }
    }

    /**
     * Builds split SQL string and returns prepared statement.
     *
     * @param jdbcConnection jdbc connection. See {@link Connection}
     * @param catalogName Athena provided catalog name.
     * @param tableName database table name.
     * @param schema table schema.
     * @param constraints constraints to push down to the database.
     * @param split table split.
     * @return prepared statement with sql. See {@link PreparedStatement}
     * @throws SQLException JDBC database exception.
     */
    public abstract PreparedStatement buildSplitSql(Connection jdbcConnection, String catalogName, TableName tableName, Schema schema, Constraints constraints, Split split)
            throws SQLException;
}