package backtype.storm.security.auth;

import java.io.IOException;
import java.net.Socket;
import java.security.Principal;
import java.util.Map;

import javax.security.auth.Subject;
import javax.security.auth.login.Configuration;
import javax.security.sasl.SaslServer;

import org.apache.thrift7.TException;
import org.apache.thrift7.TProcessor;
import org.apache.thrift7.protocol.TBinaryProtocol;
import org.apache.thrift7.protocol.TProtocol;
import org.apache.thrift7.server.TServer;
import org.apache.thrift7.server.TThreadPoolServer;
import org.apache.thrift7.transport.TSaslServerTransport;
import org.apache.thrift7.transport.TServerSocket;
import org.apache.thrift7.transport.TSocket;
import org.apache.thrift7.transport.TTransport;
import org.apache.thrift7.transport.TTransportException;
import org.apache.thrift7.transport.TTransportFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for SASL authentication plugins.
 *
 * Subclasses supply the server-side transport factory through
 * {@link #getServerTransportFactory()}; this class builds the Thrift server
 * around it and wraps the service processor so that the remote address and
 * the authenticated user are recorded in the {@link ReqContext} before each
 * call is dispatched.
 */
public abstract class SaslTransportPlugin implements ITransportPlugin {
	/** Size of the server's worker pool; used as both the minimum and maximum. */
	private static final int WORKER_THREADS = 64;

	private static final Logger LOG = LoggerFactory
			.getLogger(SaslTransportPlugin.class);

	/** Login configuration handed to {@link #prepare(Map, Configuration)}. */
	protected Configuration login_conf;

	/**
	 * Invoked once immediately after construction.
	 * 
	 * @param storm_conf
	 *            Storm configuration (not used by this base class)
	 * @param login_conf
	 *            login configuration, stored in {@link #login_conf}
	 */
	public void prepare(Map storm_conf, Configuration login_conf) {
		this.login_conf = login_conf;
	}

	/**
	 * Builds a thread-pool Thrift server listening on the given port.
	 * 
	 * The supplied processor is wrapped so that every call first populates the
	 * request context with the caller's address and identity.
	 * 
	 * @param port
	 *            TCP port the server socket is bound to
	 * @param processor
	 *            service processor that handles the actual calls
	 * @return a configured, not yet started, {@link TThreadPoolServer}
	 * @throws IOException
	 *             if the subclass fails to create its transport factory
	 * @throws TTransportException
	 *             if the server socket cannot be created
	 */
	public TServer getServer(int port, TProcessor processor)
			throws IOException, TTransportException {
		// Ask the subclass first so that no socket is opened if this fails.
		TTransportFactory serverTransportFactory = getServerTransportFactory();

		// Blocking TServerSocket + TThreadPoolServer with a fixed-size pool.
		TServerSocket serverTransport = new TServerSocket(port);
		TThreadPoolServer.Args server_args = new TThreadPoolServer.Args(
				serverTransport).processor(new TUGIWrapProcessor(processor))
				.minWorkerThreads(WORKER_THREADS)
				.maxWorkerThreads(WORKER_THREADS)
				.protocolFactory(new TBinaryProtocol.Factory());
		if (serverTransportFactory != null) {
			server_args.transportFactory(serverTransportFactory);
		}

		return new TThreadPoolServer(server_args);
	}

	/**
	 * Creates the transport factory used for server-side connections. All
	 * subclasses must implement this method.
	 * 
	 * @return the transport factory to install on the server; if
	 *         <code>null</code>, {@link #getServer(int, TProcessor)} does not
	 *         set a transport factory. Note that incoming calls are rejected
	 *         unless their transport is a {@link TSaslServerTransport}.
	 * @throws IOException
	 *             if the factory cannot be created
	 */
	protected abstract TTransportFactory getServerTransportFactory()
			throws IOException;

	/**
	 * Processor that pulls the SaslServer object out of the transport, and
	 * records the remote user and remote address in the request context
	 * before calling through to the original processor.
	 * 
	 * This is used on the server side to set the caller's identity for each
	 * specific call. It is a static nested class because it needs no state
	 * from the enclosing plugin instance.
	 */
	private static final class TUGIWrapProcessor implements TProcessor {
		final TProcessor wrapped;

		TUGIWrapProcessor(TProcessor wrapped) {
			this.wrapped = wrapped;
		}

		/**
		 * Populates the request context from the SASL transport and then
		 * delegates to the wrapped processor.
		 * 
		 * @param inProt
		 *            input protocol; its transport must be a
		 *            {@link TSaslServerTransport} layered on a {@link TSocket}
		 * @param outProt
		 *            output protocol, passed through unchanged
		 * @return whatever the wrapped processor returns
		 * @throws TException
		 *             if the transport is not of the expected kind (so the
		 *             caller cannot be identified), or if the wrapped
		 *             processor fails
		 */
		public boolean process(final TProtocol inProt, final TProtocol outProt)
				throws TException {
			// populating request context
			ReqContext req_context = ReqContext.context();

			// Sasl transport: fail with a clear message instead of an opaque
			// ClassCastException when the connection is not SASL-wrapped.
			TTransport trans = inProt.getTransport();
			if (!(trans instanceof TSaslServerTransport)) {
				throw new TException(
						"Expected a TSaslServerTransport but got "
								+ (trans == null ? "null" : trans.getClass()
										.getName())
								+ "; cannot determine the remote user");
			}
			TSaslServerTransport saslTrans = (TSaslServerTransport) trans;

			// remote address
			TTransport underlying = saslTrans.getUnderlyingTransport();
			if (!(underlying instanceof TSocket)) {
				throw new TException(
						"Expected the SASL transport to wrap a TSocket but got "
								+ (underlying == null ? "null" : underlying
										.getClass().getName())
								+ "; cannot determine the remote address");
			}
			Socket socket = ((TSocket) underlying).getSocket();
			req_context.setRemoteAddress(socket.getInetAddress());

			// remote subject: a fresh Subject holding a single User principal
			// named after the SASL authorization ID
			SaslServer saslServer = saslTrans.getSaslServer();
			String authId = saslServer.getAuthorizationID();
			Subject remoteUser = new Subject();
			remoteUser.getPrincipals().add(new User(authId));
			req_context.setSubject(remoteUser);

			// invoke service handler
			return wrapped.process(inProt, outProt);
		}
	}

	/**
	 * Principal identifying a user by name. Two users are equal when they are
	 * of the same class and have the same name.
	 */
	public static class User implements Principal {
		private final String name;

		/**
		 * @param name
		 *            the user's name; a <code>null</code> name is tolerated
		 *            and is only equal to another <code>null</code> name
		 */
		public User(String name) {
			this.name = name;
		}

		/**
		 * Get the full name of the user.
		 */
		public String getName() {
			return name;
		}

		@Override
		public boolean equals(Object o) {
			if (this == o) {
				return true;
			}
			if (o == null || getClass() != o.getClass()) {
				return false;
			}
			String otherName = ((User) o).name;
			// null-safe comparison so a null name never throws
			return name == null ? otherName == null : name.equals(otherName);
		}

		@Override
		public int hashCode() {
			// null-safe and consistent with equals
			return name == null ? 0 : name.hashCode();
		}

		@Override
		public String toString() {
			return name;
		}
	}
}