diff --git a/transformer_engine/jax/version_utils.py b/transformer_engine/jax/version_utils.py index 63598481a2..2ccf177b89 100644 --- a/transformer_engine/jax/version_utils.py +++ b/transformer_engine/jax/version_utils.py @@ -17,7 +17,7 @@ @lru_cache(maxsize=None) def jax_version_meet_requirement(version: str): """Return True if the installed JAX version is >= the required version.""" - jax_version = PkgVersion(get_pkg_version("jax")) + jax_version = PkgVersion(PkgVersion(get_pkg_version("jax")).public) jax_version_required = PkgVersion(version) return jax_version >= jax_version_required diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 1e7bdaac84..cb5c752f4a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -90,7 +90,7 @@ _flash_attn_varlen_fwd = None _flash_attn_varlen_bwd = None try: - fa_utils.version = PkgVersion(get_pkg_version("flash-attn")) + fa_utils.version = PkgVersion(PkgVersion(get_pkg_version("flash-attn")).public) except PackageNotFoundError: pass # only print warning if use_flash_attention_2 = True in get_attention_backend else: @@ -132,7 +132,7 @@ fa_utils.version, ) try: - fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3")) + fa_utils.fa3_version = PkgVersion(PkgVersion(get_pkg_version("flash-attn-3")).public) except PackageNotFoundError: flash_attn_func_v3 = None flash_attn_varlen_func_v3 = None