Skip to content

Commit

Permalink
fix: make sure FIFO order for write() when notifyChannelActive(), als…
Browse files Browse the repository at this point in the history
…o make sure channel access thread-safe and avoid potential NPE
  • Loading branch information
okg-cxf committed Jan 18, 2024
1 parent 761d602 commit d23c68f
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 42 deletions.
72 changes: 34 additions & 38 deletions src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ public class DefaultEndpoint implements RedisChannelWriter, Endpoint, PushHandle

private static final AtomicLong ENDPOINT_COUNTER = new AtomicLong();

private static final AtomicIntegerFieldUpdater<DefaultEndpoint> QUEUE_SIZE = AtomicIntegerFieldUpdater
.newUpdater(DefaultEndpoint.class, "queueSize");
private static final AtomicIntegerFieldUpdater<DefaultEndpoint> QUEUE_SIZE = AtomicIntegerFieldUpdater.newUpdater(
DefaultEndpoint.class, "queueSize");

private static final AtomicIntegerFieldUpdater<DefaultEndpoint> STATUS = AtomicIntegerFieldUpdater
.newUpdater(DefaultEndpoint.class, "status");
private static final AtomicIntegerFieldUpdater<DefaultEndpoint> STATUS = AtomicIntegerFieldUpdater.newUpdater(
DefaultEndpoint.class, "status");

private static final int ST_OPEN = 0;

Expand Down Expand Up @@ -191,9 +191,9 @@ public <K, V, T> RedisCommand<K, V, T> write(RedisCommand<K, V, T> command) {
}

