* Copyright (c) 2012-2013 Spotify AB
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.

package com.spotify.netty4.handler.codec.zmtp;

import com.google.common.util.concurrent.SettableFuture;

import org.junit.After;
import org.junit.Before;
import org.junit.experimental.theories.Theories;
import org.junit.experimental.theories.Theory;
import org.junit.experimental.theories.suppliers.TestedOn;
import org.junit.runner.RunWith;

import java.net.InetSocketAddress;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.ReferenceCountUtil;

import static com.spotify.netty4.handler.codec.zmtp.ZMTPProtocols.ZMTP20;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.ROUTER;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertFalse;

public class ProtocolViolationTests {

  private Channel serverChannel;
  private InetSocketAddress serverAddress;

  private final String identity = "identity";
  private NioEventLoopGroup bossGroup;
  private NioEventLoopGroup group;

  private static class MockHandler extends ChannelInboundHandlerAdapter {

    private SettableFuture<Void> active = SettableFuture.create();
    private SettableFuture<Throwable> exception = SettableFuture.create();
    private SettableFuture<Void> inactive = SettableFuture.create();

    private volatile boolean handshaked;
    private volatile boolean read;

    public void channelActive(final ChannelHandlerContext ctx) throws Exception {

    public void channelInactive(final ChannelHandlerContext ctx) throws Exception {

    public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
      if (evt instanceof ZMTPHandshakeSuccess) {
        handshaked = true;

    public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
      read = true;

    public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause)
        throws Exception {

  private final MockHandler mockHandler = new MockHandler();

  public void setup() {
    final ServerBootstrap serverBootstrap = new ServerBootstrap();
    bossGroup = new NioEventLoopGroup(1);
    group = new NioEventLoopGroup();
    serverBootstrap.group(bossGroup, group);
    serverBootstrap.childHandler(new ChannelInitializer<NioSocketChannel>() {
      protected void initChannel(final NioSocketChannel ch) throws Exception {

    serverChannel = serverBootstrap.bind(new InetSocketAddress("localhost", 0))
    serverAddress = (InetSocketAddress) serverChannel.localAddress();

  public void teardown() {
    if (serverChannel != null) {
    if (bossGroup != null) {
    if (group != null) {

  public void protocolErrorsCauseException(
      @TestedOn(ints = {16, 17, 27, 32, 48, 53}) final int payloadSize) throws Exception {
    final Bootstrap b = new Bootstrap();
    b.group(new NioEventLoopGroup());
    b.handler(new ChannelInitializer<NioSocketChannel>() {
      protected void initChannel(final NioSocketChannel ch) throws Exception {
        ch.pipeline().addLast(new MockHandler());

    final Channel channel = b.connect(serverAddress).awaitUninterruptibly().channel();

    final ByteBuf payload = Unpooled.buffer(payloadSize);
    for (int i = 0; i < payloadSize; i++) {

    mockHandler.active.get(5, SECONDS);
    mockHandler.exception.get(5, SECONDS);
    mockHandler.inactive.get(5, SECONDS);