Java Code Examples for org.nd4j.common.primitives.Pair#getRight()

The following examples show how to use org.nd4j.common.primitives.Pair#getRight(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: BaseJsonArrayConverter.java — from konduit-serving (Apache License 2.0), 6 votes
/**
 * Runs a JSON array through a transform process and converts every resulting row into an
 * insertion-ordered map from output field name to the row's underlying value.
 *
 * @param schema                   schema of the incoming JSON records; must equal the transform
 *                                 process's initial schema
 * @param jsonArray                the records to convert
 * @param transformProcess         the transform process to apply
 * @param dataPipelineErrorHandler passed through to {@code convertWithErrors}
 * @return a pair whose left side is the error map produced by {@code convertWithErrors} and whose
 *         right side holds one map per converted row, with keys in output-schema column order
 * @throws IllegalArgumentException if {@code schema} does not equal the transform process's
 *                                  initial schema
 */
protected Pair<Map<Integer, Integer>, List<? extends Map<FieldName, ?>>> doTransformProcessConvertPmmlWithErrors(Schema schema, JsonArray jsonArray, TransformProcess transformProcess, DataPipelineErrorHandler dataPipelineErrorHandler) {
    Schema outputSchema = transformProcess.getFinalSchema();

    if (!transformProcess.getInitialSchema().equals(schema)) {
        throw new IllegalArgumentException("Transform process specified, but does not match target input inputSchema");
    }

    List<FieldName> fieldNames = getNameRepresentationFor(outputSchema);
    int numColumns = outputSchema.numColumns();

    Pair<Map<Integer, Integer>, ArrowWritableRecordBatch> convertWithErrors = convertWithErrors(schema, jsonArray, transformProcess, dataPipelineErrorHandler);
    ArrowWritableRecordBatch conversion = convertWithErrors.getRight();

    // Presize from the converted batch, which is what the loop below actually iterates.
    List<Map<FieldName, Object>> ret = new ArrayList<>(conversion.size());
    for (int i = 0; i < conversion.size(); i++) {
        List<Writable> recordToMap = conversion.get(i);
        // LinkedHashMap keeps the fields in output-schema column order.
        Map<FieldName, Object> row = new LinkedHashMap<>();
        for (int j = 0; j < numColumns; j++) {
            row.put(fieldNames.get(j), WritableValueRetriever.getUnderlyingValue(recordToMap.get(j)));
        }
        ret.add(row);
    }

    return Pair.of(convertWithErrors.getKey(), ret);
}
 
Example 2
Source File: RecordReaderFunction.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Reads a single record from a (URI, stream) pair using {@code recordReader}.
 *
 * @param value left side is the source URI as a string; right side is the stream to read from.
 *              The stream is closed before this method returns.
 * @return the record produced by {@code recordReader.record(uri, stream)}
 * @throws IllegalStateException if reading fails; the underlying {@link IOException} is kept as
 *                               the cause
 */
@Override
public List<Writable> apply(Pair<String, InputStream> value) {
    URI uri = URI.create(value.getFirst());
    InputStream ds = value.getRight();
    // Reuse the stream when it already is a DataInputStream (or is null, as before); otherwise
    // wrap it rather than failing with a ClassCastException.
    DataInputStream in = (ds == null || ds instanceof DataInputStream)
            ? (DataInputStream) ds
            : new DataInputStream(ds);
    try (DataInputStream dis = in) {
        return recordReader.record(uri, dis);
    } catch (IOException e) {
        throw new IllegalStateException("Something went wrong reading file " + uri, e);
    }
}
 
Example 3
Source File: SequenceRecordReaderFunction.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Reads a single sequence record from a (URI, stream) pair using {@code sequenceRecordReader}.
 *
 * @param value left side is the source URI as a string; right side is the stream to read from.
 *              The stream is closed before this method returns.
 * @return the sequence produced by {@code sequenceRecordReader.sequenceRecord(uri, stream)}
 * @throws IllegalStateException if reading fails; the underlying {@link IOException} is kept as
 *                               the cause, so it is not logged separately here
 */
@Override
public List<List<Writable>> apply(Pair<String, InputStream> value) {
    URI uri = URI.create(value.getFirst());
    InputStream stream = value.getRight();
    // Reuse the stream when it already is a DataInputStream (or is null, as before); otherwise
    // wrap it rather than failing with a ClassCastException.
    DataInputStream in = (stream == null || stream instanceof DataInputStream)
            ? (DataInputStream) stream
            : new DataInputStream(stream);
    try (DataInputStream dis = in) {
        return sequenceRecordReader.sequenceRecord(uri, dis);
    } catch (IOException e) {
        throw new IllegalStateException("Something went wrong reading sequence from " + uri, e);
    }
}
 
