/* * Copyright (C) 2019 Ryan Murray * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.arrow.flight.spark; import java.io.IOException; import org.apache.arrow.flight.Action; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.Ticket; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class FlightDataReader implements InputPartitionReader<ColumnarBatch> { private static final Logger logger = LoggerFactory.getLogger(FlightDataReader.class); private FlightClient client; private FlightStream stream; private BufferAllocator allocator = null; private FlightClientFactory clientFactory; private final Ticket ticket; private final Broadcast<FlightDataSourceReader.FactoryOptions> options; private final Location location; private boolean parallel; public FlightDataReader(Broadcast<FlightDataSourceReader.FactoryOptions> options) { this.options = options; this.location = Location.forGrpcInsecure(options.value().getHost(), options.value().getPort()); this.ticket = new Ticket(options.value().getTicket()); } private void start() { if (allocator != null) { return; } FlightDataSourceReader.FactoryOptions options = this.options.getValue(); this.parallel = options.isParallel(); this.allocator = new RootAllocator(); logger.warn("setting up a data reader at host {} and port {} with ticket {}", options.getHost(), options.getPort(), new String(ticket.getBytes())); clientFactory = new FlightClientFactory(location, options.getUsername(), options.getPassword(), parallel); client = clientFactory.apply(); stream = client.getStream(ticket); if (parallel) { logger.debug("doing create action for ticket {}", new String(ticket.getBytes())); client.doAction(new Action("create", ticket.getBytes())).forEachRemaining(Object::toString); logger.debug("completed create action for ticket {}", new String(ticket.getBytes())); } } @Override public boolean next() throws IOException { start(); try { return stream.next(); } catch (Throwable t) { throw new IOException(t); } } @Override public ColumnarBatch get() { start(); ColumnarBatch batch = new ColumnarBatch( stream.getRoot().getFieldVectors() .stream() .map(FlightArrowColumnVector::new) .toArray(ColumnVector[]::new) ); batch.setNumRows(stream.getRoot().getRowCount()); return batch; } @Override public void close() throws IOException { try { if (parallel) { client.doAction(new Action("delete", ticket.getBytes())).forEachRemaining(Object::toString); } AutoCloseables.close(stream, client, clientFactory, allocator); allocator.close(); } catch (Exception e) { throw new IOException(e); } } }