Skip to content

Commit 28ec3da

Browse files
danielsuo authored and Flax Authors committed
[flax:examples:vae] Small linter fixes.
PiperOrigin-RevId: 817312045
1 parent 75fd8fa commit 28ec3da

File tree

13 files changed: 474 additions (+474) and 418 deletions (-418).

docs_nnx/mnist_tutorial.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -158,7 +158,7 @@ def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, b
158158
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
159159
(loss, logits), grads = grad_fn(model, batch)
160160
metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates.
161-
optimizer.update(grads) # In-place updates.
161+
optimizer.update(model, grads) # In-place updates.
162162
163163
@nnx.jit
164164
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):

examples/gemma/input_pipeline.py

Lines changed: 71 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,11 @@
1515
"""Input pipeline for a LM1B dataset."""
1616

1717
import os
18-
import typing
18+
from typing import Any
1919

20+
import tokenizer
2021
import tensorflow as tf
2122
import tensorflow_datasets as tfds
22-
import tokenizer
23-
from clu import deterministic_data
24-
25-
if typing.TYPE_CHECKING:
26-
from train import TrainConfig
2723

2824
AUTOTUNE = tf.data.experimental.AUTOTUNE
2925
Features = dict[str, tf.Tensor]
@@ -58,9 +54,9 @@ def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset:
5854

5955

6056
def pack_dataset(
61-
dataset: tf.data.Dataset,
62-
key2length: int | dict[str, int],
63-
keys: list[str] | None = None,
57+
dataset: tf.data.Dataset,
58+
key2length: int | dict[str, int],
59+
keys: list[str] | None = None,
6460
) -> tf.data.Dataset:
6561
"""Creates a 'packed' version of a dataset on-the-fly.
6662
@@ -107,8 +103,8 @@ def pack_dataset(
107103
for k in keys:
108104
if k not in shapes:
109105
raise ValueError(
110-
'Key %s not found in dataset. Available keys are %s'
111-
% (k, shapes.keys())
106+
'Key %s not found in dataset. Available keys are %s'
107+
% (k, shapes.keys())
112108
)
113109
if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types]
114110
raise ValueError('Tensors to be packed must be one-dimensional.')
@@ -122,14 +118,14 @@ def pack_dataset(
122118

123119
# trim to length
124120
dataset = dataset.map(
125-
lambda x: {k: x[k][: key2length[k]] for k in keys},
126-
num_parallel_calls=AUTOTUNE,
121+
lambda x: {k: x[k][: key2length[k]] for k in keys},
122+
num_parallel_calls=AUTOTUNE,
127123
)
128124
# Setting batch_size=length ensures that the concatenated sequences (if they
129125
# have length >=1) are sufficient to fill at least one packed example.
130126
batch_size = max(key2length.values())
131127
dataset = dataset.padded_batch(
132-
batch_size, padded_shapes={k: [-1] for k in keys}
128+
batch_size, padded_shapes={k: [-1] for k in keys}
133129
)
134130
dataset = _pack_with_tf_ops(dataset, keys, key2length)
135131

@@ -141,7 +137,7 @@ def my_fn(x):
141137

142138

143139
def _pack_with_tf_ops(
144-
dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int]
140+
dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int]
145141
) -> tf.data.Dataset:
146142
"""Helper-function for packing a dataset which has already been batched.
147143
@@ -166,8 +162,8 @@ def write_packed_example(partial, outputs):
166162
new_outputs = {}
167163
for k in keys_etc:
168164
new_outputs[k] = outputs[k].write(
169-
outputs[k].size(),
170-
tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]),
165+
outputs[k].size(),
166+
tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]),
171167
)
172168
return new_partial, new_outputs
173169

@@ -188,10 +184,10 @@ def map_fn(x):
188184
outputs = {}
189185
for k in keys:
190186
outputs[k] = tf.TensorArray(
191-
tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
187+
tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
192188
)
193189
outputs[k + '_position'] = tf.TensorArray(
194-
tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
190+
tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
195191
)
196192

