Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions export/orbax/export/modules/obm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from orbax.export.typing import PyTree
import tensorflow as tf

from .learning.brain.tfrt.python.saved_model.config import config as config_validator
from .learning.brain.tfrt.saved_model.config import compilation_config_pb2

ApplyFn = orbax_export_typing.ApplyFn


Expand Down Expand Up @@ -158,6 +161,23 @@ def _jax2obm_kwargs_to_options(
),
)

def _verify_xla_gpu_flags(
    self, xla_flags_per_platform: Mapping[str, Sequence[str]]
) -> None:
  """Verifies that only stable xla_gpu_flags are used for GPU platforms.

  Args:
    xla_flags_per_platform: Mapping from platform name (e.g. 'cuda',
      'rocm') to the sequence of XLA flags requested for that platform.

  Raises:
    ValueError: If the config validator rejects any of the GPU flags
      (e.g. an unstable or unknown flag was supplied).
  """
  for platform, flags in xla_flags_per_platform.items():
    # Only GPU platforms are validated; flags for other platforms (and
    # empty flag lists) pass through untouched.
    if platform.lower() in ('cuda', 'rocm') and flags:
      compilation_config = compilation_config_pb2.CompilationConfig(
          xla_gpu_flags=flags,
      )
      try:
        config_validator.validate_compilation_config(compilation_config)
      except Exception as e:
        # BUG FIX: the original raise contained an unterminated string
        # literal ("' used. See") — a syntax error; the message is now
        # completed. The original cause is chained for debuggability.
        raise ValueError(
            'XLA GPU flag validation failed. Ensure only stable flags are'
            ' used.'
        ) from e

def _normalize_apply_fn_map(
self,
apply_fn: (
Expand Down
56 changes: 56 additions & 0 deletions export/orbax/export/modules/obm_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,62 @@ def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
with self.subTest('test_weights_b_dtype'):
self.assertEqual(module.model_params['b'].dtype, expected_dtype)

def test_obm_module_gpu_xla_flags_success(self):
  """A stable GPU flag passes validation and lands in env_option_overrides."""
  spec = jax.ShapeDtypeStruct(shape=(2, 5), dtype=jnp.dtype(jnp.float32))
  fn_name = 'simple_add'
  stable_flag = '--xla_gpu_enable_latency_hiding_scheduler=true'

  module = obm_module.ObmModule(
      params=spec,
      apply_fn={fn_name: simple_add},
      jax2obm_kwargs={
          constants.CHECKPOINT_PATH: 'checkpoint_path',
          constants.NATIVE_SERIALIZATION_PLATFORMS: ('cuda',),
          constants.XLA_FLAGS_PER_PLATFORM: {'cuda': [stable_flag]},
      },
  )

  options_map = module.xla_compile_options_per_platform
  self.assertIsNotNone(options_map)
  overrides = options_map.map['cuda'].env_option_overrides
  self.assertIn('xla_gpu_enable_latency_hiding_scheduler', overrides)
  self.assertTrue(
      overrides['xla_gpu_enable_latency_hiding_scheduler'].bool_field
  )

def test_obm_module_gpu_xla_flags_rejection(self):
  """An unstable GPU flag makes ObmModule construction raise ValueError."""
  spec = jax.ShapeDtypeStruct(shape=(2, 5), dtype=jnp.dtype(jnp.float32))
  kwargs = {
      constants.CHECKPOINT_PATH: 'checkpoint_path',
      constants.NATIVE_SERIALIZATION_PLATFORMS: ('cuda',),
      constants.XLA_FLAGS_PER_PLATFORM: {
          'cuda': ['--xla_gpu_experimental_unstable_flag=true']
      },
  }

  with self.assertRaisesRegex(ValueError, 'XLA GPU flag validation failed'):
    obm_module.ObmModule(
        params=spec,
        apply_fn={'simple_add': simple_add},
        jax2obm_kwargs=kwargs,
    )


# Run the absltest runner when this test file is executed directly.
if __name__ == '__main__':
  absltest.main()
53 changes: 49 additions & 4 deletions model/orbax/experimental/model/core/python/compile_options_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,35 @@
}


def _parse_env_option_overrides(
    xla_flags: Sequence[str],
) -> dict[str, compile_options_pb2.OptionOverrideProto]:
  """Parses a list of XLA flags into a dictionary of OptionOverrideProto.

  Each flag must have the form '--key=value'. The value's type is inferred
  in priority order — bool, int, double, then string — similar to the C++
  CreateEnvironmentOptionOverridesFromFlags.

  Args:
    xla_flags: Flags such as
      '--xla_gpu_enable_latency_hiding_scheduler=true'.

  Returns:
    A dict mapping each flag name (without the '--' prefix) to an
    OptionOverrideProto holding the typed value.

  Raises:
    ValueError: If a flag does not start with '--' or has no '=value' part.
  """
  overrides = {}
  for flag in xla_flags:
    if not flag.startswith('--'):
      raise ValueError(f"Flag {flag} must start with '--'")

    # BUG FIX: the original `flag[2:].split('=', 1)` raised an opaque
    # "not enough values to unpack" ValueError for a flag without '='
    # (e.g. '--xla_gpu_foo'); partition lets us report a clear error.
    key, sep, value = flag[2:].partition('=')
    if not sep:
      raise ValueError(f"Flag {flag} must be of the form '--key=value'")

    override_proto = compile_options_pb2.OptionOverrideProto()

    # Infer type (similar to C++ CreateEnvironmentOptionOverridesFromFlags)
    if value.lower() == 'true':
      override_proto.bool_field = True
    elif value.lower() == 'false':
      override_proto.bool_field = False
    elif value.isdigit() or (value.startswith('-') and value[1:].isdigit()):
      override_proto.int_field = int(value)
    else:
      try:
        override_proto.double_field = float(value)
      except ValueError:
        # Anything that is not bool/int/float is passed through verbatim.
        override_proto.string_field = value

    overrides[key] = override_proto
  return overrides


def _generate_tpu_compilation_env(
xla_flags_overrides: Sequence[str] | None = None,
) -> xla_pb2.CompilationEnvironmentsProto:
Expand Down Expand Up @@ -202,11 +231,27 @@ def generate_xla_compile_options(
else:
xla_flags_overrides = None
compile_environment = _generate_tpu_compilation_env(xla_flags_overrides)
compile_options_map.map[platform.lower()].CopyFrom(
_generate_compilation_options(compile_environment, jax_mesh)
)
elif platform.lower() in ('cuda', 'rocm'):
compile_environment = xla_pb2.CompilationEnvironmentsProto()
compile_options = _generate_compilation_options(
compile_environment, jax_mesh
)
if xla_flags_per_platform:
gpu_flags = xla_flags_per_platform.get(platform, None)
if gpu_flags:
_validate_xla_flags_setting(gpu_flags, persist_xla_flags)
overrides_map = _parse_env_option_overrides(gpu_flags)
for k, v in overrides_map.items():
compile_options.env_option_overrides[k].CopyFrom(v)
compile_options_map.map[platform.lower()].CopyFrom(compile_options)
else:
compile_environment = None
compile_options_map.map[platform.lower()].CopyFrom(
_generate_compilation_options(compile_environment, jax_mesh)
)
compile_environment = xla_pb2.CompilationEnvironmentsProto()
compile_options_map.map[platform.lower()].CopyFrom(
_generate_compilation_options(compile_environment, jax_mesh)
)
if not persist_xla_flags:
for compile_options in compile_options_map.map.values():
compile_options.executable_build_options.comp_envs.Clear()
Expand Down
Loading