/*
 * Copyright 2019-present HiveMQ GmbH
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.hivemq.bootstrap.netty.initializer;

import com.hivemq.bootstrap.netty.ChannelDependencies;
import com.hivemq.bootstrap.netty.FakeChannelPipeline;
import com.hivemq.configuration.service.FullConfigurationService;
import com.hivemq.configuration.service.entity.Listener;
import com.hivemq.configuration.service.entity.Tls;
import com.hivemq.configuration.service.entity.TlsTcpListener;
import com.hivemq.logging.EventLog;
import com.hivemq.security.ssl.SslFactory;
import io.netty.channel.Channel;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Future;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import util.DummyHandler;

import static com.hivemq.bootstrap.netty.ChannelHandlerNames.*;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link TlsTcpChannelInitializer#addSpecialHandlers}.
 * <p>
 * Each test seeds a {@link FakeChannelPipeline} with a single handler registered under
 * {@link AbstractChannelInitializer#FIRST_ABSTRACT_HANDLER}, runs the initializer against a mocked
 * {@link SocketChannel} and then checks the handler names found at the front of the pipeline.
 */
public class TlsTcpChannelInitializerTest {

    @Mock
    private SocketChannel socketChannel;

    @Mock
    private Attribute<Listener> attribute;

    @Mock
    private ChannelDependencies channelDependencies;

    @Mock
    private SslHandler sslHandler;

    @Mock
    private TlsTcpListener tlsTcpListener;

    @Mock
    private Tls tls;

    @Mock
    private SslFactory ssl;

    @Mock
    private Future<Channel> future;

    @Mock
    private EventLog eventLog;

    @Mock
    private FullConfigurationService fullConfigurationService;

    private ChannelPipeline pipeline;

    private TlsTcpChannelInitializer tlstcpChannelInitializer;

    /**
     * Creates the mocks, wires the stubbed collaborators together and builds the initializer under test.
     */
    @Before
    public void before() throws Exception {
        MockitoAnnotations.initMocks(this);
        pipeline = new FakeChannelPipeline();

        // Channel side: the mocked channel exposes the fake pipeline and a mocked attribute.
        when(socketChannel.pipeline()).thenReturn(pipeline);
        when(socketChannel.attr(any(AttributeKey.class))).thenReturn(attribute);

        // TLS side: the listener returns the mocked Tls, the factory returns the mocked SslHandler.
        when(tlsTcpListener.getTls()).thenReturn(tls);
        when(ssl.getSslHandler(any(SocketChannel.class), any(Tls.class))).thenReturn(sslHandler);
        when(sslHandler.handshakeFuture()).thenReturn(future);

        when(channelDependencies.getConfigurationService()).thenReturn(fullConfigurationService);

        tlstcpChannelInitializer = new TlsTcpChannelInitializer(channelDependencies, tlsTcpListener, ssl, eventLog);
    }

    /**
     * With client auth mode {@code REQUIRED}, the client certificate handler is expected between
     * the SSL parameter handler and the pre-registered first abstract handler.
     */
    @Test
    public void test_add_special_handlers() throws Exception {
        addSpecialHandlersWith(Tls.ClientAuthMode.REQUIRED);

        assertHandlerNamesInOrder(
                SSL_HANDLER,
                SSL_EXCEPTION_HANDLER,
                SSL_PARAMETER_HANDLER,
                SSL_CLIENT_CERTIFICATE_HANDLER,
                AbstractChannelInitializer.FIRST_ABSTRACT_HANDLER);
    }

    /**
     * With client auth mode {@code NONE}, the first abstract handler is expected directly after
     * the SSL parameter handler.
     */
    @Test
    public void test_add_special_handlers_no_cert() throws Exception {
        addSpecialHandlersWith(Tls.ClientAuthMode.NONE);

        assertHandlerNamesInOrder(
                SSL_HANDLER,
                SSL_EXCEPTION_HANDLER,
                SSL_PARAMETER_HANDLER,
                AbstractChannelInitializer.FIRST_ABSTRACT_HANDLER);
    }

    /**
     * Registers the anchor handler, stubs the client auth mode and invokes the initializer.
     *
     * @param clientAuthMode the mode the mocked {@link Tls} reports
     */
    private void addSpecialHandlersWith(final Tls.ClientAuthMode clientAuthMode) {
        pipeline.addLast(AbstractChannelInitializer.FIRST_ABSTRACT_HANDLER, new DummyHandler());
        when(tls.getClientAuthMode()).thenReturn(clientAuthMode);

        tlstcpChannelInitializer.addSpecialHandlers(socketChannel);
    }

    /**
     * Asserts that the pipeline's handler names match the given names position by position,
     * starting at index 0. Names beyond the given ones are not inspected.
     *
     * @param expectedNames the handler names expected at indices 0..n-1
     */
    private void assertHandlerNamesInOrder(final String... expectedNames) {
        for (int index = 0; index < expectedNames.length; index++) {
            assertEquals(expectedNames[index], pipeline.names().get(index));
        }
    }

}