Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Curve25519 Public Key Handling #959

Merged
merged 2 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ public class Curve25519DH extends DHBase {

private static final String ALGORITHM = "X25519";

private static final int ENCODED_ALGORITHM_ID_KEY_LENGTH = 44;
private static final int KEY_LENGTH = 32;

private static final int ALGORITHM_ID_LENGTH = 12;
private int encodedKeyLength;

private static final int KEY_LENGTH = 32;
private int algorithmIdLength;

private final byte[] algorithmId = new byte[ALGORITHM_ID_LENGTH];
// Algorithm Identifier is set on Key Agreement Initialization
private byte[] algorithmId = new byte[KEY_LENGTH];

public Curve25519DH() {
super(ALGORITHM, ALGORITHM);
Expand Down Expand Up @@ -81,23 +82,24 @@ public void init(final AlgorithmParameterSpec params, final Factory<Random> rand
private void setPublicKey(final PublicKey publicKey) {
final byte[] encoded = publicKey.getEncoded();

// Set key and algorithm identifier lengths based on initialized Public Key
encodedKeyLength = encoded.length;
algorithmIdLength = encodedKeyLength - KEY_LENGTH;
algorithmId = new byte[algorithmIdLength];

// Encoded public key consists of the algorithm identifier and public key
if (encoded.length == ENCODED_ALGORITHM_ID_KEY_LENGTH) {
final byte[] publicKeyEncoded = new byte[KEY_LENGTH];
System.arraycopy(encoded, ALGORITHM_ID_LENGTH, publicKeyEncoded, 0, KEY_LENGTH);
setE(publicKeyEncoded);

// Save Algorithm Identifier byte array
System.arraycopy(encoded, 0, algorithmId, 0, ALGORITHM_ID_LENGTH);
} else {
throw new IllegalArgumentException(String.format("X25519 unsupported public key length [%d]", encoded.length));
}
final byte[] publicKeyEncoded = new byte[KEY_LENGTH];
System.arraycopy(encoded, algorithmIdLength, publicKeyEncoded, 0, KEY_LENGTH);
setE(publicKeyEncoded);

// Save Algorithm Identifier byte array
System.arraycopy(encoded, 0, algorithmId, 0, algorithmIdLength);
}

private KeySpec getPeerPublicKeySpec(final byte[] peerPublicKey) {
final byte[] encodedKeySpec = new byte[ENCODED_ALGORITHM_ID_KEY_LENGTH];
System.arraycopy(algorithmId, 0, encodedKeySpec, 0, ALGORITHM_ID_LENGTH);
System.arraycopy(peerPublicKey, 0, encodedKeySpec, ALGORITHM_ID_LENGTH, KEY_LENGTH);
final byte[] encodedKeySpec = new byte[encodedKeyLength];
System.arraycopy(algorithmId, 0, encodedKeySpec, 0, algorithmIdLength);
System.arraycopy(peerPublicKey, 0, encodedKeySpec, algorithmIdLength, KEY_LENGTH);
return new X509EncodedKeySpec(encodedKeySpec);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,24 @@
*/
package net.schmizz.sshj.transport.kex;

import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.random.JCERandom;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.KeyPairGenerator;
import java.security.Provider;
import java.security.Security;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

public class Curve25519DHTest {

private static final String ALGORITHM_FILTER = "KeyPairGenerator.X25519";

private static final int KEY_LENGTH = 32;

private static final byte[] PEER_PUBLIC_KEY = {
Expand All @@ -35,8 +42,16 @@ public class Curve25519DHTest {
1, 2, 3, 4, 5, 6, 7, 8
};

@BeforeEach
public void clearSecurityProvider() {
SecurityUtils.setSecurityProvider(null);
}

@Test
public void testInitPublicKeyLength() throws GeneralSecurityException {
final boolean bouncyCastleRegistrationRequired = isAlgorithmUnsupported();
SecurityUtils.setRegisterBouncyCastle(bouncyCastleRegistrationRequired);

final Curve25519DH dh = new Curve25519DH();
dh.init(null, new JCERandom.Factory());

Expand All @@ -48,6 +63,8 @@ public void testInitPublicKeyLength() throws GeneralSecurityException {

@Test
public void testInitComputeSharedSecretKey() throws GeneralSecurityException {
SecurityUtils.setRegisterBouncyCastle(true);

final Curve25519DH dh = new Curve25519DH();
dh.init(null, new JCERandom.Factory());

Expand All @@ -57,4 +74,9 @@ public void testInitComputeSharedSecretKey() throws GeneralSecurityException {
assertNotNull(sharedSecretKey);
assertEquals(BigInteger.ONE.signum(), sharedSecretKey.signum());
}

private boolean isAlgorithmUnsupported() {
final Provider[] providers = Security.getProviders(ALGORITHM_FILTER);
return providers == null || providers.length == 0;
}
}
Loading