Commit c7fcab1

[FIX] Prevent TypeError in text-only Gemma3CausalLM and improve generate_step defaults
1. Bug Fix: TypeError during generation for text-only models

The Problem: When using a Gemma3CausalLM model configured for text-only processing (i.e., with vision_encoder=None and preprocessor=None), a call to causal_lm.generate() fails with a TypeError. The root cause is that the internal generate_step method returns a dictionary containing an 'images': None key-value pair. This None value is eventually passed to ops.concatenate during the output normalization step, which does not accept None as a valid input. This workflow is common when pretraining a model from scratch.

The Fix: The generate_step method now only includes the 'images' key in its returned dictionary if an image tensor is actually present. This ensures that a None value is never passed to downstream functions, resolving the TypeError.

Proof of Bug and Fix: The following Colab notebook demonstrates the bug with the original code and shows successful execution after applying this fix: https://colab.research.google.com/drive/1QVk2idB6fcdYYJb1cBQGaKHe5QSGjCti?usp=sharing

2. Refactoring: Remove Hardcoded Stop Token

The Problem: The internal generate_step method has a hardcoded default of stop_token_ids=[106], which corresponds to the <end_of_turn> token. This is conceptually incorrect for a base architectural model, which should not have opinions about instruction-following or conversational tokens. The hardcoded value can interfere with pretraining or with sampling raw text.

The Fix: The method signature has been changed from stop_token_ids=[106] to stop_token_ids=None. This is a safe, non-breaking change because the public-facing Gemma3CausalLM.generate() method is already responsible for setting the appropriate stop tokens when a user specifies stop_token_ids="auto".
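To make the failure mode concrete, here is a minimal, runnable sketch. It is an illustration only: it uses numpy in place of keras.ops, and the names normalize_outputs and generate_step_output are hypothetical stand-ins for the keras_hub internals, not the actual API.

```python
import numpy as np

def normalize_outputs(batches):
    # Concatenate each output key across per-batch dictionaries, as the
    # output normalization step does. Concatenation fails if any value
    # is None (keras.ops.concatenate raises the TypeError reported above).
    return {
        key: np.concatenate([batch[key] for batch in batches])
        for key in batches[0]
    }

def generate_step_output(token_ids, padding_mask, images=None):
    # The fixed behavior: only include "images" when a tensor is
    # actually present, so None can never reach the concatenate call.
    output_dict = {"token_ids": token_ids, "padding_mask": padding_mask}
    if images is not None:
        output_dict["images"] = images
    return output_dict

batch = generate_step_output(
    token_ids=np.ones((2, 8), dtype="int32"),
    padding_mask=np.ones((2, 8), dtype="bool"),
    images=None,  # text-only model: no vision inputs
)
print(normalize_outputs([batch, batch])["token_ids"].shape)  # (4, 8)
```

With the old behavior, batch would have contained "images": None, and the concatenation over [None, None] inside normalize_outputs would fail.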
1 parent 93d89d8 commit c7fcab1

File tree

1 file changed (+6, -3 lines)


keras_hub/src/models/gemma3/gemma3_causal_lm.py

Lines changed: 6 additions & 3 deletions
@@ -227,7 +227,7 @@ def _build_cache(
         )
         return hidden_states, cache
 
-    def generate_step(self, inputs, stop_token_ids=[106]):
+    def generate_step(self, inputs, stop_token_ids=None):
         """A compilable generation function for a single batch of inputs.
 
         This function represents the inner, XLA-compilable, generation function
@@ -326,11 +326,14 @@ def next(prompt, cache, index):
         else:
             # Without early stopping, all locations will have been updated.
             padding_mask = ops.ones_like(token_ids, dtype="bool")
-        return {
+        output_dict = {
             "token_ids": token_ids,
             "padding_mask": padding_mask,
-            "images": images,
         }
+        if images is not None:
+            output_dict["images"] = images
+
+        return output_dict
 
     def generate(
         self,