Java Code Examples for org.nd4j.linalg.primitives.Pair#getRight()

The following examples show how to use org.nd4j.linalg.primitives.Pair#getRight().
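Before the project examples, here is a minimal, self-contained sketch of the accessor pattern they rely on. In the examples, getFirst()/getKey() appear to return the first element and getSecond()/getRight() the second. The SimplePair and PairSketch classes below are hypothetical stand-ins written for illustration only; they are not the nd4j Pair class and make no claim about its full API.

```java
import java.util.Arrays;
import java.util.List;

class PairSketch {
    // Joins each pair as "name=value", reading the second element via getRight().
    static String describe(List<SimplePair<String, Integer>> pairs) {
        StringBuilder sb = new StringBuilder();
        for (SimplePair<String, Integer> p : pairs) {
            if (sb.length() > 0) {
                sb.append(", ");
            }
            sb.append(p.getFirst()).append('=').append(p.getRight());
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<SimplePair<String, Integer>> pairs =
                Arrays.asList(SimplePair.of("rows", 3), SimplePair.of("cols", 4));
        System.out.println(describe(pairs)); // prints "rows=3, cols=4"
    }
}

// Hypothetical stand-in for a pair type (assumption: NOT the nd4j class).
final class SimplePair<K, V> {
    private final K first;
    private final V second;

    private SimplePair(K first, V second) {
        this.first = first;
        this.second = second;
    }

    static <K, V> SimplePair<K, V> of(K first, V second) {
        return new SimplePair<>(first, second);
    }

    K getFirst() {
        return first;
    }

    V getSecond() {
        return second;
    }

    // Alias for getSecond(), mirroring how the examples use getRight().
    V getRight() {
        return second;
    }
}
```

Because getRight() is only an alias in this sketch, code that mixes it with getSecond() reads the same value either way; picking one accessor per call site keeps the intent clearer.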
Example 1
Source File: JsonMappers.java    From DataVec with Apache License 2.0
/**
 * Register a set of classes (Layer, GraphVertex, InputPreProcessor, IActivation, ILossFunction, ReconstructionDistribution
 * ONLY) for JSON deserialization, with custom names.<br>
 * Using this method directly should never be required (instead: use {@link #registerLegacyCustomClassesForJSON(Class[])}),
 * but it is provided in case it is needed in non-standard circumstances.
 */
public static void registerLegacyCustomClassesForJSON(List<Pair<String,Class>> classes){
    for(Pair<String,Class> p : classes){
        String s = p.getFirst();
        Class c = p.getRight();
        //Check if it's a valid class to register...
        boolean found = false;
        for( Class<?> c2 : REGISTERABLE_CUSTOM_CLASSES){
            if(c2.isAssignableFrom(c)){
                Map<String,String> map = LegacyMappingHelper.legacyMappingForClass(c2);
                map.put(s, c.getName());
                found = true;
            }
        }

        if(!found){
            throw new IllegalArgumentException("Cannot register class for legacy JSON deserialization: class " +
                    c.getName() + " is not a subtype of classes " + REGISTERABLE_CUSTOM_CLASSES);
        }
    }
}
 
Example 2
Source File: RecordReaderFunction.java    From DataVec with Apache License 2.0
@Override
public List<Writable> apply(Pair<String, InputStream> value) {
    URI uri = URI.create(value.getFirst());
    InputStream ds = value.getRight();
    try (DataInputStream dis = (DataInputStream) ds) {
        return recordReader.record(uri, dis);
    } catch (IOException e) {
        //Keep the original exception as the cause so the failure can be diagnosed
        throw new IllegalStateException("Something went wrong reading file", e);
    }

}
 
Example 3
Source File: SequenceRecordReaderFunction.java    From DataVec with Apache License 2.0
@Override
public List<List<Writable>> apply(Pair<String, InputStream> value) {
    URI uri = URI.create(value.getFirst());
    try (DataInputStream dis = (DataInputStream) value.getRight()) {
        return sequenceRecordReader.sequenceRecord(uri, dis);
    } catch (IOException e) {
        //Rethrow with the cause attached instead of printing the stack trace and losing it
        throw new IllegalStateException("Something went wrong", e);
    }
}
 
Example 4
Source File: NDArrayMessage.java    From nd4j with Apache License 2.0
/**
 * Convert a direct buffer to an ndarray
 * message.
 * The format of the byte buffer is:
 * ndarray
 * time
 * index
 * dimension length
 * dimensions
 *
 * We use {@link AeronNDArraySerde#toArrayAndByteBuffer(DirectBuffer, int)}
 * to read in the ndarray and just use normal {@link ByteBuffer#getInt()} and
 * {@link ByteBuffer#getLong()} to get the things like dimensions and index
 * and time stamp.
 *
 * @param buffer the buffer to convert
 * @param offset  the offset to start at with the buffer - note that this
 *                method call assumes that the message opType is specified at the beginning of the buffer.
 *                This means whatever offset you pass in will be increased by 4 (the size of an int)
 * @return the ndarray message based on this direct buffer.
 */
public static NDArrayMessage fromBuffer(DirectBuffer buffer, int offset) {
    //skip the message opType
    Pair<INDArray, ByteBuffer> pair = AeronNDArraySerde.toArrayAndByteBuffer(buffer, offset + 4);
    INDArray arr = pair.getKey();
    Nd4j.getCompressor().decompressi(arr);
    //use the rest of the buffer, of note here the offset is already set, we should only need to use
    ByteBuffer rest = pair.getRight();
    long time = rest.getLong();
    long index = rest.getLong();
    //get the array next for dimensions
    int dimensionLength = rest.getInt();
    if (dimensionLength <= 0)
        throw new IllegalArgumentException("Invalid dimension length " + dimensionLength);
    int[] dimensions = new int[dimensionLength];
    for (int i = 0; i < dimensionLength; i++)
        dimensions[i] = rest.getInt();
    return NDArrayMessage.builder().sent(time).arr(arr).index(index).dimensions(dimensions).build();
}
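
The tail of the buffer layout described in the Javadoc above (time, index, dimension length, dimensions) can be illustrated with plain java.nio.ByteBuffer calls. This is a rough sketch under stated assumptions: BufferTailSketch, Tail, writeTail, and readTail are hypothetical names, the ndarray portion of the message is omitted entirely, and the buffer uses ByteBuffer's default big-endian order, which may not match what AeronNDArraySerde actually produces.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

class BufferTailSketch {

    // Hypothetical holder for the fields that follow the ndarray in the message.
    static final class Tail {
        final long time;
        final long index;
        final int[] dimensions;

        Tail(long time, long index, int[] dimensions) {
            this.time = time;
            this.index = index;
            this.dimensions = dimensions;
        }
    }

    // Writes time, index, dimension length, and dimensions, in that order.
    static ByteBuffer writeTail(long time, long index, int[] dimensions) {
        int size = 2 * Long.BYTES + Integer.BYTES + dimensions.length * Integer.BYTES;
        ByteBuffer buf = ByteBuffer.allocate(size);
        buf.putLong(time);
        buf.putLong(index);
        buf.putInt(dimensions.length);
        for (int d : dimensions) {
            buf.putInt(d);
        }
        buf.flip(); // switch the buffer from writing to reading
        return buf;
    }

    // Reads the fields back in the same order, advancing the buffer position.
    static Tail readTail(ByteBuffer rest) {
        long time = rest.getLong();
        long index = rest.getLong();
        int dimensionLength = rest.getInt();
        if (dimensionLength <= 0) {
            throw new IllegalArgumentException("Invalid dimension length " + dimensionLength);
        }
        int[] dimensions = new int[dimensionLength];
        for (int i = 0; i < dimensionLength; i++) {
            dimensions[i] = rest.getInt();
        }
        return new Tail(time, index, dimensions);
    }

    public static void main(String[] args) {
        Tail tail = readTail(writeTail(10L, 2L, new int[] {0, 1}));
        // prints "time=10 index=2 dims=[0, 1]"
        System.out.println("time=" + tail.time + " index=" + tail.index
                + " dims=" + Arrays.toString(tail.dimensions));
    }
}
```

Relative getLong() and getInt() calls advance the buffer position, which is why the read order has to match the write order exactly.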