/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of NVIDIA CORPORATION nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package com.nvidia.grcuda.test;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertEquals;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Value;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests binding host functions from a shared library through the {@code bind} member of the
 * {@code CU} value returned by the {@code grcuda} language.
 *
 * <p>
 * Before any test runs, {@link #CXX_SOURCE} is written to a temporary folder and compiled into a
 * shared library with {@code nvcc}. Each test then binds one of the compiled host functions and
 * checks that it increments every array element by one.
 */
public class BindTest {

    /** CUDA C source code of incrementing kernel. */
    private static final String CXX_SOURCE = " \n" +
                    "// C kernels \n" +
                    "__global__ void inc_kernel(float *out_arr, const int *in_arr, int num_elements) { \n" +
                    " for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elements; \n" +
                    " idx += gridDim.x * blockDim.x) { \n" +
                    " out_arr[idx] = in_arr[idx] + 1.0f; \n" +
                    " } \n" +
                    "} \n" +
                    "__global__ void inc_inplace_kernel(int *inout_arr, int num_elements) { \n" +
                    " for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elements; \n" +
                    " idx += gridDim.x * blockDim.x) { \n" +
                    " inout_arr[idx] += 1; \n" +
                    " } \n" +
                    "} \n" +
                    "\n" +
                    "// C functions \n" +
                    "extern \"C\" int inc_host(int blocks, int threads_per_block, \n" +
                    " float *out_arr, const int *in_arr, int num_elements) { \n" +
                    " inc_kernel<<<blocks, threads_per_block>>>(out_arr, in_arr, num_elements); \n" +
                    " return cudaDeviceSynchronize(); \n" +
                    "} \n" +
                    "extern \"C\" int inc_inplace_host(int blocks, int threads_per_block, \n" +
                    " int *inout_arr, int num_elements) { \n" +
                    " inc_inplace_kernel<<<blocks, threads_per_block>>>(inout_arr, num_elements); \n" +
                    " return cudaDeviceSynchronize(); \n" +
                    "} \n" +
                    "// C++ functions \n" +
                    "int cxx_inc_host(int blocks, int threads_per_block, \n" +
                    " float *out_arr, const int *in_arr, int num_elements) { \n" +
                    " inc_kernel<<<blocks, threads_per_block>>>(out_arr, in_arr, num_elements); \n" +
                    " return cudaDeviceSynchronize(); \n" +
                    "} \n" +
                    "int cxx_inc_inplace_host(int blocks, int threads_per_block, \n" +
                    " int *inout_arr, int num_elements) { \n" +
                    " inc_inplace_kernel<<<blocks, threads_per_block>>>(inout_arr, num_elements); \n" +
                    " return cudaDeviceSynchronize(); \n" +
                    "}\n";

    /** Number of elements in every device array used by the tests. */
    private static final int NUM_ELEMENTS = 100;

    /** Number of blocks passed as first argument to the bound host functions. */
    private static final int BLOCKS = 80;

    /** Threads per block passed as second argument to the bound host functions. */
    private static final int THREADS_PER_BLOCK = 256;

    /** Temporary folder that holds the generated CUDA source file and the compiled library. */
    @ClassRule public static TemporaryFolder tempFolder = new TemporaryFolder();

    /** Path of the shared library produced by {@link #setupUpClass()}. */
    public static String dynamicLibraryFile;

    /**
     * Writes {@link #CXX_SOURCE} to a file in the temporary folder and compiles it into a shared
     * library with {@code nvcc}. Compiler output is echoed to {@code System.out}.
     *
     * @throws IOException if the source file cannot be written or the compiler cannot be started
     * @throws InterruptedException if the thread is interrupted while waiting for the compiler
     */
    @BeforeClass
    public static void setupUpClass() throws IOException, InterruptedException {
        // Write CUDA C source file
        File sourceFile = tempFolder.newFile("inc_kernel.cu");
        try (PrintWriter writer = new PrintWriter(new FileWriter(sourceFile))) {
            writer.write(CXX_SOURCE);
            // PrintWriter swallows I/O errors; surface them explicitly.
            if (writer.checkError()) {
                throw new IOException("failed to write " + sourceFile.getAbsolutePath());
            }
        }
        dynamicLibraryFile = sourceFile.getParent() + File.separator + "libfoo.so";

        // Compile source file with NVCC. An explicit argument list keeps paths that contain
        // whitespace intact, which a single command string would not.
        ProcessBuilder builder = new ProcessBuilder("nvcc", "-shared", "-Xcompiler", "-fPIC",
                        sourceFile.getAbsolutePath(), "-o", dynamicLibraryFile);
        // Merge stderr into stdout so a single reader drains all compiler output.
        builder.redirectErrorStream(true);
        Process compiler = builder.start();

        // Drain the output BEFORE waiting: a subprocess blocks once its pipe buffer is full,
        // so waiting first can hang forever when the compiler is verbose.
        try (BufferedReader output = new BufferedReader(new InputStreamReader(compiler.getInputStream()))) {
            output.lines().forEach(System.out::println);
        }
        int nvccReturnCode = compiler.waitFor();
        assertEquals("nvcc failed to compile " + sourceFile.getAbsolutePath(), 0, nvccReturnCode);
    }

    /**
     * Looks up the {@code bind} member of {@code cu} and binds a function from
     * {@link #dynamicLibraryFile}.
     *
     * @param cu the value obtained by evaluating {@code "CU"} in the {@code grcuda} language
     * @param bindArgs one or two strings forwarded to {@code bind} after the library path; strings
     *            beyond the second are ignored
     * @return the bound function, asserted to be non-null
     * @throws IllegalArgumentException if {@code bindArgs} is empty
     */
    private static Value bindFunction(Value cu, String... bindArgs) {
        if (bindArgs.length == 0) {
            throw new IllegalArgumentException("at least one bind argument is required");
        }
        Value bind = cu.getMember("bind");
        Value function = bindArgs.length > 1
                        ? bind.execute(dynamicLibraryFile, bindArgs[0], bindArgs[1])
                        : bind.execute(dynamicLibraryFile, bindArgs[0]);
        assertNotNull(function);
        return function;
    }

    /**
     * Binds a function with {@code bindArgs}, calls it with separate output and input device
     * arrays, and asserts that every output element equals the matching input element plus one.
     *
     * @param bindArgs one or two strings forwarded to {@code bind} after the library path
     */
    public void callWithInAndOutArguments(String... bindArgs) {
        try (Context polyglot = Context.newBuilder().allowAllAccess(true).build()) {
            Value cu = polyglot.eval("grcuda", "CU");
            Value deviceArray = cu.getMember("DeviceArray");
            Value inDeviceArray = deviceArray.execute("int", NUM_ELEMENTS);
            Value outDeviceArray = deviceArray.execute("float", NUM_ELEMENTS);
            for (int i = 0; i < NUM_ELEMENTS; i++) {
                inDeviceArray.setArrayElement(i, i);
                outDeviceArray.setArrayElement(i, 0.0f);
            }

            // get function from shared library
            Value function = bindFunction(cu, bindArgs);

            // call function
            function.execute(BLOCKS, THREADS_PER_BLOCK, outDeviceArray, inDeviceArray, NUM_ELEMENTS);

            // verify result
            for (int i = 0; i < NUM_ELEMENTS; i++) {
                assertEquals(i + 1.0f, outDeviceArray.getArrayElement(i).asFloat(), 1e-3f);
            }
        }
    }

    /**
     * Binds a function with {@code bindArgs}, calls it with a single device array that is updated
     * in place, and asserts that every element was incremented by one.
     *
     * @param bindArgs one or two strings forwarded to {@code bind} after the library path
     */
    public void callWithInoutArgument(String... bindArgs) {
        try (Context polyglot = Context.newBuilder().allowAllAccess(true).build()) {
            Value cu = polyglot.eval("grcuda", "CU");
            Value inoutDeviceArray = cu.getMember("DeviceArray").execute("int", NUM_ELEMENTS);
            for (int i = 0; i < NUM_ELEMENTS; i++) {
                inoutDeviceArray.setArrayElement(i, i);
            }

            // get function from shared library
            Value function = bindFunction(cu, bindArgs);

            // call function
            function.execute(BLOCKS, THREADS_PER_BLOCK, inoutDeviceArray, NUM_ELEMENTS);

            // verify result
            for (int i = 0; i < NUM_ELEMENTS; i++) {
                assertEquals(i + 1, inoutDeviceArray.getArrayElement(i).asInt());
            }
        }
    }

    /** Binds the {@code extern "C"} function {@code inc_host} by symbol name plus signature. */
    @Test
    public void testCcallLegacyNFISignatureWithInAndOutArguments() {
        callWithInAndOutArguments("inc_host", "(sint32, sint32, pointer, pointer, sint32): sint32");
    }

    /** Binds the {@code extern "C"} function {@code inc_inplace_host} by symbol name plus signature. */
    @Test
    public void testCcallLegacyNFISignatureWithInoutArgument() {
        callWithInoutArgument("inc_inplace_host", "(sint32, sint32, pointer, sint32): sint32");
    }

    /** Binds the C++ function {@code cxx_inc_host} through its mangled symbol name. */
    @Test
    public void testCxxCallLegacyNFISignatureWithInAndOutArguments() {
        callWithInAndOutArguments("_Z12cxx_inc_hostiiPfPKii", "(sint32, sint32, pointer, pointer, sint32): sint32");
    }

    /** Binds the C++ function {@code cxx_inc_inplace_host} through its mangled symbol name. */
    @Test
    public void testCxxCallLegacyNFISignatureWithInoutArgument() {
        callWithInoutArgument("_Z20cxx_inc_inplace_hostiiPii", "(sint32, sint32, pointer, sint32): sint32");
    }

    /** Binds {@code inc_host} using a single signature string with named, typed parameters. */
    @Test
    public void testCcallNIDLSignatureWithInAndOutArguments() {
        callWithInAndOutArguments("" +
                        "inc_host(blocks: sint32, threads_per_block: sint32, out_arr: out pointer float, " +
                        "in_arr: in pointer sint32, num_elements: sint32): sint32");
    }

    /** Binds {@code inc_inplace_host} using a single signature string with named, typed parameters. */
    @Test
    public void testCcallNIDLSignatureWithInoutArguments() {
        callWithInoutArgument("" +
                        "inc_inplace_host(blocks: sint32, threads_per_block: sint32, inout_arr: inout pointer sint32, " +
                        "num_elements: sint32): sint32");
    }

    /** Binds {@code cxx_inc_host} using a single signature string prefixed with {@code cxx}. */
    @Test
    public void testCxxCallNIDLSignatureWithInAndOutArguments() {
        callWithInAndOutArguments("cxx " +
                        "cxx_inc_host(blocks: sint32, threads_per_block: sint32, out_arr: out pointer float, " +
                        "in_arr: in pointer sint32, num_elements: sint32): sint32");
    }

    /** Binds {@code cxx_inc_inplace_host} using a single signature string prefixed with {@code cxx}. */
    @Test
    public void testCxxcallNIDLSignatureWithInoutArguments() {
        callWithInoutArgument("cxx " +
                        "cxx_inc_inplace_host(blocks: sint32, threads_per_block: sint32, inout_arr: inout pointer sint32, " +
                        "num_elements: sint32): sint32");
    }
}