Example 4
Source File: BatchInputParser.java — from konduit-serving (Apache License 2.0), 4 votes
/**
 * Returns a list of {@link BatchPartInfo} for each part, grouped by input name.
 * The part name of each upload is expected to match one input of a computation
 * graph, in the form {@code inputName[index]}.
 *
 * <p>Upload names containing spaces have every space replaced with {@code ':'} (the
 * comments below assume a TensorFlow-style colon was mangled in transit). The
 * corrected input name must then be one of {@code inputParts}.
 *
 * @param ctx the context to get the part info from
 * @return a map from input name to that input's parts, each list sorted by
 *         {@link BatchPartInfo}'s natural ordering (the index, per the original comment)
 * @throws IllegalStateException if there are no file uploads, or if a corrected
 *                               name does not match any entry in {@code inputParts}
 */
private Map<String, List<BatchPartInfo>> partInfoForUploads(RoutingContext ctx) {
    if (ctx.fileUploads().isEmpty()) {
        throw new IllegalStateException("No files found for part info!");
    }
    log.debug("Found " + ctx.fileUploads().size() + " file uploads");

    Map<String, List<BatchPartInfo>> ret = new LinkedHashMap<>();
    //parse each file upload all at once
    for (FileUpload upload : ctx.fileUploads()) {
        //the part name: inputName[index]
        String name = upload.name();
        //likely a colon for a tensorflow name got passed in
        //verify against the name in the configuration and set it to that
        if (name.contains(" ")) {
            name = name.replace(" ", ":");
            String inputName = name;
            int bracket = inputName.lastIndexOf("[");
            if (bracket >= 0) {
                // strip the "[index]" suffix so only the input name is validated
                inputName = inputName.substring(0, bracket);
            }
            if (!inputParts.contains(inputName)) {
                throw new IllegalStateException("Illegal name for multi part passed in " + upload.name());
            }
            log.warn("Corrected input name " + upload.name() + " to " + name);
        }

        //split the input name and the index
        Pair<String, Integer> partNameAndIndex = partNameAndIndex(name);
        //the part info for this particular file
        BatchPartInfo batchPartInfo = new BatchPartInfo(
                partNameAndIndex.getRight(), upload.uploadedFileName(), name);
        //accumulate the part info under its input name
        ret.computeIfAbsent(partNameAndIndex.getFirst(), k -> new ArrayList<>())
                .add(batchPartInfo);
    }

    //sort based on index
    for (List<BatchPartInfo> info : ret.values()) {
        Collections.sort(info);
    }

    return ret;
}
 
Example 5
Source File: NDArrayMessage.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Convert a direct buffer to an ndarray message.
 * The format of the byte buffer is:
 * ndarray
 * time
 * index
 * dimension length
 * dimensions
 *
 * We use {@link AeronNDArraySerde#toArrayAndByteBuffer(DirectBuffer, int)}
 * to read in the ndarray, then the relative {@link ByteBuffer#getLong()} and
 * {@link ByteBuffer#getInt()} getters to read the time stamp, index and dimensions.
 *
 * @param buffer the buffer to convert
 * @param offset  the offset to start at with the buffer - note that this
 *                method call assumes that the message opType is specified at the beginning of the buffer.
 *                This means whatever offset you pass in will be increased by 4 (the size of an int)
 * @return the ndarray message based on this direct buffer.
 * @throws IllegalArgumentException if the encoded dimension length is not positive, or is larger
 *                                  than the remaining bytes of the buffer can hold
 */
public static NDArrayMessage fromBuffer(DirectBuffer buffer, int offset) {
    //skip the message opType
    Pair<INDArray, ByteBuffer> pair = AeronNDArraySerde.toArrayAndByteBuffer(buffer, offset + 4);
    INDArray arr = pair.getKey();
    Nd4j.getCompressor().decompressi(arr);
    // The returned buffer is expected to be positioned just past the serialized ndarray, so the
    // remaining fields are read with relative getters.
    ByteBuffer rest = pair.getRight();
    long time = rest.getLong();
    long index = rest.getLong();
    //get the array next for dimensions
    int dimensionLength = rest.getInt();
    if (dimensionLength <= 0) {
        throw new IllegalArgumentException("Invalid dimension length " + dimensionLength);
    }
    // Validate against what the buffer can actually hold before allocating, so a corrupt or
    // hostile length field cannot trigger a huge int[] allocation.
    if (dimensionLength > rest.remaining() / Integer.BYTES) {
        throw new IllegalArgumentException("Invalid dimension length " + dimensionLength
                + ": only " + rest.remaining() + " bytes remain in buffer");
    }
    int[] dimensions = new int[dimensionLength];
    for (int i = 0; i < dimensionLength; i++) {
        dimensions[i] = rest.getInt();
    }
    return NDArrayMessage.builder().sent(time).arr(arr).index(index).dimensions(dimensions).build();
}