/*
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.google.audio.asr.cloud;

import android.support.annotation.GuardedBy;
import com.google.audio.asr.CloudSpeechSessionParams;
import com.google.audio.asr.SpeechSession;
import com.google.audio.asr.SpeechSessionFactory;
import com.google.audio.asr.SpeechSessionListener;
import com.google.common.flogger.FluentLogger;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.stub.MetadataUtils;
import java.util.concurrent.TimeUnit;
import org.joda.time.Duration;

/**
 * A factory for creating cloud speech sessions.
 *
 * <p>All sessions created by one factory share a single gRPC {@link ManagedChannel} targeting
 * {@code speech.googleapis.com}. The channel is created lazily on the first call to {@link #create}
 * and released by {@link #cleanup}. Every call made on the channel carries the API key in the
 * {@code X-Goog-Api-Key} header.
 *
 * <p>Access to the session parameters is synchronized on an internal lock. Access to the channel is
 * not. NOTE(review): {@link #create} and {@link #cleanup} presumably run on a single thread —
 * confirm against callers.
 */
public class CloudSpeechSessionFactory implements SpeechSessionFactory {
  private static final FluentLogger logger = FluentLogger.forEnclosingClass();
  private static final String SERVICE_URL = "speech.googleapis.com";
  private static final String HEADER_API_KEY = "X-Goog-Api-Key";

  /** How long {@link #cleanup} waits for pre-existing calls to finish before forcing shutdown. */
  private static final Duration TERMINATE_CHANNEL_DURATION = Duration.standardSeconds(1);

  /** Lock for handling concurrent accesses to the {@code params} field. */
  private final Object paramsLock = new Object();

  @GuardedBy("paramsLock")
  private CloudSpeechSessionParams params;

  /** API key attached as a header to every call made on {@link #channel}. */
  private final String apiKey;

  /** Shared channel; {@code null} before the first {@link #create} and after {@link #cleanup}. */
  private ManagedChannel channel;

  /**
   * Creates a factory.
   *
   * @param params parameters handed to every session created by this factory; can be replaced
   *     later via {@link #setParams}
   * @param apiKey API key sent in the {@code X-Goog-Api-Key} header of every call
   */
  public CloudSpeechSessionFactory(CloudSpeechSessionParams params, String apiKey) {
    this.params = params;
    this.apiKey = apiKey;
  }

  /**
   * Creates a new cloud session that shares this factory's channel.
   *
   * <p>Creates the channel if none exists yet. Otherwise, it nudges an existing channel out of
   * {@code TRANSIENT_FAILURE} if necessary.
   *
   * @param listener receiver of the session's events
   * @param sampleRateHz sample rate of the audio that will be streamed to the session
   * @return a new {@link CloudSpeechSession} built from the current parameters
   */
  @Override
  public SpeechSession create(SpeechSessionListener listener, int sampleRateHz) {
    if (this.channel == null) {
      this.channel = createManagedChannel(apiKey);
    } else {
      ensureManagedChannelConnection();
    }
    synchronized (paramsLock) {
      return new CloudSpeechSession(params, listener, sampleRateHz, channel);
    }
  }

  /**
   * Shuts down the shared channel, if any, and forgets it.
   *
   * <p>Pre-existing calls get up to {@link #TERMINATE_CHANNEL_DURATION} to finish. After that, or
   * if the wait is interrupted, remaining calls are cancelled with {@code shutdownNow()}. A later
   * {@link #create} builds a fresh channel.
   */
  @Override
  public void cleanup() {
    if (channel == null) {
      return;
    }
    channel.shutdown();
    try {
      if (!channel.awaitTermination(
          TERMINATE_CHANNEL_DURATION.getMillis(), TimeUnit.MILLISECONDS)) {
        channel.shutdownNow();
      }
    } catch (InterruptedException e) {
      logger.atWarning().withCause(e).log("Channel termination failed.");
      // Don't leave the channel half shut down. Cancel the remaining calls now. Then restore the
      // interrupt flag so code further up the stack can still observe the interruption.
      channel.shutdownNow();
      Thread.currentThread().interrupt();
    } finally {
      channel = null;
    }
  }

  /**
   * Replaces the parameters used for sessions created from now on.
   *
   * <p>Sessions that were already created are unaffected.
   *
   * @param params the new session parameters
   */
  public void setParams(CloudSpeechSessionParams params) {
    synchronized (paramsLock) {
      this.params = params;
    }
  }

  /**
   * Kicks the channel out of {@code TRANSIENT_FAILURE}, if it is in that state.
   *
   * <p>The channel may get stuck in the {@code TRANSIENT_FAILURE} state. If so, move it to idle so
   * that the channel triggers the creation of a new connection. Must only be called while
   * {@link #channel} is non-null.
   */
  protected void ensureManagedChannelConnection() {
    // getState(false) only reads the state; it does not itself request a connection.
    if (ConnectivityState.TRANSIENT_FAILURE.equals(channel.getState(false))) {
      logger.atInfo().log("ManagedChannel was in TRANSIENT_FAILURE state.");
      channel.enterIdle();
    }
  }

  /**
   * Builds a channel to the speech service.
   *
   * <p>The channel attaches {@code apiKey} in the {@code X-Goog-Api-Key} header of every call.
   */
  private ManagedChannel createManagedChannel(String apiKey) {
    Metadata metadata = new Metadata();
    metadata.put(Metadata.Key.of(HEADER_API_KEY, Metadata.ASCII_STRING_MARSHALLER), apiKey);
    return ManagedChannelBuilder.forTarget(SERVICE_URL)
        .intercept(MetadataUtils.newAttachHeadersInterceptor(metadata))
        .build();
  }
}