diff --git a/src/main/java/org/openrewrite/polyglot/ProgressBar.java b/src/main/java/org/openrewrite/polyglot/ProgressBar.java index 63d5ba1..9a93793 100644 --- a/src/main/java/org/openrewrite/polyglot/ProgressBar.java +++ b/src/main/java/org/openrewrite/polyglot/ProgressBar.java @@ -32,4 +32,20 @@ public interface ProgressBar extends AutoCloseable { ProgressBar setExtraMessage(String extraMessage); ProgressBar setMax(int max); + + /** + * Set the canceled state of the progress bar. + * @param canceled true if the operation has been canceled + */ + default void setCanceled(boolean canceled) { + // Default no-op implementation for backward compatibility + } + + /** + * Check if the progress bar has been marked as canceled. + * @return true if the operation has been canceled + */ + default boolean isCanceled() { + return false; + } } diff --git a/src/main/java/org/openrewrite/polyglot/RemoteProgressBarReceiver.java b/src/main/java/org/openrewrite/polyglot/RemoteProgressBarReceiver.java index 71cd0e8..96d37f1 100644 --- a/src/main/java/org/openrewrite/polyglot/RemoteProgressBarReceiver.java +++ b/src/main/java/org/openrewrite/polyglot/RemoteProgressBarReceiver.java @@ -19,8 +19,7 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.net.DatagramSocket; -import java.net.SocketException; +import java.net.*; import java.util.LinkedHashMap; import java.util.Map; import java.util.UUID; @@ -36,7 +35,11 @@ public class RemoteProgressBarReceiver implements ProgressBar { private final ProgressBar delegate; private final DatagramSocket socket; private volatile boolean closed; - private final AtomicReference thrown = new AtomicReference<>(); + private final AtomicReference<@Nullable String> thrown = new AtomicReference<>(); + private volatile boolean canceled = false; + private @Nullable InetAddress lastSenderAddress; + private int lastSenderPort; + private volatile boolean cancelNotificationSent = false; public RemoteProgressBarReceiver(ProgressBar delegate) { try { @@ -61,28 +64,47 @@ protected boolean removeEldestEntry(Map.Entry eldes }; try { while (!closed) { - RemoteProgressMessage message = RemoteProgressMessage.receive(socket, incompleteMessages); - if (message == null) { - continue; - } - switch (message.getType()) { - case Exception: - if (message.getMessage() != null) { - thrown.set(message.getMessage()); - } - break; - case IntermediateResult: - delegate.intermediateResult(message.getMessage()); - break; - case Step: - delegate.step(); - break; - case SetExtraMessage: - delegate.setExtraMessage(requireNonNull(message.getMessage())); - break; - case SetMax: - delegate.setMax(Integer.parseInt(requireNonNull(message.getMessage()))); - break; + // Receive with packet info to get sender details + byte[] buf = new byte[128]; + DatagramPacket packet = new DatagramPacket(buf, 128); + try { + socket.receive(packet); + + // Store sender info for sending cancel status back + lastSenderAddress = packet.getAddress(); + lastSenderPort = packet.getPort(); + + RemoteProgressMessage message = RemoteProgressMessage.read(buf, packet.getLength(), incompleteMessages); + if (message == null) { + continue; + } + switch (message.getType()) { + case Exception: + if (message.getMessage() != null) { + thrown.set(message.getMessage()); + } + break; + case IntermediateResult: + delegate.intermediateResult(message.getMessage()); + break; + case Step: + delegate.step(); + break; + case SetExtraMessage: + delegate.setExtraMessage(requireNonNull(message.getMessage())); + break; + case SetMax: + delegate.setMax(Integer.parseInt(requireNonNull(message.getMessage()))); + break; + } + + // Only send cancel status if we haven't already notified about cancellation + if ((canceled || delegate.isCanceled()) && !cancelNotificationSent) { + sendCancelStatus(); + cancelNotificationSent = true; + } + } catch (SocketTimeoutException ignored) { + // No message received, continue } } } catch (IOException e) { @@ -136,4 +158,48 @@ private void maybeThrow() { throw RemoteException.decode(t); } } + + private void sendCancelStatus() { + if (lastSenderAddress != null && lastSenderPort > 0) { + try { + // Send a cancel notification message + String cancelMessage = "CANCEL:true"; + byte[] cancelBytes = cancelMessage.getBytes(); + DatagramPacket cancelPacket = new DatagramPacket( + cancelBytes, + cancelBytes.length, + lastSenderAddress, + lastSenderPort + ); + + // Try a few times to ensure delivery (since we only send once per cancellation) + for (int i = 0; i < 3; i++) { + socket.send(cancelPacket); + if (i < 2) { + Thread.sleep(10); // Small delay between retries + } + } + } catch (IOException | InterruptedException ignored) { + // Ignore failures when sending cancel status + } + } + } + + @Override + public void setCanceled(boolean canceled) { + boolean wasNotCanceled = !this.canceled; + this.canceled = canceled; + delegate.setCanceled(canceled); + + // If we just became canceled and haven't sent notification yet, send it + if (wasNotCanceled && canceled && !cancelNotificationSent) { + sendCancelStatus(); + cancelNotificationSent = true; + } + } + + @Override + public boolean isCanceled() { + return canceled || delegate.isCanceled(); + } } diff --git a/src/main/java/org/openrewrite/polyglot/RemoteProgressBarSender.java b/src/main/java/org/openrewrite/polyglot/RemoteProgressBarSender.java index 4e253c2..36bae59 100644 --- a/src/main/java/org/openrewrite/polyglot/RemoteProgressBarSender.java +++ b/src/main/java/org/openrewrite/polyglot/RemoteProgressBarSender.java @@ -28,6 +28,7 @@ public class RemoteProgressBarSender implements ProgressBar { private DatagramSocket socket; private InetAddress address; private int port; + private volatile boolean canceled = false; public RemoteProgressBarSender(int port) { this(null, port); @@ -41,12 +42,16 @@ public RemoteProgressBarSender(@Nullable InetAddress address, int port) { this.socket = new DatagramSocket(); this.port = port; this.address = address == null ? InetAddress.getByName(localhost) : address; + + // Set socket to non-blocking mode for checking cancel messages + this.socket.setSoTimeout(1); // 1ms timeout for non-blocking receive } catch (UnknownHostException | SocketException e) { if ("host.docker.internal".equals(localhost)) { try { this.address = InetAddress.getByName("localhost"); this.port = port; this.socket = new DatagramSocket(); + this.socket.setSoTimeout(1); // 1ms timeout for non-blocking } catch (UnknownHostException | SocketException ex) { throw new UncheckedIOException(ex); } @@ -94,6 +99,10 @@ public void throwRemote(RemoteException ex) { private void send(Type type, @Nullable String message) { try { + // Check for any pending cancel messages before sending + drainCancelMessages(); + + // Send the message for (byte[] packet : RemoteProgressMessage.toPackets(type, message)) { socket.send(new DatagramPacket(packet, packet.length, address, port)); } @@ -103,4 +112,50 @@ private void send(Type type, @Nullable String message) { throw new UncheckedIOException(e); } } + + /** + * Non-blocking check for any pending cancel messages. + * Drains all available cancel messages from the socket buffer. + */ + private void drainCancelMessages() { + if (canceled) { + return; // Already canceled, no need to check + } + + try { + byte[] buf = new byte[128]; + DatagramPacket packet = new DatagramPacket(buf, buf.length); + + // Keep reading while there are messages available (non-blocking due to timeout=0) + while (true) { + try { + socket.receive(packet); + + // Parse the received packet to check if it's a cancel message + String received = new String(packet.getData(), 0, packet.getLength()); + if (received.contains("CANCEL:true")) { + canceled = true; + // Continue draining to clear the buffer + } + } catch (SocketTimeoutException e) { + // No more messages available, done draining + break; + } + } + } catch (IOException ignored) { + // Ignore other IO exceptions during cancel check + } + } + + @Override + public void setCanceled(boolean canceled) { + this.canceled = canceled; + } + + @Override + public boolean isCanceled() { + // Also check for pending cancel messages when queried + drainCancelMessages(); + return canceled; + } } diff --git a/src/test/java/org/openrewrite/polyglot/RemoteProgressBarCancelTest.java b/src/test/java/org/openrewrite/polyglot/RemoteProgressBarCancelTest.java new file mode 100644 index 0000000..a5032b7 --- /dev/null +++ b/src/test/java/org/openrewrite/polyglot/RemoteProgressBarCancelTest.java @@ -0,0 +1,221 @@ +/* + * Copyright 2025 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.polyglot; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.net.DatagramPacket; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.*; + +public class RemoteProgressBarCancelTest { + + @Test + @Timeout(5) + public void testCancelPropagation() throws InterruptedException { + // Create a mock delegate progress bar + TestProgressBar delegate = new TestProgressBar(); + + // Create receiver + RemoteProgressBarReceiver receiver = new RemoteProgressBarReceiver(delegate); + int port = receiver.getPort(); + + // Create sender connected to receiver + RemoteProgressBarSender sender = new RemoteProgressBarSender(port); + + try { + // Initially neither should be canceled + assertFalse(sender.isCanceled()); + assertFalse(receiver.isCanceled()); + assertFalse(delegate.isCanceled()); + + // Send a progress message to establish communication + sender.step(); + Thread.sleep(100); // Give time for message to be received + + // Now cancel the delegate + delegate.setCanceled(true); + assertTrue(delegate.isCanceled()); + assertTrue(receiver.isCanceled()); // Receiver should reflect delegate state + + // Send another message - this should trigger cancel status to be sent back + sender.step(); + Thread.sleep(100); // Give time for cancel status to propagate back + + // Sender should now be canceled + assertTrue(sender.isCanceled()); + + } finally { + sender.close(); + receiver.close(); + } + } + + @Test + @Timeout(5) + public void testCancelPropagationFromReceiver() throws InterruptedException { + TestProgressBar delegate = new TestProgressBar(); + RemoteProgressBarReceiver receiver = new RemoteProgressBarReceiver(delegate); + int port = receiver.getPort(); + RemoteProgressBarSender sender = new RemoteProgressBarSender(port); + + try { + // Send initial message + sender.step(); + Thread.sleep(100); + + // Cancel directly on receiver + receiver.setCanceled(true); + assertTrue(receiver.isCanceled()); + + // Send another message to trigger cancel propagation + sender.setMax(100); + Thread.sleep(100); + + // Sender should be canceled + assertTrue(sender.isCanceled()); + + } finally { + sender.close(); + receiver.close(); + } + } + + @Test + @Timeout(10) + public void testDelayedCancelCheck() throws InterruptedException { + TestProgressBar delegate = new TestProgressBar(); + RemoteProgressBarReceiver receiver = new RemoteProgressBarReceiver(delegate); + int port = receiver.getPort(); + RemoteProgressBarSender sender = new RemoteProgressBarSender(port); + + try { + // Send initial message to establish communication + sender.step(); + Thread.sleep(100); + + // Cancel on receiver side + receiver.setCanceled(true); + + // Wait 5 seconds before sender sends another message or checks + Thread.sleep(5000); + + // Now check if sender picks up the cancel + // Either by sending a message (which calls drainCancelMessages) + sender.step(); + Thread.sleep(100); + + // Or by calling isCanceled (which also calls drainCancelMessages) + assertTrue(sender.isCanceled(), "Sender should detect cancel even after 5 second delay"); + + } finally { + sender.close(); + receiver.close(); + } + } + + @Test + @Timeout(5) + public void testCancelIsOneWayLatch() throws InterruptedException { + TestProgressBar delegate = new TestProgressBar(); + RemoteProgressBarReceiver receiver = new RemoteProgressBarReceiver(delegate); + int port = receiver.getPort(); + RemoteProgressBarSender sender = new RemoteProgressBarSender(port); + + try { + // Cancel and propagate + delegate.setCanceled(true); + sender.step(); + Thread.sleep(100); + assertTrue(sender.isCanceled()); + + // Try to uncancel - should not work + delegate.setCanceled(false); + receiver.setCanceled(false); + + // Send message + sender.step(); + Thread.sleep(100); + + // Sender should still be canceled (one-way latch) + assertTrue(sender.isCanceled()); + + } finally { + sender.close(); + receiver.close(); + } + } + + + // Test implementation of ProgressBar for testing + static class TestProgressBar implements ProgressBar { + private volatile boolean canceled = false; + private int steps = 0; + private int max = 0; + + @Override + public void intermediateResult(String message) { + // No-op + } + + @Override + public void finish(String message) { + // No-op + } + + @Override + public void close() { + // No-op + } + + @Override + public void step() { + steps++; + } + + @Override + public ProgressBar setExtraMessage(String extraMessage) { + return this; + } + + @Override + public ProgressBar setMax(int max) { + this.max = max; + return this; + } + + @Override + public void setCanceled(boolean canceled) { + this.canceled = canceled; + } + + @Override + public boolean isCanceled() { + return canceled; + } + + public int getSteps() { + return steps; + } + + public int getMax() { + return max; + } + } +} \ No newline at end of file