From a262f519008c71d90f17654561d220d591b2178d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henning=20P=C3=B6ttker?= Date: Thu, 21 Dec 2023 22:33:54 +0100 Subject: [PATCH] Implement OpenSSH strict key exchange extension (#917) --- .../com/hierynomus/sshj/SshdContainer.java | 3 +- .../transport/kex/StrictKeyExchangeTest.java | 111 ++++++++++++++++++ .../net/schmizz/sshj/transport/Converter.java | 8 ++ .../schmizz/sshj/transport/KeyExchanger.java | 34 +++++- .../net/schmizz/sshj/transport/Proposal.java | 9 +- .../schmizz/sshj/transport/TransportImpl.java | 19 ++- .../sshj/transport/KeyExchangeRepeatTest.java | 2 +- 7 files changed, 180 insertions(+), 6 deletions(-) create mode 100644 src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java diff --git a/src/itest/java/com/hierynomus/sshj/SshdContainer.java b/src/itest/java/com/hierynomus/sshj/SshdContainer.java index 82cfda644..91b531e68 100644 --- a/src/itest/java/com/hierynomus/sshj/SshdContainer.java +++ b/src/itest/java/com/hierynomus/sshj/SshdContainer.java @@ -146,8 +146,9 @@ public static Builder defaultBuilder() { .withFileFromString("sshd_config", sshdConfig.build()); } + @Override public void accept(@NotNull DockerfileBuilder builder) { - builder.from("alpine:3.18.3"); + builder.from("alpine:3.19.0"); builder.run("apk add --no-cache openssh"); builder.expose(22); builder.copy("entrypoint.sh", "/entrypoint.sh"); diff --git a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java new file mode 100644 index 000000000..2abe71a72 --- /dev/null +++ b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java @@ -0,0 +1,111 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.transport.kex; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import com.hierynomus.sshj.SshdContainer; +import net.schmizz.sshj.SSHClient; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Testcontainers +class StrictKeyExchangeTest { + + @Container + private static final SshdContainer sshd = new SshdContainer(); + + private final List watchedLoggers = new ArrayList<>(); + private final ListAppender logWatcher = new ListAppender<>(); + + @BeforeEach + void setUpLogWatcher() { + logWatcher.start(); + setUpLogger("net.schmizz.sshj.transport.Decoder"); + setUpLogger("net.schmizz.sshj.transport.Encoder"); + setUpLogger("net.schmizz.sshj.transport.KeyExchanger"); + } + + @AfterEach + void tearDown() { + watchedLoggers.forEach(Logger::detachAndStopAllAppenders); + } + + private void setUpLogger(String className) { + Logger logger = ((Logger) LoggerFactory.getLogger(className)); + logger.addAppender(logWatcher); + watchedLoggers.add(logger); + } + + @Test + void strictKeyExchange() throws Throwable { + try (SSHClient client = sshd.getConnectedClient()) { + client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1"); + assertTrue(client.isAuthenticated()); + } + List keyExchangerLogs = getLogs("KeyExchanger"); + assertThat(keyExchangerLogs).containsSequence( + "Initiating key exchange", + "Sending SSH_MSG_KEXINIT", + "Received SSH_MSG_KEXINIT", + "Enabling strict key exchange extension" + ); + List decoderLogs = getLogs("Decoder").stream() + .map(log -> log.split(":")[0]) + .collect(Collectors.toList()); + assertThat(decoderLogs).containsExactly( + "Received packet #0", + "Received packet #1", + "Received packet #2", + "Received packet #0", + "Received packet #1", + "Received packet #2", + "Received packet #3" + ); + List encoderLogs = getLogs("Encoder").stream() + .map(log -> log.split(":")[0]) + .collect(Collectors.toList()); + assertThat(encoderLogs).containsExactly( + "Encoding packet #0", + "Encoding packet #1", + "Encoding packet #2", + "Encoding packet #0", + "Encoding packet #1", + "Encoding packet #2", + "Encoding packet #3" + ); + } + + private List getLogs(String className) { + return logWatcher.list.stream() + .filter(event -> event.getLoggerName().endsWith(className)) + .map(ILoggingEvent::getFormattedMessage) + .collect(Collectors.toList()); + } + +} diff --git a/src/main/java/net/schmizz/sshj/transport/Converter.java b/src/main/java/net/schmizz/sshj/transport/Converter.java index 9d532f965..6f3431c3f 100644 --- a/src/main/java/net/schmizz/sshj/transport/Converter.java +++ b/src/main/java/net/schmizz/sshj/transport/Converter.java @@ -51,6 +51,14 @@ long getSequenceNumber() { return seq; } + void resetSequenceNumber() { + seq = -1; + } + + boolean isSequenceNumberAtMax() { + return seq == 0xffffffffL; + } + void setAlgorithms(Cipher cipher, MAC mac, Compression compression) { this.cipher = cipher; this.mac = mac; diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index b8979f7b5..abaf72e10 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -60,6 +60,10 @@ private static enum Expected { private final AtomicBoolean kexOngoing = new AtomicBoolean(); + private final AtomicBoolean initialKex = new AtomicBoolean(true); + + private final AtomicBoolean strictKex = new AtomicBoolean(); + /** What we are expecting from the next packet */ private Expected expected = Expected.KEXINIT; @@ -123,6 +127,14 @@ boolean isKexOngoing() { return kexOngoing.get(); } + boolean isStrictKex() { + return strictKex.get(); + } + + boolean isInitialKex() { + return initialKex.get(); + } + /** * Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily * after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if @@ -183,7 +195,7 @@ private void sendKexInit() throws TransportException { log.debug("Sending SSH_MSG_KEXINIT"); List knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort()); - clientProposal = new Proposal(transport.getConfig(), knownHostAlgs); + clientProposal = new Proposal(transport.getConfig(), knownHostAlgs, initialKex.get()); transport.write(clientProposal.getPacket()); kexInitSent.set(); } @@ -202,6 +214,9 @@ private void sendNewKeys() throws TransportException { log.debug("Sending SSH_MSG_NEWKEYS"); transport.write(new SSHPacket(Message.NEWKEYS)); + if (strictKex.get()) { + transport.getEncoder().resetSequenceNumber(); + } } /** @@ -234,6 +249,10 @@ private synchronized void verifyHost(PublicKey key) private void setKexDone() { kexOngoing.set(false); + initialKex.set(false); + if (strictKex.get()) { + transport.getDecoder().resetSequenceNumber(); + } kexInitSent.clear(); done.set(); } @@ -242,6 +261,7 @@ private void gotKexInit(SSHPacket buf) throws TransportException { buf.rpos(buf.rpos() - 1); final Proposal serverProposal = new Proposal(buf); + gotStrictKexInfo(serverProposal); negotiatedAlgs = clientProposal.negotiate(serverProposal); log.debug("Negotiated algorithms: {}", negotiatedAlgs); for(AlgorithmsVerifier v: algorithmVerifiers) { @@ -265,6 +285,18 @@ private void gotKexInit(SSHPacket buf) } } + private void gotStrictKexInfo(Proposal serverProposal) throws TransportException { + if (initialKex.get() && serverProposal.isStrictKeyExchangeSupportedByServer()) { + strictKex.set(true); + log.debug("Enabling strict key exchange extension"); + if (transport.getDecoder().getSequenceNumber() != 0) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "SSH_MSG_KEXINIT was not first package during strict key exchange" + ); + } + } + } + /** * Private method used while putting new keys into use that will resize the key used to initialize the cipher to the * needed length. diff --git a/src/main/java/net/schmizz/sshj/transport/Proposal.java b/src/main/java/net/schmizz/sshj/transport/Proposal.java index 5f5f8a1f2..3a4102dd3 100644 --- a/src/main/java/net/schmizz/sshj/transport/Proposal.java +++ b/src/main/java/net/schmizz/sshj/transport/Proposal.java @@ -37,8 +37,11 @@ class Proposal { private final List s2cComp; private final SSHPacket packet; - public Proposal(Config config, List knownHostAlgs) { + public Proposal(Config config, List knownHostAlgs, boolean initialKex) { kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories()); + if (initialKex) { + kex.add("kex-strict-c-v00@openssh.com"); + } sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs); c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories()); c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories()); @@ -91,6 +94,10 @@ public List getKeyExchangeAlgorithms() { return kex; } + public boolean isStrictKeyExchangeSupportedByServer() { + return kex.contains("kex-strict-s-v00@openssh.com"); + } + public List getHostKeyAlgorithms() { return sig; } diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index 58107c5b8..1cd6cb2b2 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -426,7 +426,7 @@ public long write(SSHPacket payload) assert m != Message.KEXINIT; kexer.waitForDone(); } - } else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet + } else if (encoder.isSequenceNumberAtMax()) // We get here every 2**32th packet kexer.startKex(true); final long seq = encoder.encode(payload); @@ -479,9 +479,20 @@ public void handle(Message msg, SSHPacket buf) log.trace("Received packet {}", msg); + if (kexer.isInitialKex()) { + if (decoder.isSequenceNumberAtMax()) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Sequence number of decoder is about to wrap during initial key exchange"); + } + if (kexer.isStrictKex() && !isKexerPacket(msg) && msg != Message.DISCONNECT) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Unexpected packet type during initial strict key exchange"); + } + } + if (msg.geq(50)) { // not a transport layer packet service.handle(msg, buf); - } else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet + } else if (isKexerPacket(msg)) { kexer.handle(msg, buf); } else { switch (msg) { @@ -513,6 +524,10 @@ public void handle(Message msg, SSHPacket buf) } } + private static boolean isKexerPacket(Message msg) { + return msg.in(20, 21) || msg.in(30, 49); + } + private void gotDebug(SSHPacket buf) throws TransportException { try { diff --git a/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java b/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java index c1f8655ab..e6160e701 100644 --- a/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java +++ b/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java @@ -112,7 +112,7 @@ private void performAndCheckKeyExchange() throws TransportException { } private SSHPacket getKexinitPacket() { - SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList()).getPacket(); + SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), false).getPacket(); kexinitPacket.rpos(kexinitPacket.rpos() + 1); return kexinitPacket; }