/*
 * Copyright (c) 2012-2013 Spotify AB
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.spotify.netty4.handler.codec.zmtp;


import java.util.ArrayList;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.ReferenceCountUtil;

/**
 * Netty ZMTP encoder.
 *
 * <p>Messages passed to {@link #write} are buffered until {@link #flush} is called. On flush the
 * whole batch is first sized via {@link ZMTPEncoder#estimate}, then encoded via
 * {@link ZMTPEncoder#encode} into a single {@link ByteBuf} that is written downstream in one go.
 * Each buffered message is released once it has been encoded, and every buffered write promise is
 * completed with the outcome of that single aggregated write.
 */
class ZMTPFramingEncoder extends ChannelOutboundHandlerAdapter {

  private final ZMTPEncoder encoder;

  // Messages and promises received via write() since the last flush(). Entry i of one list
  // belongs to entry i of the other.
  private final List<Object> messages = new ArrayList<Object>();
  private final List<ChannelPromise> promises = new ArrayList<ChannelPromise>();
  private final ZMTPWriter writer;
  private final ZMTPEstimator estimator;

  /**
   * Creates an encoder that frames output for the version negotiated on {@code session}.
   *
   * @param session the session whose negotiated version selects the framing
   * @param encoder the message encoder used to size and serialize buffered messages
   * @throws NullPointerException if {@code session} or {@code encoder} is null
   */
  ZMTPFramingEncoder(final ZMTPSession session, final ZMTPEncoder encoder) {
    if (session == null) {
      throw new NullPointerException("session");
    }
    if (encoder == null) {
      throw new NullPointerException("encoder");
    }
    this.encoder = encoder;
    this.writer = ZMTPWriter.create(session.negotiatedVersion());
    this.estimator = ZMTPEstimator.create(session.negotiatedVersion());
  }

  /**
   * Creates an encoder that frames output using an explicit wire format.
   *
   * @param wireFormat the wire format used for sizing and writing frames
   * @param encoder    the message encoder used to size and serialize buffered messages
   * @throws NullPointerException if {@code wireFormat} or {@code encoder} is null
   */
  public ZMTPFramingEncoder(final ZMTPWireFormat wireFormat, final ZMTPEncoder encoder) {
    if (wireFormat == null) {
      throw new NullPointerException("wireFormat");
    }
    if (encoder == null) {
      throw new NullPointerException("encoder");
    }
    this.encoder = encoder;
    this.writer = new ZMTPWriter(wireFormat);
    this.estimator = new ZMTPEstimator(wireFormat);
  }

  /**
   * Releases any messages still buffered (never flushed), fails their promises, and closes the
   * encoder. Without this, messages written but not flushed before removal (e.g. on channel
   * close) would leak and their promises would never complete.
   */
  @Override
  public void handlerRemoved(final ChannelHandlerContext ctx) {
    try {
      // Snapshot and clear first so promise listeners cannot observe or mutate the lists
      // while we iterate.
      final Object[] pending = messages.toArray();
      final ChannelPromise[] pendingPromises = promises.toArray(new ChannelPromise[promises.size()]);
      messages.clear();
      promises.clear();
      if (pending.length > 0) {
        final IllegalStateException cause =
            new IllegalStateException("ZMTP framing encoder removed with unflushed messages");
        for (final Object message : pending) {
          ReferenceCountUtil.safeRelease(message);
        }
        for (final ChannelPromise promise : pendingPromises) {
          promise.tryFailure(cause);
        }
      }
    } finally {
      encoder.close();
    }
  }

  /**
   * Buffers {@code msg} and its {@code promise}; nothing is encoded or written until
   * {@link #flush} is called.
   */
  @Override
  public void write(final ChannelHandlerContext ctx, final Object msg,
                    final ChannelPromise promise) {
    messages.add(msg);
    promises.add(promise);
  }

  /**
   * Encodes all buffered messages into one buffer, writes it with a promise that completes every
   * buffered write promise, and flushes.
   *
   * <p>If sizing or encoding fails, the partially built buffer and all not-yet-released messages
   * are released, every buffered promise is failed with the cause, and the exception is rethrown.
   */
  @Override
  public void flush(final ChannelHandlerContext ctx) throws Exception {
    if (messages.isEmpty()) {
      // Nothing buffered here, but handlers closer to the head may still have pending writes.
      ctx.flush();
      return;
    }

    // Take ownership of the batch up front so a failure part-way through cannot leave
    // already-released messages queued for the next flush.
    final Object[] batch = messages.toArray();
    final ChannelPromise[] batchPromises = promises.toArray(new ChannelPromise[promises.size()]);
    messages.clear();
    promises.clear();

    final ByteBuf output;
    try {
      output = encodeBatch(ctx, batch);
    } catch (Exception e) {
      for (final ChannelPromise promise : batchPromises) {
        promise.tryFailure(e);
      }
      throw e;
    }

    final ChannelPromise aggregate = ctx.newPromise();
    aggregate.addListener(new PromiseFanOut(batchPromises));
    ctx.write(output, aggregate);
    ctx.flush();
  }

  /**
   * Sizes and encodes {@code batch} into a freshly allocated buffer, releasing each message right
   * after it is encoded. On failure, releases the buffer and every message not yet released
   * before rethrowing.
   */
  private ByteBuf encodeBatch(final ChannelHandlerContext ctx, final Object[] batch)
      throws Exception {
    ByteBuf output = null;
    int next = 0; // index of the first message that has not been released yet
    try {
      estimator.reset();
      for (final Object message : batch) {
        encoder.estimate(message, estimator);
      }
      output = ctx.alloc().buffer(estimator.size());
      writer.reset(output);
      while (next < batch.length) {
        final Object message = batch[next];
        encoder.encode(message, writer);
        // Advance before releasing so a failing release is not retried in the catch below.
        next++;
        ReferenceCountUtil.release(message);
      }
      return output;
    } catch (Exception e) {
      if (output != null) {
        output.release();
      }
      for (int i = next; i < batch.length; i++) {
        ReferenceCountUtil.safeRelease(batch[i]);
      }
      throw e;
    }
  }

  /**
   * Completes a fixed set of write promises with the outcome of the single aggregated write they
   * were merged into. Being a listener, it observes every completion path of the aggregate
   * (success, failure, cancellation) regardless of which completion method the transport calls.
   * The {@code try*} variants are used so that a promise already completed elsewhere neither
   * throws nor prevents the remaining promises from being notified.
   */
  private static final class PromiseFanOut implements ChannelFutureListener {

    private final ChannelPromise[] promises;

    private PromiseFanOut(final ChannelPromise[] promises) {
      this.promises = promises;
    }

    @Override
    public void operationComplete(final ChannelFuture future) {
      if (future.isSuccess()) {
        for (final ChannelPromise promise : promises) {
          promise.trySuccess();
        }
      } else if (future.isCancelled()) {
        for (final ChannelPromise promise : promises) {
          promise.cancel(false);
        }
      } else {
        final Throwable cause = future.cause();
        for (final ChannelPromise promise : promises) {
          promise.tryFailure(cause);
        }
      }
    }
  }
}