/*- * #%L * Amazon Athena Query Federation SDK * %% * Copyright (C) 2019 - 2020 Amazon Web Services * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ package com.amazonaws.athena.connector.lambda.serde.v2; import com.amazonaws.athena.connector.lambda.data.BlockAllocator; import com.amazonaws.athena.connector.lambda.serde.BaseDeserializer; import com.amazonaws.athena.connector.lambda.serde.BaseSerializer; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.SerializerProvider; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.ipc.ReadChannel; import org.apache.arrow.vector.ipc.WriteChannel; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.MessageSerializer; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.channels.Channels; import java.util.concurrent.atomic.AtomicReference; import static java.util.Objects.requireNonNull; public final class ArrowRecordBatchSerDe { private ArrowRecordBatchSerDe(){} static final class Serializer extends BaseSerializer<ArrowRecordBatch> { Serializer() { super(ArrowRecordBatch.class); } @Override protected void doSerialize(ArrowRecordBatch arrowRecordBatch, JsonGenerator jgen, SerializerProvider provider) throws IOException { try { ByteArrayOutputStream out = new ByteArrayOutputStream(); MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), arrowRecordBatch); jgen.writeBinary(out.toByteArray()); } finally { arrowRecordBatch.close(); } } } static final class Deserializer extends BaseDeserializer<ArrowRecordBatch> { private final BlockAllocator blockAllocator; Deserializer(BlockAllocator allocator) { super(ArrowRecordBatch.class); this.blockAllocator = requireNonNull(allocator, "allocator is null"); } @Override protected ArrowRecordBatch doDeserialize(JsonParser jparser, DeserializationContext ctxt) throws IOException { if (jparser.nextToken() != JsonToken.VALUE_EMBEDDED_OBJECT) { throw new IllegalStateException("Expecting " + JsonToken.VALUE_STRING + " but found " + jparser.getCurrentLocation()); } byte[] bytes = jparser.getBinaryValue(); AtomicReference<ArrowRecordBatch> batch = new AtomicReference<>(); try { return blockAllocator.registerBatch((BufferAllocator root) -> { batch.set((ArrowRecordBatch) MessageSerializer.deserializeMessageBatch( new ReadChannel(Channels.newChannel(new ByteArrayInputStream(bytes))), root)); return batch.get(); }); } catch (Exception ex) { if (batch.get() != null) { batch.get().close(); } throw ex; } } } }