package tberg.murphy.gpu;

import static jcuda.driver.JCudaDriver.cuCtxCreate;
import static jcuda.driver.JCudaDriver.cuDeviceGet;
import static jcuda.driver.JCudaDriver.cuModuleLoad;

import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;

import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;

public class CudaUtil {
	
	public static CUdevice device;
	public static CUcontext context;
	
	public static void startup(int deviceId) {
        JCudaDriver.setExceptionsEnabled(true);
        JCudaDriver.cuInit(0);
        device = new CUdevice();
        cuDeviceGet(device, deviceId);
        context = new CUcontext();
        cuCtxCreate(context, 0, device);
	}
	
	public static void shutdown() {
		JCuda.cudaDeviceReset();
	}
	
	public static CUmodule compileAndLoad(String kernelName, String kernelSrc, boolean forceCompile) {
		return loadModule(preparePtxFile(kernelName, kernelSrc, forceCompile));
	}
	
	public static String preparePtxFile(String kernelName, String kernelSrc, boolean forceCompile) {
		String ptxFileName = kernelName+".ptx";
		try {
			File ptxFile = new File(ptxFileName);
			if (!forceCompile && ptxFile.exists()) {
				return ptxFileName;
			}

			long start = System.nanoTime();
			File cuFile = new File(kernelName+".cu");
			BufferedWriter out = new BufferedWriter(new FileWriter(cuFile));
			out.append(kernelSrc);
			out.flush();
			out.close();

			String modelString = "-m"+System.getProperty("sun.arch.data.model");
			int[] major = new int[1];
			int[] minor = new int[1];
			JCudaDriver.cuDeviceComputeCapability(major, minor, device);
			String command = "nvcc -use_fast_math -arch=sm_"+major[0]+""+minor[0]+" " + modelString + " -ptx "+ cuFile.getPath()+" -o "+ptxFileName;
			System.out.println("Executing\n"+command);
			Process process = Runtime.getRuntime().exec(command);

			String errorMessage =
					new String(toByteArray(process.getErrorStream()));
			String outputMessage =
					new String(toByteArray(process.getInputStream()));
			int exitValue = 0;
			try
			{
				exitValue = process.waitFor();
			}
			catch (InterruptedException e)
			{
				Thread.currentThread().interrupt();
				throw new IOException(
						"Interrupted while waiting for nvcc output", e);
			}

			if (exitValue != 0)
			{
				System.out.println("nvcc process exitValue "+exitValue);
				System.out.println("errorMessage:\n"+errorMessage);
				System.out.println("outputMessage:\n"+outputMessage);
				throw new IOException(
						"Could not create .ptx file: "+errorMessage);
			}
			System.out.println("Finished creating PTX file");

										long end = System.nanoTime();
			System.out.println("Compile time: "+(end - start) / 1e6 + "ms");
		} catch (IOException e) {
			e.printStackTrace();
		}
		return ptxFileName;
    }
	
	public static CUmodule loadModule(String name) {
        CUmodule module = new CUmodule();
        cuModuleLoad(module, name);
        return module;
	}

    private static byte[] toByteArray(InputStream inputStream) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        byte buffer[] = new byte[8192];
        while (true)
        {
            int read = inputStream.read(buffer);
            if (read == -1)
            {
                break;
            }
            baos.write(buffer, 0, read);
        }
        return baos.toByteArray();
    }
    
	// ignores the higher 16 bits
	public static float toFloat( int hbits )
	{
	    int mant = hbits & 0x03ff;            // 10 bits mantissa
	    int exp =  hbits & 0x7c00;            // 5 bits exponent
	    if( exp == 0x7c00 )                   // NaN/Inf
	        exp = 0x3fc00;                    // -> NaN/Inf
	    else if( exp != 0 )                   // normalized value
	    {
	        exp += 0x1c000;                   // exp - 15 + 127
	        if( mant == 0 && exp > 0x1c400 )  // smooth transition
	            return Float.intBitsToFloat( ( hbits & 0x8000 ) << 16
	                                            | exp << 13 | 0x3ff );
	    }
	    else if( mant != 0 )                  // && exp==0 -> subnormal
	    {
	        exp = 0x1c400;                    // make it normal
	        do {
	            mant <<= 1;                   // mantissa * 2
	            exp -= 0x400;                 // decrease exp by 1
	        } while( ( mant & 0x400 ) == 0 ); // while not normal
	        mant &= 0x3ff;                    // discard subnormal bit
	    }                                     // else +/-0 -> +/-0
	    return Float.intBitsToFloat(          // combine all parts
	        ( hbits & 0x8000 ) << 16          // sign  << ( 31 - 15 )
	        | ( exp | mant ) << 13 );         // value << ( 23 - 10 )
	}
	
	// returns all higher 16 bits as 0 for all results
	public static int fromFloat( float fval )
	{
	    int fbits = Float.floatToIntBits( fval );
	    int sign = fbits >>> 16 & 0x8000;          // sign only
	    int val = ( fbits & 0x7fffffff ) + 0x1000; // rounded value

	    if( val >= 0x47800000 )               // might be or become NaN/Inf
	    {                                     // avoid Inf due to rounding
	        if( ( fbits & 0x7fffffff ) >= 0x47800000 )
	        {                                 // is or must become NaN/Inf
	            if( val < 0x7f800000 )        // was value but too large
	                return sign | 0x7c00;     // make it +/-Inf
	            return sign | 0x7c00 |        // remains +/-Inf or NaN
	                ( fbits & 0x007fffff ) >>> 13; // keep NaN (and Inf) bits
	        }
	        return sign | 0x7bff;             // unrounded not quite Inf
	    }
	    if( val >= 0x38800000 )               // remains normalized value
	        return sign | val - 0x38000000 >>> 13; // exp - 127 + 15
	    if( val < 0x33000000 )                // too small for subnormal
	        return sign;                      // becomes +/-0
	    val = ( fbits & 0x7fffffff ) >>> 23;  // tmp exp for subnormal calc
	    return sign | ( ( fbits & 0x7fffff | 0x800000 ) // add subnormal bit
	         + ( 0x800000 >>> val - 102 )     // round depending on cut off
	      >>> 126 - val );   // div by 2^(1-(exp-127+15)) and >> 13 | exp=0
	}
	
	public static char[] convertToHalfFloat(float[] vect) {
		char[] result = new char[vect.length];
		for (int i=0; i<vect.length; ++i) {
			result[i] = (char) fromFloat(vect[i]);
		}
		return result;
	}
	
	public static float[] convertFromHalfFloat(char[] vect) {
		float[] result = new float[vect.length];
		for (int i=0; i<vect.length; ++i) {
			result[i] = toFloat(vect[i]);
		}
		return result;
	}
	
	public static float[] flatten(float[][] mat) {
		float[] result = new float[mat.length*mat[0].length];
		for (int i=0; i<mat.length; ++i) {
			System.arraycopy(mat[i], 0, result, i*mat[0].length, mat[i].length);
		}
		return result;
	}
	
	public static double[] flatten(double[][] mat) {
		double[] result = new double[mat.length*mat[0].length];
		for (int i=0; i<mat.length; ++i) {
			System.arraycopy(mat[i], 0, result, i*mat[0].length, mat[i].length);
		}
		return result;
	}
	
	public static float[] flatten(float[][][] tens) {
		float[] result = new float[tens.length*tens[0].length*tens[0][0].length];
		for (int i=0; i<tens.length; ++i) {
			for (int j=0; j<tens[0].length; ++j) {
				System.arraycopy(tens[i][j], 0, result, i*tens[0].length*tens[0][0].length + j*tens[0][0].length, tens[i][j].length);
			}
		}
		return result;
	}
	
	public static float[] flatten(List<float[]> mat) {
		float[] result = new float[mat.size()*mat.get(0).length];
		for (int i=0; i<mat.size(); ++i) {
			System.arraycopy(mat.get(i), 0, result, i*mat.get(0).length, mat.get(i).length);
		}
		return result;
	}
	
	public static int flatten(int I, int J, int i, int j) {
		return (i*J)+j;
	}
	
	public static int unflattenFirst(int I, int J, int f) {
		return f / J;
	}
	
	public static int unflattenSecond(int I, int J, int f) {
		return f % J;
	}
	
	public static String flatten(int I, int J, int i, String j) {
		return "("+(i*J)+" + "+j+")";
	}
	
	public static String flatten(int I, int J, String i, int j) {
		return "("+i+" * "+J+" + "+j+")";
	}
	
	public static String flatten(int I, int J, String i, String j) {
		return "("+i+" * "+J+" + "+j+")";
	}

	public static int flatten(int I, int J, int K, int i, int j, int k) {
		return i*J*K+j*K+k;
	}	
	
	public static String flatten(int I, int J, int K, int i, int j, String k) {
		return "("+(i*J*K+j*K)+" + "+k+")";
	}	
	
	public static String flatten(int I, int J, int K, String i, int j, int k) {
		return "("+i+" * "+(J*K) + " + "+(j*K)+" + "+k+")";
	}
	
	public static String flatten(int I, int J, int K, int i, String j, int k) {
		return "("+(i*J*K) + " + "+j+"  * "+K+" + "+k+")";
	}	
	
	public static String flatten(int I, int J, int K, String i, String j, int k) {
		return "("+i+" * "+(J*K) + " + "+j+"  * "+K+" + "+k+")";
	}	
	
	public static String flatten(int I, int J, int K, String i, int j, String k) {
		return "("+i+" * "+(J*K) + " + "+(j*K)+" + "+k+")";
	}

	public static String flatten(int I, int J, int K, int i, String j, String k) {
		return "("+(i*J*K) + " + "+j+"  * "+K+" + "+k+")";
	}	
	
	public static String flatten(int I, int J, int K, String i, String j, String k) {
		return "("+i+" * "+(J*K) + " + "+j+"  * "+K+" + "+k+")";
	}	
	
	public static float[] extendWithZeros(float[] x, int l) {
		float[] result = new float[l];
		for (int i=0; i<l; ++i) {
			if (i < x.length) {
				result[i] = x[i];
			} else {
				result[i] = 0.0f;
			}
		}
		return result;
	}

}