From 1c547886c8cfe2b2f38a56ef124fa533d26a45ee Mon Sep 17 00:00:00 2001 From: Raul Santelices Date: Tue, 21 Nov 2023 15:21:35 -0500 Subject: [PATCH 01/13] Fix for Remote port forwarding buffers can grow without limits (issue #658) (#913) * Fix for Remote port forwarding buffers can grow without limits (issue #658) * Update test classes to use JUnit 5 * Fix MB computation --- src/main/java/net/schmizz/sshj/Config.java | 4 + .../java/net/schmizz/sshj/ConfigImpl.java | 12 + .../schmizz/sshj/common/CircularBuffer.java | 194 +++++++++++++++ .../connection/channel/AbstractChannel.java | 7 +- .../channel/ChannelInputStream.java | 55 +++-- .../channel/direct/SessionChannel.java | 2 +- .../forwarded/RemotePFPerformanceTest.java | 188 +++++++++++++++ .../sshj/common/CircularBufferTest.java | 221 ++++++++++++++++++ 8 files changed, 649 insertions(+), 34 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/common/CircularBuffer.java create mode 100644 src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePFPerformanceTest.java create mode 100644 src/test/java/net/schmizz/sshj/common/CircularBufferTest.java diff --git a/src/main/java/net/schmizz/sshj/Config.java b/src/main/java/net/schmizz/sshj/Config.java index dfb6c1229..24166d08b 100644 --- a/src/main/java/net/schmizz/sshj/Config.java +++ b/src/main/java/net/schmizz/sshj/Config.java @@ -200,4 +200,8 @@ public interface Config { * See {@link #isVerifyHostKeyCertificates()}. */ void setVerifyHostKeyCertificates(boolean value); + + int getMaxCircularBufferSize(); + + void setMaxCircularBufferSize(int maxCircularBufferSize); } diff --git a/src/main/java/net/schmizz/sshj/ConfigImpl.java b/src/main/java/net/schmizz/sshj/ConfigImpl.java index 67243cb3d..23ad79bfd 100644 --- a/src/main/java/net/schmizz/sshj/ConfigImpl.java +++ b/src/main/java/net/schmizz/sshj/ConfigImpl.java @@ -49,6 +49,8 @@ public class ConfigImpl private boolean waitForServerIdentBeforeSendingClientIdent = false; private LoggerFactory loggerFactory; private boolean verifyHostKeyCertificates = true; + // HF-982: default to 16MB buffers. + private int maxCircularBufferSize = 16 * 1024 * 1024; @Override public List> getCipherFactories() { @@ -175,6 +177,16 @@ public LoggerFactory getLoggerFactory() { return loggerFactory; } + @Override + public int getMaxCircularBufferSize() { + return maxCircularBufferSize; + } + + @Override + public void setMaxCircularBufferSize(int maxCircularBufferSize) { + this.maxCircularBufferSize = maxCircularBufferSize; + } + @Override public void setLoggerFactory(LoggerFactory loggerFactory) { this.loggerFactory = loggerFactory; diff --git a/src/main/java/net/schmizz/sshj/common/CircularBuffer.java b/src/main/java/net/schmizz/sshj/common/CircularBuffer.java new file mode 100644 index 000000000..ea47351ea --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/CircularBuffer.java @@ -0,0 +1,194 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.common; + +public class CircularBuffer> { + + public static class CircularBufferException + extends SSHException { + + public CircularBufferException(String message) { + super(message); + } + } + + public static final class PlainCircularBuffer + extends CircularBuffer { + + public PlainCircularBuffer(int size, int maxSize) { + super(size, maxSize); + } + } + + /** + * Maximum size of the internal array (one plus the maximum capacity of the buffer). + */ + private final int maxSize; + /** + * Internal array for the data. All bytes minus one can be used to avoid empty vs full ambiguity when rpos == wpos. + */ + private byte[] data; + /** + * Next read position. Wraps around the end of the internal array. When it reaches wpos, the buffer becomes empty. + * Can take the value data.length, which is equivalent to 0. + */ + private int rpos; + /** + * Next write position. Wraps around the end of the internal array. If it is equal to rpos, then the buffer is + * empty; the code does not allow wpos to reach rpos from the left. This implies that the buffer can store up to + * data.length - 1 bytes. Can take the value data.length, which is equivalent to 0. + */ + private int wpos; + + /** + * Determines the size to which to grow the internal array. + */ + private int getNextSize(int currentSize) { + // Use next power of 2. + int nextSize = 1; + while (nextSize < currentSize) { + nextSize <<= 1; + if (nextSize <= 0) { + return maxSize; + } + } + return Math.min(nextSize, maxSize); // limit to max size + } + + /** + * Creates a new circular buffer of the given size. The capacity of the buffer is one less than the size/ + */ + public CircularBuffer(int size, int maxSize) { + this.maxSize = maxSize; + if (size > maxSize) { + throw new IllegalArgumentException( + String.format("Initial requested size %d larger than maximum size %d", size, maxSize)); + } + int initialSize = getNextSize(size); + this.data = new byte[initialSize]; + this.rpos = 0; + this.wpos = 0; + } + + /** + * Data available in the buffer for reading. + */ + public int available() { + int available = wpos - rpos; + return available >= 0 ? available : available + data.length; // adjust if wpos is left of rpos + } + + private void ensureAvailable(int a) + throws CircularBufferException { + if (available() < a) { + throw new CircularBufferException("Underflow"); + } + } + + /** + * Returns how many more bytes this buffer can receive. + */ + public int maxPossibleRemainingCapacity() { + // Remaining capacity is one less than remaining space to ensure that wpos does not reach rpos from the left. + int remaining = rpos - wpos - 1; + if (remaining < 0) { + remaining += data.length; // adjust if rpos is left of wpos + } + // Add the maximum amount the internal array can grow. + return remaining + maxSize - data.length; + } + + /** + * If the internal array does not have room for "capacity" more bytes, resizes the array to make that room. + */ + void ensureCapacity(int capacity) throws CircularBufferException { + int available = available(); + int remaining = data.length - available; + // If capacity fits exactly in the remaining space, expand it; otherwise, wpos would reach rpos from the left. + if (remaining <= capacity) { + int neededSize = available + capacity + 1; + int nextSize = getNextSize(neededSize); + if (nextSize < neededSize) { + throw new CircularBufferException("Attempted overflow"); + } + byte[] tmp = new byte[nextSize]; + // Copy data to the beginning of the new array. + if (wpos >= rpos) { + System.arraycopy(data, rpos, tmp, 0, available); + wpos -= rpos; // wpos must be relative to the new rpos, which will be 0 + } else { + int tail = data.length - rpos; + System.arraycopy(data, rpos, tmp, 0, tail); // segment right of rpos + System.arraycopy(data, 0, tmp, tail, wpos); // segment left of wpos + wpos += tail; // wpos must be relative to the new rpos, which will be 0 + } + rpos = 0; + data = tmp; + } + } + + /** + * Reads data from this buffer into the provided array. + */ + public void readRawBytes(byte[] destination, int offset, int length) throws CircularBufferException { + ensureAvailable(length); + + int rposNext = rpos + length; + if (rposNext <= data.length) { + System.arraycopy(data, rpos, destination, offset, length); + } else { + int tail = data.length - rpos; + System.arraycopy(data, rpos, destination, offset, tail); // segment right of rpos + rposNext = length - tail; // rpos wraps around the end of the buffer + System.arraycopy(data, 0, destination, offset + tail, rposNext); // remainder + } + // This can make rpos equal data.length, which has the same effect as wpos being 0. + rpos = rposNext; + } + + /** + * Writes data to this buffer from the provided array. + */ + @SuppressWarnings("unchecked") + public T putRawBytes(byte[] source, int offset, int length) throws CircularBufferException { + ensureCapacity(length); + + int wposNext = wpos + length; + if (wposNext <= data.length) { + System.arraycopy(source, offset, data, wpos, length); + } else { + int tail = data.length - wpos; + System.arraycopy(source, offset, data, wpos, tail); // segment right of wpos + wposNext = length - tail; // wpos wraps around the end of the buffer + System.arraycopy(source, offset + tail, data, 0, wposNext); // remainder + } + // This can make wpos equal data.length, which has the same effect as wpos being 0. + wpos = wposNext; + + return (T) this; + } + + // Used only for testing. + int length() { + return data.length; + } + + @Override + public String toString() { + return "CircularBuffer [rpos=" + rpos + ", wpos=" + wpos + ", size=" + data.length + "]"; + } + +} diff --git a/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java index cb2373439..0ea8fae45 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java @@ -164,8 +164,7 @@ public String getType() { } @Override - public void handle(Message msg, SSHPacket buf) - throws ConnectionException, TransportException { + public void handle(Message msg, SSHPacket buf) throws SSHException { switch (msg) { case CHANNEL_DATA: @@ -354,7 +353,7 @@ protected void finishOff() { } protected void gotExtendedData(SSHPacket buf) - throws ConnectionException, TransportException { + throws SSHException { throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Extended data not supported on " + type + " channel"); } @@ -375,7 +374,7 @@ protected SSHPacket newBuffer(Message cmd) { } protected void receiveInto(ChannelInputStream stream, SSHPacket buf) - throws ConnectionException, TransportException { + throws SSHException { final int len; try { len = buf.readUInt32AsInt(); diff --git a/src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java b/src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java index ee03d23cd..530f0167d 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/ChannelInputStream.java @@ -38,7 +38,7 @@ public final class ChannelInputStream private final Channel chan; private final Transport trans; private final Window.Local win; - private final Buffer.PlainBuffer buf; + private final CircularBuffer.PlainCircularBuffer buf; private final byte[] b = new byte[1]; private boolean eof; @@ -46,10 +46,11 @@ public final class ChannelInputStream public ChannelInputStream(Channel chan, Transport trans, Window.Local win) { this.chan = chan; - log = chan.getLoggerFactory().getLogger(getClass()); + this.log = chan.getLoggerFactory().getLogger(getClass()); this.trans = trans; this.win = win; - buf = new Buffer.PlainBuffer(chan.getLocalMaxPacketSize()); + this.buf = new CircularBuffer.PlainCircularBuffer( + chan.getLocalMaxPacketSize(), trans.getConfig().getMaxCircularBufferSize()); } @Override @@ -113,48 +114,44 @@ public int read(byte[] b, int off, int len) len = buf.available(); } buf.readRawBytes(b, off, len); - if (buf.rpos() > win.getMaxPacketSize() && buf.available() == 0) { - buf.clear(); - } - } - if (!chan.getAutoExpand()) { - checkWindow(); + if (!chan.getAutoExpand()) { + checkWindow(); + } } return len; } - public void receive(byte[] data, int offset, int len) - throws ConnectionException, TransportException { + public void receive(byte[] data, int offset, int len) throws SSHException { if (eof) { throw new ConnectionException("Getting data on EOF'ed stream"); } synchronized (buf) { buf.putRawBytes(data, offset, len); buf.notifyAll(); - } - // Potential fix for #203 (window consumed below 0). - // This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST - // And the window has not expanded yet. - synchronized (win) { + // Potential fix for #203 (window consumed below 0). + // This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST + // And the window has not expanded yet. win.consume(len); - } - if (chan.getAutoExpand()) { - checkWindow(); + if (chan.getAutoExpand()) { + checkWindow(); + } } } - private void checkWindow() - throws TransportException { - synchronized (win) { - final long adjustment = win.neededAdjustment(); - if (adjustment > 0) { - log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment); - trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST) - .putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment)); - win.expand(adjustment); - } + private void checkWindow() throws TransportException { + /* + * Window must fit in remaining buffer capacity. We already expect win.size() amount of data to arrive. The + * difference between that and the remaining capacity is the maximum adjustment we can make to the window. + */ + final long maxAdjustment = buf.maxPossibleRemainingCapacity() - win.getSize(); + final long adjustment = Math.min(win.neededAdjustment(), maxAdjustment); + if (adjustment > 0) { + log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment); + trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST) + .putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment)); + win.expand(adjustment); } } diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java index dfdfa55e4..873958204 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java @@ -210,7 +210,7 @@ protected void eofInputStreams() { @Override protected void gotExtendedData(SSHPacket buf) - throws ConnectionException, TransportException { + throws SSHException { try { final int dataTypeCode = buf.readUInt32AsInt(); if (dataTypeCode == 1) diff --git a/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePFPerformanceTest.java b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePFPerformanceTest.java new file mode 100644 index 000000000..448fc09e8 --- /dev/null +++ b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePFPerformanceTest.java @@ -0,0 +1,188 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.connection.channel.forwarded; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.Forward; +import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RemotePFPerformanceTest { + + private static final Logger log = LoggerFactory.getLogger(RemotePFPerformanceTest.class); + + @Test + @Disabled + public void startPF() throws IOException, InterruptedException { + DefaultConfig config = new DefaultConfig(); + config.setMaxCircularBufferSize(16 * 1024 * 1024); + SSHClient client = new SSHClient(config); + client.loadKnownHosts(); + client.addHostKeyVerifier("5c:0c:8e:9d:1c:50:a9:ba:a7:05:f6:b1:2b:0b:5f:ba"); + + client.getConnection().getKeepAlive().setKeepAliveInterval(5); + client.connect("localhost"); + client.getConnection().getKeepAlive().setKeepAliveInterval(5); + + Object consumerReadyMonitor = new Object(); + ConsumerThread consumerThread = new ConsumerThread(consumerReadyMonitor); + ProducerThread producerThread = new ProducerThread(); + try { + + client.authPassword(System.getenv().get("USERNAME"), System.getenv().get("PASSWORD")); + + /* + * We make _server_ listen on port 8080, which forwards all connections to us as a channel, and we further + * forward all such channels to google.com:80 + */ + client.getRemotePortForwarder().bind( + // where the server should listen + new Forward(8888), + // what we do with incoming connections that are forwarded to us + new SocketForwardingConnectListener(new InetSocketAddress("localhost", 12345))); + + consumerThread.start(); + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.wait(); + } + producerThread.start(); + + // Wait for consumer to finish receiving data. + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.wait(); + } + + } finally { + producerThread.interrupt(); + consumerThread.interrupt(); + client.disconnect(); + } + } + + private static class ConsumerThread extends Thread { + private final Object consumerReadyMonitor; + + private ConsumerThread(Object consumerReadyMonitor) { + super("Consumer"); + this.consumerReadyMonitor = consumerReadyMonitor; + } + + @Override + public void run() { + try (ServerSocket serverSocket = new ServerSocket(12345)) { + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.notifyAll(); + } + try (Socket acceptedSocket = serverSocket.accept()) { + InputStream in = acceptedSocket.getInputStream(); + int numRead; + byte[] buf = new byte[40000]; + //byte[] buf = new byte[255 * 4 * 1000]; + byte expectedNext = 1; + while ((numRead = in.read(buf)) != 0) { + if (Thread.interrupted()) { + log.info("Consumer thread interrupted"); + return; + } + log.info(String.format("Read %d characters; values from %d to %d", numRead, buf[0], buf[numRead - 1])); + if (buf[numRead - 1] == 0) { + verifyData(buf, numRead - 1, expectedNext); + break; + } + expectedNext = verifyData(buf, numRead, expectedNext); + // Slow down consumer to test buffering. + Thread.sleep(Long.parseLong(System.getenv().get("DELAY_MS"))); + } + log.info("Consumer read end of stream value: " + numRead); + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.notifyAll(); + } + } + } catch (Exception e) { + synchronized (consumerReadyMonitor) { + consumerReadyMonitor.notifyAll(); + } + e.printStackTrace(); + } + } + + private byte verifyData(byte[] buf, int numRead, byte expectedNext) { + for (int i = 0; i < numRead; ++i) { + if (buf[i] != expectedNext) { + fail("Expected buf[" + i + "]=" + buf[i] + " to be " + expectedNext); + } + if (++expectedNext == 0) { + expectedNext = 1; + } + } + return expectedNext; + } + } + + private static class ProducerThread extends Thread { + private ProducerThread() { + super("Producer"); + } + + @Override + public void run() { + try (Socket clientSocket = new Socket("127.0.0.1", 8888); + OutputStream writer = clientSocket.getOutputStream()) { + byte[] buf = getData(); + assertEquals(buf[0], 1); + assertEquals(buf[buf.length - 1], -1); + for (int i = 0; i < 1000; ++i) { + writer.write(buf); + if (Thread.interrupted()) { + log.info("Consumer thread interrupted"); + return; + } + log.info(String.format("Wrote %d characters; values from %d to %d", buf.length, buf[0], buf[buf.length - 1])); + } + writer.write(0); // end of stream value + log.info("Producer finished sending data"); + } catch (Exception e) { + e.printStackTrace(); + } + } + + private byte[] getData() { + byte[] buf = new byte[255 * 4 * 1000]; + byte nextValue = 1; + for (int i = 0; i < buf.length; ++i) { + buf[i] = nextValue++; + // reserve 0 for end of stream + if (nextValue == 0) { + nextValue = 1; + } + } + return buf; + } + } + +} diff --git a/src/test/java/net/schmizz/sshj/common/CircularBufferTest.java b/src/test/java/net/schmizz/sshj/common/CircularBufferTest.java new file mode 100644 index 000000000..da53afc38 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/common/CircularBufferTest.java @@ -0,0 +1,221 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.common; + +import static org.junit.jupiter.api.Assertions.*; + +import net.schmizz.sshj.common.CircularBuffer.CircularBufferException; +import net.schmizz.sshj.common.CircularBuffer.PlainCircularBuffer; +import org.junit.jupiter.api.Test; + +public class CircularBufferTest { + + @Test + public void shouldStoreDataCorrectlyWithoutResizing() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(256, Integer.MAX_VALUE); + + byte[] dataToWrite = getData(500); + buffer.putRawBytes(dataToWrite, 0, 100); + buffer.putRawBytes(dataToWrite, 100, 100); + + byte[] dataToRead = new byte[500]; + buffer.readRawBytes(dataToRead, 0, 80); + buffer.readRawBytes(dataToRead, 80, 80); + + buffer.putRawBytes(dataToWrite, 200, 100); + buffer.readRawBytes(dataToRead, 160, 80); + + buffer.putRawBytes(dataToWrite, 300, 100); + buffer.readRawBytes(dataToRead, 240, 80); + + buffer.putRawBytes(dataToWrite, 400, 100); + buffer.readRawBytes(dataToRead, 320, 80); + buffer.readRawBytes(dataToRead, 400, 100); + + assertEquals(256, buffer.length()); + assertArrayEquals(dataToWrite, dataToRead); + } + + @Test + public void shouldStoreDataCorrectlyWithResizing() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + + byte[] dataToWrite = getData(500); + buffer.putRawBytes(dataToWrite, 0, 100); + buffer.putRawBytes(dataToWrite, 100, 100); + + byte[] dataToRead = new byte[500]; + buffer.readRawBytes(dataToRead, 0, 80); + buffer.readRawBytes(dataToRead, 80, 80); + + buffer.putRawBytes(dataToWrite, 200, 100); + buffer.readRawBytes(dataToRead, 160, 80); + + buffer.putRawBytes(dataToWrite, 300, 100); + buffer.readRawBytes(dataToRead, 240, 80); + + buffer.putRawBytes(dataToWrite, 400, 100); + buffer.readRawBytes(dataToRead, 320, 80); + + buffer.readRawBytes(dataToRead, 400, 100); + + assertEquals(256, buffer.length()); + assertArrayEquals(dataToWrite, dataToRead); + } + + @Test + public void shouldNotOverflowWhenWritingFullLengthToTheEnd() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + + byte[] dataToWrite = getData(64); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should write to the end + + assertEquals(64, buffer.available()); + assertEquals(64 * 2, buffer.length()); + } + + @Test + public void shouldNotOverflowWhenWritingFullLengthWrapsAround() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + + // Move 1 byte forward. + buffer.putRawBytes(new byte[1], 0, 1); + buffer.readRawBytes(new byte[1], 0, 1); + + // Force writes to wrap around. + byte[] dataToWrite = getData(64); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should wrap around the end + + assertEquals(64, buffer.available()); + assertEquals(64 * 2, buffer.length()); + } + + @Test + public void shouldAllowWritingMaxCapacityFromZero() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64); + + // Max capacity is always one less than the buffer size. + int maxCapacity = buffer.maxPossibleRemainingCapacity(); + assertEquals(buffer.length() - 1, maxCapacity); + + byte[] dataToWrite = getData(maxCapacity); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); + + assertEquals(dataToWrite.length, buffer.available()); + assertEquals(64, buffer.length()); + } + + @Test + public void shouldAllowWritingMaxRemainingCapacity() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64); + + final int initiallyWritten = 10; + buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + + // Max remaining capacity is always one less than the remaining buffer size. + int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity(); + assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity); + + byte[] dataToWrite = getData(maxRemainingCapacity); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); + + assertEquals(dataToWrite.length + initiallyWritten, buffer.available()); + assertEquals(64, buffer.length()); + } + + @Test + public void shouldAllowWritingMaxRemainingCapacityAfterWrappingAround() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64); + + // Cause the internal write pointer to wrap around and be left of the read pointer. + final int initiallyWritten = 40; + buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + buffer.readRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + + // Max remaining capacity is always one less than the remaining buffer size. + int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity(); + assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity); + + byte[] dataToWrite = getData(maxRemainingCapacity); + buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); + + assertEquals(dataToWrite.length + initiallyWritten, buffer.available()); + assertEquals(64, buffer.length()); + } + + @Test + public void shouldOverflowWhenWritingOverMaxRemainingCapacity() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64); + + final int initiallyWritten = 10; + buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten); + + // Max remaining capacity is always one less than the remaining buffer size. + int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity(); + assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity); + + byte[] dataToWrite = getData(maxRemainingCapacity + 1); + assertThrows(CircularBufferException.class, () -> buffer.putRawBytes(dataToWrite, 0, dataToWrite.length)); + } + + @Test + public void shouldThrowWhenReadingEmptyBuffer() { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[1], 0, 1)); + } + + @Test + public void shouldThrowWhenReadingMoreThanAvailable() throws CircularBufferException { + PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE); + buffer.putRawBytes(new byte[1], 0, 1); + assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[2], 0, 2)); + } + + @Test + public void shouldThrowOnAboveMaximumInitialSize() { + assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(65, 64)); + } + + @Test + public void shouldThrowOnMaximumInitialSize() { + assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(Integer.MAX_VALUE, 64)); + } + + @Test + public void shouldAllowFullCapacity() throws CircularBufferException { + int maxSize = 1024; + PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize); + buffer.ensureCapacity(maxSize - 1); + assertEquals(maxSize - 1, buffer.maxPossibleRemainingCapacity()); + } + + @Test + public void shouldThrowOnTooLargeRequestedCapacity() { + int maxSize = 1024; + PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize); + assertThrows(CircularBufferException.class, () -> buffer.ensureCapacity(maxSize)); + } + + private static byte[] getData(int length) { + byte[] data = new byte[length]; + byte nextValue = 0; + for (int i = 0; i < length; ++i) { + data[i] = nextValue++; + } + return data; + } +} From 50c753dc5801612bec33ff6fa205c3d8a17d854a Mon Sep 17 00:00:00 2001 From: David Kocher Date: Tue, 21 Nov 2023 21:24:28 +0100 Subject: [PATCH 02/13] Fixed writing known hosts key string (#903) * Fix #902. * Add test. --- .../sshj/transport/verification/OpenSSHKnownHosts.java | 4 +++- .../sshj/transport/verification/OpenSSHKnownHostsTest.java | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java index 02ef7de41..7d71e1aa1 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java @@ -41,6 +41,7 @@ import java.security.PublicKey; import java.security.spec.RSAPublicKeySpec; import java.util.ArrayList; +import java.util.Arrays; import java.util.Base64; import java.util.List; @@ -468,7 +469,8 @@ public String getLine() { } private String getKeyString(PublicKey pk) { - return Base64.getEncoder().encodeToString(pk.getEncoded()); + final Buffer.PlainBuffer buf = new Buffer.PlainBuffer().putPublicKey(pk); + return Base64.getEncoder().encodeToString(Arrays.copyOfRange(buf.array(), buf.rpos(), buf.available())); } protected String getHostPart() { diff --git a/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java b/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java index d6c2f6056..e509656ea 100644 --- a/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java +++ b/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java @@ -63,6 +63,11 @@ public void shouldParseAndVerifyHashedHostEntry() throws Exception { OpenSSHKnownHosts ohk = new OpenSSHKnownHosts(knownHosts); assertTrue(ohk.verify("192.168.1.61", 22, k)); assertFalse(ohk.verify("192.168.1.2", 22, k)); + ohk.write(); + for (OpenSSHKnownHosts.KnownHostEntry entry : ohk.entries()) { + assertEquals("|1|F1E1KeoE/eEWhi10WpGv4OdiO6Y=|3988QV0VE8wmZL7suNrYQLITLCg= ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEA6P9Hlwdahh250jGZYKg2snRq2j2lFJVdKSHyxqbJiVy9VX9gTkN3K2MD48qyrYLYOyGs3vTttyUk+cK++JMzURWsrP4piby7LpeOT+3Iq8CQNj4gXZdcH9w15Vuk2qS11at6IsQPVHpKD9HGg9//EFUccI/4w06k4XXLm/IxOGUwj6I2AeWmEOL3aDi+fe07TTosSdLUD6INtR0cyKsg0zC7Da24ixoShT8Oy3x2MpR7CY3PQ1pUVmvPkr79VeA+4qV9F1JM09WdboAMZgWQZ+XrbtuBlGsyhpUHSCQOya+kOJ+bYryS+U7A+6nmTW3C9FX4FgFqTF89UHOC7V0zZQ==", + entry.getLine()); + } } @Test From a262f519008c71d90f17654561d220d591b2178d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henning=20P=C3=B6ttker?= Date: Thu, 21 Dec 2023 22:33:54 +0100 Subject: [PATCH 03/13] Implement OpenSSH strict key exchange extension (#917) --- .../com/hierynomus/sshj/SshdContainer.java | 3 +- .../transport/kex/StrictKeyExchangeTest.java | 111 ++++++++++++++++++ .../net/schmizz/sshj/transport/Converter.java | 8 ++ .../schmizz/sshj/transport/KeyExchanger.java | 34 +++++- .../net/schmizz/sshj/transport/Proposal.java | 9 +- .../schmizz/sshj/transport/TransportImpl.java | 19 ++- .../sshj/transport/KeyExchangeRepeatTest.java | 2 +- 7 files changed, 180 insertions(+), 6 deletions(-) create mode 100644 src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java diff --git a/src/itest/java/com/hierynomus/sshj/SshdContainer.java b/src/itest/java/com/hierynomus/sshj/SshdContainer.java index 82cfda644..91b531e68 100644 --- a/src/itest/java/com/hierynomus/sshj/SshdContainer.java +++ b/src/itest/java/com/hierynomus/sshj/SshdContainer.java @@ -146,8 +146,9 @@ public static Builder defaultBuilder() { .withFileFromString("sshd_config", sshdConfig.build()); } + @Override public void accept(@NotNull DockerfileBuilder builder) { - builder.from("alpine:3.18.3"); + builder.from("alpine:3.19.0"); builder.run("apk add --no-cache openssh"); builder.expose(22); builder.copy("entrypoint.sh", "/entrypoint.sh"); diff --git a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java new file mode 100644 index 000000000..2abe71a72 --- /dev/null +++ b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java @@ -0,0 +1,111 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.transport.kex; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import com.hierynomus.sshj.SshdContainer; +import net.schmizz.sshj.SSHClient; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Testcontainers +class StrictKeyExchangeTest { + + @Container + private static final SshdContainer sshd = new SshdContainer(); + + private final List watchedLoggers = new ArrayList<>(); + private final ListAppender logWatcher = new ListAppender<>(); + + @BeforeEach + void setUpLogWatcher() { + logWatcher.start(); + setUpLogger("net.schmizz.sshj.transport.Decoder"); + setUpLogger("net.schmizz.sshj.transport.Encoder"); + setUpLogger("net.schmizz.sshj.transport.KeyExchanger"); + } + + @AfterEach + void tearDown() { + watchedLoggers.forEach(Logger::detachAndStopAllAppenders); + } + + private void setUpLogger(String className) { + Logger logger = ((Logger) LoggerFactory.getLogger(className)); + logger.addAppender(logWatcher); + watchedLoggers.add(logger); + } + + @Test + void strictKeyExchange() throws Throwable { + try (SSHClient client = sshd.getConnectedClient()) { + client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1"); + assertTrue(client.isAuthenticated()); + } + List keyExchangerLogs = getLogs("KeyExchanger"); + assertThat(keyExchangerLogs).containsSequence( + "Initiating key exchange", + "Sending SSH_MSG_KEXINIT", + "Received SSH_MSG_KEXINIT", + "Enabling strict key exchange extension" + ); + List decoderLogs = getLogs("Decoder").stream() + .map(log -> log.split(":")[0]) + .collect(Collectors.toList()); + assertThat(decoderLogs).containsExactly( + "Received packet #0", + "Received packet #1", + "Received packet #2", + "Received packet #0", + "Received packet #1", + "Received packet #2", + "Received packet #3" + ); + List encoderLogs = getLogs("Encoder").stream() + .map(log -> log.split(":")[0]) + .collect(Collectors.toList()); + assertThat(encoderLogs).containsExactly( + "Encoding packet #0", + "Encoding packet #1", + "Encoding packet #2", + "Encoding packet #0", + "Encoding packet #1", + "Encoding packet #2", + "Encoding packet #3" + ); + } + + private List getLogs(String className) { + return logWatcher.list.stream() + .filter(event -> event.getLoggerName().endsWith(className)) + .map(ILoggingEvent::getFormattedMessage) + .collect(Collectors.toList()); + } + +} diff --git a/src/main/java/net/schmizz/sshj/transport/Converter.java b/src/main/java/net/schmizz/sshj/transport/Converter.java index 9d532f965..6f3431c3f 100644 --- a/src/main/java/net/schmizz/sshj/transport/Converter.java +++ b/src/main/java/net/schmizz/sshj/transport/Converter.java @@ -51,6 +51,14 @@ long getSequenceNumber() { return seq; } + void resetSequenceNumber() { + seq = -1; + } + + boolean isSequenceNumberAtMax() { + return seq == 0xffffffffL; + } + void setAlgorithms(Cipher cipher, MAC mac, Compression compression) { this.cipher = cipher; this.mac = mac; diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index b8979f7b5..abaf72e10 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -60,6 +60,10 @@ private static enum Expected { private final AtomicBoolean kexOngoing = new AtomicBoolean(); + private final AtomicBoolean initialKex = new AtomicBoolean(true); + + private final AtomicBoolean strictKex = new AtomicBoolean(); + /** What we are expecting from the next packet */ private Expected expected = Expected.KEXINIT; @@ -123,6 +127,14 @@ boolean isKexOngoing() { return kexOngoing.get(); } + boolean isStrictKex() { + return strictKex.get(); + } + + boolean isInitialKex() { + return initialKex.get(); + } + /** * Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily * after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if @@ -183,7 +195,7 @@ private void sendKexInit() throws TransportException { log.debug("Sending SSH_MSG_KEXINIT"); List knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort()); - clientProposal = new Proposal(transport.getConfig(), knownHostAlgs); + clientProposal = new Proposal(transport.getConfig(), knownHostAlgs, initialKex.get()); transport.write(clientProposal.getPacket()); kexInitSent.set(); } @@ -202,6 +214,9 @@ private void sendNewKeys() throws TransportException { log.debug("Sending SSH_MSG_NEWKEYS"); transport.write(new SSHPacket(Message.NEWKEYS)); + if (strictKex.get()) { + transport.getEncoder().resetSequenceNumber(); + } } /** @@ -234,6 +249,10 @@ private synchronized void verifyHost(PublicKey key) private void setKexDone() { kexOngoing.set(false); + initialKex.set(false); + if (strictKex.get()) { + transport.getDecoder().resetSequenceNumber(); + } kexInitSent.clear(); done.set(); } @@ -242,6 +261,7 @@ private void gotKexInit(SSHPacket buf) throws TransportException { buf.rpos(buf.rpos() - 1); final Proposal serverProposal = new Proposal(buf); + gotStrictKexInfo(serverProposal); negotiatedAlgs = clientProposal.negotiate(serverProposal); log.debug("Negotiated algorithms: {}", negotiatedAlgs); for(AlgorithmsVerifier v: algorithmVerifiers) { @@ -265,6 +285,18 @@ private void gotKexInit(SSHPacket buf) } } + private void gotStrictKexInfo(Proposal serverProposal) throws TransportException { + if (initialKex.get() && serverProposal.isStrictKeyExchangeSupportedByServer()) { + strictKex.set(true); + log.debug("Enabling strict key exchange extension"); + if (transport.getDecoder().getSequenceNumber() != 0) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "SSH_MSG_KEXINIT was not first package during strict key exchange" + ); + } + } + } + /** * Private method used while putting new keys into use that will resize the key used to initialize the cipher to the * needed length. diff --git a/src/main/java/net/schmizz/sshj/transport/Proposal.java b/src/main/java/net/schmizz/sshj/transport/Proposal.java index 5f5f8a1f2..3a4102dd3 100644 --- a/src/main/java/net/schmizz/sshj/transport/Proposal.java +++ b/src/main/java/net/schmizz/sshj/transport/Proposal.java @@ -37,8 +37,11 @@ class Proposal { private final List s2cComp; private final SSHPacket packet; - public Proposal(Config config, List knownHostAlgs) { + public Proposal(Config config, List knownHostAlgs, boolean initialKex) { kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories()); + if (initialKex) { + kex.add("kex-strict-c-v00@openssh.com"); + } sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs); c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories()); c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories()); @@ -91,6 +94,10 @@ public List getKeyExchangeAlgorithms() { return kex; } + public boolean isStrictKeyExchangeSupportedByServer() { + return kex.contains("kex-strict-s-v00@openssh.com"); + } + public List getHostKeyAlgorithms() { return sig; } diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index 58107c5b8..1cd6cb2b2 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -426,7 +426,7 @@ public long write(SSHPacket payload) assert m != Message.KEXINIT; kexer.waitForDone(); } - } else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet + } else if (encoder.isSequenceNumberAtMax()) // We get here every 2**32th packet kexer.startKex(true); final long seq = encoder.encode(payload); @@ -479,9 +479,20 @@ public void handle(Message msg, SSHPacket buf) log.trace("Received packet {}", msg); + if (kexer.isInitialKex()) { + if (decoder.isSequenceNumberAtMax()) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Sequence number of decoder is about to wrap during initial key exchange"); + } + if (kexer.isStrictKex() && !isKexerPacket(msg) && msg != Message.DISCONNECT) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Unexpected packet type during initial strict key exchange"); + } + } + if (msg.geq(50)) { // not a transport layer packet service.handle(msg, buf); - } else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet + } else if (isKexerPacket(msg)) { kexer.handle(msg, buf); } else { switch (msg) { @@ -513,6 +524,10 @@ public void handle(Message msg, SSHPacket buf) } } + private static boolean isKexerPacket(Message msg) { + return msg.in(20, 21) || msg.in(30, 49); + } + private void gotDebug(SSHPacket buf) throws TransportException { try { diff --git a/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java b/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java index c1f8655ab..e6160e701 100644 --- a/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java +++ b/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java @@ -112,7 +112,7 @@ private void performAndCheckKeyExchange() throws TransportException { } private SSHPacket getKexinitPacket() { - SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList()).getPacket(); + SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), false).getPacket(); kexinitPacket.rpos(kexinitPacket.rpos() + 1); return kexinitPacket; } From 81e87a4d3560521f9fb77e27574b2c796d39bc70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henning=20P=C3=B6ttker?= Date: Sat, 23 Dec 2023 10:26:29 +0100 Subject: [PATCH 04/13] Add unit tests of strict key exchange extension (#918) --- .../KeyExchangerStrictKeyExchangeTest.java | 236 ++++++++++++++++++ .../TransportImplStrictKeyExchangeTest.java | 120 +++++++++ 2 files changed, 356 insertions(+) create mode 100644 src/test/java/net/schmizz/sshj/transport/KeyExchangerStrictKeyExchangeTest.java create mode 100644 src/test/java/net/schmizz/sshj/transport/TransportImplStrictKeyExchangeTest.java diff --git a/src/test/java/net/schmizz/sshj/transport/KeyExchangerStrictKeyExchangeTest.java b/src/test/java/net/schmizz/sshj/transport/KeyExchangerStrictKeyExchangeTest.java new file mode 100644 index 000000000..c91834bf1 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/transport/KeyExchangerStrictKeyExchangeTest.java @@ -0,0 +1,236 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport; + +import java.math.BigInteger; +import java.util.Collections; +import java.util.List; + +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.common.DisconnectReason; +import net.schmizz.sshj.common.Factory; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.transport.kex.KeyExchange; +import net.schmizz.sshj.transport.verification.PromiscuousVerifier; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class KeyExchangerStrictKeyExchangeTest { + + private TransportImpl transport; + private DefaultConfig config; + private KeyExchanger keyExchanger; + + @BeforeEach + void setUp() throws Exception { + KeyExchange kex = mock(KeyExchange.class, Mockito.RETURNS_DEEP_STUBS); + transport = mock(TransportImpl.class, Mockito.RETURNS_DEEP_STUBS); + config = new DefaultConfig() { + @Override + protected void initKeyExchangeFactories() { + setKeyExchangeFactories(Collections.singletonList(new Factory.Named<>() { + @Override + public KeyExchange create() { + return kex; + } + + @Override + public String getName() { + return "mock-kex"; + } + })); + } + }; + when(transport.getConfig()).thenReturn(config); + when(transport.getServerID()).thenReturn("some server id"); + when(transport.getClientID()).thenReturn("some client id"); + when(kex.next(any(), any())).thenReturn(true); + when(kex.getH()).thenReturn(new byte[0]); + when(kex.getK()).thenReturn(BigInteger.ZERO); + when(kex.getHash().digest()).thenReturn(new byte[10]); + + keyExchanger = new KeyExchanger(transport); + keyExchanger.addHostKeyVerifier(new PromiscuousVerifier()); + } + + @Test + void initialConditions() { + assertThat(keyExchanger.isKexDone()).isFalse(); + assertThat(keyExchanger.isKexOngoing()).isFalse(); + assertThat(keyExchanger.isStrictKex()).isFalse(); + assertThat(keyExchanger.isInitialKex()).isTrue(); + } + + @Test + void startInitialKex() throws Exception { + ArgumentCaptor sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class); + when(transport.write(sshPacketCaptor.capture())).thenReturn(0L); + + keyExchanger.startKex(false); + + assertThat(keyExchanger.isKexDone()).isFalse(); + assertThat(keyExchanger.isKexOngoing()).isTrue(); + assertThat(keyExchanger.isStrictKex()).isFalse(); + assertThat(keyExchanger.isInitialKex()).isTrue(); + + SSHPacket sshPacket = sshPacketCaptor.getValue(); + List kex = new Proposal(sshPacket).getKeyExchangeAlgorithms(); + assertThat(kex).endsWith("kex-strict-c-v00@openssh.com"); + } + + @Test + void receiveKexInitWithoutServerFlag() throws Exception { + keyExchanger.startKex(false); + + keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false)); + + assertThat(keyExchanger.isKexDone()).isFalse(); + assertThat(keyExchanger.isKexOngoing()).isTrue(); + assertThat(keyExchanger.isStrictKex()).isFalse(); + assertThat(keyExchanger.isInitialKex()).isTrue(); + } + + @Test + void finishNonStrictKex() throws Exception { + keyExchanger.startKex(false); + + keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false)); + keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31)); + keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS)); + + assertThat(keyExchanger.isKexDone()).isTrue(); + assertThat(keyExchanger.isKexOngoing()).isFalse(); + assertThat(keyExchanger.isStrictKex()).isFalse(); + assertThat(keyExchanger.isInitialKex()).isFalse(); + + verify(transport.getEncoder(), never()).resetSequenceNumber(); + verify(transport.getDecoder(), never()).resetSequenceNumber(); + } + + @Test + void receiveKexInitWithServerFlag() throws Exception { + keyExchanger.startKex(false); + + keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true)); + + assertThat(keyExchanger.isKexDone()).isFalse(); + assertThat(keyExchanger.isKexOngoing()).isTrue(); + assertThat(keyExchanger.isStrictKex()).isTrue(); + assertThat(keyExchanger.isInitialKex()).isTrue(); + } + + @Test + void strictKexInitIsNotFirstPacket() throws Exception { + when(transport.getDecoder().getSequenceNumber()).thenReturn(1L); + keyExchanger.startKex(false); + + assertThatExceptionOfType(TransportException.class).isThrownBy( + () -> keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true)) + ).satisfies(e -> { + assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED); + assertThat(e.getMessage()).isEqualTo("SSH_MSG_KEXINIT was not first package during strict key exchange"); + }); + } + + @Test + void finishStrictKex() throws Exception { + keyExchanger.startKex(false); + + keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true)); + verify(transport.getEncoder(), never()).resetSequenceNumber(); + keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31)); + verify(transport.getEncoder()).resetSequenceNumber(); + verify(transport.getDecoder(), never()).resetSequenceNumber(); + keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS)); + verify(transport.getDecoder()).resetSequenceNumber(); + + assertThat(keyExchanger.isKexDone()).isTrue(); + assertThat(keyExchanger.isKexOngoing()).isFalse(); + assertThat(keyExchanger.isStrictKex()).isTrue(); + assertThat(keyExchanger.isInitialKex()).isFalse(); + } + + @Test + void noClientFlagInSecondStrictKex() throws Exception { + keyExchanger.startKex(false); + keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true)); + keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31)); + keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS)); + + ArgumentCaptor sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class); + when(transport.write(sshPacketCaptor.capture())).thenReturn(0L); + when(transport.isAuthenticated()).thenReturn(true); + + keyExchanger.startKex(false); + + assertThat(keyExchanger.isKexDone()).isFalse(); + assertThat(keyExchanger.isKexOngoing()).isTrue(); + assertThat(keyExchanger.isStrictKex()).isTrue(); + assertThat(keyExchanger.isInitialKex()).isFalse(); + + SSHPacket sshPacket = sshPacketCaptor.getValue(); + List kex = new Proposal(sshPacket).getKeyExchangeAlgorithms(); + assertThat(kex).doesNotContain("kex-strict-c-v00@openssh.com"); + } + + @Test + void serverFlagIsIgnoredInSecondKex() throws Exception { + keyExchanger.startKex(false); + keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false)); + keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31)); + keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS)); + + ArgumentCaptor sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class); + when(transport.write(sshPacketCaptor.capture())).thenReturn(0L); + when(transport.isAuthenticated()).thenReturn(true); + + keyExchanger.startKex(false); + keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true)); + + assertThat(keyExchanger.isKexDone()).isFalse(); + assertThat(keyExchanger.isKexOngoing()).isTrue(); + assertThat(keyExchanger.isStrictKex()).isFalse(); + assertThat(keyExchanger.isInitialKex()).isFalse(); + + SSHPacket sshPacket = sshPacketCaptor.getValue(); + List kex = new Proposal(sshPacket).getKeyExchangeAlgorithms(); + assertThat(kex).doesNotContain("kex-strict-c-v00@openssh.com"); + } + + private SSHPacket getKexInitPacket(boolean withServerFlag) { + SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), true).getPacket(); + if (withServerFlag) { + int finalWpos = kexinitPacket.wpos(); + kexinitPacket.wpos(22); + kexinitPacket.putString("mock-kex,kex-strict-s-v00@openssh.com"); + kexinitPacket.wpos(finalWpos); + } + kexinitPacket.rpos(kexinitPacket.rpos() + 1); + return kexinitPacket; + } + +} diff --git a/src/test/java/net/schmizz/sshj/transport/TransportImplStrictKeyExchangeTest.java b/src/test/java/net/schmizz/sshj/transport/TransportImplStrictKeyExchangeTest.java new file mode 100644 index 000000000..58891be3a --- /dev/null +++ b/src/test/java/net/schmizz/sshj/transport/TransportImplStrictKeyExchangeTest.java @@ -0,0 +1,120 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport; + +import java.lang.reflect.Field; + +import net.schmizz.sshj.Config; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.common.DisconnectReason; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHPacket; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.EnumSource.Mode; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class TransportImplStrictKeyExchangeTest { + + private final Config config = new DefaultConfig(); + private final Transport transport = new TransportImpl(config); + private final KeyExchanger kexer = mock(KeyExchanger.class); + private final Decoder decoder = mock(Decoder.class); + + @BeforeEach + void setUp() throws Exception { + Field kexerField = TransportImpl.class.getDeclaredField("kexer"); + kexerField.setAccessible(true); + kexerField.set(transport, kexer); + Field decoderField = TransportImpl.class.getDeclaredField("decoder"); + decoderField.setAccessible(true); + decoderField.set(transport, decoder); + } + + @Test + void throwExceptionOnWrapDuringInitialKex() { + when(kexer.isInitialKex()).thenReturn(true); + when(decoder.isSequenceNumberAtMax()).thenReturn(true); + + assertThatExceptionOfType(TransportException.class).isThrownBy( + () -> transport.handle(Message.KEXINIT, new SSHPacket(Message.KEXINIT)) + ).satisfies(e -> { + assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED); + assertThat(e.getMessage()).isEqualTo("Sequence number of decoder is about to wrap during initial key exchange"); + }); + } + + @ParameterizedTest + @EnumSource(value = Message.class, mode = Mode.EXCLUDE, names = { + "DISCONNECT", "KEXINIT", "NEWKEYS", "KEXDH_INIT", "KEXDH_31", "KEX_DH_GEX_INIT", "KEX_DH_GEX_REPLY", "KEX_DH_GEX_REQUEST" + }) + void forbidUnexpectedPacketsDuringStrictKeyExchange(Message message) { + when(kexer.isInitialKex()).thenReturn(true); + when(decoder.isSequenceNumberAtMax()).thenReturn(false); + when(kexer.isStrictKex()).thenReturn(true); + + assertThatExceptionOfType(TransportException.class).isThrownBy( + () -> transport.handle(message, new SSHPacket(message)) + ).satisfies(e -> { + assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED); + assertThat(e.getMessage()).isEqualTo("Unexpected packet type during initial strict key exchange"); + }); + } + + @ParameterizedTest + @EnumSource(value = Message.class, mode = Mode.INCLUDE, names = { + "KEXINIT", "NEWKEYS", "KEXDH_INIT", "KEXDH_31", "KEX_DH_GEX_INIT", "KEX_DH_GEX_REPLY", "KEX_DH_GEX_REQUEST" + }) + void expectedPacketsDuringStrictKeyExchangeAreHandled(Message message) throws Exception { + when(kexer.isInitialKex()).thenReturn(true); + when(decoder.isSequenceNumberAtMax()).thenReturn(false); + when(kexer.isStrictKex()).thenReturn(true); + SSHPacket sshPacket = new SSHPacket(message); + + assertThatCode( + () -> transport.handle(message, sshPacket) + ).doesNotThrowAnyException(); + + verify(kexer).handle(message, sshPacket); + } + + @Test + void disconnectIsAllowedDuringStrictKeyExchange() { + when(kexer.isInitialKex()).thenReturn(true); + when(decoder.isSequenceNumberAtMax()).thenReturn(false); + when(kexer.isStrictKex()).thenReturn(true); + + SSHPacket sshPacket = new SSHPacket(); + sshPacket.putUInt32(DisconnectReason.SERVICE_NOT_AVAILABLE.toInt()); + sshPacket.putString("service is down for maintenance"); + + assertThatExceptionOfType(TransportException.class).isThrownBy( + () -> transport.handle(Message.DISCONNECT, sshPacket) + ).satisfies(e -> { + assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.SERVICE_NOT_AVAILABLE); + assertThat(e.getMessage()).isEqualTo("service is down for maintenance"); + }); + } + +} From dc6b20772b8b1931d755a849c874f59ff0e451e1 Mon Sep 17 00:00:00 2001 From: Jeroen van Erp Date: Tue, 2 Jan 2024 09:26:09 +0100 Subject: [PATCH 05/13] Prepare release 0.38.0 --- README.adoc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/README.adoc b/README.adoc index e0fcca987..4c28be41b 100644 --- a/README.adoc +++ b/README.adoc @@ -1,7 +1,7 @@ = sshj - SSHv2 library for Java Jeroen van Erp :sshj_groupid: com.hierynomus -:sshj_version: 0.37.0 +:sshj_version: 0.38.0 :source-highlighter: pygments image:https://github.com/hierynomus/sshj/actions/workflows/gradle.yml/badge.svg[link="https://github.com/hierynomus/sshj/actions/workflows/gradle.yml"] @@ -10,6 +10,8 @@ image:https://codecov.io/gh/hierynomus/sshj/branch/master/graph/badge.svg["codec image:http://www.javadoc.io/badge/com.hierynomus/sshj.svg?color=blue["JavaDocs", link="http://www.javadoc.io/doc/com.hierynomus/sshj"] image:https://maven-badges.herokuapp.com/maven-central/com.hierynomus/sshj/badge.svg["Maven Central",link="https://maven-badges.herokuapp.com/maven-central/com.hierynomus/sshj"] +WARNING: SSHJ versions up to and including 0.37.0 are vulnerable to https://nvd.nist.gov/vuln/detail/CVE-2023-48795[CVE-2023-48795 - Terrapin]. Please upgrade to 0.38.0 or higher. + To get started, have a look at one of the examples. Hopefully you will find the API pleasant to work with :) == Getting SSHJ @@ -46,7 +48,7 @@ If your project is built using another build tool that uses the Maven Central re In the `examples` directory, there is a separate Maven project that shows how the library can be used in some sample cases. If you want to run them, follow these guidelines: . Install http://maven.apache.org/[Maven 2.2.1] or up. -. Clone the Overthere repository. +. Clone the SSHJ repository. . Go into the `examples` directory and run the command `mvn eclipse:eclipse`. . Import the `examples` project into Eclipse. . Change the login details in the example classes (address, username and password) and run them! @@ -108,6 +110,14 @@ Issue tracker: https://github.com/hierynomus/sshj/issues Fork away! == Release history +SSHJ 0.38.0 (2024-01-02):: +* Mitigated CVE-2023-48795 - Terrapin + * Merged https://github.com/hierynomus/sshj/pull/917[#917]: Implement OpenSSH strict key exchange extension +* Merged https://github.com/hierynomus/sshj/pull/903[#903]: Fix for writing known hosts key string +* Merged https://github.com/hierynomus/sshj/pull/913[#913]: Prevent remote port forwarding buffers to grow without bounds +* Moved tess to JUnit5 +* Merged https://github.com/hierynomus/sshj/pull/827[#827]: Fallback to posix-rename@openssh.com extension if available +* Merged https://github.com/hierynomus/sshj/pull/904[#904]: Add ChaCha20-Poly1305 support for OpenSSH keys SSHJ 0.37.0 (2023-10-11):: * Merged https://github.com/hierynomus/sshj/pull/899[#899]: Add support for AES-GCM OpenSSH private keys * Merged https://github.com/hierynomus/sshj/pull/901[#901]: Fix ZLib compression bug From f94444bc5310995b2e2482407a82edc1dd73229a Mon Sep 17 00:00:00 2001 From: Pascal Schumacher Date: Tue, 2 Jan 2024 16:02:45 +0100 Subject: [PATCH 06/13] Fix typo in README.adoc (#920) --- README.adoc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.adoc b/README.adoc index 4c28be41b..c372330cf 100644 --- a/README.adoc +++ b/README.adoc @@ -115,7 +115,7 @@ SSHJ 0.38.0 (2024-01-02):: * Merged https://github.com/hierynomus/sshj/pull/917[#917]: Implement OpenSSH strict key exchange extension * Merged https://github.com/hierynomus/sshj/pull/903[#903]: Fix for writing known hosts key string * Merged https://github.com/hierynomus/sshj/pull/913[#913]: Prevent remote port forwarding buffers to grow without bounds -* Moved tess to JUnit5 +* Moved tests to JUnit5 * Merged https://github.com/hierynomus/sshj/pull/827[#827]: Fallback to posix-rename@openssh.com extension if available * Merged https://github.com/hierynomus/sshj/pull/904[#904]: Add ChaCha20-Poly1305 support for OpenSSH keys SSHJ 0.37.0 (2023-10-11):: From 03f8b2224d18048c27d93d9cf84b59c427cdc8ca Mon Sep 17 00:00:00 2001 From: kegelh <1587490+kegelh@users.noreply.github.com> Date: Fri, 26 Jan 2024 13:36:29 +0100 Subject: [PATCH 07/13] known_hosts parsing does not ignore malformed base64 strings since 0.36.0 (#922) --- .../transport/verification/OpenSSHKnownHosts.java | 4 ++-- .../verification/OpenSSHKnownHostsTest.java | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java index 7d71e1aa1..c41b83d7a 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java @@ -292,8 +292,8 @@ public KnownHostEntry parseEntry(String line) try { byte[] keyBytes = Base64.getDecoder().decode(sKey); key = new Buffer.PlainBuffer(keyBytes).readPublicKey(); - } catch (IOException ioe) { - log.warn("Error decoding Base64 key bytes", ioe); + } catch (IOException | IllegalArgumentException exception) { + log.warn("Error decoding Base64 key bytes", exception); return new BadHostEntry(line); } } else if (isBits(sType)) { diff --git a/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java b/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java index e509656ea..01dbe2f58 100644 --- a/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java +++ b/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java @@ -23,11 +23,9 @@ import java.io.File; import java.io.IOException; -import java.lang.module.ModuleDescriptor.Opens; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.security.PublicKey; -import java.security.Security; import java.util.Base64; import java.util.stream.Stream; @@ -110,6 +108,16 @@ public void shouldNotFailOnBadBase64Entry() throws Exception { assertTrue(ohk.verify("host1", 22, k)); } + @Test + public void shouldNotFailOnMalformedBase64String() throws IOException { + File knownHosts = knownHosts( + "1.1.1.1 ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBA/CkqWXSlbdo7jPshvIWT/m3FAdpSIKUx/uTmz87ObpBxXsfF8aMSiwGMKHjqviTV4cG6F7vFf28ll+9CbGsbs=192\n" + ); + OpenSSHKnownHosts ohk = new OpenSSHKnownHosts(knownHosts); + assertEquals(1, ohk.entries().size()); + assertThat(ohk.entries().get(0)).isInstanceOf(OpenSSHKnownHosts.BadHostEntry.class); + } + @Test public void shouldMarkBadLineAndNotFail() throws Exception { File knownHosts = knownHosts( From c0d1519ee2deb7083aece6dfd0e4c0f6f8696504 Mon Sep 17 00:00:00 2001 From: Martin Volf <2805972+martin-volf@users.noreply.github.com> Date: Mon, 29 Jan 2024 11:49:43 +0100 Subject: [PATCH 08/13] connected sockets can be passed to the library (#925) * connected sockets can be passed to the library fixes hierynomus/sshj#924 Signed-off-by: Martin Volf * removed pointless socket check; test coverage improved Signed-off-by: Martin Volf * better test coverage Signed-off-by: Martin Volf --------- Signed-off-by: Martin Volf --- .../java/net/schmizz/sshj/SocketClient.java | 8 +- .../net/schmizz/sshj/ConnectedSocketTest.java | 105 ++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 src/test/java/net/schmizz/sshj/ConnectedSocketTest.java diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index d7971243a..e4809e0d2 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -65,7 +65,9 @@ public void connect(String hostname, int port) throws IOException { this.hostname = hostname; this.port = port; socket = socketFactory.createSocket(); - socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + if (! socket.isConnected()) { + socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + } onConnect(); } } @@ -104,7 +106,9 @@ public void connect(InetAddress host) throws IOException { public void connect(InetAddress host, int port) throws IOException { this.port = port; socket = socketFactory.createSocket(); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + if (! socket.isConnected()) { + socket.connect(new InetSocketAddress(host, port), connectTimeout); + } onConnect(); } diff --git a/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java new file mode 100644 index 000000000..1424d62dd --- /dev/null +++ b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java @@ -0,0 +1,105 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj; + +import com.hierynomus.sshj.test.SshServerExtension; +import net.schmizz.sshj.SSHClient; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.apache.sshd.server.SshServer; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.stream.Stream; + +import javax.net.SocketFactory; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + + +public class ConnectedSocketTest { + @RegisterExtension + public SshServerExtension fixture = new SshServerExtension(); + + @BeforeEach + public void setupClient() throws IOException { + SSHClient defaultClient = fixture.setupDefaultClient(); + } + + private static interface Connector { + void connect(SshServerExtension fx) throws IOException; + } + + private static void connectViaHostname(SshServerExtension fx) throws IOException { + SshServer server = fx.getServer(); + fx.getClient().connect("localhost", server.getPort()); + } + + private static void connectViaAddr(SshServerExtension fx) throws IOException { + SshServer server = fx.getServer(); + InetAddress addr = InetAddress.getByName(server.getHost()); + fx.getClient().connect(addr, server.getPort()); + } + + private static Stream connectMethods() { + return Stream.of(fx -> connectViaHostname(fx), fx -> connectViaAddr(fx)); + } + + @ParameterizedTest + @MethodSource("connectMethods") + public void connectsIfUnconnected(Connector connector) { + assertDoesNotThrow(() -> connector.connect(fixture)); + } + + @ParameterizedTest + @MethodSource("connectMethods") + public void handlesConnected(Connector connector) throws IOException { + Socket socket = SocketFactory.getDefault().createSocket(); + SocketFactory factory = new SocketFactory() { + @Override + public Socket createSocket() { + return socket; + } + @Override + public Socket createSocket(InetAddress host, int port) { + return socket; + } + @Override + public Socket createSocket(InetAddress address, int port, + InetAddress localAddress, int localPort) { + return socket; + } + @Override + public Socket createSocket(String host, int port) { + return socket; + } + @Override + public Socket createSocket(String host, int port, + InetAddress localHost, int localPort) { + return socket; + } + }; + socket.connect(new InetSocketAddress("localhost", fixture.getServer().getPort())); + fixture.getClient().setSocketFactory(factory); + assertDoesNotThrow(() -> connector.connect(fixture)); + } +} From 70af58d19934d4dbc4b14008c2e328f7fea8bae9 Mon Sep 17 00:00:00 2001 From: Vladimir Lagunov Date: Mon, 15 Apr 2024 09:23:53 +0200 Subject: [PATCH 09/13] Wrap IllegalArgumentException thrown by Base64 decoder (#936) * Wrap IllegalArgumentException thrown by Base64 decoder Some time ago, there had been `net.schmizz.sshj.common.Base64`. This class used to throw `IOException` in case of any problem. Although `IOException` isn't an appropriate class for indicating on parsing issues, a lot of code has been expecting `IOException` from Base64. Once, the old Base64 decoder was replaced with the one, bundled into Java 14 (see f35c2bd4ce540cc65ee114102d9395034189915f). Copy-paste elimination and switching to standard implementations is undoubtedly a good decision. Unfortunately, `java.util.Base64.Decoder` brought a pesky issue. It throws `IllegalArgumentException` in case of any problem. Since it is an unchecked exception, it was quite challenging to notice it. It's especially challenging because the error appears during processing malformed base64 strings. So, a lot of places in the code kept expecting `IOException`. Sudden `IllegalArgumentException` led to authentication termination in cases where everything used to work perfectly. One of such issues is already found and fixed: 03f8b2224d18048c27d93d9cf84b59c427cdc8ca This commit represents a work, based on revising every change made in f35c2bd4ce540cc65ee114102d9395034189915f. It should fix all other similar issues. * squash! Wrap IllegalArgumentException thrown by Base64 decoder Rename Base64DecodeError -> Base64DecodingException * squash! Wrap IllegalArgumentException thrown by Base64 decoder A better warning message in KnownHostMatchers * squash! Wrap IllegalArgumentException thrown by Base64 decoder A better error message in OpenSSHKeyFileUtil * squash! Wrap IllegalArgumentException thrown by Base64 decoder A better error message in OpenSSHKeyV1KeyFile * squash! Wrap IllegalArgumentException thrown by Base64 decoder Get rid of unnecessary `throws IOException` in Base64Decoder * squash! Wrap IllegalArgumentException thrown by Base64 decoder Better error messages in OpenSSHKeyFileUtil and PuTTYKeyFile --- .../verification/KnownHostMatchers.java | 19 ++++- .../keyprovider/OpenSSHKeyFileUtil.java | 8 +- .../keyprovider/OpenSSHKeyV1KeyFile.java | 12 +-- .../schmizz/sshj/common/Base64Decoder.java | 47 ++++++++++++ .../sshj/common/Base64DecodingException.java | 28 +++++++ .../verification/OpenSSHKnownHosts.java | 12 +-- .../userauth/keyprovider/PuTTYKeyFile.java | 50 +++++++------ .../verification/OpenSSHKnownHostsTest.java | 46 ++++++++---- .../keyprovider/CorruptedPublicKeyTest.java | 73 +++++++++++++++++++ .../sshj/keyprovider/PuTTYKeyFileTest.java | 61 ++++++++++++++++ .../net/schmizz/sshj/util/CorruptBase64.java | 42 +++++++++++ 11 files changed, 335 insertions(+), 63 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/common/Base64Decoder.java create mode 100644 src/main/java/net/schmizz/sshj/common/Base64DecodingException.java create mode 100644 src/test/java/net/schmizz/sshj/keyprovider/CorruptedPublicKeyTest.java create mode 100644 src/test/java/net/schmizz/sshj/util/CorruptBase64.java diff --git a/src/main/java/com/hierynomus/sshj/transport/verification/KnownHostMatchers.java b/src/main/java/com/hierynomus/sshj/transport/verification/KnownHostMatchers.java index 37fdaef5c..cb6645637 100644 --- a/src/main/java/com/hierynomus/sshj/transport/verification/KnownHostMatchers.java +++ b/src/main/java/com/hierynomus/sshj/transport/verification/KnownHostMatchers.java @@ -15,6 +15,8 @@ */ package com.hierynomus.sshj.transport.verification; +import net.schmizz.sshj.common.Base64DecodingException; +import net.schmizz.sshj.common.Base64Decoder; import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.transport.mac.MAC; @@ -26,9 +28,13 @@ import java.util.regex.Pattern; import com.hierynomus.sshj.transport.mac.Macs; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class KnownHostMatchers { + private static final Logger log = LoggerFactory.getLogger(KnownHostMatchers.class); + public static HostMatcher createMatcher(String hostEntry) throws SSHException { if (hostEntry.contains(",")) { return new AnyHostMatcher(hostEntry); @@ -80,17 +86,22 @@ private static class HashedHostMatcher implements HostMatcher { @Override public boolean match(String hostname) throws IOException { - return hash.equals(hashHost(hostname)); + try { + return hash.equals(hashHost(hostname)); + } catch (Base64DecodingException err) { + log.warn("Hostname [{}] not matched: salt decoding failed", hostname, err); + return false; + } } - private String hashHost(String host) throws IOException { + private String hashHost(String host) throws IOException, Base64DecodingException { sha1.init(getSaltyBytes()); return "|1|" + salt + "|" + Base64.getEncoder().encodeToString(sha1.doFinal(host.getBytes(IOUtils.UTF8))); } - private byte[] getSaltyBytes() { + private byte[] getSaltyBytes() throws IOException, Base64DecodingException { if (saltyBytes == null) { - saltyBytes = Base64.getDecoder().decode(salt); + saltyBytes = Base64Decoder.decode(salt); } return saltyBytes; } diff --git a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyFileUtil.java b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyFileUtil.java index edb56ef36..94802c419 100644 --- a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyFileUtil.java +++ b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyFileUtil.java @@ -15,6 +15,8 @@ */ package com.hierynomus.sshj.userauth.keyprovider; +import net.schmizz.sshj.common.Base64DecodingException; +import net.schmizz.sshj.common.Base64Decoder; import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.KeyType; @@ -23,7 +25,6 @@ import java.io.IOException; import java.io.Reader; import java.security.PublicKey; -import java.util.Base64; public class OpenSSHKeyFileUtil { private OpenSSHKeyFileUtil() { @@ -54,9 +55,10 @@ public static ParsedPubKey initPubKey(Reader publicKey) throws IOException { if (!keydata.isEmpty()) { String[] parts = keydata.trim().split("\\s+"); if (parts.length >= 2) { + byte[] decodedPublicKey = Base64Decoder.decode(parts[1]); return new ParsedPubKey( KeyType.fromString(parts[0]), - new Buffer.PlainBuffer(Base64.getDecoder().decode(parts[1])).readPublicKey() + new Buffer.PlainBuffer(decodedPublicKey).readPublicKey() ); } else { throw new IOException("Got line with only one column"); @@ -64,6 +66,8 @@ public static ParsedPubKey initPubKey(Reader publicKey) throws IOException { } } throw new IOException("Public key file is blank"); + } catch (Base64DecodingException err) { + throw new IOException("Public key decoding failed", err); } finally { br.close(); } diff --git a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java index 9229fa4af..5d89356ff 100644 --- a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java +++ b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java @@ -23,13 +23,8 @@ import net.i2p.crypto.eddsa.EdDSAPrivateKey; import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable; import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec; -import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.*; import net.schmizz.sshj.common.Buffer.PlainBuffer; -import net.schmizz.sshj.common.ByteArrayUtils; -import net.schmizz.sshj.common.IOUtils; -import net.schmizz.sshj.common.KeyType; -import net.schmizz.sshj.common.SSHRuntimeException; -import net.schmizz.sshj.common.SecurityUtils; import net.schmizz.sshj.transport.cipher.Cipher; import net.schmizz.sshj.userauth.keyprovider.BaseFileKeyProvider; import net.schmizz.sshj.userauth.keyprovider.FileKeyProvider; @@ -55,7 +50,6 @@ import java.security.spec.ECPrivateKeySpec; import java.security.spec.RSAPrivateCrtKeySpec; import java.util.Arrays; -import java.util.Base64; import java.util.HashMap; import java.util.Map; @@ -124,7 +118,7 @@ protected KeyPair readKeyPair() throws IOException { try { if (checkHeader(reader)) { final String encodedPrivateKey = readEncodedKey(reader); - byte[] decodedPrivateKey = Base64.getDecoder().decode(encodedPrivateKey); + byte[] decodedPrivateKey = Base64Decoder.decode(encodedPrivateKey); final PlainBuffer bufferedPrivateKey = new PlainBuffer(decodedPrivateKey); return readDecodedKeyPair(bufferedPrivateKey); } else { @@ -133,6 +127,8 @@ protected KeyPair readKeyPair() throws IOException { } } catch (final GeneralSecurityException e) { throw new SSHRuntimeException("Read OpenSSH Version 1 Key failed", e); + } catch (Base64DecodingException e) { + throw new SSHRuntimeException("Private Key decoding failed", e); } finally { IOUtils.closeQuietly(reader); } diff --git a/src/main/java/net/schmizz/sshj/common/Base64Decoder.java b/src/main/java/net/schmizz/sshj/common/Base64Decoder.java new file mode 100644 index 000000000..e29608ad1 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/Base64Decoder.java @@ -0,0 +1,47 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.schmizz.sshj.common; + +import java.io.IOException; +import java.util.Base64; + +/** + *

Wraps {@link java.util.Base64.Decoder} in order to wrap unchecked {@code IllegalArgumentException} thrown by + * the default Java Base64 decoder here and there.

+ * + *

Please use this class instead of {@link java.util.Base64.Decoder}.

+ */ +public class Base64Decoder { + private Base64Decoder() { + } + + public static byte[] decode(byte[] source) throws Base64DecodingException { + try { + return Base64.getDecoder().decode(source); + } catch (IllegalArgumentException err) { + throw new Base64DecodingException(err); + } + } + + public static byte[] decode(String src) throws Base64DecodingException { + try { + return Base64.getDecoder().decode(src); + } catch (IllegalArgumentException err) { + throw new Base64DecodingException(err); + } + } +} diff --git a/src/main/java/net/schmizz/sshj/common/Base64DecodingException.java b/src/main/java/net/schmizz/sshj/common/Base64DecodingException.java new file mode 100644 index 000000000..cc18ead7e --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/Base64DecodingException.java @@ -0,0 +1,28 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.schmizz.sshj.common; + +/** + * A checked wrapper for all {@link IllegalArgumentException}, thrown by {@link java.util.Base64.Decoder}. + * + * @see Base64Decoder + */ +public class Base64DecodingException extends Exception { + public Base64DecodingException(IllegalArgumentException cause) { + super("Failed to decode base64: " + cause.getMessage(), cause); + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java index c41b83d7a..a58219089 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java @@ -18,13 +18,7 @@ import com.hierynomus.sshj.common.KeyAlgorithm; import com.hierynomus.sshj.transport.verification.KnownHostMatchers; import com.hierynomus.sshj.userauth.certificate.Certificate; -import net.schmizz.sshj.common.Buffer; -import net.schmizz.sshj.common.IOUtils; -import net.schmizz.sshj.common.KeyType; -import net.schmizz.sshj.common.LoggerFactory; -import net.schmizz.sshj.common.SSHException; -import net.schmizz.sshj.common.SSHRuntimeException; -import net.schmizz.sshj.common.SecurityUtils; +import net.schmizz.sshj.common.*; import org.slf4j.Logger; import java.io.BufferedOutputStream; @@ -290,9 +284,9 @@ public KnownHostEntry parseEntry(String line) if (type != KeyType.UNKNOWN) { final String sKey = split[i++]; try { - byte[] keyBytes = Base64.getDecoder().decode(sKey); + byte[] keyBytes = Base64Decoder.decode(sKey); key = new Buffer.PlainBuffer(keyBytes).readPublicKey(); - } catch (IOException | IllegalArgumentException exception) { + } catch (IOException | Base64DecodingException exception) { log.warn("Error decoding Base64 key bytes", exception); return new BadHostEntry(line); } diff --git a/src/main/java/net/schmizz/sshj/userauth/keyprovider/PuTTYKeyFile.java b/src/main/java/net/schmizz/sshj/userauth/keyprovider/PuTTYKeyFile.java index 9794da0fa..444c222a5 100644 --- a/src/main/java/net/schmizz/sshj/userauth/keyprovider/PuTTYKeyFile.java +++ b/src/main/java/net/schmizz/sshj/userauth/keyprovider/PuTTYKeyFile.java @@ -22,9 +22,7 @@ import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable; import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec; import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec; -import net.schmizz.sshj.common.Buffer; -import net.schmizz.sshj.common.KeyType; -import net.schmizz.sshj.common.SecurityUtils; +import net.schmizz.sshj.common.*; import net.schmizz.sshj.userauth.password.PasswordUtils; import org.bouncycastle.asn1.nist.NISTNamedCurves; import org.bouncycastle.asn1.x9.X9ECParameters; @@ -42,7 +40,6 @@ import java.security.*; import java.security.spec.*; import java.util.Arrays; -import java.util.Base64; import java.util.HashMap; import java.util.Map; @@ -240,29 +237,34 @@ protected void parseKeyPair() throws IOException { if (this.keyFileVersion == null) { throw new IOException("Invalid key file format: missing \"PuTTY-User-Key-File-?\" entry"); } - // Retrieve keys from payload - publicKey = Base64.getDecoder().decode(payload.get("Public-Lines")); - if (this.isEncrypted()) { - final char[] passphrase; - if (pwdf != null) { - passphrase = pwdf.reqPassword(resource); - } else { - passphrase = "".toCharArray(); - } - try { - privateKey = this.decrypt(Base64.getDecoder().decode(payload.get("Private-Lines")), passphrase); - Mac mac; - if (this.keyFileVersion <= 2) { - mac = this.prepareVerifyMacV2(passphrase); + try { + // Retrieve keys from payload + publicKey = Base64Decoder.decode(payload.get("Public-Lines")); + if (this.isEncrypted()) { + final char[] passphrase; + if (pwdf != null) { + passphrase = pwdf.reqPassword(resource); } else { - mac = this.prepareVerifyMacV3(); + passphrase = "".toCharArray(); + } + try { + privateKey = this.decrypt(Base64Decoder.decode(payload.get("Private-Lines")), passphrase); + Mac mac; + if (this.keyFileVersion <= 2) { + mac = this.prepareVerifyMacV2(passphrase); + } else { + mac = this.prepareVerifyMacV3(); + } + this.verify(mac); + } finally { + PasswordUtils.blankOut(passphrase); } - this.verify(mac); - } finally { - PasswordUtils.blankOut(passphrase); + } else { + privateKey = Base64Decoder.decode(payload.get("Private-Lines")); } - } else { - privateKey = Base64.getDecoder().decode(payload.get("Private-Lines")); + } + catch (Base64DecodingException e) { + throw new IOException("PuTTY key decoding failed", e); } } diff --git a/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java b/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java index 01dbe2f58..8bfebdaeb 100644 --- a/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java +++ b/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java @@ -15,11 +15,16 @@ */ package com.hierynomus.sshj.transport.verification; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.assertj.core.api.Assertions.*; +import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.SecurityUtils; +import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts; +import net.schmizz.sshj.util.KeyUtil; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import java.io.File; import java.io.IOException; @@ -29,17 +34,8 @@ import java.util.Base64; import java.util.stream.Stream; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; - -import net.schmizz.sshj.common.Buffer; -import net.schmizz.sshj.common.SecurityUtils; -import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts; -import net.schmizz.sshj.util.KeyUtil; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.*; public class OpenSSHKnownHostsTest { @TempDir @@ -118,6 +114,24 @@ public void shouldNotFailOnMalformedBase64String() throws IOException { assertThat(ohk.entries().get(0)).isInstanceOf(OpenSSHKnownHosts.BadHostEntry.class); } + @Test + public void shouldNotFailOnMalformeSaltBase64String() throws IOException { + // A record with broken base64 inside the salt part of the hash. + // No matter how it could be generated, such broken strings must not cause unexpected errors. + String hostName = "example.com"; + File knownHosts = knownHosts( + "|1|2gujgGa6gJnK7wGPCX8zuGttvCMXX|Oqkbjtxd9RFxKQv6y3l3GIxLNiU= ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBGVVnyoAD5/uWiiuTSM3RuW8dEWRrqOXYobAMKHhAA6kuOBoPK+LoAYyUcN26bdMiCxg+VOaLHxPNWv5SlhbMWw=\n" + ); + OpenSSHKnownHosts ohk = new OpenSSHKnownHosts(knownHosts); + assertEquals(1, ohk.entries().size()); + + // Some random valid public key. It doesn't matter for the test if it matches the broken host key record or not. + PublicKey k = new Buffer.PlainBuffer(Base64.getDecoder().decode( + "AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBLTjA7hduYGmvV9smEEsIdGLdghSPD7kL8QarIIOkeXmBh+LTtT/T1K+Ot/rmXCZsP8hoUXxbvN+Tks440Ci0ck=")) + .readPublicKey(); + assertFalse(ohk.verify(hostName, 22, k)); + } + @Test public void shouldMarkBadLineAndNotFail() throws Exception { File knownHosts = knownHosts( diff --git a/src/test/java/net/schmizz/sshj/keyprovider/CorruptedPublicKeyTest.java b/src/test/java/net/schmizz/sshj/keyprovider/CorruptedPublicKeyTest.java new file mode 100644 index 000000000..cf5a6c23c --- /dev/null +++ b/src/test/java/net/schmizz/sshj/keyprovider/CorruptedPublicKeyTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.schmizz.sshj.keyprovider; + +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.util.CorruptBase64; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; + +public class CorruptedPublicKeyTest { + private final Path keyRoot = Path.of("src/test/resources"); + + @TempDir + public Path tempDir; + + @ParameterizedTest + @CsvSource({ + "keyformats/ecdsa_opensshv1,", + "keyformats/openssh,", + "keytypes/test_ecdsa_nistp521_2,", + "keytypes/ed25519_protected, sshjtest", + }) + public void corruptedPublicKey(String privateKeyFileName, String passphrase) throws IOException { + Files.createDirectories(tempDir.resolve(privateKeyFileName).getParent()); + Files.copy(keyRoot.resolve(privateKeyFileName), tempDir.resolve(privateKeyFileName)); + + { + String publicKeyText; + try (var reader = new BufferedReader(new FileReader( + keyRoot.resolve(privateKeyFileName + ".pub").toFile()))) { + publicKeyText = reader.readLine(); + } + + String[] parts = publicKeyText.split("\\s+"); + parts[1] = CorruptBase64.corruptBase64(parts[1]); + + try (var writer = new FileWriter(tempDir.resolve(privateKeyFileName + ".pub").toFile())) { + writer.write(String.join(" ", parts)); + } + } + + // Must not throw an exception. + try (var sshClient = new SSHClient()) { + sshClient.loadKeys( + tempDir.resolve(privateKeyFileName).toString(), + Optional.ofNullable(passphrase).map(String::toCharArray).orElse(null) + ).getPublic(); + } + } +} diff --git a/src/test/java/net/schmizz/sshj/keyprovider/PuTTYKeyFileTest.java b/src/test/java/net/schmizz/sshj/keyprovider/PuTTYKeyFileTest.java index cfed5537f..3b1d5218d 100644 --- a/src/test/java/net/schmizz/sshj/keyprovider/PuTTYKeyFileTest.java +++ b/src/test/java/net/schmizz/sshj/keyprovider/PuTTYKeyFileTest.java @@ -18,15 +18,19 @@ import com.hierynomus.sshj.userauth.keyprovider.OpenSSHKeyV1KeyFile; import net.schmizz.sshj.userauth.keyprovider.PKCS8KeyFile; import net.schmizz.sshj.userauth.keyprovider.PuTTYKeyFile; +import net.schmizz.sshj.util.CorruptBase64; import net.schmizz.sshj.util.UnitTestPasswordFinder; import org.junit.jupiter.api.Test; +import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.StringReader; import java.security.interfaces.RSAPrivateKey; import java.security.interfaces.RSAPublicKey; +import java.util.Objects; +import static java.lang.Math.min; import static org.junit.jupiter.api.Assertions.*; public class PuTTYKeyFileTest { @@ -558,4 +562,61 @@ public void testWrongPassphraseDsa() throws Exception { assertNull(key.getPrivate()); }); } + + @Test + public void corruptedPublicLines() throws Exception { + assertThrows(IOException.class, () -> { + PuTTYKeyFile key = new PuTTYKeyFile(); + key.init(new StringReader(corruptBase64InPuttyKey(ppk2048, "Public-Lines: "))); + key.getPublic(); + }); + } + + @Test + public void corruptedPrivateLines() throws Exception { + assertThrows(IOException.class, () -> { + PuTTYKeyFile key = new PuTTYKeyFile(); + key.init(new StringReader(corruptBase64InPuttyKey(ppk2048, "Private-Lines: "))); + key.getPublic(); + }); + } + + private String corruptBase64InPuttyKey( + @SuppressWarnings("SameParameterValue") String source, + String sectionPrefix + ) throws IOException { + try (var reader = new BufferedReader(new StringReader(source))) { + StringBuilder result = new StringBuilder(); + while (true) { + String line = reader.readLine(); + if (line == null) { + break; + } else if (line.startsWith(sectionPrefix)) { + int base64LineCount = Integer.parseInt(line.substring(sectionPrefix.length())); + StringBuilder base64 = new StringBuilder(); + for (int i = 0; i < base64LineCount; ++i) { + base64.append(Objects.requireNonNull(reader.readLine())); + } + String corruptedBase64 = CorruptBase64.corruptBase64(base64.toString()); + + // 64 is the length of base64 lines in PuTTY keys generated by puttygen. + // It's not clear if it's some standard or not. + // It doesn't match the MIME Base64 standard. + int chunkSize = 64; + + result.append(sectionPrefix); + result.append((corruptedBase64.length() + chunkSize - 1) / chunkSize); + result.append('\n'); + for (int offset = 0; offset < corruptedBase64.length(); offset += chunkSize) { + result.append(corruptedBase64, offset, min(corruptedBase64.length(), offset + chunkSize)); + result.append('\n'); + } + } else { + result.append(line); + result.append('\n'); + } + } + return result.toString(); + } + } } diff --git a/src/test/java/net/schmizz/sshj/util/CorruptBase64.java b/src/test/java/net/schmizz/sshj/util/CorruptBase64.java new file mode 100644 index 000000000..edab8a5bc --- /dev/null +++ b/src/test/java/net/schmizz/sshj/util/CorruptBase64.java @@ -0,0 +1,42 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.schmizz.sshj.util; + +import net.schmizz.sshj.common.Base64DecodingException; +import net.schmizz.sshj.common.Base64Decoder; + +import java.io.IOException; + +public class CorruptBase64 { + private CorruptBase64() { + } + + public static String corruptBase64(String source) throws IOException { + while (true) { + try { + Base64Decoder.decode(source); + } catch (Base64DecodingException e) { + return source; + } + + if (source.endsWith("=")) { + source = source.substring(0, source.length() - 1); + } + source += "X"; + } + } +} From 81d77d277c96e24d76f705fa8cfc5d8daea13e44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henning=20P=C3=B6ttker?= Date: Mon, 15 Apr 2024 09:29:06 +0200 Subject: [PATCH 10/13] Don't send keep alive signals before kex is done (#934) Otherwise, they could interfere with strict key exchange. Co-authored-by: Jeroen van Erp --- .../transport/kex/StrictKeyExchangeTest.java | 56 ++++++++++++++++--- src/main/java/net/schmizz/sshj/SSHClient.java | 2 +- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java index 2abe71a72..9d207c0ec 100644 --- a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java +++ b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java @@ -18,15 +18,26 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.Stream; import ch.qos.logback.classic.Logger; import ch.qos.logback.classic.spi.ILoggingEvent; import ch.qos.logback.core.read.ListAppender; import com.hierynomus.sshj.SshdContainer; +import net.schmizz.keepalive.KeepAlive; +import net.schmizz.keepalive.KeepAliveProvider; +import net.schmizz.sshj.Config; +import net.schmizz.sshj.DefaultConfig; import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.connection.ConnectionImpl; +import net.schmizz.sshj.transport.TransportException; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.LoggerFactory; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -62,14 +73,27 @@ private void setUpLogger(String className) { watchedLoggers.add(logger); } - @Test - void strictKeyExchange() throws Throwable { - try (SSHClient client = sshd.getConnectedClient()) { + private static Stream strictKeyExchange() { + Config defaultConfig = new DefaultConfig(); + Config heartbeaterConfig = new DefaultConfig(); + heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() { + @Override + public KeepAlive provide(ConnectionImpl connection) { + return new HotLoopHeartbeater(connection); + } + }); + return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of); + } + + @MethodSource + @ParameterizedTest + void strictKeyExchange(Config config) throws Throwable { + try (SSHClient client = sshd.getConnectedClient(config)) { client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1"); assertTrue(client.isAuthenticated()); } List keyExchangerLogs = getLogs("KeyExchanger"); - assertThat(keyExchangerLogs).containsSequence( + assertThat(keyExchangerLogs).contains( "Initiating key exchange", "Sending SSH_MSG_KEXINIT", "Received SSH_MSG_KEXINIT", @@ -78,7 +102,7 @@ void strictKeyExchange() throws Throwable { List decoderLogs = getLogs("Decoder").stream() .map(log -> log.split(":")[0]) .collect(Collectors.toList()); - assertThat(decoderLogs).containsExactly( + assertThat(decoderLogs).startsWith( "Received packet #0", "Received packet #1", "Received packet #2", @@ -90,7 +114,7 @@ void strictKeyExchange() throws Throwable { List encoderLogs = getLogs("Encoder").stream() .map(log -> log.split(":")[0]) .collect(Collectors.toList()); - assertThat(encoderLogs).containsExactly( + assertThat(encoderLogs).startsWith( "Encoding packet #0", "Encoding packet #1", "Encoding packet #2", @@ -108,4 +132,22 @@ private List getLogs(String className) { .collect(Collectors.toList()); } + private static class HotLoopHeartbeater extends KeepAlive { + + HotLoopHeartbeater(ConnectionImpl conn) { + super(conn, "sshj-Heartbeater"); + } + + @Override + public boolean isEnabled() { + return true; + } + + @Override + protected void doKeepAlive() throws TransportException { + conn.getTransport().write(new SSHPacket(Message.IGNORE)); + } + + } + } diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index dd0e38170..78b91c5f7 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -804,12 +804,12 @@ protected void onConnect() throws IOException { super.onConnect(); trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream()); + doKex(); final KeepAlive keepAliveThread = conn.getKeepAlive(); if (keepAliveThread.isEnabled()) { ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans); keepAliveThread.start(); } - doKex(); } /** From 624fe839cba84764e7c429a6243cf3107c68e995 Mon Sep 17 00:00:00 2001 From: Lucas <16666115+EndzeitBegins@users.noreply.github.com> Date: Mon, 15 Apr 2024 20:18:15 +0200 Subject: [PATCH 11/13] Support premature termination of listing (#928) * Support premature termination of listing * Added license header + small refactor --------- Co-authored-by: Jeroen van Erp --- .../sftp/RemoteResourceFilterConverter.java | 30 ++++++++++ .../sshj/sftp/RemoteResourceSelector.java | 49 +++++++++++++++++ .../schmizz/sshj/sftp/RemoteDirectory.java | 55 +++++++++++++------ .../net/schmizz/sshj/sftp/SFTPClient.java | 17 ++++-- .../schmizz/sshj/sftp/StatefulSFTPClient.java | 22 +++++--- .../sshj/sftp/SFTPClientSpec.groovy | 55 +++++++++++++++++++ 6 files changed, 196 insertions(+), 32 deletions(-) create mode 100644 src/main/java/com/hierynomus/sshj/sftp/RemoteResourceFilterConverter.java create mode 100644 src/main/java/com/hierynomus/sshj/sftp/RemoteResourceSelector.java diff --git a/src/main/java/com/hierynomus/sshj/sftp/RemoteResourceFilterConverter.java b/src/main/java/com/hierynomus/sshj/sftp/RemoteResourceFilterConverter.java new file mode 100644 index 000000000..e73aec473 --- /dev/null +++ b/src/main/java/com/hierynomus/sshj/sftp/RemoteResourceFilterConverter.java @@ -0,0 +1,30 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.sftp; + +import com.hierynomus.sshj.sftp.RemoteResourceSelector.Result; +import net.schmizz.sshj.sftp.RemoteResourceFilter; + +public class RemoteResourceFilterConverter { + + public static RemoteResourceSelector selectorFrom(RemoteResourceFilter filter) { + if (filter == null) { + return RemoteResourceSelector.ALL; + } + + return resource -> filter.accept(resource) ? Result.ACCEPT : Result.CONTINUE; + } +} diff --git a/src/main/java/com/hierynomus/sshj/sftp/RemoteResourceSelector.java b/src/main/java/com/hierynomus/sshj/sftp/RemoteResourceSelector.java new file mode 100644 index 000000000..3a9e4993d --- /dev/null +++ b/src/main/java/com/hierynomus/sshj/sftp/RemoteResourceSelector.java @@ -0,0 +1,49 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.sftp; + +import net.schmizz.sshj.sftp.RemoteResourceInfo; + +public interface RemoteResourceSelector { + public static RemoteResourceSelector ALL = new RemoteResourceSelector() { + @Override + public Result select(RemoteResourceInfo resource) { + return Result.ACCEPT; + } + }; + + enum Result { + /** + * Accept the remote resource and add it to the result. + */ + ACCEPT, + + /** + * Do not add the remote resource to the result and continue with the next. + */ + CONTINUE, + + /** + * Do not add the remote resource to the result and stop further execution. + */ + BREAK; + } + + /** + * Decide whether the remote resource should be included in the result and whether execution should continue. + */ + Result select(RemoteResourceInfo resource); +} diff --git a/src/main/java/net/schmizz/sshj/sftp/RemoteDirectory.java b/src/main/java/net/schmizz/sshj/sftp/RemoteDirectory.java index 4f9718ed7..6e4e63757 100644 --- a/src/main/java/net/schmizz/sshj/sftp/RemoteDirectory.java +++ b/src/main/java/net/schmizz/sshj/sftp/RemoteDirectory.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.sftp; +import com.hierynomus.sshj.sftp.RemoteResourceSelector; import net.schmizz.sshj.sftp.Response.StatusCode; import java.io.IOException; @@ -22,6 +23,8 @@ import java.util.List; import java.util.concurrent.TimeUnit; +import static com.hierynomus.sshj.sftp.RemoteResourceFilterConverter.selectorFrom; + public class RemoteDirectory extends RemoteResource { @@ -31,37 +34,55 @@ public RemoteDirectory(SFTPEngine requester, String path, byte[] handle) { public List scan(RemoteResourceFilter filter) throws IOException { - List rri = new LinkedList(); - // TODO: Remove GOTO! - loop: - for (; ; ) { - final Response res = requester.request(newRequest(PacketType.READDIR)) + return scan(selectorFrom(filter)); + } + + public List scan(RemoteResourceSelector selector) + throws IOException { + if (selector == null) { + selector = RemoteResourceSelector.ALL; + } + + List remoteResourceInfos = new LinkedList<>(); + + while (true) { + final Response response = requester.request(newRequest(PacketType.READDIR)) .retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS); - switch (res.getType()) { + switch (response.getType()) { case NAME: - final int count = res.readUInt32AsInt(); + final int count = response.readUInt32AsInt(); for (int i = 0; i < count; i++) { - final String name = res.readString(requester.sub.getRemoteCharset()); - res.readString(); // long name - IGNORED - shdve never been in the protocol - final FileAttributes attrs = res.readFileAttributes(); + final String name = response.readString(requester.sub.getRemoteCharset()); + response.readString(); // long name - IGNORED - shdve never been in the protocol + final FileAttributes attrs = response.readFileAttributes(); final PathComponents comps = requester.getPathHelper().getComponents(path, name); final RemoteResourceInfo inf = new RemoteResourceInfo(comps, attrs); - if (!(".".equals(name) || "..".equals(name)) && (filter == null || filter.accept(inf))) { - rri.add(inf); + + if (".".equals(name) || "..".equals(name)) { + continue; + } + + final RemoteResourceSelector.Result selectionResult = selector.select(inf); + switch (selectionResult) { + case ACCEPT: + remoteResourceInfos.add(inf); + break; + case CONTINUE: + continue; + case BREAK: + return remoteResourceInfos; } } break; case STATUS: - res.ensureStatusIs(StatusCode.EOF); - break loop; + response.ensureStatusIs(StatusCode.EOF); + return remoteResourceInfos; default: - throw new SFTPException("Unexpected packet: " + res.getType()); + throw new SFTPException("Unexpected packet: " + response.getType()); } } - return rri; } - } diff --git a/src/main/java/net/schmizz/sshj/sftp/SFTPClient.java b/src/main/java/net/schmizz/sshj/sftp/SFTPClient.java index af9d70650..bc94cd53d 100644 --- a/src/main/java/net/schmizz/sshj/sftp/SFTPClient.java +++ b/src/main/java/net/schmizz/sshj/sftp/SFTPClient.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.sftp; +import com.hierynomus.sshj.sftp.RemoteResourceSelector; import net.schmizz.sshj.connection.channel.direct.SessionFactory; import net.schmizz.sshj.xfer.FilePermission; import net.schmizz.sshj.xfer.LocalDestFile; @@ -25,6 +26,8 @@ import java.io.IOException; import java.util.*; +import static com.hierynomus.sshj.sftp.RemoteResourceFilterConverter.selectorFrom; + public class SFTPClient implements Closeable { @@ -57,16 +60,18 @@ public SFTPFileTransfer getFileTransfer() { public List ls(String path) throws IOException { - return ls(path, null); + return ls(path, RemoteResourceSelector.ALL); } public List ls(String path, RemoteResourceFilter filter) throws IOException { - final RemoteDirectory dir = engine.openDir(path); - try { - return dir.scan(filter); - } finally { - dir.close(); + return ls(path, selectorFrom(filter)); + } + + public List ls(String path, RemoteResourceSelector selector) + throws IOException { + try (RemoteDirectory dir = engine.openDir(path)) { + return dir.scan(selector == null ? RemoteResourceSelector.ALL : selector); } } diff --git a/src/main/java/net/schmizz/sshj/sftp/StatefulSFTPClient.java b/src/main/java/net/schmizz/sshj/sftp/StatefulSFTPClient.java index 4f31d0e8d..872dae8e9 100644 --- a/src/main/java/net/schmizz/sshj/sftp/StatefulSFTPClient.java +++ b/src/main/java/net/schmizz/sshj/sftp/StatefulSFTPClient.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.sftp; +import com.hierynomus.sshj.sftp.RemoteResourceSelector; import net.schmizz.sshj.connection.channel.direct.SessionFactory; import net.schmizz.sshj.xfer.LocalDestFile; import net.schmizz.sshj.xfer.LocalSourceFile; @@ -23,6 +24,8 @@ import java.util.List; import java.util.Set; +import static com.hierynomus.sshj.sftp.RemoteResourceFilterConverter.selectorFrom; + public class StatefulSFTPClient extends SFTPClient { @@ -57,7 +60,7 @@ public synchronized void cd(String dirname) public synchronized List ls() throws IOException { - return ls(cwd, null); + return ls(cwd, RemoteResourceSelector.ALL); } public synchronized List ls(RemoteResourceFilter filter) @@ -70,20 +73,21 @@ public synchronized String pwd() return super.canonicalize(cwd); } - @Override public List ls(String path) throws IOException { - return ls(path, null); + return ls(path, RemoteResourceSelector.ALL); } - @Override public List ls(String path, RemoteResourceFilter filter) throws IOException { - final RemoteDirectory dir = getSFTPEngine().openDir(cwdify(path)); - try { - return dir.scan(filter); - } finally { - dir.close(); + return ls(path, selectorFrom(filter)); + } + + @Override + public List ls(String path, RemoteResourceSelector selector) + throws IOException { + try (RemoteDirectory dir = getSFTPEngine().openDir(cwdify(path))) { + return dir.scan(selector == null ? RemoteResourceSelector.ALL : selector); } } diff --git a/src/test/groovy/com/hierynomus/sshj/sftp/SFTPClientSpec.groovy b/src/test/groovy/com/hierynomus/sshj/sftp/SFTPClientSpec.groovy index b3f6774c7..421ca60ad 100644 --- a/src/test/groovy/com/hierynomus/sshj/sftp/SFTPClientSpec.groovy +++ b/src/test/groovy/com/hierynomus/sshj/sftp/SFTPClientSpec.groovy @@ -19,6 +19,7 @@ import com.hierynomus.sshj.test.SshServerExtension import com.hierynomus.sshj.test.util.FileUtil import net.schmizz.sshj.SSHClient import net.schmizz.sshj.sftp.FileMode +import net.schmizz.sshj.sftp.RemoteResourceInfo import net.schmizz.sshj.sftp.SFTPClient import org.junit.jupiter.api.extension.RegisterExtension import spock.lang.Specification @@ -206,6 +207,60 @@ class SFTPClientSpec extends Specification { attrs.type == FileMode.Type.DIRECTORY } + def "should support premature termination of listing"() { + given: + SSHClient sshClient = fixture.setupConnectedDefaultClient() + sshClient.authPassword("test", "test") + SFTPClient sftpClient = sshClient.newSFTPClient() + + final Path source = Files.createDirectory(temp.resolve("source")).toAbsolutePath() + final Path destination = Files.createDirectory(temp.resolve("destination")).toAbsolutePath() + final Path firstFile = Files.writeString(source.resolve("a_first.txt"), "first") + final Path secondFile = Files.writeString(source.resolve("b_second.txt"), "second") + final Path thirdFile = Files.writeString(source.resolve("c_third.txt"), "third") + final Path fourthFile = Files.writeString(source.resolve("d_fourth.txt"), "fourth") + sftpClient.put(firstFile.toString(), destination.resolve(firstFile.fileName).toString()) + sftpClient.put(secondFile.toString(), destination.resolve(secondFile.fileName).toString()) + sftpClient.put(thirdFile.toString(), destination.resolve(thirdFile.fileName).toString()) + sftpClient.put(fourthFile.toString(), destination.resolve(fourthFile.fileName).toString()) + + def filesListed = 0 + RemoteResourceInfo expectedFile = null + RemoteResourceSelector limitingSelector = new RemoteResourceSelector() { + @Override + RemoteResourceSelector.Result select(RemoteResourceInfo resource) { + filesListed += 1 + + switch(filesListed) { + case 1: + return RemoteResourceSelector.Result.CONTINUE + case 2: + expectedFile = resource + return RemoteResourceSelector.Result.ACCEPT + case 3: + return RemoteResourceSelector.Result.BREAK + default: + throw new AssertionError((Object) "Should NOT select any more resources") + } + } + } + + when: + def listingResult = sftpClient + .ls(destination.toString(), limitingSelector); + + then: + // first should be skipped by CONTINUE + listingResult.contains(expectedFile) // second should be included by ACCEPT + // third should be skipped by BREAK + // fourth should be skipped by preceding BREAK + listingResult.size() == 1 + + cleanup: + sftpClient.close() + sshClient.disconnect() + } + private void doUpload(File src, File dest) throws IOException { SSHClient sshClient = fixture.setupConnectedDefaultClient() sshClient.authPassword("test", "test") From 586a66420ed0bb3dc9588f17f3767e022795a1fc Mon Sep 17 00:00:00 2001 From: Eric Vigeant Date: Mon, 15 Apr 2024 14:31:54 -0400 Subject: [PATCH 12/13] Close Session when closing SCPEngine or SFTPEngine (#926) Co-authored-by: Jeroen van Erp --- src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java | 4 +++- src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java | 9 ++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java b/src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java index ecac5afcc..eeac5d8ad 100644 --- a/src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java +++ b/src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java @@ -48,6 +48,7 @@ public class SFTPEngine protected final PathHelper pathHelper; + private final Session session; protected final Session.Subsystem sub; protected final PacketReader reader; protected final OutputStream out; @@ -63,7 +64,7 @@ public SFTPEngine(SessionFactory ssh) public SFTPEngine(SessionFactory ssh, String pathSep) throws SSHException { - Session session = ssh.startSession(); + session = ssh.startSession(); loggerFactory = session.getLoggerFactory(); log = loggerFactory.getLogger(getClass()); sub = session.startSubsystem("sftp"); @@ -346,6 +347,7 @@ public void close() throws IOException { sub.close(); reader.interrupt(); + session.close(); } protected LoggerFactory getLoggerFactory() { diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java index 821a1c864..a3a175c46 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java @@ -19,6 +19,7 @@ import net.schmizz.sshj.common.LoggerFactory; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.StreamCopier; +import net.schmizz.sshj.connection.channel.direct.Session; import net.schmizz.sshj.connection.channel.direct.Session.Command; import net.schmizz.sshj.connection.channel.direct.SessionFactory; import net.schmizz.sshj.xfer.TransferListener; @@ -41,6 +42,7 @@ class SCPEngine { private final SessionFactory host; private final TransferListener listener; + private Session session; private Command scp; private int exitStatus; @@ -82,7 +84,8 @@ void cleanSlate() { void execSCPWith(ScpCommandLine commandLine) throws SSHException { - scp = host.startSession().exec(commandLine.toCommandLine()); + session = host.startSession(); + scp = session.exec(commandLine.toCommandLine()); } void exit() { @@ -102,6 +105,10 @@ void exit() { log.warn("SCP exit signal: {}", scp.getExitSignal()); } } + if(session != null) { + IOUtils.closeQuietly(session); + session = null; + } scp = null; } From cf340c2a098a253a40b7f0c179c979d802df03e4 Mon Sep 17 00:00:00 2001 From: eshaffer321 Date: Wed, 17 Apr 2024 04:32:46 -0600 Subject: [PATCH 13/13] Update bouncyCastle to 1.78 to mitigate CVE-2024-29857 (#938) Bouncy Caste version before 1.78 have CVE-2024-29857 - Importing an EC certificate with specially crafted F2m parameters can cause high CPU usage during parameter evaluation. Is sshj impacted by this vulnerability? --- build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index b34e80cb7..d131fc2bf 100644 --- a/build.gradle +++ b/build.gradle @@ -41,7 +41,7 @@ compileJava { configurations.implementation.transitive = false -def bouncycastleVersion = "1.75" +def bouncycastleVersion = "1.78" def sshdVersion = "2.10.0" dependencies {