1515"""Input pipeline for a LM1B dataset."""
1616
1717import os
18- import typing
18+ from typing import Any
1919
20+ import tokenizer
2021import tensorflow as tf
2122import tensorflow_datasets as tfds
22- import tokenizer
23- from clu import deterministic_data
24-
25- if typing .TYPE_CHECKING :
26- from train import TrainConfig
2723
2824AUTOTUNE = tf .data .experimental .AUTOTUNE
2925Features = dict [str , tf .Tensor ]
@@ -58,9 +54,9 @@ def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset:


 def pack_dataset(
-  dataset: tf.data.Dataset,
-  key2length: int | dict[str, int],
-  keys: list[str] | None = None,
+    dataset: tf.data.Dataset,
+    key2length: int | dict[str, int],
+    keys: list[str] | None = None,
 ) -> tf.data.Dataset:
   """Creates a 'packed' version of a dataset on-the-fly.

@@ -107,8 +103,8 @@ def pack_dataset(
   for k in keys:
     if k not in shapes:
       raise ValueError(
-        'Key %s not found in dataset. Available keys are %s'
-        % (k, shapes.keys())
+          'Key %s not found in dataset. Available keys are %s'
+          % (k, shapes.keys())
       )
     if not shapes[k].is_compatible_with(tf.TensorShape([None])):  # type: ignore[wrong-arg-types]
       raise ValueError('Tensors to be packed must be one-dimensional.')
@@ -122,14 +118,14 @@ def pack_dataset(

   # trim to length
   dataset = dataset.map(
-    lambda x: {k: x[k][: key2length[k]] for k in keys},
-    num_parallel_calls=AUTOTUNE,
+      lambda x: {k: x[k][: key2length[k]] for k in keys},
+      num_parallel_calls=AUTOTUNE,
   )
   # Setting batch_size=length ensures that the concatenated sequences (if they
   # have length >=1) are sufficient to fill at least one packed example.
   batch_size = max(key2length.values())
   dataset = dataset.padded_batch(
-    batch_size, padded_shapes={k: [-1] for k in keys}
+      batch_size, padded_shapes={k: [-1] for k in keys}
   )
   dataset = _pack_with_tf_ops(dataset, keys, key2length)

@@ -141,7 +137,7 @@ def my_fn(x):


 def _pack_with_tf_ops(
-  dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int]
+    dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int]
 ) -> tf.data.Dataset:
   """Helper-function for packing a dataset which has already been batched.

@@ -166,8 +162,8 @@ def write_packed_example(partial, outputs):
     new_outputs = {}
     for k in keys_etc:
       new_outputs[k] = outputs[k].write(
-        outputs[k].size(),
-        tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]),
+          outputs[k].size(),
+          tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]),
       )
     return new_partial, new_outputs

@@ -188,10 +184,10 @@ def map_fn(x):
     outputs = {}
     for k in keys:
       outputs[k] = tf.TensorArray(
-        tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
+          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
       )
       outputs[k + '_position'] = tf.TensorArray(
-        tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
+          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
      )

     def body_fn(i, partial, outputs):
@@ -213,10 +209,10 @@ def body_fn(i, partial, outputs):
         one_example[k] = val
       for k in keys:
         can_append = tf.logical_and(
-          can_append,
-          tf.less_equal(
-            tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
-          ),
+            can_append,
+            tf.less_equal(
+                tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
+            ),
         )

       def false_fn():
@@ -232,28 +228,28 @@ def true_fn():
         new_seq_len = tf.size(new_seq)
         new_partial[k] = tf.concat([partial[k], new_seq], 0)
         new_partial[k + '_position'] = tf.concat(
-          [partial[k + '_position'], tf.range(new_seq_len)], 0
+            [partial[k + '_position'], tf.range(new_seq_len)], 0
         )
       partial = new_partial
       return i + 1, partial, outputs

     # For loop over all examples in the batch.
-    i, partial, outputs = tf.while_loop(
-      cond=lambda *_: True,
-      body=body_fn,
-      loop_vars=(i, partial, outputs),
-      shape_invariants=(
-        tf.TensorShape([]),
-        {k: tf.TensorShape([None]) for k in keys_etc},  # type: ignore[wrong-arg-types]
-        {k: tf.TensorShape(None) for k in keys_etc},  # type: ignore[wrong-arg-types]
-      ),
-      maximum_iterations=dynamic_batch_size,
+    _, partial, outputs = tf.while_loop(
+        cond=lambda *_: True,
+        body=body_fn,
+        loop_vars=(i, partial, outputs),
+        shape_invariants=(
+            tf.TensorShape([]),
+            {k: tf.TensorShape([None]) for k in keys_etc},  # type: ignore[wrong-arg-types]
+            {k: tf.TensorShape(None) for k in keys_etc},  # type: ignore[wrong-arg-types]
+        ),
+        maximum_iterations=dynamic_batch_size,
     )
     _, outputs = write_packed_example(partial, outputs)
     packed = {k: outputs[k].stack() for k in keys_etc}
     for k in keys:
       packed[k + '_segmentation'] = tf.cumsum(
-        tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1
+          tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1
       ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32)
     return packed

@@ -263,25 +259,25 @@ def true_fn():


 def shift_data_by_truncation(x):
   # https://github.com/AI-Hypercomputer/maxtext/blob/7fe1de75b3919c0fda00d23ad6cb29def9098362/MaxText/input_pipeline/_input_pipeline_utils.py#L53
-  x["inputs"] = x["inputs"][:-1]
-  x["targets"] = x["targets"][1:]
+  x['inputs'] = x['inputs'][:-1]
+  x['targets'] = x['targets'][1:]
   return x


 # -----------------------------------------------------------------------------
 # Main dataset prep routines.
 # -----------------------------------------------------------------------------
 def preprocess_data(
-  dataset,
-  shuffle: bool,
-  num_epochs: int | None = 1,
-  pack_examples: bool = True,
-  shuffle_buffer_size: int = 1024,
-  max_length: int = 512,
-  batch_size: int = 256,
-  drop_remainder: bool = True,
-  prefetch_size: int = AUTOTUNE,
-  shift: bool = True,
+    dataset,
+    shuffle: bool,
+    num_epochs: int | None = 1,
+    pack_examples: bool = True,
+    shuffle_buffer_size: int = 1024,
+    max_length: int = 512,
+    batch_size: int = 256,
+    drop_remainder: bool = True,
+    prefetch_size: int = AUTOTUNE,
+    shift: bool = True,
 ):
   """Shuffle and batch/pack the given dataset."""

@@ -303,18 +299,20 @@ def filter_fn(x):
   # Shift inputs for teacher-forced training
   if shift:
     dataset = dataset.map(
-      shift_data_by_truncation, num_parallel_calls=AUTOTUNE, deterministic=True
+        shift_data_by_truncation,
+        num_parallel_calls=AUTOTUNE,
+        deterministic=True,
    )

   if pack_examples:
     dataset = pack_dataset(dataset, max_length)
     dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
   else:  # simple (static-shape) padded batching
     dataset = dataset.padded_batch(
-      batch_size,
-      padded_shapes={'inputs': max_length, 'targets': max_length},
-      padding_values={'inputs': 0, 'targets': 0},
-      drop_remainder=drop_remainder,
+        batch_size,
+        padded_shapes={'inputs': max_length, 'targets': max_length},
+        padding_values={'inputs': 0, 'targets': 0},
+        drop_remainder=drop_remainder,
     )

   if prefetch_size:
@@ -324,10 +322,10 @@ def filter_fn(x):


 def get_datasets(
-  config: "TrainConfig",
-  *,
-  n_devices: int,
-  vocab_path: str | None = None,
+    config: Any,
+    *,
+    n_devices: int,
+    vocab_path: str | None = None,
 ):
   """Load and return dataset of batched examples for use during training."""
   if vocab_path is None:
@@ -343,16 +341,16 @@ def get_datasets(

   # Tokenize data.
   sp_processor = tokenizer.load_or_train_tokenizer(
-    train_data,
-    vocab_path=vocab_path,
-    vocab_size=config.vocab_size,
-    max_corpus_chars=config.max_corpus_chars,
+      train_data,
+      vocab_path=vocab_path,
+      vocab_size=config.vocab_size,
+      max_corpus_chars=config.max_corpus_chars,
   )
   train_data = train_data.map(
-    tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
+      tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
   )
   eval_data = eval_data.map(
-    tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
+      tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
   )

   batch_size = config.per_device_batch_size * n_devices
@@ -362,20 +360,20 @@ def get_datasets(

   eval_batch_size = batch_size
   train_ds = preprocess_data(
-    train_data,
-    shuffle=True,
-    num_epochs=None,
-    pack_examples=True,
-    batch_size=batch_size,
-    max_length=config.max_target_length,
+      train_data,
+      shuffle=True,
+      num_epochs=None,
+      pack_examples=True,
+      batch_size=batch_size,
+      max_length=config.max_target_length,
   )

   eval_ds = preprocess_data(
-    eval_data,
-    shuffle=False,
-    pack_examples=False,
-    batch_size=eval_batch_size,
-    max_length=config.max_eval_target_length,
+      eval_data,
+      shuffle=False,
+      pack_examples=False,
+      batch_size=eval_batch_size,
+      max_length=config.max_eval_target_length,
   )

   return train_ds, eval_ds, sp_processor
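
For reference, a minimal usage sketch of the packing path above (not part of the commit). It assumes this file is importable as input_pipeline; that module name and the tiny in-memory dataset are illustrative assumptions, while the expected outputs follow from the pack_dataset / _pack_with_tf_ops logic shown in the diff:

import tensorflow as tf

import input_pipeline  # hypothetical import name for this file


def gen():
  # Two short examples; both fit into a single packed row of length 8.
  yield {'inputs': [1, 2, 3], 'targets': [1, 2, 3]}
  yield {'inputs': [4, 5], 'targets': [4, 5]}


ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'inputs': tf.TensorSpec([None], tf.int32),
        'targets': tf.TensorSpec([None], tf.int32),
    },
)

packed = input_pipeline.pack_dataset(ds, 8)
for ex in packed:
  print(ex['inputs'].numpy())               # [1 2 3 4 5 0 0 0]
  print(ex['inputs_position'].numpy())      # [0 1 2 0 1 0 0 0]  positions restart per original example
  print(ex['inputs_segmentation'].numpy())  # [1 1 1 2 2 0 0 0]  example ids; 0 marks padding

The *_position and *_segmentation features are what downstream attention masking uses to keep packed examples from attending to each other.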