if (autoFlushCommands) {

if (isConnected()) {
writeToChannelAndFlush(command);
Channel channel = this.channel;
if (isConnected(channel)) {
writeToChannelAndFlush(channel, command);
} else {
writeToDisconnectedBuffer(command);
}
Expand Down Expand Up @@ -232,9 +232,9 @@ public <K, V, T> RedisCommand<K, V, T> write(RedisCommand<K, V, T> command) {
}

if (autoFlushCommands) {

if (isConnected()) {
writeToChannelAndFlush(commands);
Channel channel = this.channel;
if (isConnected(channel)) {
writeToChannelAndFlush(channel, commands);
} else {
writeToDisconnectedBuffer(commands);
}
Expand Down Expand Up @@ -284,10 +284,9 @@ private RedisException validateWrite(int commands) {
return new RedisException("Connection is closed");
}

final boolean connected = isConnected(this.channel);
if (usesBoundedQueues()) {

boolean connected = isConnected();

if (QUEUE_SIZE.get(this) + commands > clientOptions.getRequestQueueSize()) {
return new RedisException("Request queue size exceeded: " + clientOptions.getRequestQueueSize()
+ ". Commands are not accepted until the queue size drops.");
Expand All @@ -304,7 +303,7 @@ private RedisException validateWrite(int commands) {
}
}

if (!isConnected() && rejectCommandsWhileDisconnected) {
if (!connected && rejectCommandsWhileDisconnected) {
return new RedisException("Currently not connected. Commands are rejected.");
}

Expand Down Expand Up @@ -366,11 +365,11 @@ private void writeToDisconnectedBuffer(RedisCommand<?, ?, ?> command) {
commandBuffer.add(command);
}

private void writeToChannelAndFlush(RedisCommand<?, ?, ?> command) {
private void writeToChannelAndFlush(Channel channel, RedisCommand<?, ?, ?> command) {

QUEUE_SIZE.incrementAndGet(this);

ChannelFuture channelFuture = channelWriteAndFlush(command);
ChannelFuture channelFuture = channelWriteAndFlush(channel, command);

if (reliability == Reliability.AT_MOST_ONCE) {
// cancel on exceptions and remove from queue, because there is no housekeeping
Expand All @@ -383,30 +382,30 @@ private void writeToChannelAndFlush(RedisCommand<?, ?, ?> command) {
}
}

private void writeToChannelAndFlush(Collection<? extends RedisCommand<?, ?, ?>> commands) {
private void writeToChannelAndFlush(Channel channel, Collection<? extends RedisCommand<?, ?, ?>> commands) {

QUEUE_SIZE.addAndGet(this, commands.size());

if (reliability == Reliability.AT_MOST_ONCE) {

// cancel on exceptions and remove from queue, because there is no housekeeping
for (RedisCommand<?, ?, ?> command : commands) {
channelWrite(command).addListener(AtMostOnceWriteListener.newInstance(this, command));
channelWrite(channel, command).addListener(AtMostOnceWriteListener.newInstance(this, command));
}
}

if (reliability == Reliability.AT_LEAST_ONCE) {

// commands are ok to stay within the queue, reconnect will retrigger them
for (RedisCommand<?, ?, ?> command : commands) {
channelWrite(command).addListener(RetryListener.newInstance(this, command));
channelWrite(channel, command).addListener(RetryListener.newInstance(this, command));
}
}

channelFlush();
channelFlush(channel);
}

private void channelFlush() {
private void channelFlush(Channel channel) {

if (debugEnabled) {
logger.debug("{} write() channelFlush", logPrefix());
Expand All @@ -415,7 +414,7 @@ private void channelFlush() {
channel.flush();
}

private ChannelFuture channelWrite(RedisCommand<?, ?, ?> command) {
private ChannelFuture channelWrite(Channel channel, RedisCommand<?, ?, ?> command) {

if (debugEnabled) {
logger.debug("{} write() channelWrite command {}", logPrefix(), command);
Expand All @@ -424,7 +423,7 @@ private ChannelFuture channelWrite(RedisCommand<?, ?, ?> command) {
return channel.write(command);
}

private ChannelFuture channelWriteAndFlush(RedisCommand<?, ?, ?> command) {
private ChannelFuture channelWriteAndFlush(Channel channel, RedisCommand<?, ?, ?> command) {

if (debugEnabled) {
logger.debug("{} write() writeAndFlush command {}", logPrefix(), command);
Expand All @@ -437,7 +436,6 @@ private ChannelFuture channelWriteAndFlush(RedisCommand<?, ?, ?> command) {
public void notifyChannelActive(Channel channel) {

this.logPrefix = null;
this.channel = channel;
this.connectionError = null;

if (isClosed()) {
Expand All @@ -452,6 +450,7 @@ public void notifyChannelActive(Channel channel) {
}

sharedLock.doExclusive(() -> {
this.channel = channel;

try {
// Move queued commands to buffer before issuing any commands because of connection activation.
Expand All @@ -474,7 +473,7 @@ public void notifyChannelActive(Channel channel) {
inActivation = false;
}

flushCommands(disconnectedBuffer);
flushCommands(channel, disconnectedBuffer);
} catch (Exception e) {

if (debugEnabled) {
Expand Down Expand Up @@ -527,7 +526,7 @@ public void notifyException(Throwable t) {
doExclusive(this::drainCommands).forEach(cmd -> cmd.completeExceptionally(t));
}

if (!isConnected()) {
if (!isConnected(this.channel)) {
connectionError = t;
}
}
Expand All @@ -540,16 +539,16 @@ public void registerConnectionWatchdog(ConnectionWatchdog connectionWatchdog) {
@Override
@SuppressWarnings({ "rawtypes", "unchecked" })
public void flushCommands() {
flushCommands(commandBuffer);
flushCommands(this.channel, commandBuffer);
}

private void flushCommands(Queue<RedisCommand<?, ?, ?>> queue) {
private void flushCommands(Channel channel, Queue<RedisCommand<?, ?, ?>> queue) {

if (debugEnabled) {
logger.debug("{} flushCommands()", logPrefix());
}

if (isConnected()) {
if (isConnected(channel)) {

List<RedisCommand<?, ?, ?>> commands = sharedLock.doExclusive(() -> {

Expand All @@ -565,7 +564,7 @@ private void flushCommands(Queue<RedisCommand<?, ?, ?>> queue) {
}

if (!commands.isEmpty()) {
writeToChannelAndFlush(commands);
writeToChannelAndFlush(channel, commands);
}
}
}
Expand Down Expand Up @@ -628,10 +627,10 @@ public void disconnect() {

private Channel getOpenChannel() {

Channel currentChannel = this.channel;
Channel channel = this.channel;

if (currentChannel != null) {
return currentChannel;
if (channel != null /* && channel.isOpen() is this deliberately omitted? */) {
return channel;
}

return null;
Expand All @@ -648,6 +647,7 @@ public void reset() {
logger.debug("{} reset()", logPrefix());
}

Channel channel = this.channel;
if (channel != null) {
channel.pipeline().fireUserEventTriggered(new ConnectionEvents.Reset());
}
Expand Down Expand Up @@ -720,9 +720,7 @@ public void notifyDrainQueuedCommands(HasQueuedCommands queuedCommands) {
}
}

if (isConnected()) {
flushCommands(disconnectedBuffer);
}
flushCommands(this.channel, disconnectedBuffer);
});
}

Expand Down Expand Up @@ -787,9 +785,7 @@ private void cancelCommands(String message, Iterable<? extends RedisCommand<?, ?
}
}

private boolean isConnected() {

Channel channel = this.channel;
private boolean isConnected(Channel channel) {
return channel != null && channel.isActive();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,34 @@ void before() {
sut.setConnectionFacade(connectionFacade);
}

@Test
void writeShouldGuaranteeFIFOOrder() {
sut.write(Collections.singletonList(new Command<>(CommandType.SELECT, new StatusOutput<>(StringCodec.UTF8))));

sut.registerConnectionWatchdog(connectionWatchdog);
doAnswer(i -> sut.write(new Command<>(CommandType.AUTH, new StatusOutput<>(StringCodec.UTF8)))).when(connectionWatchdog)
.arm();
when(channel.isActive()).thenReturn(true);

sut.notifyChannelActive(channel);

DefaultChannelPromise promise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);

when(channel.writeAndFlush(any())).thenAnswer(invocation -> {
if (invocation.getArguments()[0] instanceof RedisCommand) {
queue.add((RedisCommand) invocation.getArguments()[0]);
}

if (invocation.getArguments()[0] instanceof Collection) {
queue.addAll((Collection) invocation.getArguments()[0]);
}
return promise;
});

assertThat(queue).hasSize(2).first().hasFieldOrPropertyWithValue("type", CommandType.SELECT);
assertThat(queue).hasSize(2).last().hasFieldOrPropertyWithValue("type", CommandType.AUTH);
}

@Test
void writeConnectedShouldWriteCommandToChannel() {

Expand Down Expand Up @@ -396,11 +424,9 @@ void shouldNotReplayActivationCommands() {

when(channel.isActive()).thenReturn(true);
ConnectionTestUtil.getDisconnectedBuffer(sut)
.add(new ActivationCommand<>(
new Command<>(CommandType.SELECT, new StatusOutput<>(StringCodec.UTF8))));
.add(new ActivationCommand<>(new Command<>(CommandType.SELECT, new StatusOutput<>(StringCodec.UTF8))));
ConnectionTestUtil.getDisconnectedBuffer(sut).add(new LatencyMeteredCommand<>(
new ActivationCommand<>(
new Command<>(CommandType.SUBSCRIBE, new StatusOutput<>(StringCodec.UTF8)))));
new ActivationCommand<>(new Command<>(CommandType.SUBSCRIBE, new StatusOutput<>(StringCodec.UTF8)))));

doAnswer(i -> {

Expand Down

0 comments on commit d23c68f

Please sign in to comment.