/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.qpid.jms.provider.amqp; import java.util.function.Function; import org.apache.qpid.jms.provider.exceptions.ProviderConnectionSecurityException; import org.apache.qpid.jms.provider.exceptions.ProviderConnectionSecuritySaslException; import org.apache.qpid.jms.sasl.Mechanism; import org.apache.qpid.jms.sasl.SaslSecurityRuntimeException; import org.apache.qpid.proton.engine.Sasl; import org.apache.qpid.proton.engine.Sasl.SaslOutcome; import org.apache.qpid.proton.engine.Transport; /** * Manage the SASL authentication process */ public class AmqpSaslAuthenticator { private final Function<String[], Mechanism> mechanismFinder; private Mechanism mechanism; private boolean complete; private ProviderConnectionSecurityException failureCause; /** * Create the authenticator and initialize it. * * @param mechanismFinder * An object that is used to locate the most correct SASL Mechanism to perform the authentication. */ public AmqpSaslAuthenticator(Function<String[], Mechanism> mechanismFinder) { this.mechanismFinder = mechanismFinder; } public boolean isComplete() { return complete; } public ProviderConnectionSecurityException getFailureCause() { return failureCause; } public boolean wasSuccessful() throws IllegalStateException { if (complete) { return failureCause == null; } else { throw new IllegalStateException("Authentication has not completed yet."); } } //----- SaslListener implementation --------------------------------------// public void handleSaslMechanisms(Sasl sasl, Transport transport) { try { String[] remoteMechanisms = sasl.getRemoteMechanisms(); if (remoteMechanisms != null && remoteMechanisms.length != 0) { try { mechanism = mechanismFinder.apply(remoteMechanisms); } catch (SaslSecurityRuntimeException ssre){ recordFailure("Could not find a suitable SASL mechanism. " + ssre.getMessage(), ssre); return; } byte[] response = mechanism.getInitialResponse(); if (response != null) { sasl.send(response, 0, response.length); } sasl.setMechanisms(mechanism.getName()); } } catch (Throwable error) { recordFailure("Exception while processing SASL init: " + error.getMessage(), error); } } public void handleSaslChallenge(Sasl sasl, Transport transport) { try { if (sasl.pending() >= 0) { byte[] challenge = new byte[sasl.pending()]; sasl.recv(challenge, 0, challenge.length); byte[] response = mechanism.getChallengeResponse(challenge); if (response != null) { sasl.send(response, 0, response.length); } } } catch (Throwable error) { recordFailure("Exception while processing SASL step: " + error.getMessage(), error); } } public void handleSaslOutcome(Sasl sasl, Transport transport) { try { switch (sasl.getState()) { case PN_SASL_FAIL: handleSaslFail(sasl); break; case PN_SASL_PASS: handleSaslCompletion(sasl); break; default: break; } } catch (Throwable error) { recordFailure(error.getMessage(), error); } } //----- Internal support methods -----------------------------------------// private void handleSaslFail(Sasl sasl) { StringBuilder message = new StringBuilder("Client failed to authenticate"); if (mechanism != null) { message.append(" using SASL: ").append(mechanism.getName()); if (mechanism.getAdditionalFailureInformation() != null) { message.append(" (").append(mechanism.getAdditionalFailureInformation()).append(")"); } } SaslOutcome outcome = sasl.getOutcome(); if (outcome.equals(SaslOutcome.PN_SASL_TEMP)) { message.append(", due to temporary system error."); } recordFailure(message.toString(), null, outcome.getCode()); } private void handleSaslCompletion(Sasl sasl) { try { if (sasl.pending() != 0) { byte[] additionalData = new byte[sasl.pending()]; sasl.recv(additionalData, 0, additionalData.length); mechanism.getChallengeResponse(additionalData); } mechanism.verifyCompletion(); complete = true; } catch (Throwable error) { recordFailure("Exception while processing SASL exchange completion: " + error.getMessage(), error); } } private void recordFailure(String message, Throwable cause) { recordFailure(message, cause, SaslOutcome.PN_SASL_NONE.getCode()); } private void recordFailure(String message, Throwable cause, int outcome) { failureCause = new ProviderConnectionSecuritySaslException(message, outcome, cause); complete = true; } }