-
Notifications
You must be signed in to change notification settings - Fork 158
Open
Labels
type:bug (Something isn't working)
Description
Actual Behavior
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File /usr/local/lib/python3.12/site-packages/tunix/models/safetensors_loader.py:106, in load_and_create_model(file_dir, model_class, config, key_mapping, mesh, preprocess_fn, dtype)
105 v = sf.get_tensor(k_name)
--> 106 jax_key_mapped, transform = torch_key_to_jax_key(key_map, k_name)
108 if transform is not None:
File /usr/local/lib/python3.12/site-packages/tunix/models/safetensors_loader.py:35, in torch_key_to_jax_key(mapping, source_key)
34 if len(subs) != 1:
---> 35 raise ValueError(f"Only one key should be found: {subs} for {source_key}")
36 else:
ValueError: Only one key should be found: [] for language_model.model.layers.10.input_layernorm.weight
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Cell In[7], line 9
7 mesh = jax.make_mesh(*MESH)
8 with mesh:
----> 9 gemma3 = params_lib.create_model_from_safe_tensors(
10 MODEL_CP_PATH, config, mesh
11 )
12 nnx.display(gemma3)
File /usr/local/lib/python3.12/site-packages/tunix/models/gemma3/params_safetensors.py:200, in create_model_from_safe_tensors(file_dir, config, mesh, dtype)
194 def create_model_from_safe_tensors(
195 file_dir: str,
196 config: model_lib.ModelConfig,
197 mesh: jax.sharding.Mesh | None = None,
198 dtype: jnp.dtype | None = None,
199 ):
--> 200 return safetensors_loader.load_and_create_model(
201 file_dir=file_dir,
202 model_class=model_lib.Gemma3,
203 config=config,
204 key_mapping=_get_key_and_transform_mapping,
205 mesh=mesh,
206 preprocess_fn=_make_preprocess_fn(config),
207 dtype=dtype,
208 )
File /usr/local/lib/python3.12/site-packages/tunix/models/safetensors_loader.py:126, in load_and_create_model(file_dir, model_class, config, key_mapping, mesh, preprocess_fn, dtype)
123 file_loaded_tensors[jax_key_mapped] = current_arr
125 except Exception as e:
--> 126 raise RuntimeError(
127 f"Failed to load tensor {k_name} from file {f.name}: {e}"
128 ) from e
130 # Apply preprocessing if provided (e.g., for MoE expert stacking)
131 if preprocess_fn is not None:
RuntimeError: Failed to load tensor language_model.model.layers.10.input_layernorm.weight from file model-00002-of-00005.safetensors: Only one key should be found: [] for language_model.model.layers.10.input_layernorm.weight
Steps to Reproduce the Problem
Run this code on Kaggle TPUs:
import os
os.environ['HF_TOKEN'] = "YO-TOKEN"
import functools
from flax import nnx
from huggingface_hub import snapshot_download
import humanize
import jax
from tunix.models.gemma3 import model as model_lib
from tunix.models.gemma3 import params_safetensors as params_lib
model_id = "google/gemma-3-12b-it"
ignore_patterns = [
"*.pth", # Ignore PyTorch .pth weight files
]
print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
repo_id=model_id, ignore_patterns=ignore_patterns
)
print(f"Model successfully downloaded to: {local_model_path}")
MODEL_CP_PATH = local_model_path
config = (
model_lib.ModelConfig.gemma3_12b()
) # pick corresponding config based on model version
MESH = [(8, 1), ("fsdp", "tp")] # update this based on your # TPU devices
mesh = jax.make_mesh(*MESH)
with mesh:
    gemma3 = params_lib.create_model_from_safe_tensors(
        MODEL_CP_PATH, config, mesh
    )
    nnx.display(gemma3)
Metadata
Assignees
Labels
type:bug (Something isn't working)