Skip to content

Commit 53a1e15

Browse files
committed
Supports Asynchronous Runs in Interactive Beam
1 parent f41cbde commit 53a1e15

File tree

6 files changed

+1048
-2
lines changed

6 files changed

+1048
-2
lines changed

sdks/python/apache_beam/runners/interactive/interactive_beam.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from apache_beam.runners.interactive.display.pcoll_visualization import visualize
5858
from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll
5959
from apache_beam.runners.interactive.options import interactive_options
60+
from apache_beam.runners.interactive.recording_manager import AsyncComputationResult
6061
from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
6162
from apache_beam.runners.interactive.utils import elements_to_df
6263
from apache_beam.runners.interactive.utils import find_pcoll_name
@@ -1012,6 +1013,90 @@ def as_pcollection(pcoll_or_df):
10121013
return result_tuple
10131014

10141015

1016+
@progress_indicated
def compute(
    *pcolls: Union[Dict[Any, PCollection], Iterable[PCollection], PCollection],
    wait_for_inputs: bool = True,
    blocking: bool = False,
    runner=None,
    options=None,
    force_compute=False,
) -> Optional[AsyncComputationResult]:
  """Computes the given PCollections, potentially asynchronously.

  Args:
    *pcolls: PCollections to compute. Each positional argument can be a single
      PCollection (or deferred DataFrame), an iterable of them, or a
      dictionary with them as values.
    wait_for_inputs: Whether to wait until the asynchronous dependencies are
      computed. Setting this to False schedules the computation immediately,
      but may result in running the same pipeline stages multiple times.
    blocking: If False, the computation runs in a non-blocking fashion. In a
      Colab/IPython environment this mode also provides the controls for the
      running pipeline. If True, the call blocks until the pipeline is done.
    runner: (optional) the runner with which to compute the results.
    options: (optional) any additional pipeline options to use to compute the
      results.
    force_compute: (optional) if True, forces recomputation rather than using
      cached PCollections.

  Returns:
    None when there are no PCollections to compute. Otherwise the value
    returned by the recording manager's ``compute_async``: an
    AsyncComputationResult object if blocking is False, otherwise None.

  Raises:
    ValueError: if an argument is not a dict, an iterable or a PCollection,
      or if the PCollections do not all belong to the same pipeline.
    AssertionError: if a flattened element is not a PCollection.
  """
  flatten_pcolls = []
  for pcoll_container in pcolls:
    if isinstance(pcoll_container, dict):
      flatten_pcolls.extend(pcoll_container.values())
    elif isinstance(pcoll_container, (beam.pvalue.PCollection, DeferredBase)):
      flatten_pcolls.append(pcoll_container)
    else:
      try:
        flatten_pcolls.extend(iter(pcoll_container))
      except TypeError as err:
        # Chain the original error so the root cause stays in the traceback.
        raise ValueError(
            f'The given pcoll {pcoll_container} is not a dict, an iterable or '
            'a PCollection.') from err

  pcolls_set = set()
  for pcoll in flatten_pcolls:
    if isinstance(pcoll, DeferredBase):
      # Deferred DataFrames are converted to a PCollection, which is then
      # watched under a generated name.
      pcoll, _ = deferred_df_to_pcollection(pcoll)
      watch({f'anonymous_pcollection_{id(pcoll)}': pcoll})
    # Explicit check instead of an `assert` statement so the validation still
    # runs under `python -O`; the exception type and message are unchanged.
    if not isinstance(pcoll, beam.pvalue.PCollection):
      raise AssertionError(
          f'{pcoll} is not an apache_beam.pvalue.PCollection.')
    pcolls_set.add(pcoll)

  if not pcolls_set:
    _LOGGER.info('No PCollections to compute.')
    return None

  pcoll_pipeline = next(iter(pcolls_set)).pipeline
  user_pipeline = ie.current_env().user_pipeline(pcoll_pipeline)
  if not user_pipeline:
    # The pipeline is not known to the environment yet: watch it ad hoc and
    # treat it as the user pipeline.
    watch({f'anonymous_pipeline_{id(pcoll_pipeline)}': pcoll_pipeline})
    user_pipeline = pcoll_pipeline

  # NOTE(review): the comparison is against `user_pipeline`, not
  # `pcoll_pipeline`; if user_pipeline() can ever return an object other than
  # the PCollection's own pipeline this check would always fail -- confirm.
  for pcoll in pcolls_set:
    if pcoll.pipeline is not user_pipeline:
      raise ValueError('All PCollections must belong to the same pipeline.')

  recording_manager = ie.current_env().get_recording_manager(
      user_pipeline, create_if_absent=True)

  return recording_manager.compute_async(
      pcolls_set,
      wait_for_inputs=wait_for_inputs,
      blocking=blocking,
      runner=runner,
      options=options,
      force_compute=force_compute,
  )
