/*
 * Copyright 2014-2020 Real Logic Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.aeron.archive;

import io.aeron.*;
import io.aeron.archive.client.AeronArchive;
import io.aeron.archive.client.ArchiveException;
import io.aeron.driver.MediaDriver;
import io.aeron.driver.ThreadingMode;
import io.aeron.security.Authenticator;
import io.aeron.security.AuthenticatorSupplier;
import io.aeron.security.CredentialsSupplier;
import io.aeron.security.SessionProxy;
import io.aeron.test.MediaDriverTestWatcher;
import io.aeron.test.TestMediaDriver;
import io.aeron.test.Tests;
import org.agrona.CloseHelper;
import org.agrona.SystemUtil;
import org.agrona.collections.MutableLong;
import org.agrona.concurrent.status.CountersReader;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.io.File;

import static io.aeron.archive.Common.*;
import static io.aeron.archive.codecs.SourceLocation.LOCAL;
import static io.aeron.security.NullCredentialsSupplier.NULL_CREDENTIAL;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.spy;

/**
 * Tests that an {@link Archive} launched with a custom {@link Authenticator} accepts or rejects
 * {@link AeronArchive} client connections according to the authenticator's decisions, covering
 * both the connect-request flow and the challenge/response flow.
 */
public class ArchiveAuthenticationTest
{
    private static final int RECORDED_STREAM_ID = 1033;
    private static final String RECORDED_CHANNEL = new ChannelUriStringBuilder()
        .media("udp")
        .endpoint("localhost:3333")
        .termLength(Common.TERM_LENGTH)
        .build();

    private static final String CREDENTIALS_STRING = "username=\"admin\"|password=\"secret\"";
    private static final String CHALLENGE_STRING = "I challenge you!";
    private static final String PRINCIPAL_STRING = "I am THE Principal!";

    // Encoded with an explicit charset so the bytes do not depend on the platform default.
    private final byte[] encodedCredentials = CREDENTIALS_STRING.getBytes(US_ASCII);
    private final byte[] encodedChallenge = CHALLENGE_STRING.getBytes(US_ASCII);

    private TestMediaDriver mediaDriver;
    private Archive archive;
    private Aeron aeron;
    private AeronArchive aeronArchive;

    private final String aeronDirectoryName = CommonContext.generateRandomDirName();

    @RegisterExtension
    public final MediaDriverTestWatcher testWatcher = new MediaDriverTestWatcher();

    /**
     * Closes all clients and services, then deletes the archive and driver directories.
     * <p>
     * Fields may still be null if a test failed before launching them. The guards keep teardown
     * from throwing a {@link NullPointerException} that would hide the original failure.
     */
    @AfterEach
    public void after()
    {
        CloseHelper.closeAll(aeronArchive, aeron, archive, mediaDriver);

        if (null != archive)
        {
            archive.context().deleteDirectory();
        }

        if (null != mediaDriver)
        {
            mediaDriver.context().deleteDirectory();
        }
    }

    /**
     * With no authenticator and no credentials supplier configured, a client can connect and record.
     */
    @Test
    @Timeout(10)
    public void shouldBeAbleToRecordWithDefaultCredentialsAndAuthenticator()
    {
        launchArchivingMediaDriver(null);
        connectClient(null);

        createRecording();
    }

    /**
     * The authenticator checks the credentials sent with the connect request and authenticates
     * the session immediately, with no challenge.
     */
    @Test
    @Timeout(10)
    public void shouldBeAbleToRecordWithAuthenticateOnConnectRequestWithCredentials()
    {
        final MutableLong authenticatorSessionId = new MutableLong(-1L);

        final CredentialsSupplier credentialsSupplier = spy(new CredentialsSupplier()
        {
            public byte[] encodedCredentials()
            {
                return encodedCredentials;
            }

            public byte[] onChallenge(final byte[] encodedChallenge)
            {
                // No challenge is expected in this flow.
                fail();
                return null;
            }
        });

        // NOTE(review): these callbacks presumably run on the archive's agent thread, so a failed
        // assertion here may surface via the archive's error handling rather than failing the
        // test thread directly. Confirm.
        final Authenticator authenticator = spy(new Authenticator()
        {
            public void onConnectRequest(final long sessionId, final byte[] encodedCredentials, final long nowMs)
            {
                authenticatorSessionId.value = sessionId;
                assertEquals(CREDENTIALS_STRING, new String(encodedCredentials, US_ASCII));
            }

            public void onChallengeResponse(final long sessionId, final byte[] encodedCredentials, final long nowMs)
            {
                fail();
            }

            public void onConnectedSession(final SessionProxy sessionProxy, final long nowMs)
            {
                assertEquals(sessionProxy.sessionId(), authenticatorSessionId.value);
                sessionProxy.authenticate(PRINCIPAL_STRING.getBytes(US_ASCII));
            }

            public void onChallengedSession(final SessionProxy sessionProxy, final long nowMs)
            {
                fail();
            }
        });

        launchArchivingMediaDriver(() -> authenticator);
        connectClient(credentialsSupplier);

        assertEquals(aeronArchive.controlSessionId(), authenticatorSessionId.value);

        createRecording();
    }

