/*
 * Copyright 2015 BlackBerry, Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.blackberry.bdp.krackle.auth;

import com.blackberry.bdp.krackle.exceptions.MissingConfigurationException;
import com.blackberry.bdp.krackle.exceptions.InvalidConfigurationTypeException;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.Socket;
import java.util.Arrays;

import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Map;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SaslPlainTextAuthenticator implements Authenticator{

	/**
	 * Progress of the client side of the SASL handshake.
	 */
	public enum SaslState {
		INITIAL, INTERMEDIATE, COMPLETE, FAILED
	}

	private static final Logger LOG = LoggerFactory.getLogger(SaslPlainTextAuthenticator.class);

	// Zero-length byte array; presumably the challenge used to start the handshake.
	private static final byte[] EMPTY = new byte[0];

	// Connection to the server and the data streams wrapped around it.
	private final Socket socket;
	private final DataInputStream inStream;
	private final DataOutputStream outStream;

	// Configurable items, populated by configure(Map).
	private String hostname;
	private Subject subject;
	private String servicePrincipal;
	private String clientPrincipal;

	// Handshake bookkeeping.
	private SaslClient saslClient;
	private boolean configured;
	private SaslState saslState;

	 * Will create a socket based on host name and port
	 * @param hostname
	 * @param port
	 * @throws IOException
	 * @throws SaslException
	public SaslPlainTextAuthenticator(String hostname, int port)
		 throws IOException, SaslException {
		this(new Socket(hostname, port));

	 * Will use an existing socket
	 * @param socket
	 * @throws IOException
	 * @throws SaslException
	public SaslPlainTextAuthenticator(Socket socket)
		 throws IOException, SaslException {
		this.socket = socket;
		this.inStream = new DataInputStream(socket.getInputStream());
		this.outStream = new DataOutputStream(socket.getOutputStream());
		this.configured = false;
		saslState = SaslState.INITIAL;

	public void configure(Map<String, ?> configs) throws
		 MissingConfigurationException, InvalidConfigurationTypeException, SaslException {

		// No longer required to be specified in config map allows config map to be shared
		hostname = socket.getInetAddress().getHostName();

		if (!configs.containsKey("subject")) {
			throw new MissingConfigurationException("`subject` not defined in configration");
		} else if (!configs.get("subject").getClass().equals(Subject.class)) {
			String type = Subject.class.getCanonicalName();
			throw new InvalidConfigurationTypeException("`subject` is not a " + type);
		} else {
			subject = (Subject) configs.get("subject");

		if (!configs.containsKey("servicePrincipal")) {
			throw new MissingConfigurationException("`servicePrincipal` not defined in configration");
		} else if (!configs.get("servicePrincipal").getClass().equals(String.class)) {
			String type = String.class.getCanonicalName();
			throw new InvalidConfigurationTypeException("`servicePrincipal` is not a " + type);
		} else {
			servicePrincipal = (String) configs.get("servicePrincipal");

		if (!configs.containsKey("clientPrincipal")) {
			throw new MissingConfigurationException("`clientPrincipal` not defined in configration");
		} else if (!configs.get("clientPrincipal").getClass().equals(String.class)) {
			String type = String.class.getCanonicalName();
			throw new InvalidConfigurationTypeException("`clientPrincipal` is not a " + type);
		} else {
			clientPrincipal = (String) configs.get("clientPrincipal");
		this.saslClient = createSaslClient();
		configured = true;
		LOG.info("authenticator has been configured");

	private SaslClient createSaslClient() throws SaslException {
		try {
			return Subject.doAs(subject, new PrivilegedExceptionAction<SaslClient>() {
				public SaslClient run() throws SaslException {
					String[] mechs = {"GSSAPI"};
					LOG.info("Creating SaslClient: client={}; service={}; serviceHostname={}; mechs={}",
						 clientPrincipal, servicePrincipal, hostname, Arrays.toString(mechs));
					return Sasl.createSaslClient(mechs, clientPrincipal, servicePrincipal, hostname, null,
						 new ClientCallbackHandler());

		} catch (PrivilegedActionException e) {
			throw new SaslException("Failed to create SaslClient", e.getCause());

	 * Sends an empty message to the server to initiate the authentication process. It then evaluates server challenges
	 * via `SaslClient.evaluateChallenge` and returns client responses until authentication succeeds or fails.
	 * The messages are sent and received as size delimited bytes that consists of a 4 byte network-ordered size N
	 * followed by N bytes representing the opaque payload.
	 * @throws java.io.IOException
	public void authenticate() throws IOException {
		if (!configured) {
			throw new IOException("authentication attempted on unconfigured authenticator");
		while (!saslClient.isComplete()) {
			switch (saslState) {
				case INITIAL:
					LOG.debug("saslClient has initial response? {}",
					saslState = SaslState.INTERMEDIATE;
					LOG.debug("sent initial empty sasl token");
					byte[] challenge;
					LOG.debug("in intermediate");
					int length = inStream.readInt();
					LOG.debug("in intermediate - read  int, length of response is {}", length);
					challenge = new byte[length];
					LOG.debug("read response");
					if (saslClient.isComplete()) {
						LOG.debug("complete sasl state detected in intermediate");
						saslState = SaslState.COMPLETE;
				case COMPLETE:
				case FAILED:
					throw new IOException("SASL handshake failed");
		LOG.debug("authentication complete");

	private void sendSaslToken(byte[] serverToken) throws IOException {
		if (!saslClient.isComplete()) {
			try {
				byte[] saslToken = createSaslToken(serverToken);
				if (saslToken != null) {
					LOG.debug("sending sasl token of length: {}", saslToken.length);
					LOG.debug("sent sasl token of length: {}", saslToken.length);
			} catch (IOException e) {
				saslState = SaslState.FAILED;
				throw e;
		} else {
			LOG.warn("attempting to send sasl token to a completed sasl client");

	private byte[] createSaslToken(final byte[] saslToken) throws SaslException {
		if (saslToken == null) {
			throw new SaslException("Error authenticating with the Kafka Broker: received a nul saslToken.");
		try {
			return Subject.doAs(subject, new PrivilegedExceptionAction<byte[]>() {
				public byte[] run() throws SaslException {
					LOG.debug("evaluating challenge of length {} to {}",
					byte[] evaluation = saslClient.evaluateChallenge(saslToken);
					LOG.debug("evaluation length is {}", evaluation.length);
					return evaluation;
		} catch (PrivilegedActionException e) {
			String error = "An error: (" + e + ") occurred when evaluating SASL token received from the Kafka Broker.";
			// Try to provide hints to use about what went wrong so they can fix their configuration.
			// TODO: introspect about e: look for GSS information.
			final String unknownServerErrorText
				 = "(Mechanism level: Server not found in Kerberos database (7) - UNKNOWN_SERVER)";
			if (e.toString().contains(unknownServerErrorText)) {
				error += " This may be caused by Java's being unable to resolve the Kafka Broker's"
					 + " hostname correctly. You may want to try to adding"
					 + " '-Dsun.net.spi.nameservice.provider.1=dns,sun' to your client's JVMFLAGS environment."
					 + " Users must configure FQDN of kafka brokers when authenticating using SASL and"
					 + " `socketChannel.socket().getInetAddress().getHostName()` must match the hostname in `principal/[email protected]`";
			error += " Kafka Client will go to AUTH_FAILED state.";
			//Unwrap the SaslException inside `PrivilegedActionException`
			throw new SaslException(error, e.getCause());

	private static class ClientCallbackHandler implements CallbackHandler {

		public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
			for (Callback callback : callbacks) {
				LOG.info("callback {} received", callback.toString());
				if (callback instanceof NameCallback) {
					NameCallback nc = (NameCallback) callback;
				} else {
					if (callback instanceof PasswordCallback) {
						// Call `setPassword` once we support obtaining a password from the user and update message below
						throw new UnsupportedCallbackException(callback, "Could not login: the client is being asked for a password, but the Kafka"
							 + " client code does not currently support obtaining a password from the user."
							 + " Make sure -Djava.security.auth.login.config property passed to JVM and"
							 + " the client is configured to use a ticket cache (using"
							 + " the JAAS configuration setting 'useTicketCache=true)'. Make sure you are using"
							 + " FQDN of the Kafka broker you are trying to connect to.");
					} else {
						if (callback instanceof RealmCallback) {
							RealmCallback rc = (RealmCallback) callback;
						} else {
							if (callback instanceof AuthorizeCallback) {
								AuthorizeCallback ac = (AuthorizeCallback) callback;
								String authId = ac.getAuthenticationID();
								String authzId = ac.getAuthorizationID();
								if (ac.isAuthorized()) {
							} else {
								throw new UnsupportedCallbackException(callback, "Unrecognized SASL ClientCallback");


	public boolean complete() {
		return saslState == SaslState.COMPLETE;

	public void close() throws IOException {

	public Socket getSocket() {
		return socket;

	public boolean configured() {
		return configured;