Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for Remote port forwarding buffers can grow without limits (issue #658) #913

Merged
merged 3 commits into from
Nov 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/main/java/net/schmizz/sshj/Config.java
Original file line number Diff line number Diff line change
@@ -200,4 +200,8 @@ public interface Config {
* See {@link #isVerifyHostKeyCertificates()}.
*/
void setVerifyHostKeyCertificates(boolean value);

int getMaxCircularBufferSize();

void setMaxCircularBufferSize(int maxCircularBufferSize);
}
12 changes: 12 additions & 0 deletions src/main/java/net/schmizz/sshj/ConfigImpl.java
Original file line number Diff line number Diff line change
@@ -49,6 +49,8 @@ public class ConfigImpl
private boolean waitForServerIdentBeforeSendingClientIdent = false;
private LoggerFactory loggerFactory;
private boolean verifyHostKeyCertificates = true;
// HF-982: default to 16MB buffers.
private int maxCircularBufferSize = 16 * 1024 * 1024;

@Override
public List<Factory.Named<Cipher>> getCipherFactories() {
@@ -175,6 +177,16 @@ public LoggerFactory getLoggerFactory() {
return loggerFactory;
}

@Override
public int getMaxCircularBufferSize() {
return maxCircularBufferSize;
}

@Override
public void setMaxCircularBufferSize(int maxCircularBufferSize) {
this.maxCircularBufferSize = maxCircularBufferSize;
}

@Override
public void setLoggerFactory(LoggerFactory loggerFactory) {
this.loggerFactory = loggerFactory;
194 changes: 194 additions & 0 deletions src/main/java/net/schmizz/sshj/common/CircularBuffer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.common;

public class CircularBuffer<T extends CircularBuffer<T>> {

public static class CircularBufferException
extends SSHException {

public CircularBufferException(String message) {
super(message);
}
}

public static final class PlainCircularBuffer
extends CircularBuffer<PlainCircularBuffer> {

public PlainCircularBuffer(int size, int maxSize) {
super(size, maxSize);
}
}

/**
* Maximum size of the internal array (one plus the maximum capacity of the buffer).
*/
private final int maxSize;
/**
* Internal array for the data. All bytes minus one can be used to avoid empty vs full ambiguity when rpos == wpos.
*/
private byte[] data;
/**
* Next read position. Wraps around the end of the internal array. When it reaches wpos, the buffer becomes empty.
* Can take the value data.length, which is equivalent to 0.
*/
private int rpos;
/**
* Next write position. Wraps around the end of the internal array. If it is equal to rpos, then the buffer is
* empty; the code does not allow wpos to reach rpos from the left. This implies that the buffer can store up to
* data.length - 1 bytes. Can take the value data.length, which is equivalent to 0.
*/
private int wpos;

/**
* Determines the size to which to grow the internal array.
*/
private int getNextSize(int currentSize) {
// Use next power of 2.
int nextSize = 1;
while (nextSize < currentSize) {
nextSize <<= 1;
if (nextSize <= 0) {
return maxSize;
}
}
return Math.min(nextSize, maxSize); // limit to max size
}

/**
* Creates a new circular buffer of the given size. The capacity of the buffer is one less than the size/
*/
public CircularBuffer(int size, int maxSize) {
this.maxSize = maxSize;
if (size > maxSize) {
throw new IllegalArgumentException(
String.format("Initial requested size %d larger than maximum size %d", size, maxSize));
}
int initialSize = getNextSize(size);
this.data = new byte[initialSize];
this.rpos = 0;
this.wpos = 0;
}

/**
* Data available in the buffer for reading.
*/
public int available() {
int available = wpos - rpos;
return available >= 0 ? available : available + data.length; // adjust if wpos is left of rpos
}

private void ensureAvailable(int a)
throws CircularBufferException {
if (available() < a) {
throw new CircularBufferException("Underflow");
}
}

/**
* Returns how many more bytes this buffer can receive.
*/
public int maxPossibleRemainingCapacity() {
// Remaining capacity is one less than remaining space to ensure that wpos does not reach rpos from the left.
int remaining = rpos - wpos - 1;
if (remaining < 0) {
remaining += data.length; // adjust if rpos is left of wpos
}
// Add the maximum amount the internal array can grow.
return remaining + maxSize - data.length;
}

/**
* If the internal array does not have room for "capacity" more bytes, resizes the array to make that room.
*/
void ensureCapacity(int capacity) throws CircularBufferException {
int available = available();
int remaining = data.length - available;
// If capacity fits exactly in the remaining space, expand it; otherwise, wpos would reach rpos from the left.
if (remaining <= capacity) {
int neededSize = available + capacity + 1;
int nextSize = getNextSize(neededSize);
if (nextSize < neededSize) {
throw new CircularBufferException("Attempted overflow");
}
byte[] tmp = new byte[nextSize];
// Copy data to the beginning of the new array.
if (wpos >= rpos) {
System.arraycopy(data, rpos, tmp, 0, available);
wpos -= rpos; // wpos must be relative to the new rpos, which will be 0
} else {
int tail = data.length - rpos;
System.arraycopy(data, rpos, tmp, 0, tail); // segment right of rpos
System.arraycopy(data, 0, tmp, tail, wpos); // segment left of wpos
wpos += tail; // wpos must be relative to the new rpos, which will be 0
}
rpos = 0;
data = tmp;
}
}

/**
* Reads data from this buffer into the provided array.
*/
public void readRawBytes(byte[] destination, int offset, int length) throws CircularBufferException {
ensureAvailable(length);

int rposNext = rpos + length;
if (rposNext <= data.length) {
System.arraycopy(data, rpos, destination, offset, length);
} else {
int tail = data.length - rpos;
System.arraycopy(data, rpos, destination, offset, tail); // segment right of rpos
rposNext = length - tail; // rpos wraps around the end of the buffer
System.arraycopy(data, 0, destination, offset + tail, rposNext); // remainder
}
// This can make rpos equal data.length, which has the same effect as wpos being 0.
rpos = rposNext;
}

/**
* Writes data to this buffer from the provided array.
*/
@SuppressWarnings("unchecked")
public T putRawBytes(byte[] source, int offset, int length) throws CircularBufferException {
ensureCapacity(length);

int wposNext = wpos + length;
if (wposNext <= data.length) {
System.arraycopy(source, offset, data, wpos, length);
} else {
int tail = data.length - wpos;
System.arraycopy(source, offset, data, wpos, tail); // segment right of wpos
wposNext = length - tail; // wpos wraps around the end of the buffer
System.arraycopy(source, offset + tail, data, 0, wposNext); // remainder
}
// This can make wpos equal data.length, which has the same effect as wpos being 0.
wpos = wposNext;

return (T) this;
}

// Used only for testing.
int length() {
return data.length;
}

@Override
public String toString() {
return "CircularBuffer [rpos=" + rpos + ", wpos=" + wpos + ", size=" + data.length + "]";
}

}
Original file line number Diff line number Diff line change
@@ -164,8 +164,7 @@ public String getType() {
}

@Override
public void handle(Message msg, SSHPacket buf)
throws ConnectionException, TransportException {
public void handle(Message msg, SSHPacket buf) throws SSHException {
switch (msg) {

case CHANNEL_DATA:
@@ -354,7 +353,7 @@ protected void finishOff() {
}

protected void gotExtendedData(SSHPacket buf)
throws ConnectionException, TransportException {
throws SSHException {
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
"Extended data not supported on " + type + " channel");
}
@@ -375,7 +374,7 @@ protected SSHPacket newBuffer(Message cmd) {
}

protected void receiveInto(ChannelInputStream stream, SSHPacket buf)
throws ConnectionException, TransportException {
throws SSHException {
final int len;
try {
len = buf.readUInt32AsInt();
Original file line number Diff line number Diff line change
@@ -38,18 +38,19 @@ public final class ChannelInputStream
private final Channel chan;
private final Transport trans;
private final Window.Local win;
private final Buffer.PlainBuffer buf;
private final CircularBuffer.PlainCircularBuffer buf;
private final byte[] b = new byte[1];

private boolean eof;
private SSHException error;

public ChannelInputStream(Channel chan, Transport trans, Window.Local win) {
this.chan = chan;
log = chan.getLoggerFactory().getLogger(getClass());
this.log = chan.getLoggerFactory().getLogger(getClass());
this.trans = trans;
this.win = win;
buf = new Buffer.PlainBuffer(chan.getLocalMaxPacketSize());
this.buf = new CircularBuffer.PlainCircularBuffer(
chan.getLocalMaxPacketSize(), trans.getConfig().getMaxCircularBufferSize());
}

@Override
@@ -113,48 +114,44 @@ public int read(byte[] b, int off, int len)
len = buf.available();
}
buf.readRawBytes(b, off, len);
if (buf.rpos() > win.getMaxPacketSize() && buf.available() == 0) {
buf.clear();
}
}

if (!chan.getAutoExpand()) {
checkWindow();
if (!chan.getAutoExpand()) {
checkWindow();
}
}

return len;
}

public void receive(byte[] data, int offset, int len)
throws ConnectionException, TransportException {
public void receive(byte[] data, int offset, int len) throws SSHException {
if (eof) {
throw new ConnectionException("Getting data on EOF'ed stream");
}
synchronized (buf) {
buf.putRawBytes(data, offset, len);
buf.notifyAll();
}
// Potential fix for #203 (window consumed below 0).
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
// And the window has not expanded yet.
synchronized (win) {
// Potential fix for #203 (window consumed below 0).
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
// And the window has not expanded yet.
win.consume(len);
}
if (chan.getAutoExpand()) {
checkWindow();
if (chan.getAutoExpand()) {
checkWindow();
}
}
}

private void checkWindow()
throws TransportException {
synchronized (win) {
final long adjustment = win.neededAdjustment();
if (adjustment > 0) {
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
win.expand(adjustment);
}
private void checkWindow() throws TransportException {
/*
* Window must fit in remaining buffer capacity. We already expect win.size() amount of data to arrive. The
* difference between that and the remaining capacity is the maximum adjustment we can make to the window.
*/
final long maxAdjustment = buf.maxPossibleRemainingCapacity() - win.getSize();
final long adjustment = Math.min(win.neededAdjustment(), maxAdjustment);
if (adjustment > 0) {
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
win.expand(adjustment);
}
}

Original file line number Diff line number Diff line change
@@ -210,7 +210,7 @@ protected void eofInputStreams() {

@Override
protected void gotExtendedData(SSHPacket buf)
throws ConnectionException, TransportException {
throws SSHException {
try {
final int dataTypeCode = buf.readUInt32AsInt();
if (dataTypeCode == 1)
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hierynomus.sshj.connection.channel.forwarded;

import static org.junit.jupiter.api.Assertions.*;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.Forward;
import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RemotePFPerformanceTest {

private static final Logger log = LoggerFactory.getLogger(RemotePFPerformanceTest.class);

@Test
@Disabled
public void startPF() throws IOException, InterruptedException {
DefaultConfig config = new DefaultConfig();
config.setMaxCircularBufferSize(16 * 1024 * 1024);
SSHClient client = new SSHClient(config);
client.loadKnownHosts();
client.addHostKeyVerifier("5c:0c:8e:9d:1c:50:a9:ba:a7:05:f6:b1:2b:0b:5f:ba");

client.getConnection().getKeepAlive().setKeepAliveInterval(5);
client.connect("localhost");
client.getConnection().getKeepAlive().setKeepAliveInterval(5);

Object consumerReadyMonitor = new Object();
ConsumerThread consumerThread = new ConsumerThread(consumerReadyMonitor);
ProducerThread producerThread = new ProducerThread();
try {

client.authPassword(System.getenv().get("USERNAME"), System.getenv().get("PASSWORD"));

/*
* We make _server_ listen on port 8080, which forwards all connections to us as a channel, and we further
* forward all such channels to google.com:80
*/
client.getRemotePortForwarder().bind(
// where the server should listen
new Forward(8888),
// what we do with incoming connections that are forwarded to us
new SocketForwardingConnectListener(new InetSocketAddress("localhost", 12345)));

consumerThread.start();
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.wait();
}
producerThread.start();

// Wait for consumer to finish receiving data.
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.wait();
}

} finally {
producerThread.interrupt();
consumerThread.interrupt();
client.disconnect();
}
}

private static class ConsumerThread extends Thread {
private final Object consumerReadyMonitor;

private ConsumerThread(Object consumerReadyMonitor) {
super("Consumer");
this.consumerReadyMonitor = consumerReadyMonitor;
}

@Override
public void run() {
try (ServerSocket serverSocket = new ServerSocket(12345)) {
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.notifyAll();
}
try (Socket acceptedSocket = serverSocket.accept()) {
InputStream in = acceptedSocket.getInputStream();
int numRead;
byte[] buf = new byte[40000];
//byte[] buf = new byte[255 * 4 * 1000];
byte expectedNext = 1;
while ((numRead = in.read(buf)) != 0) {
if (Thread.interrupted()) {
log.info("Consumer thread interrupted");
return;
}
log.info(String.format("Read %d characters; values from %d to %d", numRead, buf[0], buf[numRead - 1]));
if (buf[numRead - 1] == 0) {
verifyData(buf, numRead - 1, expectedNext);
break;
}
expectedNext = verifyData(buf, numRead, expectedNext);
// Slow down consumer to test buffering.
Thread.sleep(Long.parseLong(System.getenv().get("DELAY_MS")));
}
log.info("Consumer read end of stream value: " + numRead);
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.notifyAll();
}
}
} catch (Exception e) {
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.notifyAll();
}
e.printStackTrace();
}
}

private byte verifyData(byte[] buf, int numRead, byte expectedNext) {
for (int i = 0; i < numRead; ++i) {
if (buf[i] != expectedNext) {
fail("Expected buf[" + i + "]=" + buf[i] + " to be " + expectedNext);
}
if (++expectedNext == 0) {
expectedNext = 1;
}
}
return expectedNext;
}
}

private static class ProducerThread extends Thread {
private ProducerThread() {
super("Producer");
}

@Override
public void run() {
try (Socket clientSocket = new Socket("127.0.0.1", 8888);
OutputStream writer = clientSocket.getOutputStream()) {
byte[] buf = getData();
assertEquals(buf[0], 1);
assertEquals(buf[buf.length - 1], -1);
for (int i = 0; i < 1000; ++i) {
writer.write(buf);
if (Thread.interrupted()) {
log.info("Consumer thread interrupted");
return;
}
log.info(String.format("Wrote %d characters; values from %d to %d", buf.length, buf[0], buf[buf.length - 1]));
}
writer.write(0); // end of stream value
log.info("Producer finished sending data");
} catch (Exception e) {
e.printStackTrace();
}
}

private byte[] getData() {
byte[] buf = new byte[255 * 4 * 1000];
byte nextValue = 1;
for (int i = 0; i < buf.length; ++i) {
buf[i] = nextValue++;
// reserve 0 for end of stream
if (nextValue == 0) {
nextValue = 1;
}
}
return buf;
}
}

}
221 changes: 221 additions & 0 deletions src/test/java/net/schmizz/sshj/common/CircularBufferTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.common;

import static org.junit.jupiter.api.Assertions.*;

import net.schmizz.sshj.common.CircularBuffer.CircularBufferException;
import net.schmizz.sshj.common.CircularBuffer.PlainCircularBuffer;
import org.junit.jupiter.api.Test;

public class CircularBufferTest {

@Test
public void shouldStoreDataCorrectlyWithoutResizing() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(256, Integer.MAX_VALUE);

byte[] dataToWrite = getData(500);
buffer.putRawBytes(dataToWrite, 0, 100);
buffer.putRawBytes(dataToWrite, 100, 100);

byte[] dataToRead = new byte[500];
buffer.readRawBytes(dataToRead, 0, 80);
buffer.readRawBytes(dataToRead, 80, 80);

buffer.putRawBytes(dataToWrite, 200, 100);
buffer.readRawBytes(dataToRead, 160, 80);

buffer.putRawBytes(dataToWrite, 300, 100);
buffer.readRawBytes(dataToRead, 240, 80);

buffer.putRawBytes(dataToWrite, 400, 100);
buffer.readRawBytes(dataToRead, 320, 80);
buffer.readRawBytes(dataToRead, 400, 100);

assertEquals(256, buffer.length());
assertArrayEquals(dataToWrite, dataToRead);
}

@Test
public void shouldStoreDataCorrectlyWithResizing() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);

byte[] dataToWrite = getData(500);
buffer.putRawBytes(dataToWrite, 0, 100);
buffer.putRawBytes(dataToWrite, 100, 100);

byte[] dataToRead = new byte[500];
buffer.readRawBytes(dataToRead, 0, 80);
buffer.readRawBytes(dataToRead, 80, 80);

buffer.putRawBytes(dataToWrite, 200, 100);
buffer.readRawBytes(dataToRead, 160, 80);

buffer.putRawBytes(dataToWrite, 300, 100);
buffer.readRawBytes(dataToRead, 240, 80);

buffer.putRawBytes(dataToWrite, 400, 100);
buffer.readRawBytes(dataToRead, 320, 80);

buffer.readRawBytes(dataToRead, 400, 100);

assertEquals(256, buffer.length());
assertArrayEquals(dataToWrite, dataToRead);
}

@Test
public void shouldNotOverflowWhenWritingFullLengthToTheEnd() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);

byte[] dataToWrite = getData(64);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should write to the end

assertEquals(64, buffer.available());
assertEquals(64 * 2, buffer.length());
}

@Test
public void shouldNotOverflowWhenWritingFullLengthWrapsAround() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);

// Move 1 byte forward.
buffer.putRawBytes(new byte[1], 0, 1);
buffer.readRawBytes(new byte[1], 0, 1);

// Force writes to wrap around.
byte[] dataToWrite = getData(64);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should wrap around the end

assertEquals(64, buffer.available());
assertEquals(64 * 2, buffer.length());
}

@Test
public void shouldAllowWritingMaxCapacityFromZero() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);

// Max capacity is always one less than the buffer size.
int maxCapacity = buffer.maxPossibleRemainingCapacity();
assertEquals(buffer.length() - 1, maxCapacity);

byte[] dataToWrite = getData(maxCapacity);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);

assertEquals(dataToWrite.length, buffer.available());
assertEquals(64, buffer.length());
}

