package com.mautini.assistant.demo.api;

import com.google.assistant.embedded.v1alpha2.AssistConfig;
import com.google.assistant.embedded.v1alpha2.AssistRequest;
import com.google.assistant.embedded.v1alpha2.AssistResponse;
import com.google.assistant.embedded.v1alpha2.AudioInConfig;
import com.google.assistant.embedded.v1alpha2.AudioOutConfig;
import com.google.assistant.embedded.v1alpha2.DeviceConfig;
import com.google.assistant.embedded.v1alpha2.DialogStateIn;
import com.google.assistant.embedded.v1alpha2.EmbeddedAssistantGrpc;
import com.google.assistant.embedded.v1alpha2.SpeechRecognitionResult;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.protobuf.ByteString;
import com.mautini.assistant.demo.authentication.OAuthCredentials;
import com.mautini.assistant.demo.config.AssistantConf;
import com.mautini.assistant.demo.config.IoConf;
import com.mautini.assistant.demo.device.Device;
import com.mautini.assistant.demo.device.DeviceModel;
import com.mautini.assistant.demo.exception.ConverseException;
import io.grpc.CallCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import java.util.Date;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

/**
 * Client for the Google Assistant embedded gRPC API.
 *
 * <p>A request (audio or text, depending on {@link IoConf#getInputMode()}) is sent through
 * {@link #requestAssistant(byte[])}. The instance registers itself as the response observer of the
 * streaming call, accumulates the audio of the answer and keeps the conversation state returned by
 * the API so that the next request continues the same conversation.
 *
 * <p>Each call to {@link #requestAssistant(byte[])} replaces the per-request buffers of the
 * instance, so requests are expected to be issued one at a time.
 */
public class AssistantClient implements StreamObserver<AssistResponse> {

    private static final Logger LOGGER = LoggerFactory.getLogger(AssistantClient.class);

    // Released by onCompleted / onError; recreated for every request
    private CountDownLatch finishLatch = new CountDownLatch(1);

    private EmbeddedAssistantGrpc.EmbeddedAssistantStub embeddedAssistantStub;

    // Accumulates the audio bytes of the response currently being received
    private ByteArrayOutputStream currentResponse = new ByteArrayOutputStream();

    // See reference.conf
    private final AssistantConf assistantConf;

    // if text inputType, text query is set
    private String textQuery;

    private byte[] audioResponse;

    private String textResponse;

    private final IoConf ioConf;

    /**
     * Conversation state to continue a conversation if needed
     *
     * @see <a href="https://developers.google.com/assistant/sdk/reference/rpc/google.assistant.embedded.v1alpha2#google.assistant.embedded.v1alpha2.DialogStateOut.FIELDS.bytes.google.assistant.embedded.v1alpha2.DialogStateOut.conversation_state">Google documentation</a>
     */
    private ByteString currentConversationState;

    private final DeviceModel deviceModel;

    private final Device device;

    /**
     * Build a client bound to the endpoint of the configuration.
     *
     * @param oAuthCredentials the credentials used to authenticate the calls
     * @param assistantConf    the assistant configuration (endpoint, sample rate, chunk size, volume)
     * @param deviceModel      the device model sent with each request
     * @param device           the device sent with each request
     * @param ioConf           the input / output configuration (audio or text input)
     */
    public AssistantClient(OAuthCredentials oAuthCredentials, AssistantConf assistantConf, DeviceModel deviceModel,
                           Device device, IoConf ioConf) {

        this.assistantConf = assistantConf;
        this.deviceModel = deviceModel;
        this.device = device;
        this.currentConversationState = ByteString.EMPTY;
        this.ioConf = ioConf;

        // Create a channel to the assistant API endpoint
        ManagedChannel channel = ManagedChannelBuilder.forAddress(assistantConf.getAssistantApiEndpoint(), 443)
                .build();

        // Create the stub, the credentials are attached right after
        embeddedAssistantStub = EmbeddedAssistantGrpc.newStub(channel);

        updateCredentials(oAuthCredentials);
    }

    /**
     * Get CallCredentials from OAuthCredentials
     *
     * @param oAuthCredentials the credentials from the AuthenticationHelper
     * @return the CallCredentials for the GRPC requests
     */
    private CallCredentials getCallCredentials(OAuthCredentials oAuthCredentials) {

        // AccessToken only accepts a java.util.Date for the expiration
        AccessToken accessToken = new AccessToken(
                oAuthCredentials.getAccessToken(),
                new Date(oAuthCredentials.getExpirationTime())
        );

        OAuth2Credentials oAuth2Credentials = new OAuth2Credentials(accessToken);

        // Create an instance of {@link io.grpc.CallCredentials}
        return MoreCallCredentials.from(oAuth2Credentials);
    }

    /**
     * Update the credentials used to request the api
     *
     * @param oAuthCredentials the new credentials
     */
    public void updateCredentials(OAuthCredentials oAuthCredentials) {
        embeddedAssistantStub = embeddedAssistantStub.withCallCredentials(getCallCredentials(oAuthCredentials));
    }

