/* * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of NVIDIA CORPORATION nor the names of its * contributors may be used to endorse or promote products derived * from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/
package com.nvidia.grcuda.test;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Value;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests the GrCUDA {@code bindkernel} built-in against a PTX file that is compiled with
 * {@code nvcc} once per test class.
 *
 * <p>
 * The PTX file contains two functionally identical kernels: {@code inc_kernel}, declared
 * {@code extern "C"}, and {@code cxx_inc_kernel}, which has C++ linkage and is therefore
 * addressed by its mangled name in the two-argument test.
 */
public class BindKernelTest {

    /** CUDA C source code of the two incrementing kernels. */
    private static final String INCREMENT_KERNEL_SOURCE = "extern \"C\" \n" +
                    "__global__ void inc_kernel(int *out_arr, const int *in_arr, int num_elements) { \n" +
                    " for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elements; \n" +
                    " idx += gridDim.x * blockDim.x) { \n" +
                    " out_arr[idx] = in_arr[idx] + 1; \n" +
                    " } \n" +
                    "}\n" +
                    "\n" +
                    "__global__ void cxx_inc_kernel(int *out_arr, const int *in_arr, int num_elements) { \n" +
                    " for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elements; \n" +
                    " idx += gridDim.x * blockDim.x) { \n" +
                    " out_arr[idx] = in_arr[idx] + 1; \n" +
                    " } \n" +
                    "}\n";

    /** Number of elements in the input and output device arrays. */
    private static final int NUM_ELEMENTS = 1000;

    /** Temporary directory holding the generated {@code .cu} and {@code .ptx} files. */
    @ClassRule public static TemporaryFolder tempFolder = new TemporaryFolder();

    /** Path of the PTX file produced by {@link #setupUpClass()}. */
    public static String ptxFileName;

    /**
     * Writes the kernel source to a temporary {@code .cu} file and compiles it to PTX with
     * {@code nvcc}, which must be on the {@code PATH}.
     *
     * @throws IOException if the source file cannot be written or {@code nvcc} cannot be started
     * @throws InterruptedException if the thread is interrupted while waiting for {@code nvcc}
     */
    @BeforeClass
    public static void setupUpClass() throws IOException, InterruptedException {
        // Write CUDA C source file; try-with-resources closes the writer even on failure.
        File sourceFile = tempFolder.newFile("inc_kernel.cu");
        try (PrintWriter writer = new PrintWriter(new FileWriter(sourceFile))) {
            writer.write(INCREMENT_KERNEL_SOURCE);
        }
        BindKernelTest.ptxFileName = sourceFile.getParent() + File.separator + "inc_kernel.ptx";

        // Compile source file with NVCC. Arguments are passed as a list so that paths
        // containing whitespace are not split into several arguments.
        Process compiler = new ProcessBuilder("nvcc", "--ptx", sourceFile.getAbsolutePath(),
                        "-o", BindKernelTest.ptxFileName)
                        .redirectErrorStream(true) // one merged stream to drain
                        .start();

        // Drain the compiler output BEFORE waiting for termination: a child process that
        // fills its pipe buffer blocks until the parent reads, so waiting first can hang.
        try (BufferedReader output = new BufferedReader(new InputStreamReader(compiler.getInputStream()))) {
            output.lines().forEach(System.out::println);
        }
        int nvccReturnCode = compiler.waitFor();
        assertEquals("nvcc failed to compile " + sourceFile.getAbsolutePath(), 0, nvccReturnCode);
    }

    /**
     * Binds a kernel from the compiled PTX file, launches it once and verifies the result.
     *
     * @param bindArgs the arguments passed to {@code bindkernel} after the PTX file name: either a
     *            single string (kernel name and signature combined) or two strings (symbol name,
     *            then parameter signature). Arguments beyond the second are ignored.
     */
    void testWithSignature(String... bindArgs) {
        assertTrue("testWithSignature requires at least one bindkernel argument", bindArgs.length >= 1);

        // Build inc_kernel symbol, launch it, and check results.
        try (Context context = Context.newBuilder().allowAllAccess(true).build()) {
            Value deviceArrayConstructor = context.eval("grcuda", "DeviceArray");
            Value bindkernel = context.eval("grcuda", "bindkernel");
            Value incKernel = bindArgs.length > 1
                            ? bindkernel.execute(BindKernelTest.ptxFileName, bindArgs[0], bindArgs[1])
                            : bindkernel.execute(BindKernelTest.ptxFileName, bindArgs[0]);
            assertNotNull(incKernel);
            assertTrue(incKernel.canExecute());
            assertEquals(0, incKernel.getMember("launchCount").asInt());
            assertNotNull(incKernel.getMember("ptx").asString());

            // Input holds 0..NUM_ELEMENTS-1, output starts zeroed.
            Value inDevArray = deviceArrayConstructor.execute("int", NUM_ELEMENTS);
            Value outDevArray = deviceArrayConstructor.execute("int", NUM_ELEMENTS);
            for (int i = 0; i < NUM_ELEMENTS; ++i) {
                inDevArray.setArrayElement(i, i);
                outDevArray.setArrayElement(i, 0);
            }

            // <<<8, 128>>>: 8 blocks of 128 threads each
            Value configuredIncKernel = incKernel.execute(8, 128);
            assertTrue(configuredIncKernel.canExecute());
            configuredIncKernel.execute(outDevArray, inDevArray, NUM_ELEMENTS);

            // implicit synchronization

            // verify result: input untouched, output is input + 1
            for (int i = 0; i < NUM_ELEMENTS; ++i) {
                assertEquals(i, inDevArray.getArrayElement(i).asInt());
                assertEquals(i + 1, outDevArray.getArrayElement(i).asInt());
            }
            assertEquals(1, incKernel.getMember("launchCount").asInt());
        }
    }

    /** Binds the {@code extern "C"} kernel by name plus a separate NFI-style signature. */
    @Test
    public void testBindKernelWithLegacyNFISignatureToCKernel() {
        testWithSignature("inc_kernel", "pointer, pointer, sint32");
    }

    /** Binds the C++ kernel by its mangled symbol name plus a separate NFI-style signature. */
    @Test
    public void testBindKernelWithLegacyNFISignatureToCxxKernel() {
        testWithSignature("_Z14cxx_inc_kernelPiPKii", "pointer, pointer, sint32");
    }

    /** Binds the {@code extern "C"} kernel with a single combined name-and-signature string. */
    @Test
    public void testBindKernelWithNDILSignatureToCKernel() {
        testWithSignature("inc_kernel(out_arr: out pointer sint32, in_arr: in pointer sint32, num_elements: sint32)");
    }

    /** Binds the C++ kernel with a single combined string carrying the {@code cxx} prefix. */
    @Test
    public void testBindKernelWithNDILSignatureToCxxKernel() {
        testWithSignature("cxx cxx_inc_kernel(out_arr: out pointer sint32, in_arr: in pointer sint32, num_elements: sint32)");
    }
}