@Test
public void shouldAllowWritingMaxRemainingCapacity() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);

final int initiallyWritten = 10;
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);

// Max remaining capacity is always one less than the remaining buffer size.
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);

byte[] dataToWrite = getData(maxRemainingCapacity);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);

assertEquals(dataToWrite.length + initiallyWritten, buffer.available());
assertEquals(64, buffer.length());
}

@Test
public void shouldAllowWritingMaxRemainingCapacityAfterWrappingAround() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);

// Cause the internal write pointer to wrap around and be left of the read pointer.
final int initiallyWritten = 40;
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
buffer.readRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);

// Max remaining capacity is always one less than the remaining buffer size.
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);

byte[] dataToWrite = getData(maxRemainingCapacity);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);

assertEquals(dataToWrite.length + initiallyWritten, buffer.available());
assertEquals(64, buffer.length());
}

@Test
public void shouldOverflowWhenWritingOverMaxRemainingCapacity() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);

final int initiallyWritten = 10;
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);

// Max remaining capacity is always one less than the remaining buffer size.
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);

byte[] dataToWrite = getData(maxRemainingCapacity + 1);
assertThrows(CircularBufferException.class, () -> buffer.putRawBytes(dataToWrite, 0, dataToWrite.length));
}

@Test
public void shouldThrowWhenReadingEmptyBuffer() {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[1], 0, 1));
}

@Test
public void shouldThrowWhenReadingMoreThanAvailable() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
buffer.putRawBytes(new byte[1], 0, 1);
assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[2], 0, 2));
}

@Test
public void shouldThrowOnAboveMaximumInitialSize() {
assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(65, 64));
}

@Test
public void shouldThrowOnMaximumInitialSize() {
assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(Integer.MAX_VALUE, 64));
}

@Test
public void shouldAllowFullCapacity() throws CircularBufferException {
int maxSize = 1024;
PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize);
buffer.ensureCapacity(maxSize - 1);
assertEquals(maxSize - 1, buffer.maxPossibleRemainingCapacity());
}

@Test
public void shouldThrowOnTooLargeRequestedCapacity() {
int maxSize = 1024;
PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize);
assertThrows(CircularBufferException.class, () -> buffer.ensureCapacity(maxSize));
}

private static byte[] getData(int length) {
byte[] data = new byte[length];
byte nextValue = 0;
for (int i = 0; i < length; ++i) {
data[i] = nextValue++;
}
return data;
}
}