    /**
     * The client sends null credentials. The authenticator issues a challenge, verifies the
     * response, and then authenticates the session.
     */
    @Test
    @Timeout(10)
    public void shouldBeAbleToRecordWithAuthenticateOnChallengeResponse()
    {
        final MutableLong authenticatorSessionId = new MutableLong(-1L);

        final CredentialsSupplier credentialsSupplier = spy(new CredentialsSupplier()
        {
            public byte[] encodedCredentials()
            {
                return NULL_CREDENTIAL;
            }

            public byte[] onChallenge(final byte[] encodedChallenge)
            {
                assertEquals(CHALLENGE_STRING, new String(encodedChallenge, US_ASCII));
                return encodedCredentials;
            }
        });

        final Authenticator authenticator = spy(new Authenticator()
        {
            // Set once a valid challenge response has been received; gates authentication.
            boolean challengeSuccessful = false;

            public void onConnectRequest(final long sessionId, final byte[] encodedCredentials, final long nowMs)
            {
                authenticatorSessionId.value = sessionId;
                assertEquals(0, encodedCredentials.length);
            }

            public void onChallengeResponse(final long sessionId, final byte[] encodedCredentials, final long nowMs)
            {
                assertEquals(sessionId, authenticatorSessionId.value);
                assertEquals(CREDENTIALS_STRING, new String(encodedCredentials, US_ASCII));
                challengeSuccessful = true;
            }

            public void onConnectedSession(final SessionProxy sessionProxy, final long nowMs)
            {
                assertEquals(sessionProxy.sessionId(), authenticatorSessionId.value);
                sessionProxy.challenge(encodedChallenge);
            }

            public void onChallengedSession(final SessionProxy sessionProxy, final long nowMs)
            {
                // May be invoked before the response arrives; only authenticate once it has.
                if (challengeSuccessful)
                {
                    assertEquals(sessionProxy.sessionId(), authenticatorSessionId.value);
                    sessionProxy.authenticate(PRINCIPAL_STRING.getBytes(US_ASCII));
                }
            }
        });

        launchArchivingMediaDriver(() -> authenticator);
        connectClient(credentialsSupplier);

        assertEquals(aeronArchive.controlSessionId(), authenticatorSessionId.value);

        createRecording();
    }

    /**
     * The authenticator rejects the session on connect. The client must fail with
     * {@link ArchiveException#AUTHENTICATION_REJECTED}.
     */
    @Test
    @Timeout(10)
    public void shouldNotBeAbleToConnectWithRejectOnConnectRequest()
    {
        final MutableLong authenticatorSessionId = new MutableLong(-1L);

        final CredentialsSupplier credentialsSupplier = spy(new CredentialsSupplier()
        {
            public byte[] encodedCredentials()
            {
                return NULL_CREDENTIAL;
            }

            public byte[] onChallenge(final byte[] encodedChallenge)
            {
                assertEquals(CHALLENGE_STRING, new String(encodedChallenge, US_ASCII));
                return encodedCredentials;
            }
        });

        final Authenticator authenticator = spy(new Authenticator()
        {
            public void onConnectRequest(final long sessionId, final byte[] encodedCredentials, final long nowMs)
            {
                authenticatorSessionId.value = sessionId;
                assertEquals(0, encodedCredentials.length);
            }

            public void onChallengeResponse(final long sessionId, final byte[] encodedCredentials, final long nowMs)
            {
                fail();
            }

            public void onConnectedSession(final SessionProxy sessionProxy, final long nowMs)
            {
                assertEquals(sessionProxy.sessionId(), authenticatorSessionId.value);
                sessionProxy.reject();
            }

            public void onChallengedSession(final SessionProxy sessionProxy, final long nowMs)
            {
                fail();
            }
        });

        launchArchivingMediaDriver(() -> authenticator);

        final ArchiveException ex = assertThrows(ArchiveException.class, () -> connectClient(credentialsSupplier));
        assertEquals(ArchiveException.AUTHENTICATION_REJECTED, ex.errorCode());
    }

