/*
 * Copyright 2018 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.netflix.titus.testkit.grpc;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

import com.google.common.base.Preconditions;
import com.google.rpc.BadRequest;
import com.netflix.titus.runtime.common.grpc.GrpcClientErrorUtils;
import com.netflix.titus.common.util.rx.ObservableExt;
import com.netflix.titus.runtime.endpoint.v3.grpc.ErrorResponses;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.ServerCallStreamObserver;
import rx.Observable;
import rx.subjects.PublishSubject;

/**
 * GRPC {@link io.grpc.stub.StreamObserver} implementation for testing.
 */
public class TestStreamObserver<T> extends ServerCallStreamObserver<T> {

    private static final Object EOS_MARKER = new Object();

    private final List<T> emittedItems = new CopyOnWriteArrayList<>();

    private final BlockingQueue<T> availableItems = new LinkedBlockingQueue<>();
    private final PublishSubject<T> eventSubject = PublishSubject.create();

    private final CountDownLatch terminatedLatch = new CountDownLatch(1);
    private volatile Throwable error;
    private volatile RuntimeException mappedError;
    private volatile boolean completed;

    private Runnable onCancelHandler;
    private boolean cancelled;

    @Override
    public void onNext(T value) {
        emittedItems.add(value);
        availableItems.add(value);
        eventSubject.onNext(value);
    }

    @Override
    public void onError(Throwable error) {
        this.error = error;
        this.mappedError = exceptionMapper(error);
        eventSubject.onError(error);
        doFinish();
    }

    @Override
    public void onCompleted() {
        completed = true;
        eventSubject.onCompleted();
        doFinish();
    }

    public Observable<T> toObservable() {
        return eventSubject.compose(ObservableExt.head(() -> emittedItems));
    }

    private void doFinish() {
        availableItems.add((T) EOS_MARKER);
        terminatedLatch.countDown();

    }

    public List<T> getEmittedItems() {
        return new ArrayList<>(emittedItems);
    }

    public T getLast() throws Exception {
        terminatedLatch.await();
        throwIfError();
        if (isTerminated() && !emittedItems.isEmpty()) {
            return emittedItems.get(0);
        }
        throw new IllegalStateException("No item emitted by the stream");
    }

    public T getLast(long timout, TimeUnit timeUnit) throws InterruptedException {
        terminatedLatch.await(timout, timeUnit);
        return isTerminated() && !emittedItems.isEmpty() ? emittedItems.get(0) : null;
    }

    public T takeNext() {
        T next = availableItems.poll();
        if ((next == null || next == EOS_MARKER) && isTerminated()) {
            throw new IllegalStateException("Stream is already closed");
        }
        return next;
    }

    public T takeNext(long timeout, TimeUnit timeUnit) throws InterruptedException, IllegalStateException {
        if (isTerminated()) {
            return takeNext();
        }
        T next = availableItems.poll(timeout, timeUnit);
        if (next == EOS_MARKER) {
            throwIfError();
            throw new IllegalStateException("Stream is already closed");
        }
        return next;
    }

    public boolean isTerminated() {
        return completed || error != null;
    }

    public boolean isCompleted() {
        return completed;
    }

    public boolean hasError() {
        return error != null;
    }

    public Throwable getError() {
        Preconditions.checkState(error != null, "Error not emitted");
        return error;
    }

    public Throwable getMappedError() {
        Preconditions.checkState(error != null, "Error not emitted");
        return mappedError;
    }

    public void awaitDone() throws InterruptedException {
        terminatedLatch.await();
        if (hasError()) {
            throw new IllegalStateException("GRPC stream terminated with an error", error);
        }
    }

    public void awaitDone(long timeout, TimeUnit timeUnit) throws InterruptedException {
        if (!terminatedLatch.await(timeout, timeUnit)) {
            throw new IllegalStateException("GRPC request not completed in time");
        }
        if (hasError()) {
            throw new IllegalStateException("GRPC stream terminated with an error", error);
        }
    }

    @Override
    public boolean isReady() {
        return true;
    }

    @Override
    public boolean isCancelled() {
        return cancelled;
    }

    @Override
    public void setOnCancelHandler(Runnable onCancelHandler) {
        this.onCancelHandler = onCancelHandler;
    }

    @Override
    public void setCompression(String compression) {
    }

    @Override
    public void setOnReadyHandler(Runnable onReadyHandler) {
    }

    @Override
    public void disableAutoInboundFlowControl() {
    }

    public void cancel() {
        this.cancelled = true;
        if (onCancelHandler != null) {
            onCancelHandler.run();
        }
    }

    @Override
    public void request(int count) {
    }

    @Override
    public void setMessageCompression(boolean enable) {
    }

    private RuntimeException exceptionMapper(Throwable error) {
        if (error instanceof StatusRuntimeException) {
            StatusRuntimeException e = (StatusRuntimeException) error;
            String errorMessage = "GRPC status " + e.getStatus() + ": " + e.getTrailers().get(ErrorResponses.KEY_TITUS_ERROR_REPORT);

            Optional<BadRequest> badRequest = GrpcClientErrorUtils.getDetail(e, BadRequest.class);
            if (badRequest.isPresent()) {
                return new RuntimeException(errorMessage + ". Invalid field values: " + badRequest, error);
            }
            return new RuntimeException(errorMessage, error);
        }
        return new RuntimeException(error.getMessage(), error);
    }

    private void throwIfError() throws RuntimeException {
        if (error != null) {
            throw mappedError;
        }
    }
}