Skip to content

Commit

Permalink
Modbus/TCP Security (#97)
Browse files Browse the repository at this point in the history
Adds support for Modbus/TCP Security (Modbus over TLS).

See https://modbus.org/docs/MB-TCP-Security-v21_2018-07-24.pdf
  • Loading branch information
kevinherron authored Dec 27, 2024
1 parent a9a0fb1 commit 27aa926
Show file tree
Hide file tree
Showing 26 changed files with 2,097 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.digitalpetri.modbus.internal.util.ExecutionQueue;
import com.digitalpetri.modbus.serial.SerialPortTransportConfig;
import com.digitalpetri.modbus.serial.SerialPortTransportConfig.Builder;
import com.digitalpetri.modbus.server.ModbusRequestContext.ModbusRtuRequestContext;
import com.digitalpetri.modbus.server.ModbusRtuServerTransport;
import com.fazecast.jSerialComm.SerialPort;
import com.fazecast.jSerialComm.SerialPortDataListener;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;

/**
* Configuration for a {@link NettyTcpClientTransport}.
Expand All @@ -23,6 +26,9 @@
* {@link Bootstrap}.
* @param pipelineCustomizer a {@link Consumer} that can be used to customize the Netty
* {@link ChannelPipeline}.
* @param tlsEnabled whether to enable TLS (Modbus/TCP Security).
* @param keyManagerFactory the {@link KeyManagerFactory} to use if TLS is enabled.
* @param trustManagerFactory the {@link TrustManagerFactory} to use if TLS is enabled.
*/
public record NettyClientTransportConfig(
String hostname,
Expand All @@ -33,9 +39,12 @@ public record NettyClientTransportConfig(
EventLoopGroup eventLoopGroup,
ExecutorService executor,
Consumer<Bootstrap> bootstrapCustomizer,
Consumer<ChannelPipeline> pipelineCustomizer
Consumer<ChannelPipeline> pipelineCustomizer,
boolean tlsEnabled,
Optional<KeyManagerFactory> keyManagerFactory,
Optional<TrustManagerFactory> trustManagerFactory
) {

/**
* Create a new {@link NettyClientTransportConfig} with a callback that allows customizing the
* configuration.
Expand All @@ -59,7 +68,7 @@ public static class Builder {
/**
* The port to connect to.
*/
public int port = 502;
public int port = -1;

/**
* The connect timeout.
Expand Down Expand Up @@ -100,16 +109,42 @@ public static class Builder {
*/
public Consumer<ChannelPipeline> pipelineCustomizer = p -> {};

/**
* Whether to enable TLS (Modbus/TCP Security).
*/
public boolean tlsEnabled = false;

/**
* The {@link KeyManagerFactory} to use if TLS is enabled.
*/
public KeyManagerFactory keyManagerFactory = null;

/**
* The {@link TrustManagerFactory} to use if TLS is enabled.
*/
public TrustManagerFactory trustManagerFactory = null;

public NettyClientTransportConfig build() {
if (hostname == null) {
throw new NullPointerException("hostname must not be null");
}
if (port == -1) {
port = tlsEnabled ? 802 : 502;
}
if (eventLoopGroup == null) {
eventLoopGroup = Netty.sharedEventLoop();
}
if (executor == null) {
executor = Modbus.sharedExecutor();
}
if (tlsEnabled) {
if (keyManagerFactory == null) {
throw new NullPointerException("keyManagerFactory must not be null");
}
if (trustManagerFactory == null) {
throw new NullPointerException("trustManagerFactory must not be null");
}
}

return new NettyClientTransportConfig(
hostname,
Expand All @@ -120,7 +155,10 @@ public NettyClientTransportConfig build() {
eventLoopGroup,
executor,
bootstrapCustomizer,
pipelineCustomizer
pipelineCustomizer,
tlsEnabled,
Optional.ofNullable(keyManagerFactory),
Optional.ofNullable(trustManagerFactory)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProtocols;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -158,7 +162,20 @@ public CompletableFuture<Channel> connect(FsmContext<State, Event> fsmContext) {
.option(ChannelOption.TCP_NODELAY, Boolean.TRUE)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel channel) {
protected void initChannel(SocketChannel channel) throws Exception {
if (config.tlsEnabled()) {
SslContext sslContext = SslContextBuilder.forClient()
.clientAuth(ClientAuth.REQUIRE)
.keyManager(config.keyManagerFactory().orElseThrow())
.trustManager(config.trustManagerFactory().orElseThrow())
.protocols(SslProtocols.TLS_v1_2, SslProtocols.TLS_v1_3)
.build();

channel.pipeline().addLast(
sslContext.newHandler(channel.alloc(), config.hostname(), config.port())
);
}

channel.pipeline().addLast(new ModbusRtuClientFrameReceiver());

config.pipelineCustomizer().accept(channel.pipeline());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProtocols;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
Expand Down Expand Up @@ -163,6 +168,13 @@ protected void channelRead0(ChannelHandlerContext ctx, ModbusTcpFrame frame) {
executionQueue.submit(() -> frameReceiver.accept(frame));
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.error("Exception caught", cause);
ctx.close();
}

}

private class ModbusTcpChannelActions implements ChannelActions {
Expand All @@ -174,15 +186,7 @@ public CompletableFuture<Channel> connect(FsmContext<State, Event> fsmContext) {
.group(config.eventLoopGroup())
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) config.connectTimeout().toMillis())
.option(ChannelOption.TCP_NODELAY, Boolean.TRUE)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel channel) {
channel.pipeline().addLast(new ModbusTcpCodec());
channel.pipeline().addLast(new ModbusTcpFrameHandler());

config.pipelineCustomizer().accept(channel.pipeline());
}
});
.handler(newChannelInitializer());

config.bootstrapCustomizer().accept(bootstrap);

Expand All @@ -191,7 +195,21 @@ protected void initChannel(SocketChannel channel) {
bootstrap.connect(config.hostname(), config.port()).addListener(
(ChannelFutureListener) channelFuture -> {
if (channelFuture.isSuccess()) {
future.complete(channelFuture.channel());
Channel channel = channelFuture.channel();

if (config.tlsEnabled()) {
channel.pipeline().get(SslHandler.class).handshakeFuture().addListener(
handshakeFuture -> {
if (handshakeFuture.isSuccess()) {
future.complete(channel);
} else {
future.completeExceptionally(handshakeFuture.cause());
}
}
);
} else {
future.complete(channel);
}
} else {
future.completeExceptionally(channelFuture.cause());
}
Expand All @@ -201,6 +219,31 @@ protected void initChannel(SocketChannel channel) {
return future;
}

private ChannelInitializer<SocketChannel> newChannelInitializer() {
return new ChannelInitializer<>() {
@Override
protected void initChannel(SocketChannel channel) throws Exception {
if (config.tlsEnabled()) {
SslContext sslContext = SslContextBuilder.forClient()
.clientAuth(ClientAuth.REQUIRE)
.keyManager(config.keyManagerFactory().orElseThrow())
.trustManager(config.trustManagerFactory().orElseThrow())
.protocols(SslProtocols.TLS_v1_2, SslProtocols.TLS_v1_3)
.build();

channel.pipeline().addLast(
sslContext.newHandler(channel.alloc(), config.hostname(), config.port())
);
}

channel.pipeline().addLast(new ModbusTcpCodec());
channel.pipeline().addLast(new ModbusTcpFrameHandler());

config.pipelineCustomizer().accept(channel.pipeline());
}
};
}

@Override
public CompletableFuture<Void> disconnect(
FsmContext<State, Event> fsmContext, Channel channel) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package com.digitalpetri.modbus.tcp.security;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;

public class SecurityUtil {

/**
* Create a {@link KeyManagerFactory} from a private key and certificates.
*
* @param privateKey the private key.
* @param certificates the certificates.
* @return a {@link KeyManagerFactory}.
* @throws GeneralSecurityException if an error occurs.
* @throws IOException if an error occurs.
*/
public static KeyManagerFactory createKeyManagerFactory(
PrivateKey privateKey,
X509Certificate... certificates
) throws GeneralSecurityException, IOException {

KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null);

keyStore.setKeyEntry("key", privateKey, new char[0], certificates);

return createKeyManagerFactory(keyStore, new char[0]);
}

/**
* Create a {@link KeyManagerFactory} from a {@link KeyStore}.
*
* @param keyStore the {@link KeyStore}.
* @param keyStorePassword the password for the {@link KeyStore}.
* @return a {@link KeyManagerFactory}.
* @throws GeneralSecurityException if an error occurs.
*/
public static KeyManagerFactory createKeyManagerFactory(
KeyStore keyStore,
char[] keyStorePassword
) throws GeneralSecurityException {

KeyManagerFactory keyManagerFactory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());

keyManagerFactory.init(keyStore, keyStorePassword);

return keyManagerFactory;
}

/**
* Create a {@link TrustManagerFactory} from certificates.
*
* @param certificates the certificates.
* @return a {@link TrustManagerFactory}.
* @throws GeneralSecurityException if an error occurs.
* @throws IOException if an error occurs.
*/
public static TrustManagerFactory createTrustManagerFactory(
X509Certificate... certificates
) throws GeneralSecurityException, IOException {

KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
keyStore.load(null, null);

for (int i = 0; i < certificates.length; i++) {
keyStore.setCertificateEntry("cert" + i, certificates[i]);
}

return createTrustManagerFactory(keyStore);
}

/**
* Create a {@link TrustManagerFactory} from a {@link KeyStore}.
*
* @param keyStore the {@link KeyStore}.
* @return a {@link TrustManagerFactory}.
* @throws GeneralSecurityException if an error occurs.
*/
public static TrustManagerFactory createTrustManagerFactory(
KeyStore keyStore
) throws GeneralSecurityException {

TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());

trustManagerFactory.init(keyStore);

return trustManagerFactory;
}

}
Loading

0 comments on commit 27aa926

Please sign in to comment.