From 0ea79326253cee081867ef461be3c76d82ad30ef Mon Sep 17 00:00:00 2001 From: Martin Volf Date: Fri, 19 Jan 2024 21:47:12 +0100 Subject: [PATCH] connected sockets can be passed to the library fixes hierynomus/sshj#924 Signed-off-by: Martin Volf --- .../java/net/schmizz/sshj/SocketClient.java | 20 +++-- .../net/schmizz/sshj/ConnectedSocketTest.java | 78 +++++++++++++++++++ 2 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 src/test/java/net/schmizz/sshj/ConnectedSocketTest.java diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index d7971243..5447b557 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -65,7 +65,9 @@ public void connect(String hostname, int port) throws IOException { this.hostname = hostname; this.port = port; socket = socketFactory.createSocket(); - socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + if (! socket.isConnected()) { + socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + } onConnect(); } } @@ -77,8 +79,10 @@ public void connect(String hostname, int port, InetAddress localAddr, int localP this.hostname = hostname; this.port = port; socket = socketFactory.createSocket(); - socket.bind(new InetSocketAddress(localAddr, localPort)); - socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + if (! socket.isConnected()) { + socket.bind(new InetSocketAddress(localAddr, localPort)); + socket.connect(makeInetSocketAddress(hostname, port), connectTimeout); + } onConnect(); } } @@ -104,7 +108,9 @@ public void connect(InetAddress host) throws IOException { public void connect(InetAddress host, int port) throws IOException { this.port = port; socket = socketFactory.createSocket(); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + if (! socket.isConnected()) { + socket.connect(new InetSocketAddress(host, port), connectTimeout); + } onConnect(); } @@ -112,8 +118,10 @@ public void connect(InetAddress host, int port, InetAddress localAddr, int local throws IOException { this.port = port; socket = socketFactory.createSocket(); - socket.bind(new InetSocketAddress(localAddr, localPort)); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + if (! socket.isConnected()) { + socket.bind(new InetSocketAddress(localAddr, localPort)); + socket.connect(new InetSocketAddress(host, port), connectTimeout); + } onConnect(); } diff --git a/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java new file mode 100644 index 00000000..0efe86c9 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/ConnectedSocketTest.java @@ -0,0 +1,78 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj; + +import com.hierynomus.sshj.test.SshServerExtension; +import net.schmizz.sshj.SSHClient; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import javax.net.SocketFactory; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + + +public class ConnectedSocketTest { + @RegisterExtension + public SshServerExtension fixture = new SshServerExtension(); + + @BeforeEach + public void setupClient() throws IOException { + SSHClient defaultClient = fixture.setupDefaultClient(); + } + + @Test + public void connectsIfUnconnected() { + assertDoesNotThrow(() -> fixture.connectClient(fixture.getClient())); + } + + @Test + public void handlesConnected() throws IOException { + Socket socket = SocketFactory.getDefault().createSocket(); + SocketFactory factory = new SocketFactory() { + @Override + public Socket createSocket() { + return socket; + } + @Override + public Socket createSocket(InetAddress host, int port) { + return socket; + } + @Override + public Socket createSocket(InetAddress address, int port, + InetAddress localAddress, int localPort) { + return socket; + } + @Override + public Socket createSocket(String host, int port) { + return socket; + } + @Override + public Socket createSocket(String host, int port, + InetAddress localHost, int localPort) { + return socket; + } + }; + socket.connect(new InetSocketAddress("localhost", fixture.getServer().getPort())); + fixture.getClient().setSocketFactory(factory); + assertDoesNotThrow(() -> fixture.connectClient(fixture.getClient())); + } +}