    /**
     * Send a request to the assistant, as a text query or as audio depending on the configured
     * input mode. The audio of the answer is then available through {@link #getAudioResponse()} and
     * the display text, if any, through {@link #getTextResponse()}.
     *
     * <p>When the input mode is unknown, an error is logged and nothing is sent.
     *
     * @param request the request for the assistant (text or voice)
     * @throws ConverseException if the request cannot be sent or the wait for the answer fails
     */
    public void requestAssistant(byte[] request) throws ConverseException {
        switch (ioConf.getInputMode()) {
            case IoConf.AUDIO:
                audioResponse = audioRequestAssistant(request);
                break;
            case IoConf.TEXT:
                audioResponse = textRequestAssistant(request);
                break;
            default:
                LOGGER.error("Unknown input mode {}", ioConf.getInputMode());
        }
    }

    /**
     * Handle a text query: only the config message (which carries the text) is sent.
     *
     * @param request the text of the query, as bytes
     * @return the audio received from the assistant (possibly partial if the wait timed out)
     * @throws ConverseException if the request cannot be sent or the wait for the answer fails
     */
    private byte[] textRequestAssistant(byte[] request) throws ConverseException {
        // NOTE(review): decoded with the platform default charset; it must match the encoding
        // used by the caller to produce the bytes - confirm before forcing a specific charset
        this.textQuery = new String(request);
        try {
            // Send the config request
            StreamObserver<AssistRequest> requester = startRequest();

            LOGGER.info("Requesting the assistant");

            // Mark the end of requests
            requester.onCompleted();

            // Receiving happens asynchronously
            return awaitResponse();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the caller can still observe the interruption
            Thread.currentThread().interrupt();
            throw new ConverseException("Error requesting the assistant", e);
        } catch (Exception e) {
            throw new ConverseException("Error requesting the assistant", e);
        }
    }

    /**
     * @return the audio of the last answer, as stored by {@link #requestAssistant(byte[])}
     */
    public byte[] getAudioResponse() {
        return audioResponse;
    }

    /**
     * @return the display text received for the last request, or null if none was received
     */
    public String getTextResponse() {
        return textResponse;
    }

    /**
     * Handle an audio request: the config message is sent first, then the audio split in chunks.
     *
     * @param request the audio of the request
     * @return the audio received from the assistant (possibly partial if the wait timed out)
     * @throws ConverseException if the request cannot be sent or the wait for the answer fails
     */
    private byte[] audioRequestAssistant(byte[] request) throws ConverseException {
        try {
            LOGGER.info("Requesting the assistant");

            // Send the config request
            StreamObserver<AssistRequest> requester = startRequest();

            // Send a request for each chunk of the audio
            for (byte[] chunk : divideArray(request, assistantConf.getChunkSize())) {
                AssistRequest assistRequest = AssistRequest
                        .newBuilder()
                        .setAudioIn(ByteString.copyFrom(chunk))
                        .build();

                requester.onNext(assistRequest);
            }

            // Mark the end of requests
            requester.onCompleted();

            // Receiving happens asynchronously
            return awaitResponse();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the caller can still observe the interruption
            Thread.currentThread().interrupt();
            throw new ConverseException("Error requesting the assistant", e);
        } catch (Exception e) {
            throw new ConverseException("Error requesting the assistant", e);
        }
    }

    /**
     * Reset the per-request state, open the streaming call and send the config message.
     *
     * @return the observer used to send the following messages of the request
     */
    private StreamObserver<AssistRequest> startRequest() {
        currentResponse = new ByteArrayOutputStream();
        // Forget the text of the previous answer so it cannot be returned for this request
        textResponse = null;
        finishLatch = new CountDownLatch(1);

        StreamObserver<AssistRequest> requester = embeddedAssistantStub.assist(this);
        requester.onNext(getConfigRequest());
        return requester;
    }

    /**
     * Wait (at most one minute) for the end of the response stream.
     *
     * @return the audio received so far
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    private byte[] awaitResponse() throws InterruptedException {
        if (!finishLatch.await(1, TimeUnit.MINUTES)) {
            // Best effort: the data already received is still returned
            LOGGER.warn("Timed out waiting for the end of the assistant response");
        }

        return currentResponse.toByteArray();
    }

    /**
     * Handle one message of the response stream: log the event type and the recognized speech,
     * accumulate the audio and keep the conversation state and the display text.
     *
     * @param value the message received from the assistant
     */
    @Override
    public void onNext(AssistResponse value) {
        try {
            if (value.getEventType() != AssistResponse.EventType.EVENT_TYPE_UNSPECIFIED) {
                LOGGER.info("Event type : {}", value.getEventType().name());
            }

            // Protobuf getters never return null, presence is given by the has* methods
            if (value.hasAudioOut()) {
                currentResponse.write(value.getAudioOut().getAudioData().toByteArray());
            }

            String userRequest = value.getSpeechResultsList().stream()
                    .map(SpeechRecognitionResult::getTranscript)
                    .collect(Collectors.joining(" "));

            if (!userRequest.isEmpty()) {
                LOGGER.info("Request Text : {}", userRequest);
            }

            if (value.hasDialogStateOut()) {
                // Only overwrite the saved state with a real one: messages without state
                // (audio chunks for instance) must not reset the conversation
                ByteString conversationState = value.getDialogStateOut().getConversationState();
                if (!conversationState.isEmpty()) {
                    currentConversationState = conversationState;
                }

                // Capturing string response for text query output
                String displayText = value.getDialogStateOut().getSupplementalDisplayText();
                if (!displayText.isEmpty()) {
                    this.textResponse = displayText;
                }
            }

        } catch (Exception e) {
            LOGGER.warn("Error requesting the assistant", e);
        }
    }

