package de.cronn.proxy.ssh; import static org.junit.Assert.*; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.net.Socket; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.Arrays; import org.apache.sshd.common.config.keys.KeyUtils; import org.apache.sshd.common.util.security.SecurityUtils; import org.apache.sshd.server.SshServer; import org.apache.sshd.server.auth.pubkey.AcceptAllPublickeyAuthenticator; import org.apache.sshd.server.forward.AcceptAllForwardingFilter; import org.apache.sshd.server.keyprovider.AbstractGeneratorHostKeyProvider; import org.hamcrest.CoreMatchers; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class SshProxyTest { private static final Logger log = LoggerFactory.getLogger(SshProxyTest.class); private static final Charset CONFIG_CHARSET = StandardCharsets.ISO_8859_1; private static final Charset TRANSFER_CHARSET = StandardCharsets.UTF_16; private static final String KNOWN_HOSTS_FILENAME = "known_hosts"; private static final String CONFIG_FILENAME = "config"; private static final Path TEST_RESOURCES = Paths.get("src", "test", "resources"); private static final Path SERVER_RSA_KEY = TEST_RESOURCES.resolve("server-rsa.key"); private static final Path SERVER_ECDSA_KEY = TEST_RESOURCES.resolve("server-ecdsa.key"); private static final long TEST_TIMEOUT_MILLIS = 30_000L; @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); private String oldUserHome; private Path dotSsh; private static final String TEST_TEXT = "Hello World"; @Before public void checkBouncyCastleIsRegistered() { assertTrue("BouncyCastle is registered", SecurityUtils.isBouncyCastleRegistered()); } @Before public void setUp() throws Exception { Path userHome = temporaryFolder.getRoot().toPath(); oldUserHome = System.getProperty("user.home"); System.setProperty("user.home", userHome.toAbsolutePath().toString()); log.debug("changed 'user.home' to {}", System.getProperty("user.home")); dotSsh = userHome.resolve(".ssh"); Files.createDirectories(dotSsh); for (String file : Arrays.asList("id_rsa", "id_rsa.pub")) { Files.copy(TEST_RESOURCES.resolve(file), dotSsh.resolve(file)); } appendToSshFile(CONFIG_FILENAME, ""); appendToSshFile(KNOWN_HOSTS_FILENAME, ""); } @After public void tearDown() { System.setProperty("user.home", oldUserHome); } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testSingleHop() throws Exception { SshServer sshServer = setUpSshServer(); int sshServerPort = sshServer.getPort(); String hostConfigName = "localhost-" + sshServerPort; appendToSshFile(CONFIG_FILENAME, "Host " + hostConfigName + "\n\tHostName localhost\n\tPort " + sshServerPort + "\n\n"); try (DummyServerSocketThread dummyServerSocketThread = new DummyServerSocketThread(TRANSFER_CHARSET, TEST_TEXT); SshProxy sshProxy = new SshProxy()) { int port = sshProxy.connect(hostConfigName, "localhost", dummyServerSocketThread.getPort()); final String receivedText; try (Socket s = new Socket(SshProxy.LOCALHOST, port); InputStream is = s.getInputStream()) { log.info("connected to port: {}", port); receivedText = readLine(is); } assertEquals(TEST_TEXT, receivedText); } finally { tryStop(sshServer); } } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testSingleHop_EcDsaServer() throws Exception { SshServer sshServer = setUpSshServer(KeyUtils.EC_ALGORITHM); int sshServerPort = sshServer.getPort(); String hostConfigName = "localhost-" + sshServerPort; appendToSshFile(CONFIG_FILENAME, "Host " + hostConfigName + "\n\tHostName localhost\n\tPort " + sshServerPort + "\n\n"); try (DummyServerSocketThread dummyServerSocketThread = new DummyServerSocketThread(TRANSFER_CHARSET, TEST_TEXT); SshProxy sshProxy = new SshProxy()) { int port = sshProxy.connect(hostConfigName, "localhost", dummyServerSocketThread.getPort()); final String receivedText; try (Socket s = new Socket(SshProxy.LOCALHOST, port); InputStream is = s.getInputStream()) { log.info("connected to port: {}", port); receivedText = readLine(is); } assertEquals(TEST_TEXT, receivedText); } finally { tryStop(sshServer); } } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testSingleHopWithLocalPort() throws Exception { SshServer sshServer = setUpSshServer(); int sshServerPort = sshServer.getPort(); String hostConfigName = "localhost-" + sshServerPort; appendToSshFile(CONFIG_FILENAME, "Host " + hostConfigName + "\n\tHostName localhost\n\tPort " + sshServerPort + "\n\n"); try (DummyServerSocketThread dummyServerSocketThread = new DummyServerSocketThread(TRANSFER_CHARSET, TEST_TEXT); SshProxy sshProxy = new SshProxy()) { int port = sshProxy.connect(hostConfigName, "localhost", dummyServerSocketThread.getPort(), 2345); final String receivedText; try (Socket s = new Socket(SshProxy.LOCALHOST, port); InputStream is = s.getInputStream()) { log.info("connected to port: {}", port); receivedText = readLine(is); } assertEquals(TEST_TEXT, receivedText); } finally { tryStop(sshServer); } } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testTwoHops_ProxyCommand() throws Exception { doTestTwoHops("ProxyCommand ssh -q -W %h:%p firsthop"); } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testTwoHops_ProxyJump() throws Exception { doTestTwoHops("ProxyJump firsthop"); } private void doTestTwoHops(String proxyConfiguration) throws Exception { SshServer firstSshServer = setUpSshServer(); int firstServerPort = firstSshServer.getPort(); SshServer secondSshServer = setUpSshServer(); int secondServerPort = secondSshServer.getPort(); appendToSshFile(CONFIG_FILENAME, "Host firsthop\n\tHostName localhost\n\tPort " + firstServerPort + "\n\n"); appendToSshFile(CONFIG_FILENAME, "Host secondhop\n\tHostName localhost\n\tPort " + secondServerPort + "\n\t" + proxyConfiguration + "\n\n"); try (DummyServerSocketThread dummyServerSocketThread = new DummyServerSocketThread(TRANSFER_CHARSET, TEST_TEXT); SshProxy sshProxy = new SshProxy()) { int port = sshProxy.connect("secondhop", "localhost", dummyServerSocketThread.getPort()); final String receivedText; try (Socket s = new Socket(SshProxy.LOCALHOST, port); InputStream is = s.getInputStream()) { log.info("connected to port: {}", port); receivedText = readLine(is); } assertEquals(TEST_TEXT, receivedText); } finally { tryStop(firstSshServer); tryStop(secondSshServer); } } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testSingleHop_NoHostKeyFound() throws Exception { try (SshProxy sshProxy = new SshProxy()) { sshProxy.connect("jumphost", "targethost", 1234); fail("SshProxyRuntimeException expected"); } catch (SshProxyRuntimeException e) { log.debug("Expected exception", e); assertEquals("Failed to create SSH tunnel to targethost via jumphost", e.getMessage()); assertThat(e.getCause().getMessage(), CoreMatchers.startsWith("Found no host key for jumphost")); } } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testSingleHop_ConnectionRefused() throws Exception { SshServer sshServer = setUpSshServer(); sshServer.stop(); try (SshProxy sshProxy = new SshProxy()) { sshProxy.connect("localhost", "targethost", 1234); fail("SshProxyRuntimeException expected"); } catch (SshProxyRuntimeException e) { log.debug("Expected exception", e); assertEquals("Failed to create SSH tunnel to targethost via localhost", e.getMessage()); assertEquals("Failed to connect to targethost via localhost", e.getCause().getMessage()); } } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testSingleHop_IllegalPort() throws Exception { try (SshProxy sshProxy = new SshProxy()) { sshProxy.connect("localhost", "targethost", 0); fail("IllegalArgumentException expected"); } catch (IllegalArgumentException e) { assertEquals("illegal port: 0", e.getMessage()); } } @Test(timeout = TEST_TIMEOUT_MILLIS) public void testSingleHop_IllegalLocalPort() throws Exception { try (SshProxy sshProxy = new SshProxy()) { sshProxy.connect("localhost", "targethost", 1234, -1); fail("IllegalArgumentException expected"); } catch (IllegalArgumentException e) { assertEquals("illegal local port: -1", e.getMessage()); } } private void tryStop(SshServer sshServer) { try { log.debug("stopping SSH server"); sshServer.stop(); } catch (IOException e) { log.error("Failed to stop SSH server", e); } } private String readLine(InputStream is) throws IOException { try (BufferedReader reader = new BufferedReader(new InputStreamReader(is, TRANSFER_CHARSET))) { String line = reader.readLine(); assertNotNull(line); return line.trim(); } } private SshServer setUpSshServer() throws IOException { return setUpSshServer(KeyUtils.RSA_ALGORITHM); } private SshServer setUpSshServer(String algorithm) throws IOException { SshServer sshServer = SshServer.setUpDefaultServer(); sshServer.setPort(0); AbstractGeneratorHostKeyProvider hostKeyProvider = SecurityUtils.createGeneratorHostKeyProvider(getServerKeyFile(algorithm)); hostKeyProvider.setAlgorithm(algorithm); if (algorithm.equals(KeyUtils.EC_ALGORITHM)) { hostKeyProvider.setKeySize(256); } sshServer.setKeyPairProvider(hostKeyProvider); sshServer.setPublickeyAuthenticator(AcceptAllPublickeyAuthenticator.INSTANCE); sshServer.setForwardingFilter(AcceptAllForwardingFilter.INSTANCE); writeFingerprintToKnownHosts(algorithm); sshServer.start(); int sshServerPort = sshServer.getPort(); assertTrue(sshServerPort > 0); return sshServer; } private void writeFingerprintToKnownHosts(String algorithm) throws IOException { switch (algorithm) { case KeyUtils.RSA_ALGORITHM: appendToSshFile(KNOWN_HOSTS_FILENAME, "localhost ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDL8360Wxcgo33sggS0bSid0u7Ad4XFig8/e0UfD5l02x/w2DRJuqJow4SiDfi9jvD8p3lu7To7b/oGH/c/vsK9j35ICG0eJ/bbnQDuHROBAnbAC6PXN+/XX2F9s48KlOC5dQXrGYyYhoozW67yoHTooisZSzF/iyPdNat64rM0+ZO3dV6eEQ0FItYO632YcSiBRE7YZe9rP7ne50xaltKgrAmHRDRo+tjIcykrlcZFG1Bp/ct9Ejs2DQDsFOZRCmFbag0pQxxbkA1U6z7O3qwhhDWcJz2ZHDHK8DUkgHdX+Hbp7LxBWEaCiU8cL+S6rmCpNsui9NT/XeoLuXQ4J8jX\n"); break; case KeyUtils.EC_ALGORITHM: appendToSshFile(KNOWN_HOSTS_FILENAME, "localhost ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBCH+0xjLYNGoqVGlD4VtKHF1Tig2/Y76BxVld88bYAaRV4ojJni62vIYMKqk+FMZhL1lcQ/VQTvIeLMnYk+grKo=\n"); break; default: throw new IllegalArgumentException("Unknown algorithm: " + algorithm); } } private static Path getServerKeyFile(String algorithm) { switch (algorithm) { case KeyUtils.RSA_ALGORITHM: return SERVER_RSA_KEY; case KeyUtils.EC_ALGORITHM: return SERVER_ECDSA_KEY; default: throw new IllegalArgumentException("Unknown algorithm: " + algorithm); } } private void appendToSshFile(String filename, String text) throws IOException { Path config = dotSsh.resolve(filename); Files.write(config, text.getBytes(CONFIG_CHARSET), StandardOpenOption.APPEND, StandardOpenOption.CREATE); } }