package de.cronn.proxy.ssh; import java.io.Closeable; import java.util.ArrayDeque; import java.util.Deque; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.jcraft.jsch.JSchException; import com.jcraft.jsch.Session; import de.cronn.proxy.ssh.util.Assert; public class SshProxy implements Closeable { private static final Logger log = LoggerFactory.getLogger(SshProxy.class); public static final String LOCALHOST = "localhost"; private static final int DEFAULT_TIMEOUT_MILLIS = 10_000; private final Deque<Session> sshSessions = new ArrayDeque<>(); private final Map<Session, Set<Integer>> portForwardings = new LinkedHashMap<>(); private final SshConfiguration sshConfiguration; private int timeoutMillis; public SshProxy() { this(DEFAULT_TIMEOUT_MILLIS); } public SshProxy(int timeoutMillis) { try { sshConfiguration = SshConfiguration.getConfiguration(); } catch (Exception e) { throw new SshProxyRuntimeException("Failed to open SSH proxy", e); } this.timeoutMillis = timeoutMillis; } public int connect(String sshTunnelHost, String host, int port) { return connect(sshTunnelHost, host, port, 0); } public int connect(String sshTunnelHost, String host, int port, int localPort) { Assert.notNull(sshTunnelHost, "sshTunnelHost must not be null"); Assert.notNull(host, "host must not be null"); Assert.isTrue(port > 0, "illegal port: " + port); Assert.isTrue(localPort >= 0, "illegal local port: " + localPort); log.debug("tunneling to {}:{} via {}", host, port, sshTunnelHost); try { sshConfiguration.addIdentity(sshTunnelHost); SshProxyConfig proxyConfig = sshConfiguration.getProxyConfiguration(sshTunnelHost); if (proxyConfig == null) { return directConnect(sshTunnelHost, host, port, localPort); } int jumpPort = connect(proxyConfig); String hostUser = sshConfiguration.getHostUser(sshTunnelHost); String jumpHost = proxyConfig.getJumpHost(); Session jumpHostSession = sshConfiguration.openSession(hostUser, jumpHost, jumpPort); String hostname = sshConfiguration.getHostName(sshTunnelHost); jumpHostSession.setHostKeyAlias(hostname); sshSessions.push(jumpHostSession); jumpHostSession.setTimeout(timeoutMillis); jumpHostSession.connect(timeoutMillis); log.debug("[{}] connected via {}@localhost:{}", sshTunnelHost, hostUser, jumpPort); return addLocalPortForwarding(sshTunnelHost, jumpHostSession, host, port, localPort); } catch (Exception e) { throw new SshProxyRuntimeException("Failed to create SSH tunnel to " + host + " via " + sshTunnelHost, e); } } private int connect(SshProxyConfig proxyConfig) { String jumpHost = proxyConfig.getJumpHost(); String forwardingHost = proxyConfig.getForwardingHost(); int forwardingPort = proxyConfig.getForwardingPort(); return connect(jumpHost, forwardingHost, forwardingPort); } private int directConnect(String jumpHost, String targetHost, int targetPort, int localPort) throws JSchException { Session jumpHostSession = sshConfiguration.openSession(jumpHost); sshSessions.add(jumpHostSession); jumpHostSession.setTimeout(timeoutMillis); try { jumpHostSession.connect(timeoutMillis); } catch (JSchException e) { log.debug("Failed to connect to {} via {}", targetHost, jumpHost, e); throw new SshProxyRuntimeException("Failed to connect to " + targetHost + " via " + jumpHost); } log.debug("[{}] connected", jumpHost); return addLocalPortForwarding(jumpHost, jumpHostSession, targetHost, targetPort, localPort); } private int addLocalPortForwarding(String sshTunnelHost, Session session, String targetHost, int targetPort, int localPort) throws JSchException { int localPortReturned = session.setPortForwardingL(localPort, targetHost, targetPort); log.debug("[{}] local port {} forwarded to {}:{}", sshTunnelHost, localPortReturned, targetHost, targetPort); Set<Integer> ports = portForwardings.computeIfAbsent(session, k -> new LinkedHashSet<>()); ports.add(Integer.valueOf(localPortReturned)); return localPortReturned; } @Override public void close() { if (!sshSessions.isEmpty()) { log.debug("closing SSH sessions"); } while (!sshSessions.isEmpty()) { Session session = sshSessions.pop(); deletePortForwarding(session); try { session.disconnect(); } catch (Exception e) { log.error("Failed to disconnect SSH session", e); } } Assert.isTrue(portForwardings.isEmpty(), "port forwardings must be empty at this point"); } private void deletePortForwarding(Session session) { Set<Integer> ports = portForwardings.remove(session); if (ports != null) { for (Integer localPort : ports) { deletePortForwarding(session, localPort); } } } private void deletePortForwarding(Session session, Integer localPort) { try { String host = session.getHost(); if (host.equals(LOCALHOST)) { host = session.getHostKeyAlias(); } session.delPortForwardingL(LOCALHOST, localPort.intValue()); log.debug("deleted local port forwarding on port {} for {}", localPort, host); } catch (Exception e) { log.error("failed to delete port forwarding of port {}", localPort, e); } } }