197193
def body_fn(i, partial, outputs):
@@ -213,10 +209,10 @@ def body_fn(i, partial, outputs):
213209
one_example[k] = val
214210
for k in keys:
215211
can_append = tf.logical_and(
216-
can_append,
217-
tf.less_equal(
218-
tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
219-
),
212+
can_append,
213+
tf.less_equal(
214+
tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
215+
),
220216
)
221217

222218
def false_fn():
@@ -232,28 +228,28 @@ def true_fn():
232228
new_seq_len = tf.size(new_seq)
233229
new_partial[k] = tf.concat([partial[k], new_seq], 0)
234230
new_partial[k + '_position'] = tf.concat(
235-
[partial[k + '_position'], tf.range(new_seq_len)], 0
231+
[partial[k + '_position'], tf.range(new_seq_len)], 0
236232
)
237233
partial = new_partial
238234
return i + 1, partial, outputs
239235

240236
# For loop over all examples in the batch.
241-
i, partial, outputs = tf.while_loop(
242-
cond=lambda *_: True,
243-
body=body_fn,
244-
loop_vars=(i, partial, outputs),
245-
shape_invariants=(
246-
tf.TensorShape([]),
247-
{k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types]
248-
{k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types]
249-
),
250-
maximum_iterations=dynamic_batch_size,
237+
_, partial, outputs = tf.while_loop(
238+
cond=lambda *_: True,
239+
body=body_fn,
240+
loop_vars=(i, partial, outputs),
241+
shape_invariants=(
242+
tf.TensorShape([]),
243+
{k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types]
244+
{k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types]
245+
),
246+
maximum_iterations=dynamic_batch_size,
251247
)
252248
_, outputs = write_packed_example(partial, outputs)
253249
packed = {k: outputs[k].stack() for k in keys_etc}
254250
for k in keys:
255251
packed[k + '_segmentation'] = tf.cumsum(
256-
tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1
252+
tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1
257253
) * tf.cast(tf.not_equal(packed[k], 0), tf.int32)
258254
return packed
259255

@@ -263,25 +259,25 @@ def true_fn():
263259

264260
def shift_data_by_truncation(x):
265261
# https://github.com/AI-Hypercomputer/maxtext/blob/7fe1de75b3919c0fda00d23ad6cb29def9098362/MaxText/input_pipeline/_input_pipeline_utils.py#L53
266-
x["inputs"] = x["inputs"][:-1]
267-
x["targets"] = x["targets"][1:]
262+
x['inputs'] = x['inputs'][:-1]
263+
x['targets'] = x['targets'][1:]
268264
return x
269265

270266

271267
# -----------------------------------------------------------------------------
272268
# Main dataset prep routines.
273269
# -----------------------------------------------------------------------------
274270
def preprocess_data(
275-
dataset,
276-
shuffle: bool,
277-
num_epochs: int | None = 1,
278-
pack_examples: bool = True,
279-
shuffle_buffer_size: int = 1024,
280-
max_length: int = 512,
281-
batch_size: int = 256,
282-
drop_remainder: bool = True,
283-
prefetch_size: int = AUTOTUNE,
284-
shift: bool = True,
271+
dataset,
272+
shuffle: bool,
273+
num_epochs: int | None = 1,
274+
pack_examples: bool = True,
275+
shuffle_buffer_size: int = 1024,
276+
max_length: int = 512,
277+
batch_size: int = 256,
278+
drop_remainder: bool = True,
279+
prefetch_size: int = AUTOTUNE,
280+
shift: bool = True,
285281
):
286282
"""Shuffle and batch/pack the given dataset."""
287283

@@ -303,18 +299,20 @@ def filter_fn(x):
303299
# Shift inputs for teacher-forced training
304300
if shift:
305301
dataset = dataset.map(
306-
shift_data_by_truncation, num_parallel_calls=AUTOTUNE, deterministic=True
302+
shift_data_by_truncation,
303+
num_parallel_calls=AUTOTUNE,
304+
deterministic=True,
307305
)
308306

309307
if pack_examples:
310308
dataset = pack_dataset(dataset, max_length)
311309
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
312310
else: # simple (static-shape) padded batching
313311
dataset = dataset.padded_batch(
314-
batch_size,
315-
padded_shapes={'inputs': max_length, 'targets': max_length},
316-
padding_values={'inputs': 0, 'targets': 0},
317-
drop_remainder=drop_remainder,
312+
batch_size,
313+
padded_shapes={'inputs': max_length, 'targets': max_length},
314+
padding_values={'inputs': 0, 'targets': 0},
315+
drop_remainder=drop_remainder,
318316
)
319317

320318
if prefetch_size:
@@ -324,10 +322,10 @@ def filter_fn(x):
324322

325323

326324
def get_datasets(
327-
config: "TrainConfig",
328-
*,
329-
n_devices: int,
330-
vocab_path: str | None = None,
325+
config: Any,
326+
*,
327+
n_devices: int,
328+
vocab_path: str | None = None,
331329
):
332330
"""Load and return dataset of batched examples for use during training."""
333331
if vocab_path is None:
@@ -343,16 +341,16 @@ def get_datasets(
343341

344342
# Tokenize data.
345343
sp_processor = tokenizer.load_or_train_tokenizer(
346-
train_data,
347-
vocab_path=vocab_path,
348-
vocab_size=config.vocab_size,
349-
max_corpus_chars=config.max_corpus_chars,
344+
train_data,
345+
vocab_path=vocab_path,
346+
vocab_size=config.vocab_size,
347+
max_corpus_chars=config.max_corpus_chars,
350348
)
351349
train_data = train_data.map(
352-
tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
350+
tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
353351
)
354352
eval_data = eval_data.map(
355-
tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
353+
tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
356354
)
357355

358356
batch_size = config.per_device_batch_size * n_devices
@@ -362,20 +360,20 @@ def get_datasets(
362360
eval_batch_size = batch_size
363361

364362
train_ds = preprocess_data(
365-
train_data,
366-
shuffle=True,
367-
num_epochs=None,
368-
pack_examples=True,
369-
batch_size=batch_size,
370-
max_length=config.max_target_length,
363+
train_data,
364+
shuffle=True,
365+
num_epochs=None,
366+
pack_examples=True,
367+
batch_size=batch_size,
368+
max_length=config.max_target_length,
371369
)
372370

373371
eval_ds = preprocess_data(
374-
eval_data,
375-
shuffle=False,
376-
pack_examples=False,
377-
batch_size=eval_batch_size,
378-
max_length=config.max_eval_target_length,
372+
eval_data,
373+
shuffle=False,
374+
pack_examples=False,
375+
batch_size=eval_batch_size,
376+
max_length=config.max_eval_target_length,
379377
)
380378

381379
return train_ds, eval_ds, sp_processor

examples/gemma/main.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -18,21 +18,24 @@
1818
that can be easily tested and imported in Colab.
1919
"""
2020

21-
import jax
22-
import tensorflow as tf
23-
import train
24-
from absl import app, flags, logging
21+
from absl import app
22+
from absl import flags
23+
from absl import logging
2524
from clu import platform
25+
import train
26+
import jax
2627
from ml_collections import config_flags
28+
import tensorflow as tf
29+
2730

2831
FLAGS = flags.FLAGS
2932

3033
flags.DEFINE_string('workdir', None, 'Directory to store model data.')
3134
config_flags.DEFINE_config_file(
32-
'config',
33-
'configs/default.py',
34-
'File path to the training hyperparameter configuration.',
35-
lock_config=True,
35+
'config',
36+
'configs/default.py',
37+
'File path to the training hyperparameter configuration.',
38+
lock_config=True,
3639
)
3740
flags.mark_flags_as_required(['workdir'])
3841

@@ -51,11 +54,11 @@ def main(argv):
5154
# Add a note so that we can tell which task is which JAX host.
5255
# (Depending on the platform task 0 is not guaranteed to be host 0)
5356
platform.work_unit().set_task_status(
54-
f'process_index: {jax.process_index()}, '
55-
f'process_count: {jax.process_count()}'
57+
f'process_index: {jax.process_index()}, '
58+
f'process_count: {jax.process_count()}'
5659
)
5760
platform.work_unit().create_artifact(
58-
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
61+
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
5962
)
6063

6164
train.train_and_evaluate(FLAGS.config, FLAGS.workdir)

0 commit comments

Comments
 (0)