Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
hierynomus authored Apr 15, 2024
2 parents 699f6b1 + 624fe83 commit 90df54c
Show file tree
Hide file tree
Showing 21 changed files with 703 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,26 @@
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import com.hierynomus.sshj.SshdContainer;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.LoggerFactory;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
Expand Down Expand Up @@ -62,14 +73,27 @@ private void setUpLogger(String className) {
watchedLoggers.add(logger);
}

@Test
void strictKeyExchange() throws Throwable {
try (SSHClient client = sshd.getConnectedClient()) {
private static Stream<Arguments> strictKeyExchange() {
Config defaultConfig = new DefaultConfig();
Config heartbeaterConfig = new DefaultConfig();
heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() {
@Override
public KeepAlive provide(ConnectionImpl connection) {
return new HotLoopHeartbeater(connection);
}
});
return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of);
}

@MethodSource
@ParameterizedTest
void strictKeyExchange(Config config) throws Throwable {
try (SSHClient client = sshd.getConnectedClient(config)) {
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
assertTrue(client.isAuthenticated());
}
List<String> keyExchangerLogs = getLogs("KeyExchanger");
assertThat(keyExchangerLogs).containsSequence(
assertThat(keyExchangerLogs).contains(
"Initiating key exchange",
"Sending SSH_MSG_KEXINIT",
"Received SSH_MSG_KEXINIT",
Expand All @@ -78,7 +102,7 @@ void strictKeyExchange() throws Throwable {
List<String> decoderLogs = getLogs("Decoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(decoderLogs).containsExactly(
assertThat(decoderLogs).startsWith(
"Received packet #0",
"Received packet #1",
"Received packet #2",
Expand All @@ -90,7 +114,7 @@ void strictKeyExchange() throws Throwable {
List<String> encoderLogs = getLogs("Encoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(encoderLogs).containsExactly(
assertThat(encoderLogs).startsWith(
"Encoding packet #0",
"Encoding packet #1",
"Encoding packet #2",
Expand All @@ -108,4 +132,22 @@ private List<String> getLogs(String className) {
.collect(Collectors.toList());
}

private static class HotLoopHeartbeater extends KeepAlive {

HotLoopHeartbeater(ConnectionImpl conn) {
super(conn, "sshj-Heartbeater");
}

@Override
public boolean isEnabled() {
return true;
}

@Override
protected void doKeepAlive() throws TransportException {
conn.getTransport().write(new SSHPacket(Message.IGNORE));
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hierynomus.sshj.sftp;

import com.hierynomus.sshj.sftp.RemoteResourceSelector.Result;
import net.schmizz.sshj.sftp.RemoteResourceFilter;

public class RemoteResourceFilterConverter {

public static RemoteResourceSelector selectorFrom(RemoteResourceFilter filter) {
if (filter == null) {
return RemoteResourceSelector.ALL;
}

return resource -> filter.accept(resource) ? Result.ACCEPT : Result.CONTINUE;
}
}
49 changes: 49 additions & 0 deletions src/main/java/com/hierynomus/sshj/sftp/RemoteResourceSelector.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hierynomus.sshj.sftp;

import net.schmizz.sshj.sftp.RemoteResourceInfo;

public interface RemoteResourceSelector {
public static RemoteResourceSelector ALL = new RemoteResourceSelector() {
@Override
public Result select(RemoteResourceInfo resource) {
return Result.ACCEPT;
}
};

enum Result {
/**
* Accept the remote resource and add it to the result.
*/
ACCEPT,

/**
* Do not add the remote resource to the result and continue with the next.
*/
CONTINUE,

/**
* Do not add the remote resource to the result and stop further execution.
*/
BREAK;
}

/**
* Decide whether the remote resource should be included in the result and whether execution should continue.
*/
Result select(RemoteResourceInfo resource);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.hierynomus.sshj.transport.verification;

import net.schmizz.sshj.common.Base64DecodingException;
import net.schmizz.sshj.common.Base64Decoder;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.transport.mac.MAC;
Expand All @@ -26,9 +28,13 @@
import java.util.regex.Pattern;

import com.hierynomus.sshj.transport.mac.Macs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KnownHostMatchers {

private static final Logger log = LoggerFactory.getLogger(KnownHostMatchers.class);

public static HostMatcher createMatcher(String hostEntry) throws SSHException {
if (hostEntry.contains(",")) {
return new AnyHostMatcher(hostEntry);
Expand Down Expand Up @@ -80,17 +86,22 @@ private static class HashedHostMatcher implements HostMatcher {

@Override
public boolean match(String hostname) throws IOException {
return hash.equals(hashHost(hostname));
try {
return hash.equals(hashHost(hostname));
} catch (Base64DecodingException err) {
log.warn("Hostname [{}] not matched: salt decoding failed", hostname, err);
return false;
}
}

private String hashHost(String host) throws IOException {
private String hashHost(String host) throws IOException, Base64DecodingException {
sha1.init(getSaltyBytes());
return "|1|" + salt + "|" + Base64.getEncoder().encodeToString(sha1.doFinal(host.getBytes(IOUtils.UTF8)));
}

private byte[] getSaltyBytes() {
private byte[] getSaltyBytes() throws IOException, Base64DecodingException {
if (saltyBytes == null) {
saltyBytes = Base64.getDecoder().decode(salt);
saltyBytes = Base64Decoder.decode(salt);
}
return saltyBytes;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.hierynomus.sshj.userauth.keyprovider;

import net.schmizz.sshj.common.Base64DecodingException;
import net.schmizz.sshj.common.Base64Decoder;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.KeyType;

Expand All @@ -23,7 +25,6 @@
import java.io.IOException;
import java.io.Reader;
import java.security.PublicKey;
import java.util.Base64;

public class OpenSSHKeyFileUtil {
private OpenSSHKeyFileUtil() {
Expand Down Expand Up @@ -54,16 +55,19 @@ public static ParsedPubKey initPubKey(Reader publicKey) throws IOException {
if (!keydata.isEmpty()) {
String[] parts = keydata.trim().split("\\s+");
if (parts.length >= 2) {
byte[] decodedPublicKey = Base64Decoder.decode(parts[1]);
return new ParsedPubKey(
KeyType.fromString(parts[0]),
new Buffer.PlainBuffer(Base64.getDecoder().decode(parts[1])).readPublicKey()
new Buffer.PlainBuffer(decodedPublicKey).readPublicKey()
);
} else {
throw new IOException("Got line with only one column");
}
}
}
throw new IOException("Public key file is blank");
} catch (Base64DecodingException err) {
throw new IOException("Public key decoding failed", err);
} finally {
br.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,8 @@
import net.i2p.crypto.eddsa.EdDSAPrivateKey;
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable;
import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.common.Buffer.PlainBuffer;
import net.schmizz.sshj.common.ByteArrayUtils;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.SSHRuntimeException;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.cipher.Cipher;
import net.schmizz.sshj.userauth.keyprovider.BaseFileKeyProvider;
import net.schmizz.sshj.userauth.keyprovider.FileKeyProvider;
Expand All @@ -55,7 +50,6 @@
import java.security.spec.ECPrivateKeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -124,7 +118,7 @@ protected KeyPair readKeyPair() throws IOException {
try {
if (checkHeader(reader)) {
final String encodedPrivateKey = readEncodedKey(reader);
byte[] decodedPrivateKey = Base64.getDecoder().decode(encodedPrivateKey);
byte[] decodedPrivateKey = Base64Decoder.decode(encodedPrivateKey);
final PlainBuffer bufferedPrivateKey = new PlainBuffer(decodedPrivateKey);
return readDecodedKeyPair(bufferedPrivateKey);
} else {
Expand All @@ -133,6 +127,8 @@ protected KeyPair readKeyPair() throws IOException {
}
} catch (final GeneralSecurityException e) {
throw new SSHRuntimeException("Read OpenSSH Version 1 Key failed", e);
} catch (Base64DecodingException e) {
throw new SSHRuntimeException("Private Key decoding failed", e);
} finally {
IOUtils.closeQuietly(reader);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/net/schmizz/sshj/SSHClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -804,12 +804,12 @@ protected void onConnect()
throws IOException {
super.onConnect();
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
doKex();
final KeepAlive keepAliveThread = conn.getKeepAlive();
if (keepAliveThread.isEnabled()) {
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
keepAliveThread.start();
}
doKex();
}

/**
Expand Down
8 changes: 6 additions & 2 deletions src/main/java/net/schmizz/sshj/SocketClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ public void connect(String hostname, int port) throws IOException {
this.hostname = hostname;
this.port = port;
socket = socketFactory.createSocket();
socket.connect(makeInetSocketAddress(hostname, port), connectTimeout);
if (! socket.isConnected()) {
socket.connect(makeInetSocketAddress(hostname, port), connectTimeout);
}
onConnect();
}
}
Expand Down Expand Up @@ -104,7 +106,9 @@ public void connect(InetAddress host) throws IOException {
public void connect(InetAddress host, int port) throws IOException {
this.port = port;
socket = socketFactory.createSocket();
socket.connect(new InetSocketAddress(host, port), connectTimeout);
if (! socket.isConnected()) {
socket.connect(new InetSocketAddress(host, port), connectTimeout);
}
onConnect();
}

Expand Down
47 changes: 47 additions & 0 deletions src/main/java/net/schmizz/sshj/common/Base64Decoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.schmizz.sshj.common;

import java.io.IOException;
import java.util.Base64;

/**
* <p>Wraps {@link java.util.Base64.Decoder} in order to wrap unchecked {@code IllegalArgumentException} thrown by
* the default Java Base64 decoder here and there.</p>
*
* <p>Please use this class instead of {@link java.util.Base64.Decoder}.</p>
*/
public class Base64Decoder {
private Base64Decoder() {
}

public static byte[] decode(byte[] source) throws Base64DecodingException {
try {
return Base64.getDecoder().decode(source);
} catch (IllegalArgumentException err) {
throw new Base64DecodingException(err);
}
}

public static byte[] decode(String src) throws Base64DecodingException {
try {
return Base64.getDecoder().decode(src);
} catch (IllegalArgumentException err) {
throw new Base64DecodingException(err);
}
}
}
Loading

0 comments on commit 90df54c

Please sign in to comment.