/* * This file is part of Discord4J. * * Discord4J is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * Discord4J is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public License * along with Discord4J. If not, see <http://www.gnu.org/licenses/>. */ package discord4j.gateway; import discord4j.common.GitProperties; import discord4j.common.LogUtil; import discord4j.common.ResettableInterval; import discord4j.common.close.CloseException; import discord4j.common.close.CloseStatus; import discord4j.common.close.DisconnectBehavior; import discord4j.common.operator.RateLimitOperator; import discord4j.common.retry.ReconnectContext; import discord4j.common.retry.ReconnectOptions; import discord4j.discordjson.json.gateway.*; import discord4j.gateway.json.GatewayPayload; import discord4j.gateway.limiter.PayloadTransformer; import discord4j.gateway.payload.PayloadReader; import discord4j.gateway.payload.PayloadWriter; import discord4j.gateway.retry.GatewayException; import discord4j.gateway.retry.GatewayRetrySpec; import discord4j.gateway.retry.GatewayStateChange; import discord4j.gateway.retry.ReconnectException; import io.netty.buffer.ByteBuf; import io.netty.util.IllegalReferenceCountException; import org.reactivestreams.Publisher; import reactor.core.publisher.*; import reactor.netty.ConnectionObserver; import reactor.netty.http.client.WebsocketClientSpec; import reactor.util.Logger; import reactor.util.Loggers; import reactor.util.context.Context; import reactor.util.retry.Retry; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Objects; import java.util.Properties; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import static discord4j.common.LogUtil.format; import static io.netty.handler.codec.http.HttpHeaderNames.USER_AGENT; import static reactor.function.TupleUtils.consumer; /** * Represents a Discord WebSocket client, called Gateway, implementing its lifecycle. * <p> * Keeps track of a single websocket session by wrapping an instance of {@link GatewayWebsocketHandler} each time a * new WebSocket connection to Discord is made, therefore only one instance of this class is enough to * handle the lifecycle of the Gateway operations, that could span multiple WebSocket sessions over time. * <p> * Provides automatic reconnecting through a configurable retry policy, allows consumers to receive inbound events * through {@link #dispatch()}, mapped payloads through {@link #receiver()} and allows a producer to * submit events through {@link #sender()}. * <p> * Provides sending raw {@link ByteBuf} payloads through {@link #sendBuffer(Publisher)} and receiving raw * {@link ByteBuf} payloads mapped in-flight using a specified mapper using {@link #receiver(Function)}. */ public class DefaultGatewayClient implements GatewayClient { private static final Logger log = Loggers.getLogger(DefaultGatewayClient.class); private static final Logger senderLog = Loggers.getLogger("discord4j.gateway.protocol.sender"); private static final Logger receiverLog = Loggers.getLogger("discord4j.gateway.protocol.receiver"); // basic properties private final GatewayReactorResources reactorResources; private final PayloadReader payloadReader; private final PayloadWriter payloadWriter; private final ReconnectOptions reconnectOptions; private final ReconnectContext reconnectContext; private final IdentifyOptions identifyOptions; private final String token; private final GatewayObserver observer; private final PayloadTransformer identifyLimiter; private final ResettableInterval heartbeat; private final int maxMissedHeartbeatAck; // reactive pipelines private final EmitterProcessor<ByteBuf> receiver = EmitterProcessor.create(false); private final EmitterProcessor<ByteBuf> sender = EmitterProcessor.create(false); private final EmitterProcessor<Dispatch> dispatch = EmitterProcessor.create(false); private final EmitterProcessor<GatewayPayload<?>> outbound = EmitterProcessor.create(false); private final EmitterProcessor<GatewayPayload<Heartbeat>> heartbeats = EmitterProcessor.create(false); private final FluxSink<ByteBuf> receiverSink; private final FluxSink<ByteBuf> senderSink; private final FluxSink<Dispatch> dispatchSink; private final FluxSink<GatewayPayload<?>> outboundSink; private final FluxSink<GatewayPayload<Heartbeat>> heartbeatSink; private final ReplayProcessor<GatewayConnection.State> state; private final FluxSink<GatewayConnection.State> stateChanges; // mutable state, modified here and at PayloadHandlers private final AtomicInteger sequence = new AtomicInteger(0); private final AtomicReference<String> sessionId = new AtomicReference<>(""); private final AtomicLong lastSent = new AtomicLong(0); private final AtomicLong lastAck = new AtomicLong(0); private final AtomicInteger missedAck = new AtomicInteger(0); private volatile long responseTime = 0; private volatile MonoProcessor<CloseStatus> disconnectNotifier; private volatile GatewayWebsocketHandler sessionHandler; /** * Initializes a new GatewayClient. * * @param options the {@link GatewayOptions} to configure this client */ public DefaultGatewayClient(GatewayOptions options) { this.token = Objects.requireNonNull(options.getToken()); this.reactorResources = Objects.requireNonNull(options.getReactorResources()); this.payloadReader = Objects.requireNonNull(options.getPayloadReader()); this.payloadWriter = Objects.requireNonNull(options.getPayloadWriter()); this.reconnectOptions = options.getReconnectOptions(); this.reconnectContext = new ReconnectContext( this.reconnectOptions.getFirstBackoff(), this.reconnectOptions.getMaxBackoffInterval()); this.identifyOptions = Objects.requireNonNull(options.getIdentifyOptions()); this.observer = options.getInitialObserver(); this.identifyLimiter = Objects.requireNonNull(options.getIdentifyLimiter()); this.maxMissedHeartbeatAck = Math.max(0, options.getMaxMissedHeartbeatAck()); // TODO: consider exposing OverflowStrategy to GatewayOptions this.receiverSink = receiver.sink(FluxSink.OverflowStrategy.BUFFER); this.senderSink = sender.sink(FluxSink.OverflowStrategy.ERROR); this.dispatchSink = dispatch.sink(FluxSink.OverflowStrategy.BUFFER); this.outboundSink = outbound.sink(FluxSink.OverflowStrategy.ERROR); this.heartbeatSink = heartbeats.sink(FluxSink.OverflowStrategy.ERROR); this.heartbeat = new ResettableInterval(this.reactorResources.getTimerTaskScheduler()); SessionInfo resumeSession = this.identifyOptions.getResumeSession().orElse(null); if (resumeSession != null) { this.sequence.set(resumeSession.getSequence()); this.sessionId.set(resumeSession.getId()); this.state = ReplayProcessor.cacheLastOrDefault(GatewayConnection.State.START_RESUMING); } else { this.state = ReplayProcessor.cacheLastOrDefault(GatewayConnection.State.START_IDENTIFYING); } this.stateChanges = state.sink(FluxSink.OverflowStrategy.LATEST); } @Override public Mono<Void> execute(String gatewayUrl) { return Mono.deferWithContext( context -> { disconnectNotifier = MonoProcessor.create(); lastAck.set(0); lastSent.set(0); missedAck.set(0); MonoProcessor<Void> ping = MonoProcessor.create(); // Setup the sending logic from multiple sources into one merged Flux Flux<ByteBuf> heartbeatFlux = heartbeats.flatMap(payload -> Flux.from(payloadWriter.write(payload))); Flux<ByteBuf> identifyFlux = outbound.filter(payload -> Opcode.IDENTIFY.equals(payload.getOp())) .delayUntil(payload -> ping) .flatMap(payload -> Flux.from(payloadWriter.write(payload))) .transform(identifyLimiter); Flux<ByteBuf> payloadFlux = outbound.filter(payload -> !Opcode.IDENTIFY.equals(payload.getOp())) .flatMap(payload -> Flux.from(payloadWriter.write(payload))) .transform(buf -> Flux.merge(buf, sender)) .transform(new RateLimitOperator<>(outboundLimiterCapacity(), Duration.ofSeconds(60), reactorResources.getTimerTaskScheduler(), reactorResources.getPayloadSenderScheduler())); Flux<ByteBuf> outFlux = Flux.merge(heartbeatFlux, identifyFlux, payloadFlux) .doOnNext(buf -> logPayload(senderLog, context, buf)); sessionHandler = new GatewayWebsocketHandler(receiverSink, outFlux, context); Mono<Void> readyHandler = dispatch.filter(DefaultGatewayClient::isReadyOrResumed) .zipWith(state.next().repeat()) .doOnNext(consumer((event, currentState) -> { ConnectionObserver.State observerState; if (currentState == GatewayConnection.State.START_IDENTIFYING || currentState == GatewayConnection.State.START_RESUMING) { log.info(format(context, "Connected to Gateway")); dispatchSink.next(GatewayStateChange.connected()); observerState = GatewayObserver.CONNECTED; } else { log.info(format(context, "Reconnected to Gateway")); dispatchSink.next(GatewayStateChange.retrySucceeded(reconnectContext.getAttempts())); observerState = GatewayObserver.RETRY_SUCCEEDED; } reconnectContext.reset(); stateChanges.next(GatewayConnection.State.CONNECTED); notifyObserver(observerState); })) .then(); // Subscribe the receiver to process and transform the inbound payloads into Dispatch events Mono<Void> receiverFuture = receiver.map(ByteBuf::retain) .doOnNext(buf -> logPayload(receiverLog, context, buf)) .flatMap(payloadReader::read) .doOnDiscard(ByteBuf.class, DefaultGatewayClient::safeRelease) .doOnNext(payload -> { if (Opcode.HEARTBEAT_ACK.equals(payload.getOp())) { ping.onComplete(); } }) .map(this::updateSequence) .map(payload -> new PayloadContext<>(payload, sessionHandler, this, context)) .flatMap(PayloadHandlers::handle) .then(); // Subscribe the handler's outbound exchange with our outbound signals // routing completion signals to close the gateway Mono<Void> senderFuture = outbound.doOnComplete(sessionHandler::close) .doOnNext(payload -> { if (Opcode.RECONNECT.equals(payload.getOp())) { sessionHandler.error( new GatewayException(context, "Reconnecting due to user action")); } }) .then(); // Create the heartbeat loop, and subscribe it using the sender sink Mono<Void> heartbeatHandler = heartbeat.ticks() .flatMap(t -> { long now = System.nanoTime(); lastAck.compareAndSet(0, now); long delay = now - lastAck.get(); if (lastSent.get() - lastAck.get() > 0) { if (missedAck.incrementAndGet() > maxMissedHeartbeatAck) { log.warn(format(context, "Missing heartbeat ACK for {} (tick: {}, seq: {})"), Duration.ofNanos(delay), t, sequence.get()); sessionHandler.error(new GatewayException(context, "Reconnecting due to zombie or failed connection")); return Mono.empty(); } } log.debug(format(context, "Sending heartbeat {} after last ACK"), Duration.ofNanos(delay)); lastSent.set(now); return Mono.just(GatewayPayload.heartbeat(ImmutableHeartbeat.of(sequence.get()))); }) .doOnNext(heartbeatSink::next) .then(); Mono<Void> httpFuture = reactorResources.getHttpClient() .headers(headers -> headers.add(USER_AGENT, initUserAgent())) .observe(getObserver(context)) .websocket(WebsocketClientSpec.builder() .maxFramePayloadLength(Integer.MAX_VALUE) .build()) .uri(gatewayUrl) .handle(sessionHandler::handle) .subscriberContext(LogUtil.clearContext()) .flatMap(t2 -> handleClose(t2.getT1(), t2.getT2())) .then(); return Mono.zip(httpFuture, readyHandler, receiverFuture, senderFuture, heartbeatHandler) .doOnError(t -> { if (t instanceof ReconnectException) { log.info(format(context, "{}"), t.getMessage()); } else { if (log.isTraceEnabled()) { log.error(format(context, "Gateway client error"), t); } else { log.error(format(context, "{}"), t.toString()); } } }) .doOnTerminate(heartbeat::stop) .doOnCancel(() -> sessionHandler.close()) .then(); }) .subscriberContext(ctx -> ctx.put(LogUtil.KEY_SHARD_ID, identifyOptions.getShardInfo().getIndex())) .retryWhen(retryFactory()) .then(Mono.defer(() -> disconnectNotifier.then())) .doOnSubscribe(s -> { if (disconnectNotifier != null) { throw new IllegalStateException("execute can only be subscribed once"); } }); } private String initUserAgent() { final Properties properties = GitProperties.getProperties(); final String version = properties.getProperty(GitProperties.APPLICATION_VERSION, "3"); final String url = properties.getProperty(GitProperties.APPLICATION_URL, "https://discord4j.com"); return "DiscordBot(" + url + ", " + version + ")"; } private void logPayload(Logger logger, Context context, ByteBuf buf) { logger.trace(format(context, buf.toString(StandardCharsets.UTF_8) .replaceAll("(\"token\": ?\")([A-Za-z0-9._-]*)(\")", "$1hunter2$3"))); } private static boolean isReadyOrResumed(Dispatch d) { return Ready.class.isAssignableFrom(d.getClass()) || Resumed.class.isAssignableFrom(d.getClass()); } private GatewayPayload<?> updateSequence(GatewayPayload<?> payload) { if (payload.getSequence() != null) { sequence.set(payload.getSequence()); notifyObserver(GatewayObserver.SEQUENCE); } return payload; } private Retry retryFactory() { return GatewayRetrySpec.create(reconnectOptions, reconnectContext) .doBeforeRetry(retry -> { stateChanges.next(retry.nextState()); long attempt = retry.iteration(); Duration backoff = retry.nextBackoff(); log.debug(format(getContextFromException(retry.failure()), "{} in {} (attempts: {})"), retry.nextState(), backoff, attempt); if (retry.iteration() == 1) { if (retry.nextState() == GatewayConnection.State.RESUMING) { dispatchSink.next(GatewayStateChange.retryStarted(backoff)); notifyObserver(GatewayObserver.RETRY_STARTED); } else { dispatchSink.next(GatewayStateChange.retryStartedResume(backoff)); notifyObserver(GatewayObserver.RETRY_RESUME_STARTED); } } else { dispatchSink.next(GatewayStateChange.retryFailed(attempt - 1, backoff)); notifyObserver(GatewayObserver.RETRY_FAILED); } }); } private Context getContextFromException(Throwable t) { if (t instanceof CloseException) { return ((CloseException) t).getContext(); } if (t instanceof GatewayException) { return ((GatewayException) t).getContext(); } return Context.empty(); } private Mono<CloseStatus> handleClose(DisconnectBehavior sourceBehavior, CloseStatus closeStatus) { return Mono.deferWithContext(ctx -> { DisconnectBehavior behavior; if (GatewayRetrySpec.NON_RETRYABLE_STATUS_CODES.contains(closeStatus.getCode())) { // non-retryable close codes are non-transient errors therefore stopping is the only choice behavior = DisconnectBehavior.stop(sourceBehavior.getCause()); } else { behavior = sourceBehavior; } log.debug(format(ctx, "Closing and {} with status {}"), behavior, closeStatus); stateChanges.next(GatewayConnection.State.DISCONNECTING); heartbeat.stop(); if (behavior.getAction() == DisconnectBehavior.Action.STOP_ABRUPTLY) { dispatchSink.next(GatewayStateChange.disconnectedResume()); notifyObserver(GatewayObserver.DISCONNECTED_RESUME); } else if (behavior.getAction() == DisconnectBehavior.Action.STOP) { dispatchSink.next(GatewayStateChange.disconnected(sourceBehavior, closeStatus)); sequence.set(0); sessionId.set(""); notifyObserver(GatewayObserver.DISCONNECTED); } switch (behavior.getAction()) { case STOP_ABRUPTLY: case STOP: reconnectContext.clear(); responseTime = 0; lastSent.set(0); lastAck.set(0); stateChanges.next(GatewayConnection.State.DISCONNECTED); if (behavior.getCause() != null) { return Mono.just(new CloseException(closeStatus, ctx, behavior.getCause())) .flatMap(ex -> { disconnectNotifier.onError(ex); return Mono.error(ex); }); } return Mono.just(closeStatus).doOnNext(status -> disconnectNotifier.onNext(closeStatus)); case RETRY_ABRUPTLY: case RETRY: default: return Mono.error(new CloseException(closeStatus, ctx, behavior.getCause())); } }); } private ConnectionObserver getObserver(Context context) { return (connection, newState) -> { log.debug(format(context, "{} {}"), newState, connection); notifyObserver(newState); }; } private void notifyObserver(ConnectionObserver.State state) { observer.onStateChange(state, this); } @Override public Mono<Void> close(boolean allowResume) { return Mono.defer(() -> { if (sessionHandler == null || disconnectNotifier == null) { return Mono.error(new IllegalStateException("Gateway client is not active!")); } if (!disconnectNotifier.isTerminated()) { if (allowResume) { sessionHandler.close(DisconnectBehavior.stopAbruptly(null)); } else { sessionHandler.close(DisconnectBehavior.stop(null)); } } return disconnectNotifier.then(); }); } @Override public Flux<Dispatch> dispatch() { return dispatch; } @Override public Flux<GatewayPayload<?>> receiver() { return receiver(payloadReader::read); } @Override public <T> Flux<T> receiver(Function<ByteBuf, Publisher<? extends T>> mapper) { return receiver.map(ByteBuf::retainedDuplicate) .doOnDiscard(ByteBuf.class, DefaultGatewayClient::safeRelease) .flatMap(mapper); } private static void safeRelease(ByteBuf buf) { if (buf.refCnt() > 0) { try { buf.release(); } catch (IllegalReferenceCountException e) { if (log.isDebugEnabled()) { log.debug("", e); } } } } @Override public FluxSink<GatewayPayload<?>> sender() { return outboundSink; } @Override public Mono<Void> sendBuffer(Publisher<ByteBuf> publisher) { return Flux.from(publisher).doOnNext(senderSink::next).then(); } @Override public int getShardCount() { return identifyOptions.getShardInfo().getCount(); } @Override public String getSessionId() { return sessionId.get(); } @Override public int getSequence() { return sequence.get(); } @Override public Flux<GatewayConnection.State> stateEvents() { return state; } @Override public Mono<Boolean> isConnected() { return state.next() .filter(s -> s == GatewayConnection.State.CONNECTED) .hasElement() .defaultIfEmpty(false); } @Override public Duration getResponseTime() { return Duration.ofNanos(responseTime); } ///////////////////////////////// // Methods for PayloadHandlers // ///////////////////////////////// void ackHeartbeat() { responseTime = lastAck.updateAndGet(x -> System.nanoTime()) - lastSent.get(); missedAck.set(0); } //////////////////////////////// // Fields for PayloadHandlers // //////////////////////////////// /** * Obtains the FluxSink to send Dispatch events towards GatewayClient's users. * * @return a {@link FluxSink} for {@link Dispatch} * objects */ FluxSink<Dispatch> dispatchSink() { return dispatchSink; } /** * Gets the atomic reference for the current heartbeat sequence. * * @return an AtomicInteger representing the current gateway sequence */ AtomicInteger sequence() { return sequence; } /** * Gets the atomic reference for the current session ID. * * @return an AtomicReference of the String representing the current session ID */ AtomicReference<String> sessionId() { return sessionId; } /** * Gets the heartbeat manager bound to this GatewayClient. * * @return a {@link ResettableInterval} to manipulate heartbeat operations */ ResettableInterval heartbeat() { return heartbeat; } /** * Gets the token used to connect to the gateway. * * @return a token String */ String token() { return token; } /** * Gets the configuration object for gateway identifying procedure. * * @return an IdentifyOptions configuration object */ IdentifyOptions identifyOptions() { return identifyOptions; } /** * JVM property that allows modifying the number of outbound payloads permitted before activating the * rate-limiter and delaying every following payload for 60 seconds. Default value: 115 permits */ private static final String OUTBOUND_CAPACITY_PROPERTY = "discord4j.gateway.outbound.capacity"; private int outboundLimiterCapacity() { String capacityValue = System.getProperty(OUTBOUND_CAPACITY_PROPERTY); if (capacityValue != null) { try { int capacity = Integer.parseInt(capacityValue); log.info("Overriding default outbound limiter capacity: {}", capacity); } catch (NumberFormatException e) { log.warn("Invalid custom outbound limiter capacity: {}", capacityValue); } } return 115; } }