1098+
1099+
10151100
@progress_indicated
10161101
def show_graph(pipeline):
10171102
"""Shows the current pipeline shape of a given Beam pipeline as a DAG.

sdks/python/apache_beam/runners/interactive/interactive_beam_test.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import unittest
2626
from typing import NamedTuple
2727
from unittest.mock import patch
28+
from concurrent.futures import Future, TimeoutError
2829

2930
import apache_beam as beam
3031
from apache_beam import dataframe as frames
@@ -36,6 +37,7 @@
3637
from apache_beam.runners.interactive.dataproc.dataproc_cluster_manager import DataprocClusterManager
3738
from apache_beam.runners.interactive.dataproc.types import ClusterMetadata
3839
from apache_beam.runners.interactive.options.capture_limiters import Limiter
40+
from apache_beam.runners.interactive.recording_manager import AsyncComputationResult
3941
from apache_beam.runners.interactive.testing.mock_env import isolated_env
4042
from apache_beam.runners.runner import PipelineState
4143
from apache_beam.testing.test_stream import TestStream
@@ -671,5 +673,243 @@ def test_default_value_for_invalid_worker_number(self):
671673
self.assertEqual(meta.num_workers, 2)
672674

673675

676+
@isolated_env
class InteractiveBeamComputeTest(unittest.TestCase):
  """Tests for ``ib.compute`` in blocking and non-blocking modes."""

  def setUp(self):
    self.env = ie.current_env()
    self.env._is_in_ipython = False  # Default to non-IPython

  def test_compute_blocking(self):
    p = beam.Pipeline(ir.InteractiveRunner())
    data = list(range(10))
    pcoll = p | 'Create' >> beam.Create(data)
    ib.watch(locals())
    self.env.track_user_pipelines()

    result = ib.compute(pcoll, blocking=True)
    self.assertIsNone(result)  # Blocking returns None
    self.assertIn(pcoll, self.env.computed_pcollections)
    collected = ib.collect(pcoll, raw_records=True)
    self.assertEqual(collected, data)

  def test_compute_non_blocking(self):
    p = beam.Pipeline(ir.InteractiveRunner())
    data = list(range(5))
    pcoll = p | 'Create' >> beam.Create(data)
    ib.watch(locals())
    self.env.track_user_pipelines()

    async_result = ib.compute(pcoll, blocking=False)
    self.assertIsInstance(async_result, AsyncComputationResult)

    pipeline_result = async_result.result(timeout=60)
    self.assertTrue(async_result.done())
    self.assertIsNone(async_result.exception())
    self.assertEqual(pipeline_result.state, PipelineState.DONE)
    self.assertIn(pcoll, self.env.computed_pcollections)
    collected = ib.collect(pcoll, raw_records=True)
    self.assertEqual(collected, data)

  def test_compute_with_list_input(self):
    p = beam.Pipeline(ir.InteractiveRunner())
    pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
    pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
    ib.watch(locals())
    self.env.track_user_pipelines()

    ib.compute([pcoll1, pcoll2], blocking=True)
    self.assertIn(pcoll1, self.env.computed_pcollections)
    self.assertIn(pcoll2, self.env.computed_pcollections)

  def test_compute_with_dict_input(self):
    p = beam.Pipeline(ir.InteractiveRunner())
    pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
    pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
    ib.watch(locals())
    self.env.track_user_pipelines()

    ib.compute({'a': pcoll1, 'b': pcoll2}, blocking=True)
    self.assertIn(pcoll1, self.env.computed_pcollections)
    self.assertIn(pcoll2, self.env.computed_pcollections)

  def test_compute_empty_input(self):
    result = ib.compute([], blocking=True)
    self.assertIsNone(result)
    result_async = ib.compute([], blocking=False)
    self.assertIsNone(result_async)

  def test_compute_force_recompute(self):
    p = beam.Pipeline(ir.InteractiveRunner())
    pcoll = p | 'Create' >> beam.Create([1, 2, 3])
    ib.watch(locals())
    self.env.track_user_pipelines()

    ib.compute(pcoll, blocking=True)
    self.assertIn(pcoll, self.env.computed_pcollections)

    # Mock evict_computed_pcollections to check if it's called
    with patch.object(self.env, 'evict_computed_pcollections') as mock_evict:
      ib.compute(pcoll, blocking=True, force_compute=True)
      mock_evict.assert_called_once_with(p)
    self.assertIn(pcoll, self.env.computed_pcollections)

  def test_compute_non_blocking_exception(self):
    p = beam.Pipeline(ir.InteractiveRunner())

    def raise_error(elem):
      raise ValueError('Test Error')

    pcoll = p | 'Create' >> beam.Create([1]) | 'Error' >> beam.Map(raise_error)
    ib.watch(locals())
    self.env.track_user_pipelines()

    async_result = ib.compute(pcoll, blocking=False)
    self.assertIsInstance(async_result, AsyncComputationResult)

    with self.assertRaises(ValueError):
      async_result.result(timeout=60)

    self.assertTrue(async_result.done())
    self.assertIsInstance(async_result.exception(), ValueError)
    self.assertNotIn(pcoll, self.env.computed_pcollections)

  # Decorators apply bottom-up, so the mock arguments below are listed in the
  # reverse order of the @patch lines (VBox first, display last).
  @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
  @patch('apache_beam.runners.interactive.recording_manager.display')
  @patch('ipywidgets.Button')
  @patch('ipywidgets.FloatProgress')
  @patch('ipywidgets.Output')
  @patch('ipywidgets.HBox')
  @patch('ipywidgets.VBox')
  def test_compute_non_blocking_ipython_widgets(
      self,
      mock_vbox,
      mock_hbox,
      mock_output,
      mock_progress,
      mock_button,
      mock_display,
  ):
    self.env._is_in_ipython = True
    p = beam.Pipeline(ir.InteractiveRunner())
    pcoll = p | 'Create' >> beam.Create(range(3))
    ib.watch(locals())
    self.env.track_user_pipelines()

    async_result = ib.compute(pcoll, blocking=False)
    self.assertIsNotNone(async_result)
    mock_button.assert_called_once_with(description='Cancel')
    mock_progress.assert_called_once()
    mock_output.assert_called_once()
    mock_hbox.assert_called_once()
    mock_vbox.assert_called_once()
    mock_display.assert_called_once()
    async_result.result(timeout=60)  # Let it finish

  def test_compute_dependency_wait_true(self):
    p = beam.Pipeline(ir.InteractiveRunner())
    pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
    pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2)
    ib.watch(locals())
    self.env.track_user_pipelines()

    # No computation has run yet, so ask for the manager to be created; this
    # guarantees there is a real object to spy on below.
    rm = self.env.get_recording_manager(p, create_if_absent=True)

    # Start pcoll1 computation
    async_res1 = ib.compute(pcoll1, blocking=False)
    self.assertTrue(self.env.is_pcollection_computing(pcoll1))

    # Spy on _wait_for_dependencies
    with patch.object(
        rm, '_wait_for_dependencies',
        wraps=rm._wait_for_dependencies) as spy_wait:
      async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=True)

      # Check that wait_for_dependencies was called for pcoll2
      spy_wait.assert_called_with({pcoll2}, async_res2)

      # Let pcoll1 finish
      async_res1.result(timeout=60)
      self.assertIn(pcoll1, self.env.computed_pcollections)
      self.assertFalse(self.env.is_pcollection_computing(pcoll1))

      # pcoll2 should now run and complete
      async_res2.result(timeout=60)
      self.assertIn(pcoll2, self.env.computed_pcollections)

  @patch.object(ie.InteractiveEnvironment, 'is_pcollection_computing')
  def test_compute_dependency_wait_false(self, mock_is_computing):
    p = beam.Pipeline(ir.InteractiveRunner())
    pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
    pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2)
    ib.watch(locals())
    self.env.track_user_pipelines()

    # No computation has run yet, so ask for the manager to be created; this
    # guarantees there is a real object to spy on below.
    rm = self.env.get_recording_manager(p, create_if_absent=True)

    # Pretend pcoll1 is computing
    mock_is_computing.side_effect = lambda pcoll: pcoll is pcoll1

    with patch.object(
        rm, '_execute_pipeline_fragment',
        wraps=rm._execute_pipeline_fragment) as spy_execute:
      async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=False)
      async_res2.result(timeout=60)

      # Assert that execute was called for pcoll2 without waiting.
      # ANY is reached through unittest.mock because only `patch` is imported
      # by name from that module.
      spy_execute.assert_called_with(
          {pcoll2}, async_res2, unittest.mock.ANY, unittest.mock.ANY)
      self.assertIn(pcoll2, self.env.computed_pcollections)

  def test_async_computation_result_cancel(self):
    p = beam.Pipeline(ir.InteractiveRunner())
    # A long-sleeping Map keeps the pipeline busy so cancellation can be
    # exercised while it is still in flight.
    pcoll = p | beam.Create([1]) | beam.Map(lambda x: time.sleep(100))
    ib.watch(locals())
    self.env.track_user_pipelines()

    async_result = ib.compute(pcoll, blocking=False)
    self.assertIsInstance(async_result, AsyncComputationResult)

    # Give it a moment to start
    time.sleep(0.1)

    # Mock the pipeline result's cancel method.
    # MagicMock is reached through unittest.mock because only `patch` is
    # imported by name from that module.
    mock_pipeline_result = unittest.mock.MagicMock()
    mock_pipeline_result.state = PipelineState.RUNNING
    async_result.set_pipeline_result(mock_pipeline_result)

    self.assertTrue(async_result.cancel())
    mock_pipeline_result.cancel.assert_called_once()

    # The future should be cancelled eventually by the runner
    # This part is hard to test without deeper runner integration
    with self.assertRaises(TimeoutError):
      async_result.result(timeout=1)  # It should not complete successfully

  def test_compute_multiple_async(self):
    p = beam.Pipeline(ir.InteractiveRunner())
    pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
    pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
    pcoll3 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)
    ib.watch(locals())
    self.env.track_user_pipelines()

    res1 = ib.compute(pcoll1, blocking=False)
    res2 = ib.compute(pcoll2, blocking=False)
    res3 = ib.compute(pcoll3, blocking=False)  # Depends on pcoll1

    self.assertIsNotNone(res1)
    self.assertIsNotNone(res2)
    self.assertIsNotNone(res3)

    res1.result(timeout=60)
    res2.result(timeout=60)
    res3.result(timeout=60)

    self.assertIn(pcoll1, self.env.computed_pcollections)
    self.assertIn(pcoll2, self.env.computed_pcollections)
    self.assertIn(pcoll3, self.env.computed_pcollections)
912+
913+
674914
if __name__ == '__main__':
675915
unittest.main()

sdks/python/apache_beam/runners/interactive/interactive_environment.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ def __init__(self):
175175
# Tracks the computation completeness of PCollections. PCollections tracked
176176
# here don't need to be re-computed when data introspection is needed.
177177
self._computed_pcolls = set()
178+
179+
self._computing_pcolls = set()
180+
178181
# Always watch __main__ module.
179182
self.watch('__main__')
180183
# Check if [interactive] dependencies are installed.
@@ -720,3 +723,19 @@ def _get_gcs_cache_dir(self, pipeline, cache_dir):
720723
bucket_name = cache_dir_path.parts[1]
721724
assert_bucket_exists(bucket_name)
722725
return 'gs://{}/{}'.format('/'.join(cache_dir_path.parts[1:]), id(pipeline))
726+
727+
@property
def computing_pcollections(self):
  """Returns the set of PCollections currently marked as being computed.

  The returned object is the internal set itself, not a copy.
  """
  in_flight = self._computing_pcolls
  return in_flight
730+
731+
def mark_pcollection_computing(self, pcolls):
  """Adds every element of pcolls to the set of computing PCollections."""
  in_flight = self._computing_pcolls
  for pcoll in pcolls:
    in_flight.add(pcoll)
734+
735+
def unmark_pcollection_computing(self, pcolls):
  """Drops every element of pcolls from the set of computing PCollections.

  Elements that are not currently in the set are ignored.
  """
  in_flight = self._computing_pcolls
  # Snapshot first so passing the tracked set itself is safe to iterate.
  for pcoll in tuple(pcolls):
    in_flight.discard(pcoll)
738+
739+
def is_pcollection_computing(self, pcoll):
  """Returns True if pcoll is in the set of computing PCollections."""
  in_flight = self._computing_pcolls
  return pcoll in in_flight

0 commit comments

Comments
 (0)