         UserWarning,
     )
 
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from torch.utils.data import DataLoader
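Note: `instructlab.training.hpu_utils` is added elsewhere in this commit and does not appear in this diff. As a rough idea of what the gate does, a minimal sketch (an assumption, not the committed file) is:

    # Hypothetical sketch of instructlab/training/hpu_utils.py, not the committed file.
    from functools import lru_cache


    @lru_cache(maxsize=None)
    def is_torch_hpu_available() -> bool:
        # Gaudi support is usable only if the Habana PyTorch bridge imports cleanly.
        try:
            import habana_frameworks.torch.core  # noqa: F401
        except ImportError:
            return False
        return True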
@@ -221,7 +229,22 @@ def setup_model(
     )
     model.config.eos_token_id = tokenizer.eos_token_id
 
-    if "ForCausalLM" not in model.__class__.__name__:
+    if not is_torch_hpu_available():
+        class_name = model.__class__.__name__
+    else:
+        class_name = model._orig_mod.__class__.__name__ if model.__class__.__name__ == 'OptimizedModule' else model.__class__.__name__
+
+    replace_no_split_modules = {
+        'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
+    }
+
+    if class_name in replace_no_split_modules:
+        if model.__class__.__name__ == 'OptimizedModule':
+            model._orig_mod._no_split_modules = replace_no_split_modules[class_name]
+        else:
+            model._no_split_modules = replace_no_split_modules[class_name]
+
+    if "ForCausalLM" not in class_name:
         raise ValueError(
             f"Model class name: {model.__class__.__name__} is not supported."
         )
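The `_no_split_modules` override above matters because the FSDP auto-wrap policy is derived from that list; once `adapt_transformers_to_gaudi()` swaps in the Gaudi class names, each `GaudiLlamaDecoderLayer` is wrapped as its own FSDP unit. An illustrative sketch of how such a policy is typically built from `_no_split_modules` (not the committed code):

    import functools

    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


    def wrap_policy_from_no_split(model):
        # Collect the module classes named in _no_split_modules and wrap each one.
        no_split = getattr(model, "_no_split_modules", None) or []
        layer_classes = {
            module.__class__
            for module in model.modules()
            if module.__class__.__name__ in no_split
        }
        return functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls=layer_classes
        )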
@@ -271,6 +294,11 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a
     model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
     accelerator = setup_accelerator(args, model, grad_accum)
+
+    if is_torch_hpu_available():
+        accelerator.state.fsdp_plugin.use_orig_params = True
+        accelerator.state.fsdp_plugin.sync_module_states = True
+
     if args.distributed_training_framework == DistributedBackend.FSDP.value:
         model = accelerator.prepare(model)
     optimizer = setup_optimizer(args, model)
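Here `use_orig_params` and `sync_module_states` are forced on after the accelerator is built: `use_orig_params` is commonly required when combining FSDP with `torch.compile`, and `sync_module_states` broadcasts rank-0 weights when sharding. If the plugin were constructed directly, the equivalent settings (illustrative, not what this commit does) would be:

    from accelerate.utils import FullyShardedDataParallelPlugin

    # Illustrative: the same flags set at plugin construction time instead of after the fact.
    fsdp_plugin = FullyShardedDataParallelPlugin(
        use_orig_params=True,
        sync_module_states=True,
    )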
@@ -413,10 +441,19 @@ def train(
             total_length = float(torch.tensor([batch.pop("total_length")]))
             if not args.use_dolomite:
                 for k in batch:
-                    batch[k] = batch[k].to(local_rank)
+                    batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+
+            hpu_args = {}
+            if is_torch_hpu_available():
+                hpu_args = {
+                    "use_flash_attention": True,
+                    "lazy_mode": False,
+                }
+
             output = model(
                 **batch,
                 use_cache=False,
+                **hpu_args,
             )
             loss = output.loss
             log_loss = loss.detach().item()
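`use_flash_attention` and `lazy_mode` are keyword arguments understood only by the Gaudi-patched forward that `adapt_transformers_to_gaudi()` installs; on CUDA the dict stays empty and the call is unchanged. A defensive variant (hypothetical, not in this commit) could filter the extra kwargs against the forward signature, e.g. `hpu_args = filter_forward_kwargs(model, hpu_args)`:

    import inspect


    def filter_forward_kwargs(model, extra_kwargs):
        # Keep only kwargs that the (possibly Gaudi-patched) forward() actually accepts.
        accepted = inspect.signature(model.forward).parameters
        return {k: v for k, v in extra_kwargs.items() if k in accepted}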
@@ -453,8 +490,14 @@ def train(
                 elapsed_time = time.time() - start
                 overall_throughput = args.samples_per_gpu * world_size / elapsed_time
                 current_lr = lr_scheduler.get_last_lr()[0]
-                cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-                cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+                if is_torch_hpu_available():
+                    mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                    malloc_retries = 0
+                else:
+                    mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                    malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
                 global_grad_norm = (
                     model.get_global_grad_norm()
                     if hasattr(model, "get_global_grad_norm")
@@ -476,8 +519,8 @@ def train(
476519 "rank" : torch .distributed .get_rank (),
477520 "overall_throughput" : overall_throughput ,
478521 "lr" : current_lr ,
479- "cuda_mem_allocated" : cuda_mem_allocated ,
480- "cuda_malloc_retries" : cuda_malloc_retries ,
522+ ( "hpu" if is_torch_hpu_available () else "cuda" ) + "_mem_allocated" : mem_allocated ,
523+ ( "hpu" if is_torch_hpu_available () else "cuda" ) + "_malloc_retries" : malloc_retries ,
481524 "num_loss_counted_tokens" : int (num_loss_counted_tokens ),
482525 "num_tokens_rank0" : int (total_length ),
483526 "batch_size" : int (micro_batch_size ),
@@ -518,7 +561,10 @@ def train(
             global_step += 1
             if local_rank == 0:
                 inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if not is_torch_hpu_available():
+                torch.cuda.empty_cache()
+
         if args.checkpoint_at_epoch:
             base_logger.debug(f"Saving checkpoint at epoch {epoch}")
             save_checkpoint(
@@ -575,13 +621,22 @@ def main(args):
     # gets converted to a timedelta of 1:40:00 if the default is kept
     nccl_timeout = int(os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS", "6000000"))
     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if is_torch_hpu_available():
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])
     torch.distributed.init_process_group(
-        "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout)
+        "hccl" if is_torch_hpu_available() else "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout)
     )
     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if is_torch_hpu_available():
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()