diff --git a/export/orbax/export/modules/obm_module.py b/export/orbax/export/modules/obm_module.py index ae34f9f3a..850acc28e 100644 --- a/export/orbax/export/modules/obm_module.py +++ b/export/orbax/export/modules/obm_module.py @@ -29,6 +29,9 @@ from orbax.export.typing import PyTree import tensorflow as tf +from .learning.brain.tfrt.python.saved_model.config import config as config_validator +from .learning.brain.tfrt.saved_model.config import compilation_config_pb2 + ApplyFn = orbax_export_typing.ApplyFn @@ -158,6 +161,23 @@ def _jax2obm_kwargs_to_options( ), ) + def _verify_xla_gpu_flags( + self, xla_flags_per_platform: Mapping[str, Sequence[str]] + ) -> None: + """Verifies that only stable xla_gpu_flags are used for GPU platforms.""" + for platform, flags in xla_flags_per_platform.items(): + if platform.lower() in ('cuda', 'rocm') and flags: + compilation_config = compilation_config_pb2.CompilationConfig( + xla_gpu_flags=flags, + ) + try: + config_validator.validate_compilation_config(compilation_config) + except Exception as e: + raise ValueError( + 'XLA GPU flag validation failed. Ensure only stable flags are' + ' used. 
See the list of supported stable XLA GPU flags.' + ) from e
['--xla_gpu_experimental_unstable_flag=true'] + }, + } + + with self.assertRaisesRegex(ValueError, 'XLA GPU flag validation failed'): + obm_module.ObmModule( + params=param_spec, + apply_fn={model_function_name: simple_add}, + jax2obm_kwargs=jax2obm_kwargs, + ) + if __name__ == '__main__': absltest.main() diff --git a/model/orbax/experimental/model/core/python/compile_options_util.py b/model/orbax/experimental/model/core/python/compile_options_util.py index 7c572b249..9e6deac36 100644 --- a/model/orbax/experimental/model/core/python/compile_options_util.py +++ b/model/orbax/experimental/model/core/python/compile_options_util.py @@ -38,6 +38,35 @@ } +def _parse_env_option_overrides( + xla_flags: Sequence[str], +) -> dict[str, compile_options_pb2.OptionOverrideProto]: + """Parses a list of XLA flags into a dictionary of OptionOverrideProto.""" + overrides = {} + for flag in xla_flags: + if not flag.startswith('--'): + raise ValueError(f"Flag {flag} must start with '--'") + + key, value = flag[2:].split('=', 1) + override_proto = compile_options_pb2.OptionOverrideProto() + + # Infer type (similar to C++ CreateEnvironmentOptionOverridesFromFlags) + if value.lower() == 'true': + override_proto.bool_field = True + elif value.lower() == 'false': + override_proto.bool_field = False + elif value.isdigit() or (value.startswith('-') and value[1:].isdigit()): + override_proto.int_field = int(value) + else: + try: + override_proto.double_field = float(value) + except ValueError: + override_proto.string_field = value + + overrides[key] = override_proto + return overrides + + def _generate_tpu_compilation_env( xla_flags_overrides: Sequence[str] | None = None, ) -> xla_pb2.CompilationEnvironmentsProto: @@ -202,11 +231,27 @@ def generate_xla_compile_options( else: xla_flags_overrides = None compile_environment = _generate_tpu_compilation_env(xla_flags_overrides) + compile_options_map.map[platform.lower()].CopyFrom( + _generate_compilation_options(compile_environment, jax_mesh) + ) 
+ elif platform.lower() in ('cuda', 'rocm'): + compile_environment = xla_pb2.CompilationEnvironmentsProto() + compile_options = _generate_compilation_options( + compile_environment, jax_mesh + ) + if xla_flags_per_platform: + gpu_flags = xla_flags_per_platform.get(platform, None) + if gpu_flags: + _validate_xla_flags_setting(gpu_flags, persist_xla_flags) + overrides_map = _parse_env_option_overrides(gpu_flags) + for k, v in overrides_map.items(): + compile_options.env_option_overrides[k].CopyFrom(v) + compile_options_map.map[platform.lower()].CopyFrom(compile_options) else: - compile_environment = None - compile_options_map.map[platform.lower()].CopyFrom( - _generate_compilation_options(compile_environment, jax_mesh) - ) + compile_environment = xla_pb2.CompilationEnvironmentsProto() + compile_options_map.map[platform.lower()].CopyFrom( + _generate_compilation_options(compile_environment, jax_mesh) + ) if not persist_xla_flags: for compile_options in compile_options_map.map.values(): compile_options.executable_build_options.comp_envs.Clear()