diff --git a/README.md b/README.md index c3dc81f6..c2c4a75f 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ repositories { } ext { - rlibVersion = "10.0.alpha7" + rlibVersion = "10.0.alpha8" } dependencies { diff --git a/build.gradle b/build.gradle index 8c95004e..44c3004c 100644 --- a/build.gradle +++ b/build.gradle @@ -1,4 +1,4 @@ -rootProject.version = "10.0.alpha7" +rootProject.version = "10.0.alpha8" group = 'javasabr.rlib' allprojects { diff --git a/rlib-collections/src/test/java/javasabr/rlib/collections/array/IntArrayTest.java b/rlib-collections/src/test/java/javasabr/rlib/collections/array/IntArrayTest.java index f5360900..f99f63e5 100644 --- a/rlib-collections/src/test/java/javasabr/rlib/collections/array/IntArrayTest.java +++ b/rlib-collections/src/test/java/javasabr/rlib/collections/array/IntArrayTest.java @@ -2,7 +2,6 @@ import static org.assertj.core.api.Assertions.assertThat; -import java.util.List; import java.util.stream.Stream; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; diff --git a/rlib-network/src/loadTest/java/javasabr/rlib/network/StringNetworkLoadTest.java b/rlib-network/src/loadTest/java/javasabr/rlib/network/StringNetworkLoadTest.java index 5a1b262b..d24ee458 100644 --- a/rlib-network/src/loadTest/java/javasabr/rlib/network/StringNetworkLoadTest.java +++ b/rlib-network/src/loadTest/java/javasabr/rlib/network/StringNetworkLoadTest.java @@ -41,7 +41,7 @@ private static class TestClient implements AutoCloseable { String clientId = "Client_%s".formatted(ID_FACTORY.incrementAndGet()); NetworkConfig networkConfig = NetworkConfig.SimpleNetworkConfig.builder() - .groupName(clientId) + .threadGroupName(clientId) .writeBufferSize(256) .readBufferSize(256) .pendingBufferSize(512) @@ -67,7 +67,7 @@ void connectAndSendMessages( ThreadUtils.sleep(random.nextInt(5000)); StringDataConnection connection = network.connect(serverAddress); - connection.onReceive((serverConnection, packet) -> statistics + connection.onReceiveValidPacket((serverConnection, packet) -> statistics .receivedServerPackersPerSecond() .accumulate(1)); @@ -141,7 +141,7 @@ void testServerWithMultiplyClients() { InetSocketAddress serverAddress = serverNetwork.start(); serverNetwork.onAccept(accepted -> accepted - .onReceive((connection, packet) -> { + .onReceiveValidPacket((connection, packet) -> { StringReadableNetworkPacket receivedPacket = (StringReadableNetworkPacket) packet; statistics .receivedClientPackersPerSecond() diff --git a/rlib-network/src/loadTest/java/javasabr/rlib/network/StringSslNetworkLoadTest.java b/rlib-network/src/loadTest/java/javasabr/rlib/network/StringSslNetworkLoadTest.java index 91a8120d..22e559ba 100644 --- a/rlib-network/src/loadTest/java/javasabr/rlib/network/StringSslNetworkLoadTest.java +++ b/rlib-network/src/loadTest/java/javasabr/rlib/network/StringSslNetworkLoadTest.java @@ -44,7 +44,7 @@ private static class TestClient implements AutoCloseable { String clientId = "Client_%s".formatted(ID_FACTORY.incrementAndGet()); NetworkConfig networkConfig = NetworkConfig.SimpleNetworkConfig.builder() - .groupName(clientId) + .threadGroupName(clientId) .writeBufferSize(256) .readBufferSize(256) .pendingBufferSize(512) @@ -70,7 +70,7 @@ void connectAndSendMessages( ThreadUtils.sleep(random.nextInt(5000)); StringDataSslConnection connection = network.connect(serverAddress); - connection.onReceive((serverConnection, packet) -> statistics + connection.onReceiveValidPacket((serverConnection, packet) -> statistics .receivedServerPackersPerSecond() .accumulate(1)); @@ -150,7 +150,7 @@ void testServerWithMultiplyClients() { InetSocketAddress serverAddress = serverNetwork.start(); serverNetwork.onAccept(accepted -> accepted - .onReceive((connection, packet) -> { + .onReceiveValidPacket((connection, packet) -> { StringReadableNetworkPacket receivedPacket = (StringReadableNetworkPacket) packet; statistics .receivedClientPackersPerSecond() diff --git a/rlib-network/src/main/java/javasabr/rlib/network/Connection.java b/rlib-network/src/main/java/javasabr/rlib/network/Connection.java index 52b3040b..770a6c52 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/Connection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/Connection.java @@ -13,10 +13,10 @@ */ public interface Connection> { - record ReceivedPacketEvent(C connection, R packet) { + record ReceivedPacketEvent(C connection, R packet, boolean valid) { @Override public String toString() { - return "[" + connection + "|" + packet + ']'; + return "[" + connection + '|' + packet + '|' + valid + ']'; } } @@ -53,9 +53,14 @@ public String toString() { CompletableFuture sendWithFeedback(WritableNetworkPacket packet); /** - * Register a consumer to handle received packets. + * Register a consumer to handle received valid packets. */ - void onReceive(BiConsumer> consumer); + void onReceiveValidPacket(BiConsumer> consumer); + + /** + * Register a consumer to handle received invalid packets. + */ + void onReceiveInvalidPacket(BiConsumer> consumer); /** * Get a stream of received packet events. @@ -72,15 +77,29 @@ default > Flux> rec } /** - * Get a stream of received packets. + * Get a stream of received valid packets. + */ + Flux> receivedValidPackets(); + + /** + * Get a stream of received invalid packets. */ - Flux> receivedPackets(); + Flux> receivedInvalidPackets(); + + /** + * Get a stream of received valid packets with expected type. + */ + default > Flux receivedValidPackets(Class packetType) { + return receivedValidPackets() + .filter(packetType::isInstance) + .map(networkPacket -> (R) networkPacket); + } /** - * Get a stream of received packets with expected type. + * Get a stream of received invalid packets with expected type. */ - default > Flux receivedPackets(Class packetType) { - return receivedPackets() + default > Flux receivedInvalidPackets(Class packetType) { + return receivedInvalidPackets() .filter(packetType::isInstance) .map(networkPacket -> (R) networkPacket); } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/Network.java b/rlib-network/src/main/java/javasabr/rlib/network/Network.java index 14030f12..3c8b8443 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/Network.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/Network.java @@ -13,6 +13,8 @@ public interface Network> { NetworkConfig config(); + void inNetworkThread(Runnable task); + /** * Shutdown this network. */ diff --git a/rlib-network/src/main/java/javasabr/rlib/network/NetworkConfig.java b/rlib-network/src/main/java/javasabr/rlib/network/NetworkConfig.java index 05db4ee1..9b45cc62 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/NetworkConfig.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/NetworkConfig.java @@ -1,6 +1,7 @@ package javasabr.rlib.network; import java.nio.ByteOrder; +import javasabr.rlib.common.util.GroupThreadFactory.ThreadConstructor; import lombok.Builder; import lombok.Getter; import lombok.experimental.Accessors; @@ -18,9 +19,13 @@ public interface NetworkConfig { class SimpleNetworkConfig implements NetworkConfig { @Builder.Default - private String groupName = "NetworkThread"; + private String threadGroupName = "NetworkThread"; @Builder.Default private ByteOrder byteOrder = ByteOrder.BIG_ENDIAN; + @Builder.Default + private ThreadConstructor threadConstructor = Thread::new; + @Builder.Default + private int threadPriority = Thread.NORM_PRIORITY; @Builder.Default private int readBufferSize = 2048; @@ -46,6 +51,22 @@ public String threadGroupName() { } }; + /** + * Get a thread constructor which should be used to create network threads. + */ + default ThreadConstructor threadConstructor() { + return Thread::new; + } + + /** + * Get a priority of network threads. + * + * @return the priority of network threads. + */ + default int threadPriority() { + return Thread.NORM_PRIORITY; + } + /** * Get a group name of network threads. */ diff --git a/rlib-network/src/main/java/javasabr/rlib/network/ServerNetworkConfig.java b/rlib-network/src/main/java/javasabr/rlib/network/ServerNetworkConfig.java index 56576270..ab090343 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/ServerNetworkConfig.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/ServerNetworkConfig.java @@ -82,21 +82,4 @@ default int threadGroupMaxSize() { default int scheduledThreadGroupSize() { return 1; } - - - /** - * Get a thread constructor which should be used to create network threads. - */ - default ThreadConstructor threadConstructor() { - return Thread::new; - } - - /** - * Get a priority of network threads. - * - * @return the priority of network threads. - */ - default int threadPriority() { - return Thread.NORM_PRIORITY; - } } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/client/impl/DefaultClientNetwork.java b/rlib-network/src/main/java/javasabr/rlib/network/client/impl/DefaultClientNetwork.java index c2028803..69997992 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/client/impl/DefaultClientNetwork.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/client/impl/DefaultClientNetwork.java @@ -2,10 +2,12 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.nio.channels.AsynchronousChannelGroup; import java.nio.channels.AsynchronousSocketChannel; import java.nio.channels.CompletionHandler; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; @@ -43,6 +45,8 @@ public class DefaultClientNetwork> extends AbstractNetwo @Getter final ScheduledExecutorService scheduledExecutor; + final ExecutorService networkExecutor; + final AsynchronousChannelGroup channelGroup; @Nullable @Getter(AccessLevel.PROTECTED) @@ -57,11 +61,17 @@ public DefaultClientNetwork( BiFunction, AsynchronousSocketChannel, C> channelToConnection) { super(config, channelToConnection); this.connecting = new AtomicBoolean(false); - this.scheduledExecutor = Executors - .newSingleThreadScheduledExecutor(new GroupThreadFactory(config.scheduledThreadGroupName())); + this.scheduledExecutor = buildScheduledExecutor(config); + this.networkExecutor = buildExecutor(config); + this.channelGroup = Utils.uncheckedGet(networkExecutor, AsynchronousChannelGroup::withThreadPool); log.info(config, DefaultClientNetwork::buildConfigDescription); } + @Override + public void inNetworkThread(Runnable task) { + networkExecutor.execute(task); + } + @Override public C connect(InetSocketAddress serverAddress) { return connectAsync(serverAddress).join(); @@ -88,7 +98,7 @@ public CompletableFuture connectAsync(InetSocketAddress serverAddress) { var asyncResult = new CompletableFuture(); @SuppressWarnings("resource") - var channel = Utils.uncheckedGet(AsynchronousSocketChannel::open); + var channel = Utils.uncheckedGet(channelGroup, AsynchronousSocketChannel::open); channel.connect(serverAddress, this, new CompletionHandler<>() { @Override public void completed(@Nullable Void result, DefaultClientNetwork network) { @@ -136,6 +146,33 @@ public void shutdown() { if (connection != null) { Utils.unchecked(connection, C::close); } + channelGroup.shutdown(); + scheduledExecutor.shutdown(); + networkExecutor.shutdown(); + } + + protected ExecutorService buildExecutor(NetworkConfig config) { + var threadFactory = new GroupThreadFactory( + config.threadGroupName(), + config.threadConstructor(), + config.threadPriority(), + false); + ExecutorService executorService = Executors.newSingleThreadScheduledExecutor(threadFactory); + // activate the executor + executorService.submit(() -> {}); + return executorService; + } + + protected ScheduledExecutorService buildScheduledExecutor(NetworkConfig config) { + var threadFactory = new GroupThreadFactory( + config.scheduledThreadGroupName(), + config.threadConstructor(), + config.threadPriority(), + false); + ScheduledExecutorService scheduledExecutor = Executors.newSingleThreadScheduledExecutor(threadFactory); + // activate the executor + scheduledExecutor.submit(() -> {}); + return scheduledExecutor; } private static String buildConfigDescription(NetworkConfig conf) { diff --git a/rlib-network/src/main/java/javasabr/rlib/network/exception/MalformedProtocolException.java b/rlib-network/src/main/java/javasabr/rlib/network/exception/MalformedProtocolException.java index fdd1f74b..4441f294 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/exception/MalformedProtocolException.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/exception/MalformedProtocolException.java @@ -1,6 +1,6 @@ package javasabr.rlib.network.exception; -public class MalformedProtocolException extends RuntimeException { +public class MalformedProtocolException extends NetworkException { public MalformedProtocolException(String message) { super(message); } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/exception/NetworkException.java b/rlib-network/src/main/java/javasabr/rlib/network/exception/NetworkException.java new file mode 100644 index 00000000..0fe81c62 --- /dev/null +++ b/rlib-network/src/main/java/javasabr/rlib/network/exception/NetworkException.java @@ -0,0 +1,16 @@ +package javasabr.rlib.network.exception; + +public class NetworkException extends RuntimeException { + + protected NetworkException(String message) { + super(message); + } + + protected NetworkException(String message, Throwable cause) { + super(message, cause); + } + + protected NetworkException(Throwable cause) { + super(cause); + } +} diff --git a/rlib-network/src/main/java/javasabr/rlib/network/exception/UserDefinedNetworkException.java b/rlib-network/src/main/java/javasabr/rlib/network/exception/UserDefinedNetworkException.java new file mode 100644 index 00000000..47932ef1 --- /dev/null +++ b/rlib-network/src/main/java/javasabr/rlib/network/exception/UserDefinedNetworkException.java @@ -0,0 +1,15 @@ +package javasabr.rlib.network.exception; + +public class UserDefinedNetworkException extends NetworkException { + protected UserDefinedNetworkException(String message) { + super(message); + } + + protected UserDefinedNetworkException(String message, Throwable cause) { + super(message, cause); + } + + protected UserDefinedNetworkException(Throwable cause) { + super(cause); + } +} diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java index 21b5c67e..8adebcac 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/AbstractConnection.java @@ -62,7 +62,8 @@ public WritablePacketWithFeedback(CompletableFuture attachment, Writabl final StampedLock lock; final AtomicBoolean closed; - final MutableArray>> subscribers; + final MutableArray>> validPacketSubscribers; + final MutableArray>> invalidPacketSubscribers; final int maxPacketsByRead; @@ -81,7 +82,8 @@ public AbstractConnection( this.pendingPackets = DequeFactory.arrayBasedBased(WritableNetworkPacket.class); this.network = network; this.closed = new AtomicBoolean(false); - this.subscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); + this.validPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); + this.invalidPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class); this.remoteAddress = String.valueOf(NetworkUtils.getRemoteAddress(channel)); } @@ -93,9 +95,15 @@ public void onConnected() {} protected abstract NetworkPacketWriter packetWriter(); @Override - public void onReceive(BiConsumer> consumer) { - subscribers.add(consumer); - packetReader().startRead(); + public void onReceiveValidPacket(BiConsumer> consumer) { + validPacketSubscribers.add(consumer); + network.inNetworkThread(() -> packetReader().startRead()); + } + + @Override + public void onReceiveInvalidPacket(BiConsumer> consumer) { + invalidPacketSubscribers.add(consumer); + network.inNetworkThread(() -> packetReader().startRead()); } @Override @@ -104,26 +112,49 @@ public void onReceive(BiConsumer> consumer) } @Override - public Flux> receivedPackets() { - return Flux.create(this::registerFluxOnReceivedPackets); + public Flux> receivedValidPackets() { + return Flux.create(this::registerFluxOnReceivedValidPackets); + } + + @Override + public Flux> receivedInvalidPackets() { + return Flux.create(this::registerFluxOnReceivedInvalidPackets); } protected void registerFluxOnReceivedEvents( FluxSink>> sink) { - BiConsumer> listener = + BiConsumer> validListener = (connection, packet) -> sink.next(new ReceivedPacketEvent<>(connection, - packet)); + packet, true)); + BiConsumer> invalidListener = + (connection, packet) -> sink.next(new ReceivedPacketEvent<>(connection, + packet, false)); - onReceive(listener); - sink.onDispose(() -> subscribers.remove(listener)); + validPacketSubscribers.add(validListener); + invalidPacketSubscribers.add(invalidListener); + + sink.onDispose(() -> { + validPacketSubscribers.remove(validListener); + validPacketSubscribers.remove(invalidListener); + }); + + network.inNetworkThread(() -> packetReader().startRead()); + } + + protected void registerFluxOnReceivedValidPackets(FluxSink> sink) { + BiConsumer> listener = (connection, packet) -> sink.next(packet); + validPacketSubscribers.add(listener); + sink.onDispose(() -> validPacketSubscribers.remove(listener)); + network.inNetworkThread(() -> packetReader().startRead()); } - protected void registerFluxOnReceivedPackets(FluxSink> sink) { + protected void registerFluxOnReceivedInvalidPackets(FluxSink> sink) { BiConsumer> listener = (connection, packet) -> sink.next(packet); - onReceive(listener); - sink.onDispose(() -> subscribers.remove(listener)); + invalidPacketSubscribers.add(listener); + sink.onDispose(() -> invalidPacketSubscribers.remove(listener)); + network.inNetworkThread(() -> packetReader().startRead()); } @Nullable @@ -169,9 +200,16 @@ public boolean closed() { protected void serializedPacket(WritableNetworkPacket packet) {} - protected void handleReceivedPacket(ReadableNetworkPacket packet) { - log.debug(packet, remoteAddress, "Handle received packet:[%s] from:[%s]"::formatted); - subscribers + protected void handleReceivedValidPacket(ReadableNetworkPacket packet) { + log.debug(packet, remoteAddress, "Handle received valid packet:[%s] from:[%s]"::formatted); + validPacketSubscribers + .iterations() + .forEach((C) this, packet, BiConsumer::accept); + } + + protected void handleReceivedInvalidPacket(ReadableNetworkPacket packet) { + log.debug(packet, remoteAddress, "Handle failed received packet:[%s] from:[%s]"::formatted); + invalidPacketSubscribers .iterations() .forEach((C) this, packet, BiConsumer::accept); } @@ -190,18 +228,15 @@ public final void send(WritableNetworkPacket packet) { } protected void sendImpl(WritableNetworkPacket packet) { - if (closed()) { return; } - long stamp = lock.writeLock(); try { pendingPackets.addLast(packet); } finally { lock.unlockWrite(stamp); } - packetWriter().tryToSendNextPacket(); } @@ -216,15 +251,11 @@ protected void queueAtFirst(WritableNetworkPacket packet) { @Override public CompletableFuture sendWithFeedback(WritableNetworkPacket packet) { - var asyncResult = new CompletableFuture(); - sendImpl(new WritablePacketWithFeedback<>(asyncResult, packet)); - if (closed()) { return CompletableFuture.completedFuture(Boolean.FALSE); } - return asyncResult; } @@ -241,11 +272,9 @@ protected void clearWaitPackets() { } protected void doClearWaitPackets() { - for (var pendingPacket : pendingPackets) { handleSentPacket(pendingPacket, false); } - pendingPackets.clear(); } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/DefaultDataConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/DefaultDataConnection.java index d59701d9..e99db5d8 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/DefaultDataConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/DefaultDataConnection.java @@ -44,7 +44,8 @@ protected NetworkPacketReader createPacketReader() { return new DefaultNetworkPacketReader<>( (C) this, this::updateLastActivity, - this::handleReceivedPacket, + this::handleReceivedValidPacket, + this::handleReceivedInvalidPacket, value -> createReadablePacket(), packetLengthHeaderSize, maxPacketsByRead); diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/DefaultDataSslConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/DefaultDataSslConnection.java index facfe8ba..da766124 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/DefaultDataSslConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/DefaultDataSslConnection.java @@ -45,7 +45,8 @@ protected NetworkPacketReader createPacketReader() { return new DefaultSslNetworkPacketReader<>( (C) this, this::updateLastActivity, - this::handleReceivedPacket, + this::handleReceivedValidPacket, + this::handleReceivedInvalidPacket, value -> createReadablePacket(), sslEngine, this::sendImpl, diff --git a/rlib-network/src/main/java/javasabr/rlib/network/impl/IdBasedPacketConnection.java b/rlib-network/src/main/java/javasabr/rlib/network/impl/IdBasedPacketConnection.java index 7bf9be7e..7d9c8fb6 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/impl/IdBasedPacketConnection.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/impl/IdBasedPacketConnection.java @@ -50,7 +50,8 @@ protected NetworkPacketReader createPacketReader() { return new IdBasedNetworkPacketReader<>( (C) this, this::updateLastActivity, - this::handleReceivedPacket, + this::handleReceivedValidPacket, + this::handleReceivedInvalidPacket, packetLengthHeaderSize, maxPacketsByRead, packetIdHeaderSize, diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java index cba2669d..de7e5b8a 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractNetworkPacketReader.java @@ -62,7 +62,8 @@ public void failed(Throwable exc, ByteBuffer readingBuffer) { final ByteBuffer pendingBuffer; final Runnable updateActivityFunction; - final Consumer packetHandler; + final Consumer validPacketHandler; + final Consumer invalidPacketHandler; @Getter(AccessLevel.PROTECTED) @Setter(AccessLevel.PROTECTED) @@ -74,13 +75,15 @@ public void failed(Throwable exc, ByteBuffer readingBuffer) { protected AbstractNetworkPacketReader( C connection, Runnable updateActivityFunction, - Consumer packetHandler, + Consumer validPacketHandler, + Consumer invalidPacketHandler, int maxPacketsByRead) { this.connection = connection; this.readBuffer = connection.bufferAllocator().takeReadBuffer(); this.pendingBuffer = connection.bufferAllocator().takePendingBuffer(); this.updateActivityFunction = updateActivityFunction; - this.packetHandler = packetHandler; + this.validPacketHandler = validPacketHandler; + this.invalidPacketHandler = invalidPacketHandler; this.maxPacketsByRead = maxPacketsByRead; } @@ -321,10 +324,9 @@ else if (packetFullLength > tempBigBuffer.capacity()) { protected void readAndHandlePacket(ByteBuffer bufferToRead, int remainingDataLength, R packetInstance) { if (packetInstance.read(connection, bufferToRead, remainingDataLength)) { - packetHandler.accept(packetInstance); + validPacketHandler.accept(packetInstance); } else { - log.error(remoteAddress(), packetInstance, - "[%s] Packet:[%s] was read incorrectly"::formatted); + invalidPacketHandler.accept(packetInstance); } } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java index 0c8b88b3..6ab75309 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/AbstractSslNetworkPacketReader.java @@ -50,11 +50,12 @@ public abstract class AbstractSslNetworkPacketReader< protected AbstractSslNetworkPacketReader( C connection, Runnable updateActivityFunction, - Consumer readPacketHandler, + Consumer validPacketHandler, + Consumer invalidPacketHandler, SSLEngine sslEngine, Consumer> packetWriter, int maxPacketsByRead) { - super(connection, updateActivityFunction, readPacketHandler, maxPacketsByRead); + super(connection, updateActivityFunction, validPacketHandler, invalidPacketHandler, maxPacketsByRead); BufferAllocator bufferAllocator = connection.bufferAllocator(); this.sslEngine = sslEngine; this.sslDataBuffer = bufferAllocator.takeBuffer(sslEngine diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/DefaultNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/DefaultNetworkPacketReader.java index f5f0b320..9df519e1 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/DefaultNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/DefaultNetworkPacketReader.java @@ -23,11 +23,12 @@ public class DefaultNetworkPacketReader< public DefaultNetworkPacketReader( C connection, Runnable updateActivityFunction, - Consumer packetHandler, + Consumer validPacketHandler, + Consumer invalidPacketHandler, IntFunction readablePacketFactory, int packetLengthHeaderSize, int maxPacketsByRead) { - super(connection, updateActivityFunction, packetHandler, maxPacketsByRead); + super(connection, updateActivityFunction, validPacketHandler, invalidPacketHandler, maxPacketsByRead); this.readablePacketFactory = readablePacketFactory; this.packetLengthHeaderSize = packetLengthHeaderSize; } diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/DefaultSslNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/DefaultSslNetworkPacketReader.java index 461aa78b..c2e6a450 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/DefaultSslNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/DefaultSslNetworkPacketReader.java @@ -25,7 +25,8 @@ public class DefaultSslNetworkPacketReader< public DefaultSslNetworkPacketReader( C connection, Runnable updateActivityFunction, - Consumer packetHandler, + Consumer validPacketHandler, + Consumer invalidPacketHandler, IntFunction packetResolver, SSLEngine sslEngine, Consumer> packetWriter, @@ -34,7 +35,8 @@ public DefaultSslNetworkPacketReader( super( connection, updateActivityFunction, - packetHandler, + validPacketHandler, + invalidPacketHandler, sslEngine, packetWriter, maxPacketsByRead); diff --git a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/IdBasedNetworkPacketReader.java b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/IdBasedNetworkPacketReader.java index 55c34ff9..3aefe9b7 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/IdBasedNetworkPacketReader.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/packet/impl/IdBasedNetworkPacketReader.java @@ -24,12 +24,13 @@ public class IdBasedNetworkPacketReader< public IdBasedNetworkPacketReader( C connection, Runnable updateActivityFunction, - Consumer packetHandler, + Consumer validPacketHandler, + Consumer invalidPacketHandler, int packetLengthHeaderSize, int maxPacketsByRead, int packetIdHeaderSize, ReadableNetworkPacketRegistry packetRegistry) { - super(connection, updateActivityFunction, packetHandler, maxPacketsByRead); + super(connection, updateActivityFunction, validPacketHandler, invalidPacketHandler, maxPacketsByRead); this.packetLengthHeaderSize = packetLengthHeaderSize; this.packetIdHeaderSize = packetIdHeaderSize; this.packetRegistry = packetRegistry; diff --git a/rlib-network/src/main/java/javasabr/rlib/network/server/impl/DefaultServerNetwork.java b/rlib-network/src/main/java/javasabr/rlib/network/server/impl/DefaultServerNetwork.java index 7a932229..4f382367 100644 --- a/rlib-network/src/main/java/javasabr/rlib/network/server/impl/DefaultServerNetwork.java +++ b/rlib-network/src/main/java/javasabr/rlib/network/server/impl/DefaultServerNetwork.java @@ -57,7 +57,7 @@ private interface ServerCompletionHandler> extends public void completed(AsynchronousSocketChannel channel, DefaultServerNetwork network) { var connection = network.channelToConnection.apply(network, channel); log.debug(connection.remoteAddress(), "Accepted new connection:[%s]"::formatted); - network.onAccept(connection); + network.consumeAccepted(connection); network.acceptNext(); } @@ -77,8 +77,9 @@ public void failed(Throwable exc, DefaultServerNetwork network) { @Getter ScheduledExecutorService scheduledExecutor; + ExecutorService networkExecutor; - AsynchronousChannelGroup group; + AsynchronousChannelGroup channelGroup; AsynchronousServerSocketChannel channel; MutableArray> subscribers; @@ -86,9 +87,10 @@ public DefaultServerNetwork( ServerNetworkConfig config, BiFunction, AsynchronousSocketChannel, C> channelToConnection) { super(config, channelToConnection); - this.group = Utils.uncheckedGet(buildExecutor(config), AsynchronousChannelGroup::withThreadPool); + this.networkExecutor = buildExecutor(config); + this.channelGroup = Utils.uncheckedGet(networkExecutor, AsynchronousChannelGroup::withThreadPool); this.scheduledExecutor = buildScheduledExecutor(config); - this.channel = Utils.uncheckedGet(group, AsynchronousServerSocketChannel::open); + this.channel = Utils.uncheckedGet(channelGroup, AsynchronousServerSocketChannel::open); this.subscribers = ArrayFactory.copyOnModifyArray(Consumer.class); log.info(config, DefaultServerNetwork::buildConfigDescription); } @@ -109,18 +111,23 @@ public InetSocketAddress start() { log.info(address, "Started server socket on address:[%s]"::formatted); if (!subscribers.isEmpty()) { - acceptNext(); + inNetworkThread(this::acceptNext); } return address; } + @Override + public void inNetworkThread(Runnable task) { + networkExecutor.execute(task); + } + @Override public > S start(InetSocketAddress serverAddress) { Utils.unchecked(channel, serverAddress, AsynchronousServerSocketChannel::bind); log.info(serverAddress, addr -> "Started server socket on address: " + addr); if (!subscribers.isEmpty()) { - acceptNext(); + inNetworkThread(this::acceptNext); } return ClassUtils.unsafeNNCast(this); } @@ -135,7 +142,7 @@ protected void acceptNext() { } } - protected void onAccept(C connection) { + protected void consumeAccepted(C connection) { connection.onConnected(); subscribers .iterations() @@ -145,7 +152,7 @@ protected void onAccept(C connection) { @Override public void onAccept(Consumer consumer) { subscribers.add(consumer); - acceptNext(); + inNetworkThread(this::acceptNext); } @Override @@ -162,7 +169,9 @@ protected void registerFluxOnAccepted(FluxSink sink) { @Override public void shutdown() { Utils.unchecked(channel, AsynchronousChannel::close); - group.shutdown(); + channelGroup.shutdown(); + scheduledExecutor.shutdown(); + networkExecutor.shutdown(); } protected ExecutorService buildExecutor(ServerNetworkConfig config) { diff --git a/rlib-network/src/test/java/javasabr/rlib/network/BaseNetworkTest.java b/rlib-network/src/test/java/javasabr/rlib/network/BaseNetworkTest.java index 47e9bdf1..67649695 100644 --- a/rlib-network/src/test/java/javasabr/rlib/network/BaseNetworkTest.java +++ b/rlib-network/src/test/java/javasabr/rlib/network/BaseNetworkTest.java @@ -59,12 +59,20 @@ public CompletableFuture sendWithFeedback(WritableNetworkPacket packet) } @Override - public Flux receivedPackets() { + public Flux> receivedValidPackets() { return Flux.empty(); } @Override - public void onReceive(BiConsumer consumer) {} + public Flux> receivedInvalidPackets() { + return Flux.empty(); + } + + @Override + public void onReceiveValidPacket(BiConsumer> consumer) {} + + @Override + public void onReceiveInvalidPacket(BiConsumer> consumer) {} } public static final MockConnection MOCK_CONNECTION = new MockConnection(); diff --git a/rlib-network/src/test/java/javasabr/rlib/network/DefaultNetworkTest.java b/rlib-network/src/test/java/javasabr/rlib/network/DefaultNetworkTest.java index 707d636e..9501789b 100644 --- a/rlib-network/src/test/java/javasabr/rlib/network/DefaultNetworkTest.java +++ b/rlib-network/src/test/java/javasabr/rlib/network/DefaultNetworkTest.java @@ -263,7 +263,7 @@ public ByteBuffer takeBuffer(int bufferSize) { DefaultConnection serverToClient = testNetwork.serverToClient; var pendingPacketsOnServer = serverToClient - .receivedPackets() + .receivedValidPackets() .buffer(packetCount); List messages = IntStream diff --git a/rlib-network/src/test/java/javasabr/rlib/network/HandlingValidAndInvalidReceivedPacketsTest.java b/rlib-network/src/test/java/javasabr/rlib/network/HandlingValidAndInvalidReceivedPacketsTest.java new file mode 100644 index 00000000..6edd39a7 --- /dev/null +++ b/rlib-network/src/test/java/javasabr/rlib/network/HandlingValidAndInvalidReceivedPacketsTest.java @@ -0,0 +1,127 @@ +package javasabr.rlib.network; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; +import javasabr.rlib.logger.api.LoggerLevel; +import javasabr.rlib.logger.api.LoggerManager; +import javasabr.rlib.network.annotation.NetworkPacketDescription; +import javasabr.rlib.network.impl.DefaultConnection; +import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket; +import javasabr.rlib.network.packet.impl.DefaultWritableNetworkPacket; +import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry; +import javasabr.rlib.network.server.ServerNetwork; +import lombok.CustomLog; +import lombok.RequiredArgsConstructor; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +@CustomLog +public class HandlingValidAndInvalidReceivedPacketsTest extends BaseNetworkTest { + + static { + LoggerManager.enable(HandlingValidAndInvalidReceivedPacketsTest.class, LoggerLevel.INFO); + //LoggerManager.enable(AbstractConnection.class, LoggerLevel.DEBUG); + } + + // client packets + interface ClientPackets { + + @RequiredArgsConstructor + @NetworkPacketDescription(id = 1) + class TestValidatablePacket extends DefaultWritableNetworkPacket { + + private final boolean valid; + + @Override + protected void writeImpl(DefaultConnection connection, ByteBuffer buffer) { + super.writeImpl(connection, buffer); + writeByte(buffer, valid ? 1 : 0); + } + } + + @NetworkPacketDescription(id = 2) + class MockReadablePacket extends DefaultReadableNetworkPacket {} + } + + // server packets + interface ServerPackets { + + @NetworkPacketDescription(id = 1) + class TestValidatablePacket extends DefaultReadableNetworkPacket { + + @Override + protected void readImpl(DefaultConnection connection, ByteBuffer buffer) { + super.readImpl(connection, buffer); + if (readByte(buffer) == 0) { + throw new RuntimeException("Received invalid packet"); + } + } + } + } + + @Test + void shouldCorrectlyReceiveValidAndInvalidPackets() throws InterruptedException { + // given: + ReadableNetworkPacketRegistry, DefaultConnection> serverPackets = + ReadableNetworkPacketRegistry.of(DefaultReadableNetworkPacket.class, DefaultConnection.class, ServerPackets.TestValidatablePacket.class); + ReadableNetworkPacketRegistry, DefaultConnection> clientPackets = + ReadableNetworkPacketRegistry.of(DefaultReadableNetworkPacket.class, DefaultConnection.class, ClientPackets.MockReadablePacket.class); + + ServerNetwork serverNetwork = NetworkFactory.defaultServerNetwork(serverPackets); + InetSocketAddress serverAddress = serverNetwork.start(); + List receivedValidPackets = Collections + .synchronizedList(new ArrayList<>()); + List receivedInvalidPackets = Collections + .synchronizedList(new ArrayList<>()); + + var counter = new CountDownLatch(30); + + // when: + serverNetwork + .accepted() + .flatMap(connection -> connection.receivedEvents(ServerPackets.TestValidatablePacket.class)) + .doOnNext(event -> { + var packet = event.packet(); + if (event.valid()) { + receivedValidPackets.add(packet); + } else { + receivedInvalidPackets.add(packet); + } + counter.countDown(); + }) + .subscribe(event -> log.info(event, "Received from client:[%s]"::formatted)); + + var clientNetwork = NetworkFactory.defaultClientNetwork(clientPackets); + clientNetwork + .connectReactive(serverAddress) + .doOnNext(connection -> IntStream + .range(0, 30) + .forEach(length -> { + if (length % 5 == 0) { + connection.send(new ClientPackets.TestValidatablePacket(false)); + } else { + connection.send(new ClientPackets.TestValidatablePacket(true)); + } + })) + .subscribe(); + + Assertions.assertTrue( + counter.await(10000, TimeUnit.MILLISECONDS), + "Still wait for " + counter.getCount() + " packets..."); + + clientNetwork.shutdown(); + serverNetwork.shutdown(); + + // then: + assertThat(receivedInvalidPackets).hasSize(6); + assertThat(receivedValidPackets).hasSize(24); + } +} diff --git a/rlib-network/src/test/java/javasabr/rlib/network/StringNetworkTest.java b/rlib-network/src/test/java/javasabr/rlib/network/StringNetworkTest.java index 8de053a2..2a88489c 100644 --- a/rlib-network/src/test/java/javasabr/rlib/network/StringNetworkTest.java +++ b/rlib-network/src/test/java/javasabr/rlib/network/StringNetworkTest.java @@ -119,7 +119,7 @@ public ByteBuffer takeBuffer(int bufferSize) { StringDataConnection serverToClient = testNetwork.serverToClient; var pendingPacketsOnServer = serverToClient - .receivedPackets(RECEIVED_PACKET_TYPE) + .receivedValidPackets(RECEIVED_PACKET_TYPE) .buffer(packetCount); List messages = IntStream @@ -165,7 +165,7 @@ void shouldReceiveManyPacketsFromSmallToBigSize() { StringDataConnection serverToClient = testNetwork.serverToClient; var pendingPacketsOnServer = serverToClient - .receivedPackets(RECEIVED_PACKET_TYPE) + .receivedValidPackets(RECEIVED_PACKET_TYPE) .doOnNext(packet -> log.info(packet.data().length(), "Received [%s] symbols from client"::formatted)) .buffer(packetCount); @@ -223,7 +223,7 @@ void shouldSendBiggerPacketThanWriteBuffer() { StringDataConnection serverToClient = testNetwork.serverToClient; var pendingPacketsOnServer = serverToClient - .receivedPackets(RECEIVED_PACKET_TYPE) + .receivedValidPackets(RECEIVED_PACKET_TYPE) .doOnNext(packet -> log.info(packet.data().length(), "Received [%s] symbols from client"::formatted)) .buffer(packetCount); @@ -352,7 +352,7 @@ void testServerWithMultiplyClientsUsingOldApi() { var connectedClients = new CountDownLatch(clientCount); serverNetwork.onAccept(connection -> { - connection.onReceive((con, packet) -> { + connection.onReceiveValidPacket((con, packet) -> { receivedPacketsOnServer.incrementAndGet(); con.send(newMessage(minMessageLength, maxMessageLength)); }); @@ -371,7 +371,7 @@ void testServerWithMultiplyClientsUsingOldApi() { .stream() .map(ClientNetwork::currentConnection) .filter(Objects::nonNull) - .peek(connection -> connection.onReceive((con, packet) -> { + .peek(connection -> connection.onReceiveValidPacket((con, packet) -> { receivedPacketsOnClients.incrementAndGet(); counter.countDown(); })) @@ -403,7 +403,7 @@ void shouldGetAllPacketWithFeedback() { StringDataConnection serverToClient = testNetwork.serverToClient; var pendingPacketsOnServer = serverToClient - .receivedPackets() + .receivedValidPackets() .buffer(packetCount); List> asyncResults = IntStream diff --git a/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java b/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java index b32ac629..72389f34 100644 --- a/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java +++ b/rlib-network/src/test/java/javasabr/rlib/network/StringSslNetworkTest.java @@ -293,7 +293,7 @@ void shouldReceiveManyPacketsFromSmallToBigSize() { StringDataSslConnection serverToClient = testNetwork.serverToClient; var pendingPacketsOnServer = serverToClient - .receivedPackets() + .receivedValidPackets() .doOnNext(packet -> log.info("Received from client: " + packet)) .buffer(packetCount);