Skip to content

Gemma3 12B loading error #611

@heydaari

Description

@heydaari

Actual Behavior

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /usr/local/lib/python3.12/site-packages/tunix/models/safetensors_loader.py:106, in load_and_create_model(file_dir, model_class, config, key_mapping, mesh, preprocess_fn, dtype)
    105 v = sf.get_tensor(k_name)
--> 106 jax_key_mapped, transform = torch_key_to_jax_key(key_map, k_name)
    108 if transform is not None:

File /usr/local/lib/python3.12/site-packages/tunix/models/safetensors_loader.py:35, in torch_key_to_jax_key(mapping, source_key)
     34 if len(subs) != 1:
---> 35   raise ValueError(f"Only one key should be found: {subs} for {source_key}")
     36 else:

ValueError: Only one key should be found: [] for language_model.model.layers.10.input_layernorm.weight

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[7], line 9
      7 mesh = jax.make_mesh(*MESH)
      8 with mesh:
----> 9   gemma3 = params_lib.create_model_from_safe_tensors(
     10       MODEL_CP_PATH, config, mesh
     11   )
     12   nnx.display(gemma3)

File /usr/local/lib/python3.12/site-packages/tunix/models/gemma3/params_safetensors.py:200, in create_model_from_safe_tensors(file_dir, config, mesh, dtype)
    194 def create_model_from_safe_tensors(
    195     file_dir: str,
    196     config: model_lib.ModelConfig,
    197     mesh: jax.sharding.Mesh | None = None,
    198     dtype: jnp.dtype | None = None,
    199 ):
--> 200   return safetensors_loader.load_and_create_model(
    201       file_dir=file_dir,
    202       model_class=model_lib.Gemma3,
    203       config=config,
    204       key_mapping=_get_key_and_transform_mapping,
    205       mesh=mesh,
    206       preprocess_fn=_make_preprocess_fn(config),
    207       dtype=dtype,
    208   )

File /usr/local/lib/python3.12/site-packages/tunix/models/safetensors_loader.py:126, in load_and_create_model(file_dir, model_class, config, key_mapping, mesh, preprocess_fn, dtype)
    123       file_loaded_tensors[jax_key_mapped] = current_arr
    125     except Exception as e:
--> 126       raise RuntimeError(
    127           f"Failed to load tensor {k_name} from file {f.name}: {e}"
    128       ) from e
    130 # Apply preprocessing if provided (e.g., for MoE expert stacking)
    131 if preprocess_fn is not None:

RuntimeError: Failed to load tensor language_model.model.layers.10.input_layernorm.weight from file model-00002-of-00005.safetensors: Only one key should be found: [] for language_model.model.layers.10.input_layernorm.weight

Steps to Reproduce the Problem
Run the following code on Kaggle TPUs:

import os
os.environ['HF_TOKEN'] = "YO-TOKEN"

import functools
from flax import nnx
from huggingface_hub import snapshot_download
import humanize
import jax
from tunix.models.gemma3 import model as model_lib
from tunix.models.gemma3 import params_safetensors as params_lib

model_id = "google/gemma-3-12b-it"
ignore_patterns = [
    "*.pth",  # Ignore PyTorch .pth weight files
]
print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id, ignore_patterns=ignore_patterns
)
print(f"Model successfully downloaded to: {local_model_path}")

MODEL_CP_PATH = local_model_path

config = (
    model_lib.ModelConfig.gemma3_12b()
)  # pick the corresponding config based on the model version
MESH = [(8, 1), ("fsdp", "tp")]  # update this based on your number of TPU devices
mesh = jax.make_mesh(*MESH)
with mesh:
  gemma3 = params_lib.create_model_from_safe_tensors(
      MODEL_CP_PATH, config, mesh
  )
  nnx.display(gemma3)

Metadata

Metadata

Assignees

Labels

type:bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions