package edu.cmu.sv.kelinci;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.BindException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.Arrays;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

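/**
 * Kelinci server. Accepts fuzzing requests over TCP, runs the main method of
 * the target application on each received input, and sends back a status byte
 * followed by the contents of Mem.mem.
 *
 * @author rodykers
 */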
class Kelinci {

	// Maximum number of pending requests; runServer() rejects new connections
	// with STATUS_QUEUE_FULL once the queue holds this many.
	private static final int maxQueue = 10;
	// Accepted connections waiting to be handled. Filled by runServer(),
	// drained by doFuzzerRuns().
	private static Queue<FuzzRequest> requestQueue = new ConcurrentLinkedQueue<>();

	// Status codes; one of these is sent to the client as the first byte of the response.
	// Target run completed within the timeout.
	public static final byte STATUS_SUCCESS = 0;
	// Target run did not complete within the timeout.
	public static final byte STATUS_TIMEOUT = 1;
	// Initial result of every request; remains if the target run throws.
	public static final byte STATUS_CRASH = 2;
	// Request rejected because the request queue was full.
	public static final byte STATUS_QUEUE_FULL = 3;
	// The request could not be read from the connection.
	public static final byte STATUS_COMM_ERROR = 4;
	// Not referenced within this class.
	public static final byte STATUS_DONE = 5;

	public static final long DEFAULT_TIMEOUT = 300000L; // in milliseconds
	// Maximum time to wait for one target run, in milliseconds (option -t / -timeout).
	private static long timeout;
	
	public static final int DEFAULT_VERBOSITY = 2;
	// Verbosity level (option -v / -verbosity): values > 1 print progress messages,
	// values > 2 print additional details, values <= 0 redirect System.out and
	// System.err to a NullOutputStream.
	private static int verbosity;
	
	public static final int DEFAULT_PORT = 7007;
	// TCP port the server listens on (option -p / -port).
	private static int port;
	
	// Request modes; the mode is the first byte of a request. Any value other
	// than LOCAL_MODE is handled as default mode (DEFAULT_MODE itself is not
	// referenced within this class).
	public static final byte DEFAULT_MODE = 0;
	public static final byte LOCAL_MODE = 1;

	// main(String[]) method of the target class, invoked reflectively for every run.
	private static Method targetMain;
	// Arguments for the target's main method; "@@" entries are replaced by the
	// input file name for each run.
	private static String targetArgs[];

	// Temporary file that received input bytes are written to before a run.
	private static File tmpfile;

	/**
	 * A pending fuzzing request: wraps the accepted client connection until
	 * the handler thread picks it up from the request queue.
	 */
	private static class FuzzRequest {
		// Connection the request is read from and the response is written to.
		Socket clientSocket;

		FuzzRequest(Socket clientSocket) {
			this.clientSocket = clientSocket;
		}
	}

	/**
	 * Method to run in a thread to accept requests coming
	 * in over TCP and put them in a queue.
	 *
	 * A connection is queued only while fewer than {@code maxQueue} requests
	 * are pending. Otherwise the client receives {@code STATUS_QUEUE_FULL} and
	 * the connection is closed. Loops until the JVM exits.
	 *
	 * Terminates the JVM with exit code 1 if the port cannot be bound or if
	 * accepting connections fails.
	 */
	private static void runServer() {

		try (ServerSocket ss = new ServerSocket(port)) {
			if (verbosity > 1)
				System.out.println("Server listening on port " + port);

			while (true) {
				Socket s = ss.accept();
				if (verbosity > 1)
					System.out.println("Connection established.");

				boolean status = false;
				if (requestQueue.size() < maxQueue) {
					status = requestQueue.offer(new FuzzRequest(s));
					if (verbosity > 1)
						System.out.println("Request added to queue: " + status);
				}
				if (!status) {
					if (verbosity > 1)
						System.out.println("Queue full.");
					rejectConnection(s);
				}
			}
		} catch (BindException be) {
			System.err.println("Unable to bind to port " + port);
			System.exit(1);
		} catch (Exception e) {
			System.err.println("Exception in request server");
			e.printStackTrace();
			System.exit(1);
		}
	}

	/**
	 * Tells a client that the queue is full and closes its connection.
	 *
	 * I/O errors are confined to this one connection: a client that already
	 * went away must not take the whole server down, and the socket is closed
	 * in every case.
	 *
	 * @param s the accepted connection to reject
	 */
	private static void rejectConnection(Socket s) {
		try {
			OutputStream os = s.getOutputStream();
			os.write(STATUS_QUEUE_FULL);
			os.flush();
			s.shutdownOutput();
			s.shutdownInput();
		} catch (IOException ioe) {
			if (verbosity > 1)
				System.out.println("Could not notify client: " + ioe.getMessage());
		} finally {
			try {
				s.close();
			} catch (IOException ignored) {
				// nothing more can be done for this connection
			}
			if (verbosity > 1)
				System.out.println("Connection closed.");
		}
	}
	
	/**
	 * Calls the target's main() with the configured target arguments, in which
	 * every "@@" is replaced by the provided file name.
	 *
	 * @param filename The name of the input file to substitute for "@@".
	 * @return The time the invocation took, in nanoseconds.
	 * @throws RuntimeException if the invocation fails, which includes the
	 *         target itself throwing; the original exception is attached as cause.
	 */
	private static long runApplication(String filename) {
		// work on a copy so the "@@" placeholders survive for the next run
		String[] args = Arrays.copyOf(targetArgs, targetArgs.length);
		for (int i = 0; i < args.length; i++) {
			if (args[i].equals("@@")) {
				args[i] = filename;
			}
		}
		long pre = System.nanoTime();
		try {
			// cast to Object so the array is passed as the single String[] parameter
			targetMain.invoke(null, (Object) args);
		} catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
			e.printStackTrace();
			// keep the cause: for InvocationTargetException it wraps what the target threw
			throw new RuntimeException("Error invoking target main method", e);
		}
		return System.nanoTime() - pre;
	}

	/**
	 * Writes the provided input to the tmp file, then calls main().
	 * Replaces @@ by the tmp file name.
	 *
	 * @param input The file contents as a byte array.
	 * @return The time the invocation took, in nanoseconds.
	 * @throws RuntimeException if the tmp file cannot be written (the
	 *         IOException is attached as cause) or if the invocation fails.
	 */
	private static long runApplication(byte[] input) {
		// The stream is closed when this try block ends, so the file is
		// complete on disk before the target reads it.
		try (FileOutputStream stream = new FileOutputStream(tmpfile)) {
			stream.write(input);
		} catch (IOException ioe) {
			throw new RuntimeException("Error writing to tmp file", ioe);
		}
		return runApplication(tmpfile.getAbsolutePath());
	}

	/**
	 * One run of the target application, packaged as a task for an executor.
	 *
	 * Built either from the raw input bytes or from the path of an input
	 * file; the run returns the time reported by runApplication().
	 */
	private static class ApplicationCall implements Callable<Long> {
		// exactly one of these is set, depending on the constructor used
		byte[] input;
		String path;

		/** Creates a run on the given input bytes. */
		ApplicationCall(byte[] input) {
			this.input = input;
		}

		/** Creates a run on the input file at the given path. */
		ApplicationCall(String path) {
			this.path = path;
		}

		@Override
		public Long call() throws Exception {
			// a path takes precedence; otherwise fall back to the raw bytes
			return (path == null) ? runApplication(input) : runApplication(path);
		}
	}

	/**
	 * Method to run in a thread handling one request from the queue at a time.
	 * 
	 * LOCAL_MODE means you only send over a path to the input file.
	 * DEFAULT_MODE means the actual bytes of the file are sent.
	 *
	 * Request layout: one mode byte, a 4-byte little-endian length, then that
	 * many bytes (the path in LOCAL_MODE, the file contents otherwise).
	 * Response layout: one status byte followed by the contents of Mem.mem.
	 *
	 * A SocketException ends only the current request. Any other exception
	 * ends this method with a RuntimeException that carries the cause.
	 */
	private static void doFuzzerRuns() {
		if (verbosity > 1)
			System.out.println("Fuzzer runs handler thread started.");

		while (true) {
			try {
				FuzzRequest request = requestQueue.poll();
				if (request == null) {
					// if no request, close your eyes for a bit
					Thread.sleep(100);
					continue;
				}
				if (verbosity > 1)
					System.out.println("Handling request 1 of " + (requestQueue.size()+1));

				try {
					handleRequest(request.clientSocket);
				} finally {
					// Release the connection on error paths too; closing a
					// socket that is already closed has no effect.
					try {
						request.clientSocket.close();
					} catch (IOException ignored) {
						// nothing more can be done for this connection
					}
				}
			} catch (SocketException se) {
				// Connection was reset, most probably means AFL process was killed.
				if (verbosity > 1)
					System.out.println("Connection reset.");
			} catch (Exception e) {
				e.printStackTrace();
				throw new RuntimeException("Exception running fuzzed input", e);
			}
		}
	}

	/**
	 * Reads one request from the socket, runs the target on it and writes
	 * the status byte plus Mem.mem back, then closes the connection.
	 *
	 * NOTE(review): reads block until the announced number of bytes arrived
	 * or the peer closed its side; a client that stalls without closing
	 * will stall the handler.
	 */
	private static void handleRequest(Socket socket) throws IOException {
		InputStream is = socket.getInputStream();
		OutputStream os = socket.getOutputStream();

		Mem.clear();
		byte result = STATUS_CRASH;
		ApplicationCall appCall = null;

		// read the mode (local or default)
		byte mode = (byte) is.read();

		/* LOCAL MODE */
		if (mode == LOCAL_MODE) {
			if (verbosity > 1)
				System.out.println("Handling request in LOCAL MODE.");

			// read the length of the path (integer)
			int pathlen = readLittleEndianInt(is);
			if (verbosity > 2)
				System.out.println("Path len = " + pathlen);

			if (pathlen < 0) {
				if (verbosity > 1)
					System.err.println("Failed to read path length");
				result = STATUS_COMM_ERROR;
			} else {
				// read the path; a truncated path is useless, so report an error
				byte[] input = new byte[pathlen];
				if (readFully(is, input) < pathlen) {
					if (verbosity > 1)
						System.err.println("No input available from stream, strangely, breaking.");
					result = STATUS_COMM_ERROR;
				} else {
					// decoded with the platform default charset
					String path = new String(input);
					if (verbosity > 1)
						System.out.println("Received path: " + path);

					appCall = new ApplicationCall(path);
				}
			}

		/* DEFAULT MODE */
		} else {
			if (verbosity > 1)
				System.out.println("Handling request in DEFAULT MODE.");

			// read the size of the input file (integer)
			int filesize = readLittleEndianInt(is);
			if (verbosity > 2)
				System.out.println("File size = " + filesize);

			if (filesize < 0) {
				if (verbosity > 1)
					System.err.println("Failed to read file size");
				result = STATUS_COMM_ERROR;
			} else {
				// read the input file; if the stream ends early the rest of
				// the (zero-initialized) buffer is left as zeros
				byte[] input = new byte[filesize];
				if (readFully(is, input) < filesize && verbosity > 1) {
					System.err.println("No input available from stream, strangely");
					System.err.println("Appending a 0");
				}

				appCall = new ApplicationCall(input);
			}
		}

		// appCall is only set when the request was read successfully
		if (appCall != null)
			result = executeWithTimeout(appCall);

		if (verbosity > 1)
			System.out.println("Result: " + result);

		if (verbosity > 2)
			Mem.print();

		// send back status
		os.write(result);

		// send back "shared memory" over TCP
		os.write(Mem.mem, 0, Mem.mem.length);

		// close connection
		os.flush();
		socket.shutdownOutput();
		socket.shutdownInput();
		socket.setSoLinger(true, 100000);
		socket.close();
		if (verbosity > 1)
			System.out.println("Connection closed.");
	}

	/**
	 * Runs the application call in a fresh single-thread executor and waits
	 * at most {@code timeout} milliseconds for it.
	 *
	 * @return STATUS_SUCCESS if the run completed, STATUS_TIMEOUT if it did
	 *         not complete in time, STATUS_CRASH if waiting for it threw.
	 */
	private static byte executeWithTimeout(ApplicationCall appCall) {
		byte result = STATUS_CRASH;
		ExecutorService executor = Executors.newSingleThreadExecutor();
		Future<Long> future = executor.submit(appCall);

		try {
			if (verbosity > 1)
				System.out.println("Started...");
			future.get(timeout, TimeUnit.MILLISECONDS);
			result = STATUS_SUCCESS;
			if (verbosity > 1)
				System.out.println("Finished!");
		} catch (TimeoutException te) {
			future.cancel(true);
			if (verbosity > 1)
				System.out.println("Time-out!");
			result = STATUS_TIMEOUT;
		} catch (Throwable e) {
			future.cancel(true);
			if (e.getCause() instanceof RuntimeException) {
				if (verbosity > 1)
					System.out.println("RuntimeException thrown!");
			} else if (e.getCause() instanceof Error) {
				if (verbosity > 1)
					System.out.println("Error thrown!");
			} else {
				if (verbosity > 1)
					System.out.println("Uncaught throwable!");
			}
			e.printStackTrace();
		} finally {
			executor.shutdownNow();
		}
		return result;
	}

	/**
	 * Reads a 4-byte little-endian integer.
	 *
	 * @return the value, or -1 if the stream ended before four bytes were read
	 */
	private static int readLittleEndianInt(InputStream is) throws IOException {
		int value = 0;
		for (int shift = 0; shift < 32; shift += 8) {
			int b = is.read();
			if (b < 0)
				return -1;
			value |= b << shift;
		}
		return value;
	}

	/**
	 * Fills the buffer from the stream, blocking until the data arrives.
	 *
	 * available() is deliberately not consulted: on a socket it only reports
	 * what is already buffered, so a 0 does not mean the sender is done.
	 *
	 * @return the number of bytes read; less than buffer.length only if the
	 *         stream ended first
	 */
	private static int readFully(InputStream is, byte[] buffer) throws IOException {
		int total = 0;
		while (total < buffer.length) {
			int n = is.read(buffer, total, buffer.length - total);
			if (n < 0)
				break;
			total += n;
		}
		return total;
	}

	/**
	 * Entry point of the Kelinci server.
	 *
	 * Parses the leading options (-p/-port, -v/-verbosity, -t/-timeout, each
	 * followed by a number), takes the next argument as the target main class
	 * and the remaining ones as the target's arguments, of which at least one
	 * must be "@@". Then loads the target's main method, creates the tmp
	 * input file and starts the server thread and the handler thread.
	 *
	 * Invalid command lines are reported on System.err instead of ending in
	 * an uncaught exception.
	 *
	 * @param args the command line, see the usage message
	 */
	public static void main(String args[]) {

		if (args.length < 1) {
			printUsage();
			return;
		}

		port = DEFAULT_PORT;
		timeout = DEFAULT_TIMEOUT;
		verbosity = DEFAULT_VERBOSITY;

		// Parse leading options; the first argument that is not an option
		// is the name of the target main class.
		int curArg = 0;
		while (curArg < args.length) {
			String option = args[curArg];
			boolean isPort = option.equals("-p") || option.equals("-port");
			boolean isVerbosity = option.equals("-v") || option.equals("-verbosity");
			boolean isTimeout = option.equals("-t") || option.equals("-timeout");
			if (!isPort && !isVerbosity && !isTimeout)
				break;

			if (curArg + 1 >= args.length) {
				System.err.println("Error: missing value for option " + option);
				printUsage();
				System.exit(1);
			}
			String value = args[curArg + 1];
			try {
				if (isPort)
					port = Integer.parseInt(value);
				else if (isVerbosity)
					verbosity = Integer.parseInt(value);
				else
					timeout = Long.parseLong(value);
			} catch (NumberFormatException nfe) {
				System.err.println("Error: invalid value for option " + option + ": " + value);
				printUsage();
				System.exit(1);
			}
			curArg += 2;
		}
		if (curArg >= args.length) {
			System.err.println("Error: no main class specified");
			printUsage();
			System.exit(1);
		}
		String mainClass = args[curArg];
		targetArgs = Arrays.copyOfRange(args, curArg+1, args.length);

		// Check if at least one of the target parameters is @@
		boolean present = false;
		for (int i = 0; i < targetArgs.length; i++) {
			if (targetArgs[i].equals("@@")) {
				present = true;
				break;
			}
		}
		if (!present) {
			System.err.println("Error: none of the target application parameters is @@");
			System.exit(1);
		}

		// Redirect target program output to /dev/null if requested.
		if (verbosity <= 0) {
			PrintStream nullStream = new PrintStream(new NullOutputStream());
			System.setOut(nullStream);
			System.setErr(nullStream);
		}

		ClassLoader classloader = Thread.currentThread().getContextClassLoader();
		try {
			Class<?> target = classloader.loadClass(mainClass);
			targetMain = target.getMethod("main", String[].class);
		} catch (ClassNotFoundException e) {
			System.err.println("Main class not found: " + mainClass);
			return;
		} catch (NoSuchMethodException e) {
			System.err.println("No main method found in class: " + mainClass);
			return;
		} catch (SecurityException e) {
			System.err.println("Main method in class not accessible: " + mainClass);
			return;
		}

		// Create the tmp file to serve as input file to the program.
		try {
			tmpfile = File.createTempFile("kelinci-input", "");
			tmpfile.deleteOnExit();
		} catch (IOException ioe) {
			throw new RuntimeException("Error creating tmp file", ioe);
		}

		// Start the server thread
		Thread server = new Thread(new Runnable() {
			@Override
			public void run() {
				runServer();
			}
		});
		server.start();

		// Handle requests for fuzzer runs in separate thread.
		Thread fuzzerRuns = new Thread(new Runnable() {
			@Override
			public void run() {
				doFuzzerRuns();
			}
		});
		fuzzerRuns.start();
	}

	/**
	 * Prints the command line usage message to System.err.
	 */
	private static void printUsage() {
		System.err.println("Usage: java edu.cmu.sv.kelinci.Kelinci [-v N] [-p N] [-t N] package.ExampleMain <args>");
	}
	
	/**
	 * Output stream that throws away everything written to it, like
	 * /dev/null. Used to redirect output of target program.
	 * 
	 * Something like this is also in Apache Commons IO, but with a local
	 * copy we don't need any libs on the classpath when running the
	 * Kelinci server.
	 * 
	 * @author rodykers
	 */
	private static class NullOutputStream extends ByteArrayOutputStream {

		/** Discards the byte. */
		@Override
		public void write(int b) {
			// intentionally empty
		}

		/** Discards the given range of bytes. */
		@Override
		public void write(byte[] b, int off, int len) {
			// intentionally empty
		}

		/** Writes nothing to the given stream. */
		@Override
		public void writeTo(OutputStream out) throws IOException {
			// intentionally empty
		}
	}
}