package org.whispersystems.signalservice.internal.websocket; import com.google.protobuf.InvalidProtocolBufferException; import org.whispersystems.libsignal.logging.Log; import org.whispersystems.libsignal.util.Pair; import org.whispersystems.libsignal.util.guava.Optional; import org.whispersystems.signalservice.api.push.TrustStore; import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.SleepTimer; import org.whispersystems.signalservice.api.util.Tls12SocketFactory; import org.whispersystems.signalservice.api.websocket.ConnectivityListener; import org.whispersystems.signalservice.internal.util.BlacklistingTrustManager; import org.whispersystems.signalservice.internal.util.Util; import org.whispersystems.signalservice.internal.util.concurrent.SettableFuture; import java.io.IOException; import java.security.KeyManagementException; import java.security.NoSuchAlgorithmException; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.Map; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; import okhttp3.ConnectionSpec; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.Response; import okhttp3.WebSocket; import okhttp3.WebSocketListener; import okio.ByteString; import static org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketMessage; import static org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketRequestMessage; import static org.whispersystems.signalservice.internal.websocket.WebSocketProtos.WebSocketResponseMessage; public class WebSocketConnection extends WebSocketListener { private static final String TAG = WebSocketConnection.class.getSimpleName(); private static final int KEEPALIVE_TIMEOUT_SECONDS = 55; private final LinkedList<WebSocketRequestMessage> incomingRequests = new LinkedList<>(); private final Map<Long, SettableFuture<Pair<Integer, String>>> outgoingRequests = new HashMap<>(); private final String wsUri; private final TrustStore trustStore; private final Optional<CredentialsProvider> credentialsProvider; private final String userAgent; private final ConnectivityListener listener; private final SleepTimer sleepTimer; private WebSocket client; private KeepAliveSender keepAliveSender; private int attempts; private boolean connected; public WebSocketConnection(String httpUri, TrustStore trustStore, Optional<CredentialsProvider> credentialsProvider, String userAgent, ConnectivityListener listener, SleepTimer timer) { this.trustStore = trustStore; this.credentialsProvider = credentialsProvider; this.userAgent = userAgent; this.listener = listener; this.sleepTimer = timer; this.attempts = 0; this.connected = false; String uri = httpUri.replace("https://", "wss://").replace("http://", "ws://"); if (credentialsProvider.isPresent()) this.wsUri = uri + "/v1/websocket/?login=%s&password=%s"; else this.wsUri = uri + "/v1/websocket/"; } public synchronized void connect() { Log.w(TAG, "WSC connect()..."); if (client == null) { String filledUri; if (credentialsProvider.isPresent()) { String identifier = credentialsProvider.get().getUuid() != null ? credentialsProvider.get().getUuid().toString() : credentialsProvider.get().getE164(); filledUri = String.format(wsUri, identifier, credentialsProvider.get().getPassword()); } else { filledUri = wsUri; } Pair<SSLSocketFactory, X509TrustManager> socketFactory = createTlsSocketFactory(trustStore); OkHttpClient okHttpClient = new OkHttpClient.Builder() .sslSocketFactory(new Tls12SocketFactory(socketFactory.first()), socketFactory.second()) .connectionSpecs(Util.immutableList(ConnectionSpec.RESTRICTED_TLS)) .readTimeout(KEEPALIVE_TIMEOUT_SECONDS + 10, TimeUnit.SECONDS) .connectTimeout(KEEPALIVE_TIMEOUT_SECONDS + 10, TimeUnit.SECONDS) .build(); Request.Builder requestBuilder = new Request.Builder().url(filledUri); if (userAgent != null) { requestBuilder.addHeader("X-Signal-Agent", userAgent); } if (listener != null) { listener.onConnecting(); } this.connected = false; this.client = okHttpClient.newWebSocket(requestBuilder.build(), this); } } public synchronized void disconnect() { Log.w(TAG, "WSC disconnect()..."); if (client != null) { client.close(1000, "OK"); client = null; connected = false; } if (keepAliveSender != null) { keepAliveSender.shutdown(); keepAliveSender = null; } } public synchronized WebSocketRequestMessage readRequest(long timeoutMillis) throws TimeoutException, IOException { if (client == null) { throw new IOException("Connection closed!"); } long startTime = System.currentTimeMillis(); while (client != null && incomingRequests.isEmpty() && elapsedTime(startTime) < timeoutMillis) { Util.wait(this, Math.max(1, timeoutMillis - elapsedTime(startTime))); } if (incomingRequests.isEmpty() && client == null) throw new IOException("Connection closed!"); else if (incomingRequests.isEmpty()) throw new TimeoutException("Timeout exceeded"); else return incomingRequests.removeFirst(); } public synchronized Future<Pair<Integer, String>> sendRequest(WebSocketRequestMessage request) throws IOException { if (client == null || !connected) throw new IOException("No connection!"); WebSocketMessage message = WebSocketMessage.newBuilder() .setType(WebSocketMessage.Type.REQUEST) .setRequest(request) .build(); SettableFuture<Pair<Integer, String>> future = new SettableFuture<>(); outgoingRequests.put(request.getId(), future); if (!client.send(ByteString.of(message.toByteArray()))) { throw new IOException("Write failed!"); } return future; } public synchronized void sendResponse(WebSocketResponseMessage response) throws IOException { if (client == null) { throw new IOException("Connection closed!"); } WebSocketMessage message = WebSocketMessage.newBuilder() .setType(WebSocketMessage.Type.RESPONSE) .setResponse(response) .build(); if (!client.send(ByteString.of(message.toByteArray()))) { throw new IOException("Write failed!"); } } private synchronized void sendKeepAlive() throws IOException { if (keepAliveSender != null && client != null) { byte[] message = WebSocketMessage.newBuilder() .setType(WebSocketMessage.Type.REQUEST) .setRequest(WebSocketRequestMessage.newBuilder() .setId(System.currentTimeMillis()) .setPath("/v1/keepalive") .setVerb("GET") .build()).build() .toByteArray(); if (!client.send(ByteString.of(message))) { throw new IOException("Write failed!"); } } } @Override public synchronized void onOpen(WebSocket webSocket, Response response) { if (client != null && keepAliveSender == null) { Log.w(TAG, "onConnected()"); attempts = 0; connected = true; keepAliveSender = new KeepAliveSender(); keepAliveSender.start(); if (listener != null) listener.onConnected(); } } @Override public synchronized void onMessage(WebSocket webSocket, ByteString payload) { Log.w(TAG, "WSC onMessage()"); try { WebSocketMessage message = WebSocketMessage.parseFrom(payload.toByteArray()); Log.w(TAG, "Message Type: " + message.getType().getNumber()); if (message.getType().getNumber() == WebSocketMessage.Type.REQUEST_VALUE) { incomingRequests.add(message.getRequest()); } else if (message.getType().getNumber() == WebSocketMessage.Type.RESPONSE_VALUE) { SettableFuture<Pair<Integer, String>> listener = outgoingRequests.get(message.getResponse().getId()); if (listener != null) listener.set(new Pair<>(message.getResponse().getStatus(), new String(message.getResponse().getBody().toByteArray()))); } notifyAll(); } catch (InvalidProtocolBufferException e) { Log.w(TAG, e); } } @Override public synchronized void onClosed(WebSocket webSocket, int code, String reason) { Log.w(TAG, "onClose()..."); this.connected = false; Iterator<Map.Entry<Long, SettableFuture<Pair<Integer, String>>>> iterator = outgoingRequests.entrySet().iterator(); while (iterator.hasNext()) { Map.Entry<Long, SettableFuture<Pair<Integer, String>>> entry = iterator.next(); entry.getValue().setException(new IOException("Closed: " + code + ", " + reason)); iterator.remove(); } if (keepAliveSender != null) { keepAliveSender.shutdown(); keepAliveSender = null; } if (listener != null) { listener.onDisconnected(); } Util.wait(this, Math.min(++attempts * 200, TimeUnit.SECONDS.toMillis(15))); if (client != null) { client.close(1000, "OK"); client = null; connected = false; connect(); } notifyAll(); } @Override public synchronized void onFailure(WebSocket webSocket, Throwable t, Response response) { Log.w(TAG, "onFailure()"); Log.w(TAG, t); if (response != null && (response.code() == 401 || response.code() == 403)) { if (listener != null) listener.onAuthenticationFailure(); } if (client != null) { onClosed(webSocket, 1000, "OK"); } } @Override public void onMessage(WebSocket webSocket, String text) { Log.w(TAG, "onMessage(text)! " + text); } @Override public synchronized void onClosing(WebSocket webSocket, int code, String reason) { Log.w(TAG, "onClosing()!..."); webSocket.close(1000, "OK"); } private long elapsedTime(long startTime) { return System.currentTimeMillis() - startTime; } private Pair<SSLSocketFactory, X509TrustManager> createTlsSocketFactory(TrustStore trustStore) { try { SSLContext context = SSLContext.getInstance("TLS"); TrustManager[] trustManagers = BlacklistingTrustManager.createFor(trustStore); context.init(null, trustManagers, null); return new Pair<>(context.getSocketFactory(), (X509TrustManager)trustManagers[0]); } catch (NoSuchAlgorithmException | KeyManagementException e) { throw new AssertionError(e); } } private class KeepAliveSender extends Thread { private AtomicBoolean stop = new AtomicBoolean(false); public void run() { while (!stop.get()) { try { sleepTimer.sleep(TimeUnit.SECONDS.toMillis(KEEPALIVE_TIMEOUT_SECONDS)); Log.w(TAG, "Sending keep alive..."); sendKeepAlive(); } catch (Throwable e) { Log.w(TAG, e); } } } public void shutdown() { stop.set(true); } } }