
Commit f22d2bc

Save Logical Type and Coder Registry on cloudpickle save main session (#36271)
* Save Logical Type Registry and Coder Registry on cloudpickle save main session; fix naming
* Track custom_urn set in logical type registry
* Fix, add tests
* Set save_main_session default to true for cloudpickle and introduce overwrite flag
* Fix test; trigger postcommits
* Fix test as Dataflow runner submission now stages a main session file
1 parent 8ff2c94 commit f22d2bc

File tree: 14 files changed (+239, -87 lines)

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run.",
+  "pr": "36271",
   "modification": 35
 }

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run",
-  "modification": 14
+  "modification": 0
 }
Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 {
   "comment": "Modify this file in a trivial way to cause this test suite to run",
-  "modification": 5
+  "modification": 0
 }

CHANGES.md

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@
 ## Bugfixes

 * Fixed FirestoreV1 Beam connectors allow configuring inconsistent project/database IDs between RPC requests and routing headers #36895 (Java) ([#36895](https://github.com/apache/beam/issues/36895)).
+* Logical type and coder registry are saved for pipelines in the case of default pickler. This fixes a side effect of switching to cloudpickle as default pickler in Beam 2.65.0 (Python) ([#35738](https://github.com/apache/beam/issues/35738)).

 ## Known Issues
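As an illustration of the user-facing effect, here is a minimal sketch (MyType and MyCoder are hypothetical names, not part of this commit): a coder registered in __main__ now survives into the saved main session under the default (cloudpickle) pickler.

from apache_beam.coders import typecoders
from apache_beam.coders.coders import Coder

class MyType:  # hypothetical custom type defined in __main__
  def __init__(self, value):
    self.value = value

class MyCoder(Coder):  # hypothetical coder for MyType
  def encode(self, o):
    return o.value.encode('utf-8')

  def decode(self, b):
    return MyType(b.decode('utf-8'))

  def is_deterministic(self):
    return True

# Registered in __main__; with save_main_session (now on by default for
# cloudpickle) this registration is dumped with the session and restored
# on the worker.
typecoders.registry.register_coder(MyType, MyCoder)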
sdks/python/apache_beam/coders/typecoders.py

Lines changed: 18 additions & 7 deletions
@@ -114,6 +114,14 @@ def _register_coder_internal(
       typehint_coder_class: Type[coders.Coder]) -> None:
     self._coders[typehint_type] = typehint_coder_class

+  @staticmethod
+  def _normalize_typehint_type(typehint_type):
+    if typehint_type.__module__ == '__main__':
+      # See https://github.com/apache/beam/issues/21541
+      # TODO(robertwb): Remove once all runners are portable.
+      return getattr(typehint_type, '__name__', str(typehint_type))
+    return typehint_type
+
   def register_coder(
       self, typehint_type: Any,
       typehint_coder_class: Type[coders.Coder]) -> None:
@@ -123,11 +131,8 @@ def register_coder(
         'Received %r instead.' % typehint_coder_class)
     if typehint_type not in self.custom_types:
       self.custom_types.append(typehint_type)
-    if typehint_type.__module__ == '__main__':
-      # See https://github.com/apache/beam/issues/21541
-      # TODO(robertwb): Remove once all runners are portable.
-      typehint_type = getattr(typehint_type, '__name__', str(typehint_type))
-    self._register_coder_internal(typehint_type, typehint_coder_class)
+    self._register_coder_internal(
+        self._normalize_typehint_type(typehint_type), typehint_coder_class)

   def get_coder(self, typehint: Any) -> coders.Coder:
     if typehint and typehint.__module__ == '__main__':
@@ -170,9 +175,15 @@ def get_coder(self, typehint: Any) -> coders.Coder:
       coder = self._fallback_coder
     return coder.from_type_hint(typehint, self)

-  def get_custom_type_coder_tuples(self, types):
+  def get_custom_type_coder_tuples(self, types=None):
     """Returns type/coder tuples for all custom types passed in."""
-    return [(t, self._coders[t]) for t in types if t in self.custom_types]
+    return [(t, self._coders[self._normalize_typehint_type(t)])
+            for t in self.custom_types if (types is None or t in types)]
+
+  def load_custom_type_coder_tuples(self, type_coder):
+    """Load type/coder tuples into coder registry."""
+    for t, c in type_coder:
+      self.register_coder(t, c)

   def verify_deterministic(self, key_coder, op_name, silent=True):
     if not key_coder.is_deterministic():

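A minimal sketch of the export/import surface added above (Point and PointCoder are hypothetical names used only for illustration):

from apache_beam.coders import typecoders
from apache_beam.coders.coders import Coder

class Point:  # hypothetical custom type
  def __init__(self, x, y):
    self.x, self.y = x, y

class PointCoder(Coder):  # hypothetical coder
  def encode(self, p):
    return ('%d,%d' % (p.x, p.y)).encode('utf-8')

  def decode(self, b):
    x, y = b.decode('utf-8').split(',')
    return Point(int(x), int(y))

  def is_deterministic(self):
    return True

typecoders.registry.register_coder(Point, PointCoder)

# With no argument, every custom registration is exported as a
# (type, coder_class) tuple; __main__ types are looked up under their
# normalized name, which is what dump_session relies on.
pairs = typecoders.registry.get_custom_type_coder_tuples()

# On the restoring side (e.g. after load_session), each tuple is
# re-registered one by one.
typecoders.registry.load_custom_type_coder_tuples(pairs)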
sdks/python/apache_beam/internal/cloudpickle_pickler.py

Lines changed: 29 additions & 6 deletions
@@ -252,12 +252,35 @@ def _lock_reducer(obj):


 def dump_session(file_path):
-  # It is possible to dump session with cloudpickle. However, since references
-  # are saved it should not be necessary. See https://s.apache.org/beam-picklers
-  pass
+  # Since references are saved (https://s.apache.org/beam-picklers), we only
+  # dump supported Beam registries (currently the coder and logical type
+  # registries).
+  from apache_beam.coders import typecoders
+  from apache_beam.typehints import schemas
+
+  with _pickle_lock, open(file_path, 'wb') as file:
+    coder_reg = typecoders.registry.get_custom_type_coder_tuples()
+    logical_type_reg = schemas.LogicalType._known_logical_types.copy_custom()
+
+    pickler = cloudpickle.CloudPickler(file)
+    # TODO(https://github.com/apache/beam/issues/18500) add file system
+    # registry once implemented
+    pickler.dump({"coder": coder_reg, "logical_type": logical_type_reg})


 def load_session(file_path):
-  # It is possible to load_session with cloudpickle. However, since references
-  # are saved it should not be necessary. See https://s.apache.org/beam-picklers
-  pass
+  from apache_beam.coders import typecoders
+  from apache_beam.typehints import schemas
+
+  with _pickle_lock, open(file_path, 'rb') as file:
+    registries = cloudpickle.load(file)
+    if not isinstance(registries, dict):
+      raise ValueError(
+          "Failed loading session: expected dict, got {}".format(
+              type(registries)))
+    if "coder" in registries:
+      typecoders.registry.load_custom_type_coder_tuples(registries["coder"])
+    else:
+      _LOGGER.warning('No coder registry found in saved session')
+    if "logical_type" in registries:
+      schemas.LogicalType._known_logical_types.load(registries["logical_type"])
+    else:
+      _LOGGER.warning('No logical type registry found in saved session')

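A short round-trip sketch of the two functions above (the file name is arbitrary):

import os
import tempfile

from apache_beam.internal import cloudpickle_pickler

with tempfile.TemporaryDirectory() as tmp_dir:
  path = os.path.join(tmp_dir, 'session')
  # Writes a {"coder": ..., "logical_type": ...} dict via cloudpickle.
  cloudpickle_pickler.dump_session(path)
  # Reads the dict back and merges both registries into this process.
  cloudpickle_pickler.load_session(path)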
sdks/python/apache_beam/internal/cloudpickle_pickler_test.py

Lines changed: 20 additions & 0 deletions
@@ -20,6 +20,7 @@
 # pytype: skip-file

 import os
+import tempfile
 import threading
 import types
 import unittest
@@ -31,6 +32,7 @@
 from apache_beam.internal import module_test
 from apache_beam.internal.cloudpickle_pickler import dumps
 from apache_beam.internal.cloudpickle_pickler import loads
+from apache_beam.typehints.schemas import LogicalTypeRegistry
 from apache_beam.utils import shared

 GLOBAL_DICT_REF = module_test.GLOBAL_DICT
@@ -244,6 +246,24 @@ def sample_func():
     unpickled_filename = os.path.abspath(unpickled_code.co_filename)
     self.assertEqual(unpickled_filename, original_filename)

+  @mock.patch(
+      "apache_beam.coders.typecoders.registry.load_custom_type_coder_tuples")
+  @mock.patch(
+      "apache_beam.typehints.schemas.LogicalType._known_logical_types.load")
+  def test_dump_load_session(self, logicaltype_mock, coder_mock):
+    session_file = 'pickled'
+
+    with tempfile.TemporaryDirectory() as tmp_dirname:
+      pickled_session_file = os.path.join(tmp_dirname, session_file)
+      beam_cloudpickle.dump_session(pickled_session_file)
+      beam_cloudpickle.load_session(pickled_session_file)
+      load_logical_types = logicaltype_mock.call_args.args
+      load_coders = coder_mock.call_args.args
+      self.assertEqual(len(load_logical_types), 1)
+      self.assertEqual(len(load_coders), 1)
+      self.assertTrue(isinstance(load_logical_types[0], LogicalTypeRegistry))
+      self.assertTrue(isinstance(load_coders[0], list))
+

 if __name__ == '__main__':
   unittest.main()

sdks/python/apache_beam/internal/pickler.py

Lines changed: 9 additions & 2 deletions
@@ -91,6 +91,14 @@ def load_session(file_path):
   return desired_pickle_lib.load_session(file_path)


+def is_currently_dill():
+  return desired_pickle_lib == dill_pickler
+
+
+def is_currently_cloudpickle():
+  return desired_pickle_lib == cloudpickle_pickler
+
+
 def set_library(selected_library=DEFAULT_PICKLE_LIB):
   """ Sets pickle library that will be used. """
   global desired_pickle_lib
@@ -108,12 +116,11 @@ def set_library(selected_library=DEFAULT_PICKLE_LIB):
         "Pipeline option pickle_library=dill_unsafe is set, but dill is not "
         "installed. Install dill in job submission and runtime environments.")

-  is_currently_dill = (desired_pickle_lib == dill_pickler)
   dill_is_requested = (
       selected_library == USE_DILL or selected_library == USE_DILL_UNSAFE)

   # If switching to or from dill, update the pickler hook overrides.
-  if is_currently_dill != dill_is_requested:
+  if is_currently_dill() != dill_is_requested:
     dill_pickler.override_pickler_hooks(selected_library == USE_DILL)

   if dill_is_requested:

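A small usage sketch of the new helpers. The 'cloudpickle' literal matches the pickle_library option value; that set_library accepts it directly is an assumption based on the option's choices.

from apache_beam.internal import pickler

# Switches the module-level desired_pickle_lib (assumed to accept the
# same strings as the pickle_library pipeline option).
pickler.set_library('cloudpickle')
assert pickler.is_currently_cloudpickle()
assert not pickler.is_currently_dill()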
sdks/python/apache_beam/options/pipeline_options.py

Lines changed: 27 additions & 2 deletions
@@ -64,7 +64,10 @@
 # Map defined with option names to flag names for boolean options
 # that have a destination(dest) in parser.add_argument() different
 # from the flag name and whose default value is `None`.
-_FLAG_THAT_SETS_FALSE_VALUE = {'use_public_ips': 'no_use_public_ips'}
+_FLAG_THAT_SETS_FALSE_VALUE = {
+    'use_public_ips': 'no_use_public_ips',
+    'save_main_session': 'no_save_main_session'
+}
 # Set of options which should not be overriden when applying options from a
 # different language. This is relevant when using x-lang transforms where the
 # expansion service is started up with some pipeline options, and will
@@ -1672,14 +1675,23 @@ def _add_argparse_args(cls, parser):
         choices=['cloudpickle', 'default', 'dill', 'dill_unsafe'])
     parser.add_argument(
         '--save_main_session',
-        default=False,
+        default=None,
         action='store_true',
         help=(
             'Save the main session state so that pickled functions and classes '
             'defined in __main__ (e.g. interactive session) can be unpickled. '
             'Some workflows do not need the session state if for instance all '
             'their functions/classes are defined in proper modules '
             '(not __main__) and the modules are importable in the worker. '))
+    parser.add_argument(
+        '--no_save_main_session',
+        default=None,
+        action='store_false',
+        dest='save_main_session',
+        help=(
+            'Disable saving the main session state. It is enabled by default '
+            'for cloudpickle and disabled by default for dill. '
+            'See "save_main_session".'))
+
     parser.add_argument(
         '--sdk_location',
         default='default',
@@ -1780,10 +1792,23 @@ def _add_argparse_args(cls, parser):
             'If not specified, the default Maven Central repository will be '
             'used.'))

+  def _handle_load_main_session(self, validator):
+    save_main_session = getattr(self, 'save_main_session')
+    if save_main_session is None:
+      # save_main_session defaults to False for dill and to True for
+      # cloudpickle.
+      pickle_library = getattr(self, 'pickle_library')
+      if pickle_library in ['default', 'cloudpickle']:
+        setattr(self, 'save_main_session', True)
+      else:
+        setattr(self, 'save_main_session', False)
+    return []
+
   def validate(self, validator):
     errors = []
     errors.extend(validator.validate_container_prebuilding_options(self))
     errors.extend(validator.validate_pickle_library(self))
+    errors.extend(self._handle_load_main_session(validator))
     return errors
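A hedged sketch of how the new default resolution and opt-out flag behave. PipelineOptions, SetupOptions, and view_as are existing Beam API; the resolution itself runs inside validate() during pipeline construction, so the first assertion below reflects the parsed (unresolved) value.

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

# Left unset, save_main_session parses as None; validate() later resolves
# it to True for cloudpickle/default and False for dill.
opts = PipelineOptions(['--pickle_library=cloudpickle'])
assert opts.view_as(SetupOptions).save_main_session is None

# The new flag is an explicit opt-out that overrides the resolved default.
opts_off = PipelineOptions(
    ['--pickle_library=cloudpickle', '--no_save_main_session'])
assert opts_off.view_as(SetupOptions).save_main_session is False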
sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py

Lines changed: 43 additions & 20 deletions
@@ -42,6 +42,7 @@
 from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_options
 from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_streaming_options
 from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
+from apache_beam.runners.internal import names
 from apache_beam.runners.runner import PipelineState
 from apache_beam.testing.extra_assertions import ExtraAssertionsMixin
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -243,6 +244,18 @@ def test_create_runner(self):
     self.assertTrue(
         isinstance(create_runner('TestDataflowRunner'), TestDataflowRunner))

+  @staticmethod
+  def dependency_proto_from_main_session_file(serialized_path):
+    return [
+        beam_runner_api_pb2.ArtifactInformation(
+            type_urn=common_urns.artifact_types.FILE.urn,
+            type_payload=serialized_path,
+            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
+            role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload(
+                staged_name=names.PICKLED_MAIN_SESSION_FILE).SerializeToString())
+    ]
+
   def test_environment_override_translation_legacy_worker_harness_image(self):
     self.default_properties.append('--experiments=beam_fn_api')
     self.default_properties.append('--worker_harness_container_image=LEGACY')
@@ -256,17 +269,22 @@ def test_environment_override_translation_legacy_worker_harness_image(self):
         | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
         | ptransform.GroupByKey())

+    actual = list(remote_runner.proto_pipeline.components.environments.values())
+    self.assertEqual(len(actual), 1)
+    actual = actual[0]
+    file_path = actual.dependencies[0].type_payload
+    # The dependency payload contains the main session file from a transient
+    # temp directory, so build the expected value from the actual path.
+    main_session_dep = self.dependency_proto_from_main_session_file(file_path)
     self.assertEqual(
-        list(remote_runner.proto_pipeline.components.environments.values()),
-        [
-            beam_runner_api_pb2.Environment(
-                urn=common_urns.environments.DOCKER.urn,
-                payload=beam_runner_api_pb2.DockerPayload(
-                    container_image='LEGACY').SerializeToString(),
-                capabilities=environments.python_sdk_docker_capabilities(),
-                dependencies=environments.python_sdk_dependencies(
-                    options=options))
-        ])
+        actual,
+        beam_runner_api_pb2.Environment(
+            urn=common_urns.environments.DOCKER.urn,
+            payload=beam_runner_api_pb2.DockerPayload(
+                container_image='LEGACY').SerializeToString(),
+            capabilities=environments.python_sdk_docker_capabilities(),
+            dependencies=environments.python_sdk_dependencies(options=options) +
+            main_session_dep))

   def test_environment_override_translation_sdk_container_image(self):
     self.default_properties.append('--experiments=beam_fn_api')
@@ -281,17 +299,22 @@ def test_environment_override_translation_sdk_container_image(self):
         | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
         | ptransform.GroupByKey())

+    actual = list(remote_runner.proto_pipeline.components.environments.values())
+    self.assertEqual(len(actual), 1)
+    actual = actual[0]
+    file_path = actual.dependencies[0].type_payload
+    # The dependency payload contains the main session file from a transient
+    # temp directory, so build the expected value from the actual path.
+    main_session_dep = self.dependency_proto_from_main_session_file(file_path)
     self.assertEqual(
-        list(remote_runner.proto_pipeline.components.environments.values()),
-        [
-            beam_runner_api_pb2.Environment(
-                urn=common_urns.environments.DOCKER.urn,
-                payload=beam_runner_api_pb2.DockerPayload(
-                    container_image='FOO').SerializeToString(),
-                capabilities=environments.python_sdk_docker_capabilities(),
-                dependencies=environments.python_sdk_dependencies(
-                    options=options))
-        ])
+        actual,
+        beam_runner_api_pb2.Environment(
+            urn=common_urns.environments.DOCKER.urn,
+            payload=beam_runner_api_pb2.DockerPayload(
+                container_image='FOO').SerializeToString(),
+            capabilities=environments.python_sdk_docker_capabilities(),
+            dependencies=environments.python_sdk_dependencies(options=options) +
+            main_session_dep))

   def test_remote_runner_translation(self):
     remote_runner = DataflowRunner()