/*
 * Copyright 2016 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.cronet;

import static io.grpc.internal.GrpcUtil.CONTENT_TYPE_KEY;
import static io.grpc.internal.GrpcUtil.TE_HEADER;
import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY;

// TODO(ericgribkoff): Consider changing from android.util.Log to java logging.
import android.util.Log;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.io.BaseEncoding;
import io.grpc.Attributes;
import io.grpc.CallOptions;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.cronet.CronetChannelBuilder.StreamBuilderFactory;
import io.grpc.internal.AbstractClientStream;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.Http2ClientStreamTransportState;
import io.grpc.internal.ReadableBuffers;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportFrameUtil;
import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBuffer;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import org.chromium.net.BidirectionalStream;
import org.chromium.net.CronetException;
import org.chromium.net.ExperimentalBidirectionalStream;
import org.chromium.net.UrlResponseInfo;

/**
 * Client stream for the cronet transport.
 */
class CronetClientStream extends AbstractClientStream {
  private static final int READ_BUFFER_CAPACITY = 4 * 1024;
  private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0);
  private static final String LOG_TAG = "grpc-java-cronet";
  private final String url;
  private final String userAgent;
  private final StatsTraceContext statsTraceCtx;
  private final Executor executor;
  private final Metadata headers;
  private final CronetClientTransport transport;
  private final Runnable startCallback;
  @VisibleForTesting
  final boolean idempotent;
  private BidirectionalStream stream;
  private final boolean delayRequestHeader;
  private final Object annotation;
  private final Collection<Object> annotations;
  private final TransportState state;
  private final Sink sink = new Sink();
  private StreamBuilderFactory streamFactory;

  CronetClientStream(
      final String url,
      @Nullable String userAgent,
      Executor executor,
      final Metadata headers,
      CronetClientTransport transport,
      Runnable startCallback,
      Object lock,
      int maxMessageSize,
      boolean alwaysUsePut,
      MethodDescriptor<?, ?> method,
      StatsTraceContext statsTraceCtx,
      CallOptions callOptions,
      TransportTracer transportTracer) {
    super(
        new CronetWritableBufferAllocator(), statsTraceCtx, transportTracer, headers, callOptions,
        method.isSafe());
    this.url = Preconditions.checkNotNull(url, "url");
    this.userAgent = Preconditions.checkNotNull(userAgent, "userAgent");
    this.statsTraceCtx = Preconditions.checkNotNull(statsTraceCtx, "statsTraceCtx");
    this.executor = Preconditions.checkNotNull(executor, "executor");
    this.headers = Preconditions.checkNotNull(headers, "headers");
    this.transport = Preconditions.checkNotNull(transport, "transport");
    this.startCallback = Preconditions.checkNotNull(startCallback, "startCallback");
    this.idempotent = method.isIdempotent() || alwaysUsePut;
    // Only delay flushing header for unary rpcs.
    this.delayRequestHeader = (method.getType() == MethodDescriptor.MethodType.UNARY);
    this.annotation = callOptions.getOption(CronetCallOptions.CRONET_ANNOTATION_KEY);
    this.annotations = callOptions.getOption(CronetCallOptions.CRONET_ANNOTATIONS_KEY);
    this.state = new TransportState(maxMessageSize, statsTraceCtx, lock, transportTracer);
  }

  @Override
  protected TransportState transportState() {
    return state;
  }

  @Override
  protected Sink abstractClientStreamSink() {
    return sink;
  }

  @Override
  public void setAuthority(String authority) {
    throw new UnsupportedOperationException("Cronet does not support overriding authority");
  }

  class Sink implements AbstractClientStream.Sink {
    @Override
    public void writeHeaders(Metadata metadata, byte[] payload) {
      startCallback.run();

      BidirectionalStreamCallback callback = new BidirectionalStreamCallback();
      String path = url;
      if (payload != null) {
        path += "?" + BaseEncoding.base64().encode(payload);
      }
      BidirectionalStream.Builder builder =
          streamFactory.newBidirectionalStreamBuilder(path, callback, executor);
      if (payload != null) {
        builder.setHttpMethod("GET");
      } else if (idempotent) {
        builder.setHttpMethod("PUT");
      }
      if (delayRequestHeader) {
        builder.delayRequestHeadersUntilFirstFlush(true);
      }
      if (annotation != null) {
        ((ExperimentalBidirectionalStream.Builder) builder).addRequestAnnotation(annotation);
      }
      if (annotations != null) {
        for (Object o : annotations) {
          ((ExperimentalBidirectionalStream.Builder) builder).addRequestAnnotation(o);
        }
      }
      setGrpcHeaders(builder);
      stream = builder.build();
      stream.start();
    }

    @Override
    public void writeFrame(
        WritableBuffer buffer, boolean endOfStream, boolean flush, int numMessages) {
      synchronized (state.lock) {
        if (state.cancelSent) {
          return;
        }
        ByteBuffer byteBuffer;
        if (buffer != null) {
          byteBuffer = ((CronetWritableBuffer) buffer).buffer();
          byteBuffer.flip();
        } else {
          byteBuffer = EMPTY_BUFFER;
        }
        onSendingBytes(byteBuffer.remaining());
        if (!state.streamReady) {
          state.enqueuePendingData(new PendingData(byteBuffer, endOfStream, flush));
        } else {
          streamWrite(byteBuffer, endOfStream, flush);
        }
      }
    }

    @Override
    public void request(final int numMessages) {
      synchronized (state.lock) {
        state.requestMessagesFromDeframer(numMessages);
      }
    }

    @Override
    public void cancel(Status reason) {
      synchronized (state.lock) {
        if (state.cancelSent) {
          return;
        }
        state.cancelSent = true;
        state.cancelReason = reason;
        state.clearPendingData();
        if (stream != null) {
          // Will report stream finish when BidirectionalStreamCallback.onCanceled is called.
          stream.cancel();
        } else {
          transport.finishStream(CronetClientStream.this, reason);
        }
      }
    }
  }

  class TransportState extends Http2ClientStreamTransportState {
    private final Object lock;
    @GuardedBy("lock")
    private Queue<PendingData> pendingData = new LinkedList<PendingData>();
    @GuardedBy("lock")
    private boolean streamReady;
    @GuardedBy("lock")
    private boolean cancelSent = false;
    @GuardedBy("lock")
    private int bytesPendingProcess;
    @GuardedBy("lock")
    private Status cancelReason;
    @GuardedBy("lock")
    private boolean readClosed;
    @GuardedBy("lock")
    private boolean firstWriteComplete;

    public TransportState(
        int maxMessageSize, StatsTraceContext statsTraceCtx, Object lock,
        TransportTracer transportTracer) {
      super(maxMessageSize, statsTraceCtx, transportTracer);
      this.lock = Preconditions.checkNotNull(lock, "lock");
    }

    @GuardedBy("lock")
    public void start(StreamBuilderFactory factory) {
      streamFactory = factory;
    }

    @GuardedBy("lock")
    @Override
    protected void onStreamAllocated() {
      super.onStreamAllocated();
    }

    @GuardedBy("lock")
    @Override
    protected void http2ProcessingFailed(Status status, boolean stopDelivery, Metadata trailers) {
      stream.cancel();
      transportReportStatus(status, stopDelivery, trailers);
    }

    @GuardedBy("lock")
    @Override
    public void deframeFailed(Throwable cause) {
      http2ProcessingFailed(Status.fromThrowable(cause), true, new Metadata());
    }

    @Override
    public void runOnTransportThread(final Runnable r) {
      synchronized (lock) {
        r.run();
      }
    }

    @GuardedBy("lock")
    @Override
    public void bytesRead(int processedBytes) {
      bytesPendingProcess -= processedBytes;
      if (bytesPendingProcess == 0 && !readClosed) {
        if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
          Log.v(LOG_TAG, "BidirectionalStream.read");
        }
        stream.read(ByteBuffer.allocateDirect(READ_BUFFER_CAPACITY));
      }
    }

    @GuardedBy("lock")
    private void transportHeadersReceived(Metadata metadata, boolean endOfStream) {
      if (endOfStream) {
        transportTrailersReceived(metadata);
      } else {
        transportHeadersReceived(metadata);
      }
    }

    @GuardedBy("lock")
    private void transportDataReceived(ByteBuffer buffer, boolean endOfStream) {
      bytesPendingProcess += buffer.remaining();
      super.transportDataReceived(ReadableBuffers.wrap(buffer), endOfStream);
    }

    @GuardedBy("lock")
    private void clearPendingData() {
      for (PendingData data : pendingData) {
        data.buffer.clear();
      }
      pendingData.clear();
    }

    @GuardedBy("lock")
    private void enqueuePendingData(PendingData data) {
      pendingData.add(data);
    }

    @GuardedBy("lock")
    private void writeAllPendingData() {
      for (PendingData data : pendingData) {
        streamWrite(data.buffer, data.endOfStream, data.flush);
      }
      pendingData.clear();
    }
  }

  // TODO(ericgribkoff): move header related method to a common place like GrpcUtil.
  private static boolean isApplicationHeader(String key) {
    // Don't allow reserved non HTTP/2 pseudo headers to be added
    // HTTP/2 headers can not be created as keys because Header.Key disallows the ':' character.
    return !CONTENT_TYPE_KEY.name().equalsIgnoreCase(key)
        && !USER_AGENT_KEY.name().equalsIgnoreCase(key)
        && !TE_HEADER.name().equalsIgnoreCase(key);
  }

  private void setGrpcHeaders(BidirectionalStream.Builder builder) {
    // Psuedo-headers are set by cronet.
    // All non-pseudo headers must come after pseudo headers.
    // TODO(ericgribkoff): remove this and set it on CronetEngine after crbug.com/588204 gets fixed.
    builder.addHeader(USER_AGENT_KEY.name(), userAgent);
    builder.addHeader(CONTENT_TYPE_KEY.name(), GrpcUtil.CONTENT_TYPE_GRPC);
    builder.addHeader("te", GrpcUtil.TE_TRAILERS);

    // Now add any application-provided headers.
    // TODO(ericgribkoff): make a String-based version to avoid unnecessary conversion between
    // String and byte array.
    byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(headers);
    for (int i = 0; i < serializedHeaders.length; i += 2) {
      String key = new String(serializedHeaders[i], Charset.forName("UTF-8"));
      // TODO(ericgribkoff): log an error or throw an exception
      if (isApplicationHeader(key)) {
        String value = new String(serializedHeaders[i + 1], Charset.forName("UTF-8"));
        builder.addHeader(key, value);
      }
    }
  }

  private void streamWrite(ByteBuffer buffer, boolean endOfStream, boolean flush) {
    if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
      Log.v(LOG_TAG, "BidirectionalStream.write");
    }
    stream.write(buffer, endOfStream);
    if (flush) {
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "BidirectionalStream.flush");
      }
      stream.flush();
    }
  }

  private void finishStream(Status status) {
    transport.finishStream(this, status);
  }

  @Override
  public Attributes getAttributes() {
    return Attributes.EMPTY;
  }

  class BidirectionalStreamCallback extends BidirectionalStream.Callback {
    private List<Map.Entry<String, String>> trailerList;

    @Override
    public void onStreamReady(BidirectionalStream stream) {
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "onStreamReady");
      }
      synchronized (state.lock) {
        // Now that the stream is ready, call the listener's onReady callback if
        // appropriate.
        state.onStreamAllocated();
        state.streamReady = true;
        state.writeAllPendingData();
      }
    }

    @Override
    public void onResponseHeadersReceived(BidirectionalStream stream, UrlResponseInfo info) {
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "onResponseHeadersReceived. Header=" + info.getAllHeadersAsList());
        Log.v(LOG_TAG, "BidirectionalStream.read");
      }
      reportHeaders(info.getAllHeadersAsList(), false);
      stream.read(ByteBuffer.allocateDirect(READ_BUFFER_CAPACITY));
    }

    @Override
    public void onReadCompleted(BidirectionalStream stream, UrlResponseInfo info,
        ByteBuffer buffer, boolean endOfStream) {
      buffer.flip();
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "onReadCompleted. Size=" + buffer.remaining());
      }

      synchronized (state.lock) {
        state.readClosed = endOfStream;
        // The endOfStream in gRPC has a different meaning so we always call transportDataReceived
        // with endOfStream=false.
        if (buffer.remaining() != 0) {
          state.transportDataReceived(buffer, false);
        }
      }
      if (endOfStream && trailerList != null) {
        // Process trailers if we have already received any.
        reportHeaders(trailerList, true);
      }
    }

    @Override
    public void onWriteCompleted(BidirectionalStream stream, UrlResponseInfo info,
        ByteBuffer buffer, boolean endOfStream) {
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "onWriteCompleted");
      }
      synchronized (state.lock) {
        if (!state.firstWriteComplete) {
          // Cronet API doesn't notify when headers are written to wire, but it occurs before first
          // onWriteCompleted callback.
          state.firstWriteComplete = true;
          statsTraceCtx.clientOutboundHeaders();
        }
        state.onSentBytes(buffer.position());
      }
    }

    @Override
    public void onResponseTrailersReceived(BidirectionalStream stream, UrlResponseInfo info,
        UrlResponseInfo.HeaderBlock trailers) {
      processTrailers(trailers.getAsList());
    }

    // We need this method because UrlResponseInfo.HeaderBlock is a final class and cannot be
    // mocked.
    @VisibleForTesting
    void processTrailers(List<Map.Entry<String, String>> trailerList) {
      this.trailerList = trailerList;
      boolean readClosed;
      synchronized (state.lock) {
        readClosed = state.readClosed;
      }
      if (readClosed) {
        // There's no pending onReadCompleted callback so we can report trailers now.
        reportHeaders(trailerList, true);
      }
      // Otherwise report trailers in onReadCompleted, or onSucceeded.
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "onResponseTrailersReceived. Trailer=" + trailerList.toString());
      }
    }

    @Override
    public void onSucceeded(BidirectionalStream stream, UrlResponseInfo info) {
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "onSucceeded");
      }

      if (!haveTrailersBeenReported()) {
        if (trailerList != null) {
          reportHeaders(trailerList, true);
        } else if (info != null) {
          reportHeaders(info.getAllHeadersAsList(), true);
        } else {
          throw new AssertionError("No response header or trailer");
        }
      }
      finishStream(toGrpcStatus(info));
    }

    @Override
    public void onFailed(BidirectionalStream stream, UrlResponseInfo info,
        CronetException error) {
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "onFailed");
      }
      finishStream(Status.UNAVAILABLE.withCause(error));
    }

    @Override
    public void onCanceled(BidirectionalStream stream, UrlResponseInfo info) {
      if (Log.isLoggable(LOG_TAG, Log.VERBOSE)) {
        Log.v(LOG_TAG, "onCanceled");
      }
      Status status;
      synchronized (state.lock) {
        if (state.cancelReason != null) {
          status = state.cancelReason;
        } else if (info != null) {
          status = toGrpcStatus(info);
        } else {
          status = Status.CANCELLED.withDescription("stream cancelled without reason");
        }
      }
      finishStream(status);
    }

    private void reportHeaders(List<Map.Entry<String, String>> headers, boolean endOfStream) {
      // TODO(ericgribkoff): create new utility methods to eliminate all these conversions
      List<String> headerList = new ArrayList<>();
      for (Map.Entry<String, String> entry : headers) {
        headerList.add(entry.getKey());
        headerList.add(entry.getValue());
      }

      byte[][] headerValues = new byte[headerList.size()][];
      for (int i = 0; i < headerList.size(); i += 2) {
        headerValues[i] = headerList.get(i).getBytes(Charset.forName("UTF-8"));
        headerValues[i + 1] = headerList.get(i + 1).getBytes(Charset.forName("UTF-8"));
      }
      Metadata metadata =
          InternalMetadata.newMetadata(TransportFrameUtil.toRawSerializedHeaders(headerValues));
      synchronized (state.lock) {
        // There's no pending onReadCompleted callback so we can report trailers now.
        state.transportHeadersReceived(metadata, endOfStream);
      }
    }

    private boolean haveTrailersBeenReported() {
      synchronized (state.lock) {
        return trailerList != null && state.readClosed;
      }
    }

    private Status toGrpcStatus(UrlResponseInfo info) {
      return GrpcUtil.httpStatusToGrpcStatus(info.getHttpStatusCode());
    }
  }

  private static class PendingData {
    ByteBuffer buffer;
    boolean endOfStream;
    boolean flush;

    PendingData(ByteBuffer buffer, boolean endOfStream, boolean flush) {
      this.buffer = buffer;
      this.endOfStream = endOfStream;
      this.flush = flush;
    }
  }
}