|
25 | 25 | import unittest |
26 | 26 | from typing import NamedTuple |
27 | 27 | from unittest.mock import patch |
| 28 | +from concurrent.futures import Future, TimeoutError |
28 | 29 |
|
29 | 30 | import apache_beam as beam |
30 | 31 | from apache_beam import dataframe as frames |
|
36 | 37 | from apache_beam.runners.interactive.dataproc.dataproc_cluster_manager import DataprocClusterManager |
37 | 38 | from apache_beam.runners.interactive.dataproc.types import ClusterMetadata |
38 | 39 | from apache_beam.runners.interactive.options.capture_limiters import Limiter |
| 40 | +from apache_beam.runners.interactive.recording_manager import AsyncComputationResult |
39 | 41 | from apache_beam.runners.interactive.testing.mock_env import isolated_env |
40 | 42 | from apache_beam.runners.runner import PipelineState |
41 | 43 | from apache_beam.testing.test_stream import TestStream |
@@ -671,5 +673,243 @@ def test_default_value_for_invalid_worker_number(self): |
671 | 673 | self.assertEqual(meta.num_workers, 2) |
672 | 674 |
|
673 | 675 |
|
| 676 | +@isolated_env |
| 677 | +class InteractiveBeamComputeTest(unittest.TestCase): |
| 678 | + |
| 679 | + def setUp(self): |
| 680 | + self.env = ie.current_env() |
| 681 | + self.env._is_in_ipython = False # Default to non-IPython |
| 682 | + |
| 683 | + def test_compute_blocking(self): |
| 684 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 685 | + data = list(range(10)) |
| 686 | + pcoll = p | 'Create' >> beam.Create(data) |
| 687 | + ib.watch(locals()) |
| 688 | + self.env.track_user_pipelines() |
| 689 | + |
| 690 | + result = ib.compute(pcoll, blocking=True) |
| 691 | + self.assertIsNone(result) # Blocking returns None |
| 692 | + self.assertTrue(pcoll in self.env.computed_pcollections) |
| 693 | + collected = ib.collect(pcoll, raw_records=True) |
| 694 | + self.assertEqual(collected, data) |
| 695 | + |
| 696 | + def test_compute_non_blocking(self): |
| 697 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 698 | + data = list(range(5)) |
| 699 | + pcoll = p | 'Create' >> beam.Create(data) |
| 700 | + ib.watch(locals()) |
| 701 | + self.env.track_user_pipelines() |
| 702 | + |
| 703 | + async_result = ib.compute(pcoll, blocking=False) |
| 704 | + self.assertIsInstance(async_result, AsyncComputationResult) |
| 705 | + |
| 706 | + pipeline_result = async_result.result(timeout=60) |
| 707 | + self.assertTrue(async_result.done()) |
| 708 | + self.assertIsNone(async_result.exception()) |
| 709 | + self.assertEqual(pipeline_result.state, PipelineState.DONE) |
| 710 | + self.assertTrue(pcoll in self.env.computed_pcollections) |
| 711 | + collected = ib.collect(pcoll, raw_records=True) |
| 712 | + self.assertEqual(collected, data) |
| 713 | + |
| 714 | + def test_compute_with_list_input(self): |
| 715 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 716 | + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) |
| 717 | + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) |
| 718 | + ib.watch(locals()) |
| 719 | + self.env.track_user_pipelines() |
| 720 | + |
| 721 | + ib.compute([pcoll1, pcoll2], blocking=True) |
| 722 | + self.assertTrue(pcoll1 in self.env.computed_pcollections) |
| 723 | + self.assertTrue(pcoll2 in self.env.computed_pcollections) |
| 724 | + |
| 725 | + def test_compute_with_dict_input(self): |
| 726 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 727 | + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) |
| 728 | + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) |
| 729 | + ib.watch(locals()) |
| 730 | + self.env.track_user_pipelines() |
| 731 | + |
| 732 | + ib.compute({'a': pcoll1, 'b': pcoll2}, blocking=True) |
| 733 | + self.assertTrue(pcoll1 in self.env.computed_pcollections) |
| 734 | + self.assertTrue(pcoll2 in self.env.computed_pcollections) |
| 735 | + |
| 736 | + def test_compute_empty_input(self): |
| 737 | + result = ib.compute([], blocking=True) |
| 738 | + self.assertIsNone(result) |
| 739 | + result_async = ib.compute([], blocking=False) |
| 740 | + self.assertIsNone(result_async) |
| 741 | + |
| 742 | + def test_compute_force_recompute(self): |
| 743 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 744 | + pcoll = p | 'Create' >> beam.Create([1, 2, 3]) |
| 745 | + ib.watch(locals()) |
| 746 | + self.env.track_user_pipelines() |
| 747 | + |
| 748 | + ib.compute(pcoll, blocking=True) |
| 749 | + self.assertTrue(pcoll in self.env.computed_pcollections) |
| 750 | + |
| 751 | + # Mock evict_computed_pcollections to check if it's called |
| 752 | + with patch.object(self.env, 'evict_computed_pcollections') as mock_evict: |
| 753 | + ib.compute(pcoll, blocking=True, force_compute=True) |
| 754 | + mock_evict.assert_called_once_with(p) |
| 755 | + self.assertTrue(pcoll in self.env.computed_pcollections) |
| 756 | + |
| 757 | + def test_compute_non_blocking_exception(self): |
| 758 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 759 | + |
| 760 | + def raise_error(elem): |
| 761 | + raise ValueError('Test Error') |
| 762 | + |
| 763 | + pcoll = p | 'Create' >> beam.Create([1]) | 'Error' >> beam.Map(raise_error) |
| 764 | + ib.watch(locals()) |
| 765 | + self.env.track_user_pipelines() |
| 766 | + |
| 767 | + async_result = ib.compute(pcoll, blocking=False) |
| 768 | + self.assertIsInstance(async_result, AsyncComputationResult) |
| 769 | + |
| 770 | + with self.assertRaises(ValueError): |
| 771 | + async_result.result(timeout=60) |
| 772 | + |
| 773 | + self.assertTrue(async_result.done()) |
| 774 | + self.assertIsInstance(async_result.exception(), ValueError) |
| 775 | + self.assertFalse(pcoll in self.env.computed_pcollections) |
| 776 | + |
| 777 | + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) |
| 778 | + @patch('apache_beam.runners.interactive.recording_manager.display') |
| 779 | + @patch('ipywidgets.Button') |
| 780 | + @patch('ipywidgets.FloatProgress') |
| 781 | + @patch('ipywidgets.Output') |
| 782 | + @patch('ipywidgets.HBox') |
| 783 | + @patch('ipywidgets.VBox') |
| 784 | + def test_compute_non_blocking_ipython_widgets( |
| 785 | + self, |
| 786 | + mock_vbox, |
| 787 | + mock_hbox, |
| 788 | + mock_output, |
| 789 | + mock_progress, |
| 790 | + mock_button, |
| 791 | + mock_display, |
| 792 | + ): |
| 793 | + self.env._is_in_ipython = True |
| 794 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 795 | + pcoll = p | 'Create' >> beam.Create(range(3)) |
| 796 | + ib.watch(locals()) |
| 797 | + self.env.track_user_pipelines() |
| 798 | + |
| 799 | + async_result = ib.compute(pcoll, blocking=False) |
| 800 | + self.assertIsNotNone(async_result) |
| 801 | + mock_button.assert_called_once_with(description='Cancel') |
| 802 | + mock_progress.assert_called_once() |
| 803 | + mock_output.assert_called_once() |
| 804 | + mock_hbox.assert_called_once() |
| 805 | + mock_vbox.assert_called_once() |
| 806 | + mock_display.assert_called_once() |
| 807 | + async_result.result(timeout=60) # Let it finish |
| 808 | + |
| 809 | + def test_compute_dependency_wait_true(self): |
| 810 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 811 | + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) |
| 812 | + pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2) |
| 813 | + ib.watch(locals()) |
| 814 | + self.env.track_user_pipelines() |
| 815 | + |
| 816 | + rm = self.env.get_recording_manager(p) |
| 817 | + |
| 818 | + # Start pcoll1 computation |
| 819 | + async_res1 = ib.compute(pcoll1, blocking=False) |
| 820 | + self.assertTrue(self.env.is_pcollection_computing(pcoll1)) |
| 821 | + |
| 822 | + # Spy on _wait_for_dependencies |
| 823 | + with patch.object( |
| 824 | + rm, '_wait_for_dependencies', wraps=rm._wait_for_dependencies |
| 825 | + ) as spy_wait: |
| 826 | + async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=True) |
| 827 | + |
| 828 | + # Check that wait_for_dependencies was called for pcoll2 |
| 829 | + spy_wait.assert_called_with({pcoll2}, async_res2) |
| 830 | + |
| 831 | + # Let pcoll1 finish |
| 832 | + async_res1.result(timeout=60) |
| 833 | + self.assertTrue(pcoll1 in self.env.computed_pcollections) |
| 834 | + self.assertFalse(self.env.is_pcollection_computing(pcoll1)) |
| 835 | + |
| 836 | + # pcoll2 should now run and complete |
| 837 | + async_res2.result(timeout=60) |
| 838 | + self.assertTrue(pcoll2 in self.env.computed_pcollections) |
| 839 | + |
| 840 | + @patch.object(ie.InteractiveEnvironment, 'is_pcollection_computing') |
| 841 | + def test_compute_dependency_wait_false(self, mock_is_computing): |
| 842 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 843 | + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) |
| 844 | + pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2) |
| 845 | + ib.watch(locals()) |
| 846 | + self.env.track_user_pipelines() |
| 847 | + |
| 848 | + rm = self.env.get_recording_manager(p) |
| 849 | + |
| 850 | + # Pretend pcoll1 is computing |
| 851 | + mock_is_computing.side_effect = lambda pcoll: pcoll is pcoll1 |
| 852 | + |
| 853 | + with patch.object( |
| 854 | + rm, '_execute_pipeline_fragment', wraps=rm._execute_pipeline_fragment |
| 855 | + ) as spy_execute: |
| 856 | + async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=False) |
| 857 | + async_res2.result(timeout=60) |
| 858 | + |
| 859 | + # Assert that execute was called for pcoll2 without waiting |
| 860 | + spy_execute.assert_called_with({pcoll2}, async_res2, ANY, ANY) |
| 861 | + self.assertTrue(pcoll2 in self.env.computed_pcollections) |
| 862 | + |
| 863 | + def test_async_computation_result_cancel(self): |
| 864 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 865 | + # A stream that never finishes to test cancellation |
| 866 | + pcoll = p | beam.Create([1]) | beam.Map(lambda x: time.sleep(100)) |
| 867 | + ib.watch(locals()) |
| 868 | + self.env.track_user_pipelines() |
| 869 | + |
| 870 | + async_result = ib.compute(pcoll, blocking=False) |
| 871 | + self.assertIsInstance(async_result, AsyncComputationResult) |
| 872 | + |
| 873 | + # Give it a moment to start |
| 874 | + time.sleep(0.1) |
| 875 | + |
| 876 | + # Mock the pipeline result's cancel method |
| 877 | + mock_pipeline_result = MagicMock() |
| 878 | + mock_pipeline_result.state = PipelineState.RUNNING |
| 879 | + async_result.set_pipeline_result(mock_pipeline_result) |
| 880 | + |
| 881 | + self.assertTrue(async_result.cancel()) |
| 882 | + mock_pipeline_result.cancel.assert_called_once() |
| 883 | + |
| 884 | + # The future should be cancelled eventually by the runner |
| 885 | + # This part is hard to test without deeper runner integration |
| 886 | + with self.assertRaises(TimeoutError): |
| 887 | + async_result.result(timeout=1) # It should not complete successfully |
| 888 | + |
| 889 | + def test_compute_multiple_async(self): |
| 890 | + p = beam.Pipeline(ir.InteractiveRunner()) |
| 891 | + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) |
| 892 | + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) |
| 893 | + pcoll3 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) |
| 894 | + ib.watch(locals()) |
| 895 | + self.env.track_user_pipelines() |
| 896 | + |
| 897 | + res1 = ib.compute(pcoll1, blocking=False) |
| 898 | + res2 = ib.compute(pcoll2, blocking=False) |
| 899 | + res3 = ib.compute(pcoll3, blocking=False) # Depends on pcoll1 |
| 900 | + |
| 901 | + self.assertIsNotNone(res1) |
| 902 | + self.assertIsNotNone(res2) |
| 903 | + self.assertIsNotNone(res3) |
| 904 | + |
| 905 | + res1.result(timeout=60) |
| 906 | + res2.result(timeout=60) |
| 907 | + res3.result(timeout=60) |
| 908 | + |
| 909 | + self.assertTrue(pcoll1 in self.env.computed_pcollections) |
| 910 | + self.assertTrue(pcoll2 in self.env.computed_pcollections) |
| 911 | + self.assertTrue(pcoll3 in self.env.computed_pcollections) |
| 912 | + |
| 913 | + |
674 | 914 | if __name__ == '__main__': |
675 | 915 | unittest.main() |
0 commit comments