/* * Copyright (C) 2016-2017 Neo Visionaries Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.neovisionaries.ws.client; import java.io.IOException; import java.net.Inet6Address; import java.net.InetAddress; import java.net.Socket; import java.net.UnknownHostException; import java.util.Arrays; import java.util.Comparator; import javax.net.SocketFactory; import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; /** * A class to connect to the server. * * @since 1.20 * * @author Takahiko Kawasaki */ class SocketConnector { private final SocketFactory mSocketFactory; private final Address mAddress; private final int mConnectionTimeout; private final String[] mServerNames; private final ProxyHandshaker mProxyHandshaker; private final SSLSocketFactory mSSLSocketFactory; private final String mHost; private final int mPort; private DualStackMode mDualStackMode = DualStackMode.BOTH; private int mDualStackFallbackDelay = 250; private boolean mVerifyHostname; private Socket mSocket; SocketConnector(SocketFactory socketFactory, Address address, int timeout, String[] serverNames) { this(socketFactory, address, timeout, serverNames, null, null, null, 0); } SocketConnector( SocketFactory socketFactory, Address address, int timeout, String[] serverNames, ProxyHandshaker handshaker, SSLSocketFactory sslSocketFactory, String host, int port) { mSocketFactory = socketFactory; mAddress = address; mConnectionTimeout = timeout; mServerNames = serverNames; mProxyHandshaker = handshaker; mSSLSocketFactory = sslSocketFactory; mHost = host; mPort = port; } public int getConnectionTimeout() { return mConnectionTimeout; } public Socket getSocket() { return mSocket; } public Socket getConnectedSocket() throws WebSocketException { // Connect lazily. if (mSocket == null) { connectSocket(); } return mSocket; } private void connectSocket() throws WebSocketException { // Create socket initiator. SocketInitiator socketInitiator = new SocketInitiator( mSocketFactory, mAddress, mConnectionTimeout, mServerNames, mDualStackMode, mDualStackFallbackDelay); // Resolve hostname to IP addresses InetAddress[] addresses = resolveHostname(); // Let the sockets race until one has been established, following // RFC 6555 (*happy eyeballs*). try { mSocket = socketInitiator.establish(addresses); } catch (Exception e) { // True if a proxy server is set. boolean proxied = mProxyHandshaker != null; // Failed to connect the server. String message = String.format("Failed to connect to %s'%s': %s", (proxied ? "the proxy " : ""), mAddress, e.getMessage()); // Raise an exception with SOCKET_CONNECT_ERROR. throw new WebSocketException(WebSocketError.SOCKET_CONNECT_ERROR, message, e); } } private InetAddress[] resolveHostname() throws WebSocketException { InetAddress[] addresses = null; UnknownHostException exception = null; try { // Resolve hostname to IP addresses. addresses = InetAddress.getAllByName(mAddress.getHostname()); // Sort addresses: IPv6 first, then IPv4. Arrays.sort(addresses, new Comparator<InetAddress>() { public int compare(InetAddress left, InetAddress right) { if (left.getClass() == right.getClass()) { return 0; } if (left instanceof Inet6Address) { return -1; } else { return 1; } } }); } catch (UnknownHostException e) { exception = e; } // Return the ordered IP addresses (if any), otherwise raise the exception. if (addresses != null && addresses.length > 0) { return addresses; } if (exception == null) { exception = new UnknownHostException("No IP addresses found"); } // Failed to resolve hostname to IP address. String message = String.format("Failed to resolve hostname %s: %s", mAddress, exception.getMessage()); // Raise an exception with SOCKET_CONNECT_ERROR. throw new WebSocketException(WebSocketError.SOCKET_CONNECT_ERROR, message, exception); } public Socket connect() throws WebSocketException { try { // Connect to the server (either a proxy or a WebSocket endpoint). doConnect(); assert mSocket != null; return mSocket; } catch (WebSocketException e) { // Failed to connect the server. if (mSocket != null) { try { // Close the socket. mSocket.close(); } catch (IOException ioe) { // Ignore any error raised by close(). } } throw e; } } SocketConnector setDualStackSettings(DualStackMode mode, int fallbackDelay) { mDualStackMode = mode; mDualStackFallbackDelay = fallbackDelay; return this; } SocketConnector setVerifyHostname(boolean verifyHostname) { mVerifyHostname = verifyHostname; return this; } private void doConnect() throws WebSocketException { // True if a proxy server is set. boolean proxied = mProxyHandshaker != null; // Establish a socket associated to one of the resolved IP addresses connectSocket(); assert mSocket != null; if (mSocket instanceof SSLSocket) { // Verify that the hostname matches the certificate here since // this is not automatically done by the SSLSocket. verifyHostname((SSLSocket)mSocket, mAddress.getHostname()); } // If a proxy server is set. if (proxied) { // Perform handshake with the proxy server. // SSL handshake is performed as necessary, too. handshake(); } } private void verifyHostname(SSLSocket socket, String hostname) throws HostnameUnverifiedException { if (mVerifyHostname == false) { // Skip hostname verification. return; } // Hostname verifier. OkHostnameVerifier verifier = OkHostnameVerifier.INSTANCE; // The SSL session. SSLSession session = socket.getSession(); // Verify the hostname. if (verifier.verify(hostname, session)) { // Verified. No problem. return; } // The certificate of the peer does not match the expected hostname. throw new HostnameUnverifiedException(socket, hostname); } /** * Perform proxy handshake and optionally SSL handshake. */ private void handshake() throws WebSocketException { // Sanity check assert mSocket != null; try { // Perform handshake with the proxy server. mProxyHandshaker.perform(mSocket); } catch (IOException e) { // Handshake with the proxy server failed. String message = String.format( "Handshake with the proxy server (%s) failed: %s", mAddress, e.getMessage()); // Raise an exception with PROXY_HANDSHAKE_ERROR. throw new WebSocketException(WebSocketError.PROXY_HANDSHAKE_ERROR, message, e); } if (mSSLSocketFactory == null) { // SSL handshake with the WebSocket endpoint is not needed. return; } try { // Overlay the existing socket. mSocket = mSSLSocketFactory.createSocket(mSocket, mHost, mPort, true); } catch (IOException e) { // Failed to overlay an existing socket. String message = "Failed to overlay an existing socket: " + e.getMessage(); // Raise an exception with SOCKET_OVERLAY_ERROR. throw new WebSocketException(WebSocketError.SOCKET_OVERLAY_ERROR, message, e); } try { // Start the SSL handshake manually. As for the reason, see // http://docs.oracle.com/javase/7/docs/technotes/guides/security/jsse/samples/sockets/client/SSLSocketClient.java ((SSLSocket)mSocket).startHandshake(); // Verify that the proxied hostname matches the certificate here since // this is not automatically done by the SSLSocket. verifyHostname((SSLSocket)mSocket, mProxyHandshaker.getProxiedHostname()); } catch (IOException e) { // SSL handshake with the WebSocket endpoint failed. String message = String.format( "SSL handshake with the WebSocket endpoint (%s) failed: %s", mAddress, e.getMessage()); // Raise an exception with SSL_HANDSHAKE_ERROR. throw new WebSocketException(WebSocketError.SSL_HANDSHAKE_ERROR, message, e); } } void closeSilently() { if (mSocket != null) { try { mSocket.close(); } catch (Throwable t) { // Ignored. } } } }