/* * Copyright 2016 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.linecorp.armeria.client; import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.ContentTooLargeException; import com.linecorp.armeria.common.ProtocolViolationException; import com.linecorp.armeria.common.stream.ClosedStreamException; import com.linecorp.armeria.common.unsafe.PooledHttpData; import com.linecorp.armeria.internal.common.ArmeriaHttpUtil; import com.linecorp.armeria.internal.common.InboundTrafficController; import com.linecorp.armeria.internal.common.KeepAliveHandler; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandler; import io.netty.channel.EventLoop; import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpStatusClass; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.LastHttpContent; import io.netty.util.ReferenceCountUtil; final class Http1ResponseDecoder extends HttpResponseDecoder implements ChannelInboundHandler { private static final Logger logger = LoggerFactory.getLogger(Http1ResponseDecoder.class); private enum State { NEED_HEADERS, NEED_INFORMATIONAL_DATA, NEED_DATA_OR_TRAILERS, DISCARD } /** The request being decoded currently. */ @Nullable private HttpResponseWrapper res; @Nullable private KeepAliveHandler keepAliveHandler; private int resId = 1; private int lastPingReqId = -1; private State state = State.NEED_HEADERS; Http1ResponseDecoder(Channel channel) { super(channel, InboundTrafficController.ofHttp1(channel)); } @Override HttpResponseWrapper addResponse( int id, DecodedHttpResponse res, @Nullable ClientRequestContext ctx, EventLoop eventLoop, long responseTimeoutMillis, long maxContentLength) { final HttpResponseWrapper resWrapper = super.addResponse(id, res, ctx, eventLoop, responseTimeoutMillis, maxContentLength); resWrapper.whenComplete().handle((unused, cause) -> { if (eventLoop.inEventLoop()) { onWrapperCompleted(resWrapper, cause); } else { eventLoop.execute(() -> onWrapperCompleted(resWrapper, cause)); } return null; }); return resWrapper; } private void onWrapperCompleted(HttpResponseWrapper resWrapper, @Nullable Throwable cause) { // Cancel timeout future and abort the request if it exists. resWrapper.onSubscriptionCancelled(cause); if (cause != null) { // Disconnect when the response has been closed with an exception because there's no way // to recover from it in HTTP/1. channel().close(); } } @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { maybeInitializeKeepAliveHandler(ctx); } @Override public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { destroyKeepAliveHandler(); } @Override public void channelRegistered(ChannelHandlerContext ctx) throws Exception { maybeInitializeKeepAliveHandler(ctx); ctx.fireChannelRegistered(); } @Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { ctx.fireChannelUnregistered(); } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { maybeInitializeKeepAliveHandler(ctx); ctx.fireChannelActive(); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { if (res != null) { res.close(ClosedSessionException.get()); } destroyKeepAliveHandler(); ctx.fireChannelInactive(); } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (!(msg instanceof HttpObject)) { ctx.fireChannelRead(msg); return; } if (isPing()) { onPingRead(msg); ReferenceCountUtil.release(msg); return; } try { switch (state) { case NEED_HEADERS: if (msg instanceof HttpResponse) { final HttpResponse nettyRes = (HttpResponse) msg; final DecoderResult decoderResult = nettyRes.decoderResult(); if (!decoderResult.isSuccess()) { fail(ctx, new ProtocolViolationException(decoderResult.cause())); return; } if (!HttpUtil.isKeepAlive(nettyRes)) { disconnectWhenFinished(); } final HttpResponseWrapper res = getResponse(resId); assert res != null; this.res = res; res.logResponseFirstBytesTransferred(); if (nettyRes.status().codeClass() == HttpStatusClass.INFORMATIONAL) { state = State.NEED_INFORMATIONAL_DATA; } else { state = State.NEED_DATA_OR_TRAILERS; } res.initTimeout(); if (!res.tryWrite(ArmeriaHttpUtil.toArmeria(nettyRes))) { fail(ctx, ClosedStreamException.get()); return; } } else { failWithUnexpectedMessageType(ctx, msg, HttpResponse.class); } break; case NEED_INFORMATIONAL_DATA: if (msg instanceof LastHttpContent) { state = State.NEED_HEADERS; } else { failWithUnexpectedMessageType(ctx, msg, LastHttpContent.class); } break; case NEED_DATA_OR_TRAILERS: if (msg instanceof HttpContent) { final HttpContent content = (HttpContent) msg; final DecoderResult decoderResult = content.decoderResult(); if (!decoderResult.isSuccess()) { fail(ctx, new ProtocolViolationException(decoderResult.cause())); return; } final ByteBuf data = content.content(); final int dataLength = data.readableBytes(); if (dataLength > 0) { assert res != null; final long maxContentLength = res.maxContentLength(); if (maxContentLength > 0 && res.writtenBytes() > maxContentLength - dataLength) { fail(ctx, ContentTooLargeException.get()); return; } else if (!res.tryWrite(PooledHttpData.wrap(data.retain()))) { fail(ctx, ClosedStreamException.get()); return; } } if (msg instanceof LastHttpContent) { final HttpResponseWrapper res = removeResponse(resId++); assert res != null; assert this.res == res; this.res = null; state = State.NEED_HEADERS; final HttpHeaders trailingHeaders = ((LastHttpContent) msg).trailingHeaders(); if (!trailingHeaders.isEmpty() && !res.tryWrite(ArmeriaHttpUtil.toArmeria(trailingHeaders))) { fail(ctx, ClosedStreamException.get()); return; } res.close(); if (needsToDisconnectNow()) { ctx.close(); } } } else { failWithUnexpectedMessageType(ctx, msg, HttpContent.class); } break; case DISCARD: break; } } finally { ReferenceCountUtil.release(msg); } } private void failWithUnexpectedMessageType(ChannelHandlerContext ctx, Object msg, Class<?> expected) { fail(ctx, new ProtocolViolationException( "unexpected message type: " + msg.getClass().getName() + " (expected: " + expected.getName() + ')')); } private void fail(ChannelHandlerContext ctx, Throwable cause) { state = State.DISCARD; final HttpResponseWrapper res = this.res; this.res = null; if (res != null) { res.close(cause); } else { logger.warn("Unexpected exception:", cause); } ctx.close(); } @Override public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { ctx.fireChannelReadComplete(); } @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { ctx.fireUserEventTriggered(evt); } @Override public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { ctx.fireChannelWritabilityChanged(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { ctx.fireExceptionCaught(cause); } void setKeepAliveHandler(ChannelHandlerContext ctx, KeepAliveHandler keepAliveHandler) { this.keepAliveHandler = keepAliveHandler; maybeInitializeKeepAliveHandler(ctx); } private void maybeInitializeKeepAliveHandler(ChannelHandlerContext ctx) { if (keepAliveHandler != null && ctx.channel().isActive()) { keepAliveHandler.initialize(ctx); } } private void destroyKeepAliveHandler() { if (keepAliveHandler != null) { keepAliveHandler.destroy(); } } private void onPingRead(Object msg) { if (msg instanceof HttpResponse) { assert keepAliveHandler != null; keepAliveHandler.onPing(); } if (msg instanceof LastHttpContent) { onPingComplete(); } } void setPingReqId(int id) { lastPingReqId = id; } boolean isPingReqId(int id) { return lastPingReqId == id; } private boolean isPing() { return lastPingReqId == resId; } private void onPingComplete() { lastPingReqId = -1; resId++; } }