156 changes: 156 additions & 0 deletions sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py
@@ -62,6 +62,8 @@
import dis
from enum import Enum
import functools
import hashlib
import inspect
import io
import itertools
import logging
@@ -176,6 +178,160 @@ class CloudPickleConfig:
_extract_code_globals_cache = weakref.WeakKeyDictionary()


def _generate_deterministic_id(cls):
"""Generate deterministic ID for a class that's consistent across processes."""

# Special handling for TypeVar objects
if isinstance(cls, typing.TypeVar):
# For TypeVars, derive the ID from the name, bound, constraints and variance
components = [
("type", "TypeVar"),
("name", cls.__name__),
("bound", str(cls.__bound__) if cls.__bound__ is not None else None),
(
"constraints",
str(cls.__constraints__) if cls.__constraints__ else None,
),
("covariant", cls.__covariant__),
("contravariant", cls.__contravariant__),
]

# Create a deterministic string representation
repr_components = []
for key, value in components:
repr_components.append(f"{key}:{repr(value)}")

# Sort for determinism
repr_components.sort()

# Create a deterministic hash
m = hashlib.sha256()
for component in repr_components:
m.update(component.encode("utf-8"))

return m.hexdigest()

components = []

# Basic class metadata
components.append(("__name__", cls.__name__))
components.append(("__module__", cls.__module__))

# Handle __qualname__ safely (some class-like objects might not define it)
if hasattr(cls, "__qualname__"):
components.append(("__qualname__", cls.__qualname__))

# Handle __bases__ safely (some objects might not have it)
if hasattr(cls, "__bases__"):
# Base classes (by name, not by identity)
base_names = []
for base in cls.__bases__:
if base is object:
base_names.append("object")
else:
# Handle bases that might not have __qualname__
base_name = base.__module__ + "." + base.__name__
if hasattr(base, "__qualname__"):
base_name = base.__module__ + "." + base.__qualname__
base_names.append(base_name)
components.append(("__bases__", tuple(base_names)))

# Class dictionary content - only if it's a regular class
if hasattr(cls, "__dict__"):
try:
cls_dict = _extract_class_dict(cls)
for key in sorted(cls_dict.keys()):
value = cls_dict[key]

if isinstance(value, types.FunctionType):
# For methods, include their code
code = value.__code__
components.append((f"{key}.__code__", code.co_code))
components.append((f"{key}.__code__.co_consts", code.co_consts))
components.append((f"{key}.__code__.co_names", code.co_names))

# Include closure values for primitive types
if value.__closure__:
closure_values = []
for cell in value.__closure__:
try:
cell_value = cell.cell_contents
if isinstance(cell_value, (int, float, str, bool, type(None))):
closure_values.append(str(cell_value))
else:
# For non-primitive types, include only the type name (object identity is not stable across processes)
closure_values.append(f"<{type(cell_value).__name__}>")
except ValueError:
closure_values.append("<empty>")
components.append((f"{key}.__closure__", tuple(closure_values)))

elif isinstance(value, (int, float, str, bool, type(None))):
# For primitive types, include their values
components.append((key, value))

else:
# For other types, include type information
components.append((key, f"<{type(value).__name__}>"))
except (TypeError, AttributeError):
# Some objects might not support dictionary extraction
components.append(("__dict_extraction_failed__", True))

# Source location information - only try for regular classes
if isinstance(cls, type):
try:
source_info = []
if hasattr(cls, "__module__") and cls.__module__ != "__main__":
# For classes defined in importable modules, include the module's file path
module = sys.modules.get(cls.__module__)
if module and hasattr(module, "__file__"):
source_info.append(("module_file", module.__file__))

# Try to get source code location
try:
source_lines, start_line = inspect.getsourcelines(cls)
source_info.append(("source_lines", "".join(source_lines)))
source_info.append(("start_line", start_line))
except (TypeError, OSError):
pass

# For classes defined in functions, include function information
if hasattr(cls, "__qualname__") and "." in cls.__qualname__:
parts = cls.__qualname__.split(".")
if len(parts) >= 2 and parts[-2].endswith(">"):
# This suggests the class is defined in a function
func_qualname = parts[-2]
if "<locals>" in func_qualname:
# Record the defining function's qualified name
func_name = cls.__qualname__.rsplit(".<locals>", 1)[0]
source_info.append(("defined_in_function", func_name))

if source_info:
components.append(("__source_info__", tuple(source_info)))
except Exception:
# Fallback if source extraction fails
pass

# Create a deterministic string representation
repr_components = []
for component in components:
key, value = component
try:
repr_value = repr(value)
except Exception:
repr_value = f"<{type(value).__name__}>"
repr_components.append(f"{key}:{repr_value}")

# Sort for determinism
repr_components.sort()

# Create a deterministic hash
m = hashlib.sha256()
for component in repr_components:
m.update(component.encode("utf-8"))

return m.hexdigest()


def _get_or_create_tracker_id(class_def, id_generator):
with _DYNAMIC_CLASS_TRACKER_LOCK:
class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
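The docstring above promises IDs that are consistent across processes; a minimal sketch of that property follows. It is illustrative only and not part of the change: the import path assumes the vendored module layout in this diff, and make_class is a made-up factory.

# Illustrative sketch -- not part of the diff.
from apache_beam.internal.cloudpickle.cloudpickle import (
    _generate_deterministic_id)


def make_class(offset):
  # Each call creates a distinct class object with identical structure.
  class Adder:
    def add(self, x):
      return x + offset
  return Adder


a, b, c = make_class(10), make_class(10), make_class(11)
assert a is not b
# Same name, bases, bytecode and captured primitive value -> same digest.
assert _generate_deterministic_id(a) == _generate_deterministic_id(b)
# A different closure constant changes the hashed components -> new digest.
assert _generate_deterministic_id(a) != _generate_deterministic_id(c)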
28 changes: 14 additions & 14 deletions sdks/python/apache_beam/internal/pickler.py
@@ -48,22 +48,22 @@ def dumps(
     enable_trace=True,
     use_zlib=False,
     enable_best_effort_determinism=False,
-    enable_stable_code_identifier_pickling=False) -> bytes:
+    enable_stable_code_identifier_pickling=False,
+    config=None) -> bytes:
 
+  kwargs = {
+      'enable_trace': enable_trace,
+      'use_zlib': use_zlib,
+      'enable_best_effort_determinism': enable_best_effort_determinism,
+  }
+
   if (desired_pickle_lib == cloudpickle_pickler):
-    return cloudpickle_pickler.dumps(
-        o,
-        enable_trace=enable_trace,
-        use_zlib=use_zlib,
-        enable_best_effort_determinism=enable_best_effort_determinism,
-        enable_stable_code_identifier_pickling=
-        enable_stable_code_identifier_pickling,
-    )
-  return desired_pickle_lib.dumps(
-      o,
-      enable_trace=enable_trace,
-      use_zlib=use_zlib,
-      enable_best_effort_determinism=enable_best_effort_determinism)
+    pickling_key = 'enable_stable_code_identifier_pickling'
+    kwargs[pickling_key] = enable_stable_code_identifier_pickling
+    if config is not None:
+      kwargs['config'] = config
+    return cloudpickle_pickler.dumps(o, **kwargs)
+  return desired_pickle_lib.dumps(o, **kwargs)
 
 
 def loads(encoded, enable_trace=True, use_zlib=False):
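To close, a hedged sketch of how a caller might exercise the widened dumps() signature; the function and values are invented for illustration, and config is left at its default because its construction is outside this hunk. Per the branch above, the stable-code-identifier flag (and config, when given) is only forwarded when the cloudpickle backend is the selected pickler.

# Illustrative sketch -- not part of the diff.
from apache_beam.internal import pickler


def double(x):
  return 2 * x


payload = pickler.dumps(
    double,
    enable_best_effort_determinism=True,
    enable_stable_code_identifier_pickling=True)

assert pickler.loads(payload)(21) == 42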