/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.kafka.common.security.authenticator;

import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
import org.apache.kafka.common.errors.UnsupportedVersionException;
import org.apache.kafka.common.network.Authenticator;
import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.network.Mode;
import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.NetworkSend;
import org.apache.kafka.common.network.Send;
import org.apache.kafka.common.network.TransportLayer;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.protocol.Protocol;
import org.apache.kafka.common.protocol.types.SchemaException;
import org.apache.kafka.common.requests.AbstractRequest;
import org.apache.kafka.common.requests.AbstractResponse;
import org.apache.kafka.common.requests.ApiVersionsResponse;
import org.apache.kafka.common.requests.RequestHeader;
import org.apache.kafka.common.requests.SaslHandshakeRequest;
import org.apache.kafka.common.requests.SaslHandshakeResponse;
import org.apache.kafka.common.security.auth.AuthCallbackHandler;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.auth.PrincipalBuilder;
import org.apache.kafka.common.security.kerberos.KerberosName;
import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.ScramMechanism;
import org.apache.kafka.common.security.scram.ScramServerCallbackHandler;
import org.ietf.jgss.GSSContext;
import org.ietf.jgss.GSSCredential;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.GSSManager;
import org.ietf.jgss.GSSName;
import org.ietf.jgss.Oid;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.security.auth.Subject;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.security.Principal;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class SaslServerAuthenticator implements Authenticator {

    private static final Logger LOG = LoggerFactory.getLogger(SaslServerAuthenticator.class);

    /**
     * States of the server side of the SASL exchange.
     *
     * GSSAPI_OR_HANDSHAKE_REQUEST: initial state; the first client packet may be a Kafka request
     *     (ApiVersions/SaslHandshake) or, for 0.9.0.x clients, a raw GSSAPI token.
     * HANDSHAKE_REQUEST: a Kafka request header has been seen; further Kafka requests are expected.
     * AUTHENTICATE: a mechanism has been selected and SASL tokens are being exchanged.
     * COMPLETE: the SaslServer reported completion.
     * FAILED: an exception occurred while processing a client packet.
     */
    public enum SaslState {
        GSSAPI_OR_HANDSHAKE_REQUEST, HANDSHAKE_REQUEST, AUTHENTICATE, COMPLETE, FAILED
    }

    private final String node;
    private final JaasContext jaasContext;
    private final Subject subject; // the server's subject, used for privileged SaslServer creation
    private final KerberosShortNamer kerberosNamer;
    private final int maxReceiveSize;
    private final String host;
    private final CredentialCache credentialCache;

    // Current SASL state
    private SaslState saslState = SaslState.GSSAPI_OR_HANDSHAKE_REQUEST;
    // Next SASL state to be set when outgoing writes associated with the current SASL state complete
    private SaslState pendingSaslState = null;
    // javax.security.sasl server that evaluates client tokens for the negotiated mechanism
    private SaslServer saslServer;
    private String saslMechanism;
    private AuthCallbackHandler callbackHandler;

    // assigned in `configure`
    private TransportLayer transportLayer;
    private Set<String> enabledMechanisms;
    private Map<String, ?> configs;

    // buffers used in `authenticate`
    private NetworkReceive netInBuffer;
    private Send netOutBuffer;

    /**
     * Creates a server-side SASL authenticator for a single connection.
     *
     * @throws IllegalArgumentException if {@code subject} is null
     */
    public SaslServerAuthenticator(String node, JaasContext jaasContext, final Subject subject, KerberosShortNamer kerberosNameParser, String host, int maxReceiveSize, CredentialCache credentialCache) throws IOException {
        if (subject == null)
            throw new IllegalArgumentException("subject cannot be null");
        this.node = node;
        this.jaasContext = jaasContext;
        this.subject = subject;
        this.kerberosNamer = kerberosNameParser;
        this.maxReceiveSize = maxReceiveSize;
        this.host = host;
        this.credentialCache = credentialCache;
    }

    /**
     * Stores the transport and configs and initializes the set of SASL mechanisms enabled on this server.
     *
     * @param transportLayer   transport used to read client packets and write responses
     * @param principalBuilder not used by this implementation
     * @param configs          server configs; {@link SaslConfigs#SASL_ENABLED_MECHANISMS} must be a non-empty list
     * @throws IllegalArgumentException if no SASL mechanisms are enabled
     */
    public void configure(TransportLayer transportLayer, PrincipalBuilder principalBuilder, Map<String, ?> configs) {
        this.transportLayer = transportLayer;
        this.configs = configs;
        // Unchecked: the config value is assumed to be a List<String>.
        @SuppressWarnings("unchecked")
        List<String> enabledMechanisms = (List<String>) this.configs.get(SaslConfigs.SASL_ENABLED_MECHANISMS);
        if (enabledMechanisms == null || enabledMechanisms.isEmpty())
            throw new IllegalArgumentException("No SASL mechanisms are enabled");
        this.enabledMechanisms = new HashSet<>(enabledMechanisms);
    }

    /**
     * Creates the callback handler and the {@link SaslServer} for the mechanism selected by the client.
     * SCRAM mechanisms use a credential-cache backed handler; all others use the JAAS-based handler.
     *
     * @throws SaslException if the SaslServer cannot be created, including when no registered
     *                       factory supports {@code mechanism}
     */
    private void createSaslServer(String mechanism) throws IOException {
        this.saslMechanism = mechanism;
        if (!ScramMechanism.isScram(mechanism))
            callbackHandler = new SaslServerCallbackHandler(jaasContext, kerberosNamer);
        else
            callbackHandler = new ScramServerCallbackHandler(credentialCache.cache(mechanism, ScramCredential.class));
        callbackHandler.configure(configs, Mode.SERVER, subject, saslMechanism);
        if (mechanism.equals(SaslConfigs.GSSAPI_MECHANISM)) {
            saslServer = createSaslKerberosServer(callbackHandler, configs, subject);
        } else {
            try {
                saslServer = Subject.doAs(subject, new PrivilegedExceptionAction<SaslServer>() {
                    public SaslServer run() throws SaslException {
                        return Sasl.createSaslServer(saslMechanism, "kafka", host, configs, callbackHandler);
                    }
                });
            } catch (PrivilegedActionException e) {
                throw new SaslException("Kafka Server failed to create a SaslServer to interact with a client during session authentication", e.getCause());
            }
        }
        // Sasl.createSaslServer returns null when no SaslServerFactory supports the mechanism; fail fast here
        // instead of hitting a NullPointerException on the first evaluateResponse call.
        if (saslServer == null)
            throw new SaslException("Kafka Server failed to create a SaslServer for mechanism " + mechanism
                    + ": no SaslServerFactory supports this mechanism");
    }

    /**
     * Creates a GSSAPI (Kerberos) SaslServer whose service name and host name are taken from the
     * first principal of the server's subject.
     *
     * @throws KafkaException if the principal name cannot be parsed as a Kerberos name
     * @throws SaslException  if the SaslServer cannot be created
     */
    private SaslServer createSaslKerberosServer(final AuthCallbackHandler saslServerCallbackHandler, final Map<String, ?> configs, Subject subject) throws IOException {
        // server is using a JAAS-authenticated subject: determine service principal name and hostname from kafka server's subject.
        final String servicePrincipal = SaslClientAuthenticator.firstPrincipal(subject);
        KerberosName kerberosName;
        try {
            kerberosName = KerberosName.parse(servicePrincipal);
        } catch (IllegalArgumentException e) {
            throw new KafkaException("Principal has name with unexpected format " + servicePrincipal, e);
        }
        final String servicePrincipalName = kerberosName.serviceName();
        final String serviceHostname = kerberosName.hostName();

        LOG.debug("Creating SaslServer for {} with mechanism {}", kerberosName, saslMechanism);

        // As described in http://docs.oracle.com/javase/8/docs/technotes/guides/security/jgss/jgss-features.html:
        // "To enable Java GSS to delegate to the native GSS library and its list of native mechanisms,
        // set the system property "sun.security.jgss.native" to true"
        // "In addition, when performing operations as a particular Subject, for example, Subject.doAs(...)
        // or Subject.doAsPrivileged(...), the to-be-used GSSCredential should be added to Subject's
        // private credential set. Otherwise, the GSS operations will fail since no credential is found."
        boolean usingNativeJgss = Boolean.getBoolean("sun.security.jgss.native");
        if (usingNativeJgss) {
            try {
                GSSManager manager = GSSManager.getInstance();
                // This Oid is used to represent the Kerberos version 5 GSS-API mechanism. It is defined in
                // RFC 1964.
                Oid krb5Mechanism = new Oid("1.2.840.113554.1.2.2");
                GSSName gssName = manager.createName(servicePrincipalName + "@" + serviceHostname, GSSName.NT_HOSTBASED_SERVICE);
                GSSCredential cred = manager.createCredential(gssName, GSSContext.INDEFINITE_LIFETIME, krb5Mechanism, GSSCredential.ACCEPT_ONLY);
                subject.getPrivateCredentials().add(cred);
            } catch (GSSException ex) {
                LOG.warn("Cannot add private credential to subject; clients authentication may fail", ex);
            }
        }

        try {
            return Subject.doAs(subject, new PrivilegedExceptionAction<SaslServer>() {
                public SaslServer run() throws SaslException {
                    return Sasl.createSaslServer(saslMechanism, servicePrincipalName, serviceHostname, configs, saslServerCallbackHandler);
                }
            });
        } catch (PrivilegedActionException e) {
            throw new SaslException("Kafka Server failed to create a SaslServer to interact with a client during session authentication", e.getCause());
        }
    }

    /**
     * Evaluates client responses via `SaslServer.evaluateResponse` and returns the issued challenge to the client until
     * authentication succeeds or fails.
     *
     * The messages are sent and received as size delimited bytes that consists of a 4 byte network-ordered size N
     * followed by N bytes representing the opaque payload.
     *
     * Before a mechanism is selected, incoming packets are parsed as Kafka requests (ApiVersions or SaslHandshake).
     * A first packet that is not a Kafka request is treated as a GSSAPI token for compatibility with 0.9.0.x clients.
     * If a previous response has not been fully written, this call only continues flushing it.
     *
     * @throws IOException wrapping any failure while processing a client packet; the state is then set to FAILED
     */
    public void authenticate() throws IOException {
        // Finish writing any pending response before reading more client data.
        if (netOutBuffer != null && !flushNetOutBufferAndUpdateInterestOps())
            return;

        if (saslServer != null && saslServer.isComplete()) {
            setSaslState(SaslState.COMPLETE);
            return;
        }
        // Allocate a receive buffer if needed and read from the transport.
        if (netInBuffer == null) netInBuffer = new NetworkReceive(maxReceiveSize, node);

        netInBuffer.readFrom(transportLayer);

        // Process only once a complete size-delimited packet has been received.
        if (netInBuffer.complete()) {
            netInBuffer.payload().rewind();
            byte[] clientToken = new byte[netInBuffer.payload().remaining()];
            netInBuffer.payload().get(clientToken, 0, clientToken.length);
            netInBuffer = null; // reset the networkReceive as we read all the data.
            try {
                switch (saslState) {
                    // Expecting a Kafka request (ApiVersions or SaslHandshake).
                    case HANDSHAKE_REQUEST:
                        handleKafkaRequest(clientToken);
                        break;
                    case GSSAPI_OR_HANDSHAKE_REQUEST:
                        if (handleKafkaRequest(clientToken))
                            break;
                        // For default GSSAPI, fall through to authenticate using the client token as the first GSSAPI packet.
                        // This is required for interoperability with 0.9.0.x clients which do not send handshake request
                    // Evaluate the client's SASL token with the mechanism-specific SaslServer.
                    case AUTHENTICATE:
                        byte[] response = saslServer.evaluateResponse(clientToken);
                        if (response != null) {
                            netOutBuffer = new NetworkSend(node, ByteBuffer.wrap(response));
                            flushNetOutBufferAndUpdateInterestOps();
                        }
                        // When the authentication exchange is complete and no more tokens are expected from the client,
                        // update SASL state. Current SASL state will be updated when outgoing writes to the client complete.
                        if (saslServer.isComplete())
                            setSaslState(SaslState.COMPLETE);
                        break;
                    default:
                        break;
                }
            } catch (Exception e) {
                setSaslState(SaslState.FAILED);
                throw new IOException(e);
            }
        }
    }

    /**
     * Returns the authenticated principal, built from the SaslServer's authorization id.
     * Requires that a SaslServer has been created (i.e. a mechanism was selected).
     */
    public Principal principal() {
        return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, saslServer.getAuthorizationID());
    }

    /**
     * Returns true once the SASL state has reached COMPLETE.
     */
    public boolean complete() {
        return saslState == SaslState.COMPLETE;
    }

    /**
     * Disposes the SaslServer and closes the callback handler, if they were created.
     * The callback handler is closed even if disposing the SaslServer throws.
     */
    public void close() throws IOException {
        try {
            if (saslServer != null)
                saslServer.dispose();
        } finally {
            if (callbackHandler != null)
                callbackHandler.close();
        }
    }

    /**
     * Applies {@code saslState} now if no response is being written; otherwise records it as pending
     * so it is applied once the outgoing buffer has been fully flushed.
     */
    private void setSaslState(SaslState saslState) {
        if (netOutBuffer != null && !netOutBuffer.completed())
            pendingSaslState = saslState;
        else {
            this.pendingSaslState = null;
            this.saslState = saslState;
            LOG.debug("Set SASL server state to {}", saslState);
        }
    }

    /**
     * Flushes the outgoing buffer. When fully written, stops listening for OP_WRITE and applies any
     * pending state; otherwise registers interest in OP_WRITE so the flush can be resumed.
     *
     * @return true if the outgoing buffer was written completely
     */
    private boolean flushNetOutBufferAndUpdateInterestOps() throws IOException {
        boolean flushedCompletely = flushNetOutBuffer();
        if (flushedCompletely) {
            transportLayer.removeInterestOps(SelectionKey.OP_WRITE);
            if (pendingSaslState != null)
                setSaslState(pendingSaslState);
        } else
            transportLayer.addInterestOps(SelectionKey.OP_WRITE);
        return flushedCompletely;
    }

    /**
     * Writes the outgoing buffer to the transport if it is not yet complete.
     *
     * @return true if the outgoing buffer has been fully written
     */
    private boolean flushNetOutBuffer() throws IOException {
        if (!netOutBuffer.completed())
            netOutBuffer.writeTo(transportLayer);
        return netOutBuffer.completed();
    }

    /**
     * Attempts to handle {@code requestBytes} as a Kafka request (ApiVersions or SaslHandshake).
     * If the bytes cannot be parsed as a Kafka request while in GSSAPI_OR_HANDSHAKE_REQUEST, GSSAPI is
     * selected as the mechanism (when enabled) for compatibility with 0.9.0.x clients.
     * When a mechanism is selected, its SaslServer is created and the state moves to AUTHENTICATE.
     *
     * @return true if the bytes were parsed as a Kafka request header
     * @throws UnsupportedVersionException      if the request version is unsupported (other than ApiVersions)
     * @throws IllegalSaslStateException        if an unexpected Kafka request type is received
     * @throws UnsupportedSaslMechanismException if the requested mechanism (or fallback GSSAPI) is not enabled
     */
    private boolean handleKafkaRequest(byte[] requestBytes) throws IOException, AuthenticationException {
        boolean isKafkaRequest = false;
        String clientMechanism = null;
        try {
            ByteBuffer requestBuffer = ByteBuffer.wrap(requestBytes);
            RequestHeader requestHeader = RequestHeader.parse(requestBuffer);
            ApiKeys apiKey = ApiKeys.forId(requestHeader.apiKey());
            // A valid Kafka request header was received. SASL authentication tokens are now expected only
            // following a SaslHandshakeRequest since this is not a GSSAPI client token from a Kafka 0.9.0.x client.
            setSaslState(SaslState.HANDSHAKE_REQUEST);
            isKafkaRequest = true;

            // Reject unsupported api key/version combinations; ApiVersions gets an explicit error response.
            if (!Protocol.apiVersionSupported(requestHeader.apiKey(), requestHeader.apiVersion())) {
                if (apiKey == ApiKeys.API_VERSIONS)
                    sendKafkaResponse(ApiVersionsResponse.unsupportedVersionSend(node, requestHeader));
                else
                    throw new UnsupportedVersionException("Version " + requestHeader.apiVersion() + " is not supported for apiKey " + apiKey);
            } else {
                AbstractRequest request = AbstractRequest.getRequest(requestHeader.apiKey(), requestHeader.apiVersion(),
                        requestBuffer).request;

                LOG.debug("Handle Kafka request {}", apiKey);
                switch (apiKey) {
                    case API_VERSIONS:
                        handleApiVersionsRequest(requestHeader);
                        break;
                    case SASL_HANDSHAKE:
                        clientMechanism = handleHandshakeRequest(requestHeader, (SaslHandshakeRequest) request);
                        break;
                    default:
                        throw new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL handshake.");
                }
            }
        } catch (SchemaException | IllegalArgumentException e) {
            if (saslState == SaslState.GSSAPI_OR_HANDSHAKE_REQUEST) {
                // SchemaException is thrown if the request is not in Kafka format. IllegalArgumentException is thrown
                // if the API key is invalid. For compatibility with 0.9.0.x where the first packet is a GSSAPI token
                // starting with 0x60, revert to GSSAPI for both these exceptions.
                if (LOG.isDebugEnabled()) {
                    StringBuilder tokenBuilder = new StringBuilder();
                    for (byte b : requestBytes) {
                        tokenBuilder.append(String.format("%02x", b));
                        if (tokenBuilder.length() >= 20)
                             break;
                    }
                    LOG.debug("Received client packet of length {} starting with bytes 0x{}, process as GSSAPI packet", requestBytes.length, tokenBuilder);
                }
                if (enabledMechanisms.contains(SaslConfigs.GSSAPI_MECHANISM)) {
                    LOG.debug("First client packet is not a SASL mechanism request, using default mechanism GSSAPI");
                    clientMechanism = SaslConfigs.GSSAPI_MECHANISM;
                } else
                    throw new UnsupportedSaslMechanismException("Exception handling first SASL packet from client, GSSAPI is not supported by server", e);
            } else
                throw e;
        }
        if (clientMechanism != null) {
            // A mechanism was negotiated: create the SaslServer for it and start the token exchange.
            createSaslServer(clientMechanism);
            setSaslState(SaslState.AUTHENTICATE);
        }
        return isKafkaRequest;
    }

    /**
     * Responds to a SaslHandshakeRequest with the enabled mechanisms.
     *
     * @return the client's mechanism if it is enabled on this server
     * @throws UnsupportedSaslMechanismException if the mechanism is not enabled (after sending an error response)
     */
    private String handleHandshakeRequest(RequestHeader requestHeader, SaslHandshakeRequest handshakeRequest) throws IOException, UnsupportedSaslMechanismException {
        String clientMechanism = handshakeRequest.mechanism();
        if (enabledMechanisms.contains(clientMechanism)) {
            LOG.debug("Using SASL mechanism '{}' provided by client", clientMechanism);
            sendKafkaResponse(requestHeader, new SaslHandshakeResponse(Errors.NONE, enabledMechanisms));
            return clientMechanism;
        } else {
            LOG.debug("SASL mechanism '{}' requested by client is not supported", clientMechanism);
            sendKafkaResponse(requestHeader, new SaslHandshakeResponse(Errors.UNSUPPORTED_SASL_MECHANISM, enabledMechanisms));
            throw new UnsupportedSaslMechanismException("Unsupported SASL mechanism " + clientMechanism);
        }
    }

    /**
     * Responds to an ApiVersionsRequest with the default ApiVersionsResponse.
     */
    private void handleApiVersionsRequest(RequestHeader requestHeader) throws IOException, UnsupportedSaslMechanismException {
        sendKafkaResponse(requestHeader, ApiVersionsResponse.API_VERSIONS_RESPONSE);
    }

    /**
     * Serializes {@code response} for {@code requestHeader} and starts sending it.
     */
    private void sendKafkaResponse(RequestHeader requestHeader, AbstractResponse response) throws IOException {
        sendKafkaResponse(response.toSend(node, requestHeader));
    }

    /**
     * Installs {@code send} as the outgoing buffer and attempts to flush it.
     */
    private void sendKafkaResponse(Send send) throws IOException {
        netOutBuffer = send;
        flushNetOutBufferAndUpdateInterestOps();
    }
}