    /**
     * Log the error and release the thread waiting for the response.
     *
     * @param t the error reported for the call
     */
    @Override
    public void onError(Throwable t) {
        LOGGER.warn("Error requesting the assistant", t);
        finishLatch.countDown();
    }

    /**
     * Release the thread waiting for the response.
     */
    @Override
    public void onCompleted() {
        LOGGER.info("End of the response");
        finishLatch.countDown();
    }

    /**
     * Create the config message, this message must be sent first for each request
     *
     * @return the request to send
     */
    private AssistRequest getConfigRequest() {
        AudioInConfig audioInConfig = AudioInConfig
                .newBuilder()
                .setEncoding(AudioInConfig.Encoding.LINEAR16)
                .setSampleRateHertz(assistantConf.getAudioSampleRate())
                .build();

        AudioOutConfig audioOutConfig = AudioOutConfig
                .newBuilder()
                .setEncoding(AudioOutConfig.Encoding.LINEAR16)
                .setSampleRateHertz(assistantConf.getAudioSampleRate())
                .setVolumePercentage(assistantConf.getVolumePercent())
                .build();

        DialogStateIn.Builder dialogStateInBuilder = DialogStateIn
                .newBuilder()
                // Language of the request.
                // NOTE(review): the usual tag for British English is "en-GB" (and "en-US" for
                // American English) - confirm that "en-UK" is the intended value
                .setLanguageCode("en-UK")
                .setConversationState(currentConversationState);

        DeviceConfig deviceConfig = DeviceConfig
                .newBuilder()
                .setDeviceModelId(deviceModel.getDeviceModelId())
                .setDeviceId(device.getId())
                .build();

        AssistConfig.Builder assistConfigBuilder = AssistConfig
                .newBuilder()
                .setDialogStateIn(dialogStateInBuilder.build())
                .setDeviceConfig(deviceConfig)
                .setAudioInConfig(audioInConfig)
                .setAudioOutConfig(audioOutConfig);

        // Preparing AssistantConfig based on type of input. ie audio or text
        assistConfigBuilder = getAssistConfigBuilder(
                assistConfigBuilder, audioInConfig, textQuery
        );

        return AssistRequest
                .newBuilder()
                .setConfig(assistConfigBuilder.build())
                .build();
    }

    /**
     * Set the input of the config according to the configured input mode: the audio config for
     * an audio input, the text query for a text input. The builder is returned untouched (and an
     * error is logged) for an unknown input mode.
     *
     * @param assistConfigBuilder the builder to complete
     * @param audioConfig         the audio input config, used for an audio input
     * @param query               the text query, used for a text input
     * @return the completed builder
     */
    private AssistConfig.Builder getAssistConfigBuilder(
            AssistConfig.Builder assistConfigBuilder,
            AudioInConfig audioConfig,
            String query
    ) {
        switch (ioConf.getInputMode()) {
            case IoConf.AUDIO:
                return assistConfigBuilder.setAudioInConfig(audioConfig);
            case IoConf.TEXT:
                return assistConfigBuilder.setTextQuery(query);
            default:
                LOGGER.error("Unknown input mode {}", ioConf.getInputMode());
                return assistConfigBuilder;
        }

    }

    /**
     * Divide an array of byte in chunks of at most chunkSize bytes. Every chunk has exactly
     * chunkSize bytes except the last one, which only holds the remaining bytes (no padding).
     *
     * @param source    the source byte array
     * @param chunkSize the size of a chunk, must be positive
     * @return an array of chunks, empty if the source is empty
     * @throws IllegalArgumentException if chunkSize is not positive
     * @see <a href="http://stackoverflow.com/questions/3405195/divide-array-into-smaller-parts">Divide array into smaller parts</a>
     */
    private byte[][] divideArray(byte[] source, int chunkSize) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive, got " + chunkSize);
        }

        int chunkCount = source.length / chunkSize + (source.length % chunkSize == 0 ? 0 : 1);
        byte[][] chunks = new byte[chunkCount][];

        int start = 0;

        for (int i = 0; i < chunkCount; i++) {
            // Never read past the end of the source: copyOfRange would pad with zeros
            int length = Math.min(chunkSize, source.length - start);
            chunks[i] = Arrays.copyOfRange(source, start, start + length);
            start += length;
        }

        return chunks;
    }
}