    /**
     * The authenticator challenges the client and rejects the session after receiving the
     * response. The client must fail with {@link ArchiveException#AUTHENTICATION_REJECTED}.
     */
    @Test
    @Timeout(10)
    public void shouldNotBeAbleToConnectWithRejectOnChallengeResponse()
    {
        final MutableLong authenticatorSessionId = new MutableLong(-1L);

        final CredentialsSupplier credentialsSupplier = spy(new CredentialsSupplier()
        {
            public byte[] encodedCredentials()
            {
                return NULL_CREDENTIAL;
            }

            public byte[] onChallenge(final byte[] encodedChallenge)
            {
                assertEquals(CHALLENGE_STRING, new String(encodedChallenge, US_ASCII));
                return encodedCredentials;
            }
        });

        final Authenticator authenticator = spy(new Authenticator()
        {
            // Set once the challenge response has been received; gates the rejection.
            boolean challengeRespondedTo = false;

            public void onConnectRequest(final long sessionId, final byte[] encodedCredentials, final long nowMs)
            {
                authenticatorSessionId.value = sessionId;
                assertEquals(0, encodedCredentials.length);
            }

            public void onChallengeResponse(final long sessionId, final byte[] encodedCredentials, final long nowMs)
            {
                assertEquals(sessionId, authenticatorSessionId.value);
                assertEquals(CREDENTIALS_STRING, new String(encodedCredentials, US_ASCII));
                challengeRespondedTo = true;
            }

            public void onConnectedSession(final SessionProxy sessionProxy, final long nowMs)
            {
                assertEquals(sessionProxy.sessionId(), authenticatorSessionId.value);
                sessionProxy.challenge(encodedChallenge);
            }

            public void onChallengedSession(final SessionProxy sessionProxy, final long nowMs)
            {
                if (challengeRespondedTo)
                {
                    assertEquals(sessionProxy.sessionId(), authenticatorSessionId.value);
                    sessionProxy.reject();
                }
            }
        });

        launchArchivingMediaDriver(() -> authenticator);

        final ArchiveException ex = assertThrows(ArchiveException.class, () -> connectClient(credentialsSupplier));
        assertEquals(ArchiveException.AUTHENTICATION_REJECTED, ex.errorCode());
    }

    /**
     * Connects an {@link Aeron} client to the test driver, then an {@link AeronArchive} client
     * that uses the given credentials supplier.
     *
     * @param credentialsSupplier supplier handed to the archive client context; the
     *                            default-credentials test passes {@code null}.
     */
    private void connectClient(final CredentialsSupplier credentialsSupplier)
    {
        aeron = Aeron.connect(
            new Aeron.Context()
                .aeronDirectoryName(aeronDirectoryName));

        aeronArchive = AeronArchive.connect(
            new AeronArchive.Context()
                .credentialsSupplier(credentialsSupplier)
                .aeron(aeron));
    }

    /**
     * Launches a shared-threaded media driver and archive that share {@link #aeronDirectoryName}.
     * The archive directory is deleted on start.
     *
     * @param authenticatorSupplier supplier handed to the archive context; the
     *                              default-authenticator test passes {@code null}.
     */
    private void launchArchivingMediaDriver(final AuthenticatorSupplier authenticatorSupplier)
    {
        mediaDriver = TestMediaDriver.launch(
            new MediaDriver.Context()
                .aeronDirectoryName(aeronDirectoryName)
                .termBufferSparseFile(true)
                .threadingMode(ThreadingMode.SHARED)
                .errorHandler(Tests::onError)
                .spiesSimulateConnection(false)
                .dirDeleteOnStart(true),
            testWatcher);

        archive = Archive.launch(
            new Archive.Context()
                .maxCatalogEntries(Common.MAX_CATALOG_ENTRIES)
                .aeronDirectoryName(aeronDirectoryName)
                .deleteArchiveOnStart(true)
                .archiveDir(new File(SystemUtil.tmpDirName(), "archive"))
                .fileSyncLevel(0)
                .authenticatorSupplier(authenticatorSupplier)
                .threadingMode(ArchiveThreadingMode.SHARED));
    }

    /**
     * Starts a local recording of {@link #RECORDED_CHANNEL} and publishes and consumes a batch
     * of messages. It then waits for the recording counter to reach the publication position
     * before stopping the recording.
     */
    private void createRecording()
    {
        final String messagePrefix = "Message-Prefix-";
        final int messageCount = 10;

        final long subscriptionId = aeronArchive.startRecording(RECORDED_CHANNEL, RECORDED_STREAM_ID, LOCAL);

        try (Subscription subscription = aeron.addSubscription(RECORDED_CHANNEL, RECORDED_STREAM_ID);
            Publication publication = aeron.addPublication(RECORDED_CHANNEL, RECORDED_STREAM_ID))
        {
            final CountersReader counters = aeron.countersReader();
            final int counterId = Common.awaitRecordingCounterId(counters, publication.sessionId());

            offer(publication, messageCount, messagePrefix);
            consume(subscription, messageCount, messagePrefix);

            // Ensure everything published has been recorded before stopping.
            final long currentPosition = publication.position();
            awaitPosition(counters, counterId, currentPosition);
        }

        aeronArchive.stopRecording(subscriptionId);
    }
}