
Commit a7d6f6f

enable fine tuning on HPU
1 parent 7682500 commit a7d6f6f

5 files changed: 128 additions & 20 deletions

src/instructlab/training/hpu_utils.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+import torch
+
+from functools import lru_cache
+
+@lru_cache(maxsize=None)
+def is_torch_hpu_available() -> bool:
+    try:
+        import habana_frameworks.torch.core  # noqa: F401
+    except ImportError:
+        return False
+    return True
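The new helper probes once for the Habana PyTorch bridge and caches the answer via lru_cache, so call sites can branch on it freely. A minimal usage sketch (the device selection shown here is illustrative, not part of the commit):

import torch

from instructlab.training.hpu_utils import is_torch_hpu_available

# The probe is cached, so branching on it in hot paths costs a dict lookup
# rather than a repeated import attempt.
device = "hpu" if is_torch_hpu_available() else "cuda"
x = torch.ones(4, device=device)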

src/instructlab/training/main_ds.py

Lines changed: 65 additions & 10 deletions
@@ -42,6 +42,14 @@
     UserWarning,
 )
 
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    import habana_frameworks.torch.core as htcore
+    import habana_frameworks.torch.distributed.hccl
+    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+    adapt_transformers_to_gaudi()
+
 # Third Party
 from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from torch.utils.data import DataLoader
@@ -221,7 +229,22 @@ def setup_model(
     )
     model.config.eos_token_id = tokenizer.eos_token_id
 
-    if "ForCausalLM" not in model.__class__.__name__:
+    if not is_torch_hpu_available():
+        class_name = model.__class__.__name__
+    else:
+        class_name = model._orig_mod.__class__.__name__ if model.__class__.__name__ == 'OptimizedModule' else model.__class__.__name__
+
+    replace_no_split_modules = {
+        'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',]
+    }
+
+    if class_name in replace_no_split_modules:
+        if model.__class__.__name__ == 'OptimizedModule':
+            model._orig_mod._no_split_modules = replace_no_split_modules[class_name]
+        else:
+            model._no_split_modules = replace_no_split_modules[class_name]
+
+    if "ForCausalLM" not in class_name:
         raise ValueError(
             f"Model class name: {model.__class__.__name__} is not supported."
         )
@@ -271,6 +294,11 @@ def make_inputs_require_grad(module, input, output):  # pylint: disable=unused-argument
     model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
     accelerator = setup_accelerator(args, model, grad_accum)
+
+    if is_torch_hpu_available():
+        accelerator.state.fsdp_plugin.use_orig_params = True
+        accelerator.state.fsdp_plugin.sync_module_states = True
+
     if args.distributed_training_framework == DistributedBackend.FSDP.value:
         model = accelerator.prepare(model)
     optimizer = setup_optimizer(args, model)
@@ -413,10 +441,19 @@ def train(
         total_length = float(torch.tensor([batch.pop("total_length")]))
         if not args.use_dolomite:
             for k in batch:
-                batch[k] = batch[k].to(local_rank)
+                batch[k] = batch[k].to('hpu' if is_torch_hpu_available() else local_rank)
+
+        # must be a dict, not a list: it is unpacked with ** below
+        hpu_args = {}
+        if is_torch_hpu_available():
+            hpu_args = {
+                "use_flash_attention": True,
+                "lazy_mode": False,
+            }
+
         output = model(
             **batch,
             use_cache=False,
+            **hpu_args,
         )
         loss = output.loss
         log_loss = loss.detach().item()
@@ -453,8 +490,14 @@ def train(
             elapsed_time = time.time() - start
             overall_throughput = args.samples_per_gpu * world_size / elapsed_time
             current_lr = lr_scheduler.get_last_lr()[0]
-            cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
-            cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
+            if is_torch_hpu_available():
+                mem_allocated = torch.hpu.memory_allocated() / (1024**3)
+                malloc_retries = 0  # no equivalent retry counter is read on HPU
+            else:
+                mem_allocated = torch.cuda.memory_allocated() / (1024**3)
+                malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+
             global_grad_norm = (
                 model.get_global_grad_norm()
                 if hasattr(model, "get_global_grad_norm")
@@ -476,8 +519,8 @@ def train(
                     "rank": torch.distributed.get_rank(),
                     "overall_throughput": overall_throughput,
                     "lr": current_lr,
-                    "cuda_mem_allocated": cuda_mem_allocated,
-                    "cuda_malloc_retries": cuda_malloc_retries,
+                    ("hpu" if is_torch_hpu_available() else "cuda") + "_mem_allocated": mem_allocated,
+                    ("hpu" if is_torch_hpu_available() else "cuda") + "_malloc_retries": malloc_retries,
                     "num_loss_counted_tokens": int(num_loss_counted_tokens),
                     "num_tokens_rank0": int(total_length),
                     "batch_size": int(micro_batch_size),
@@ -518,7 +561,10 @@ def train(
            global_step += 1
            if local_rank == 0:
                inner_pb.update(1)
-            torch.cuda.empty_cache()
+
+            if not is_torch_hpu_available():
+                torch.cuda.empty_cache()
+
        if args.checkpoint_at_epoch:
            base_logger.debug(f"Saving checkpoint at epoch {epoch}")
            save_checkpoint(
@@ -575,13 +621,22 @@ def main(args):
     # gets converted to a timedelta of 1:40:00 if the default is kept
     nccl_timeout = int(os.getenv("INSTRUCTLAB_NCCL_TIMEOUT_MS", "6000000"))
     #### distributed init #####
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    if is_torch_hpu_available():
+        torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
+    else:
+        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
     args.local_rank = int(os.environ["LOCAL_RANK"])
     torch.distributed.init_process_group(
-        "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout)
+        "hccl" if is_torch_hpu_available() else "nccl", timeout=datetime.timedelta(milliseconds=nccl_timeout)
     )
     args.global_rank = torch.distributed.get_rank()
-    tensor = torch.ByteTensor([False]).cuda()
+
+    if is_torch_hpu_available():
+        tensor = torch.ByteTensor([False]).to('hpu')
+    else:
+        tensor = torch.ByteTensor([False]).cuda()
+
     torch.distributed.all_reduce(tensor)
     torch.distributed.barrier()
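Two of the hunks above deal with torch.compile: a compiled model is wrapped in an OptimizedModule whose _orig_mod attribute holds the original module, so the class-name check in setup_model must look through the wrapper. A self-contained sketch of that unwrapping, using a toy module in place of a real causal LM:

import torch

class TinyForCausalLM(torch.nn.Module):
    def forward(self, x):
        return x

compiled = torch.compile(TinyForCausalLM())

# Mirrors the setup_model logic above: read the class name through the
# torch.compile wrapper before testing for the "ForCausalLM" suffix.
class_name = (
    compiled._orig_mod.__class__.__name__
    if compiled.__class__.__name__ == "OptimizedModule"
    else compiled.__class__.__name__
)
assert class_name == "TinyForCausalLM"
assert "ForCausalLM" in class_name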

src/instructlab/training/multipack_sampler.py

Lines changed: 6 additions & 4 deletions
@@ -34,6 +34,8 @@
 import torch
 import torch.distributed as dist
 
+from instructlab.training.utils import bucket
+
 
 def find_max_pack_len_with_padding(
     dataset,
@@ -211,11 +213,11 @@ def ffd_check_padding(a: np.ndarray, c: int, n: int):
         not_found = True
         for idx in range(n):
             # Calculate the new capacity if size is added to the bin
-            new_capacity = max(bins_max_lengths[idx], size) * (
+            new_capacity = bucket(max(bins_max_lengths[idx], size)) * (
                 bins_num_samples[idx] + 1
             )
             if new_capacity <= c:
-                bins_max_lengths[idx] = max(bins_max_lengths[idx], size)
+                bins_max_lengths[idx] = bucket(max(bins_max_lengths[idx], size))
                 bins_num_samples[idx] += 1
                 not_found = False
                 break
@@ -266,11 +268,11 @@ def ffd_with_result_padding(a: np.ndarray, c: int, start_index: int):
         add_new = True
         for idx in range(len(bins_max_lengths)):
             # Calculate the new capacity if size is added to the bin
-            new_capacity = max(bins_max_lengths[idx], size) * (
+            new_capacity = bucket(max(bins_max_lengths[idx], size)) * (
                 bins_num_samples[idx] + 1
             )
             if new_capacity <= c:
-                bins_max_lengths[idx] = max(bins_max_lengths[idx], size)
+                bins_max_lengths[idx] = bucket(max(bins_max_lengths[idx], size))
                 bins_num_samples[idx] += 1
                 bins_result[idx].append(indices[a_id] + start_index)
                 add_new = False
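Switching the first-fit-decreasing capacity estimate from raw lengths to bucket(...) makes it match what a padded batch actually costs on HPU, where pad_collate_fn (see the utils.py diff below) pads every sequence in a pack up to the bucketed maximum length. A plain-Python restatement with illustrative numbers (the bucketing rule is inlined here; the real code imports bucket from instructlab.training.utils):

def bucket(length):
    # inlined copy of simple_bucket from utils.py, for illustration only
    msb = length.bit_length()
    align = (1 << (msb - 4)) if msb >= 4 else 1
    return (length + align - 1) // align * align

bin_max_len, bin_num_samples, size = 1000, 3, 990

# Raw estimate: the bin appears to cost 1000 * 4 = 4000 tokens.
raw = max(bin_max_len, size) * (bin_num_samples + 1)
# Bucketed estimate: padding rounds the max length up to 1024, so the pack
# really occupies 1024 * 4 = 4096 tokens, which is what gets compared to c.
padded = bucket(max(bin_max_len, size)) * (bin_num_samples + 1)
print(raw, padded)  # 4000 4096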

src/instructlab/training/setup_accelerator.py

Lines changed: 12 additions & 4 deletions
@@ -2,7 +2,6 @@
 from functools import partial
 
 # Third Party
-from accelerate import Accelerator
 from peft.utils.other import fsdp_auto_wrap_policy
 from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
@@ -12,6 +11,12 @@
 # First Party
 from instructlab.training.config import DeepSpeedOptions
 from instructlab.training.utils import get_module_class_from_name, patch_target_module
+from instructlab.training.hpu_utils import is_torch_hpu_available
+
+if is_torch_hpu_available():
+    from optimum.habana.accelerate import GaudiAccelerator
+else:
+    from accelerate import Accelerator
 
 
 def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
@@ -51,7 +56,10 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
 
 def get_fsdp_config(args, model: PreTrainedModel):
     # Third Party
-    from accelerate.utils import FullyShardedDataParallelPlugin
+    if is_torch_hpu_available():
+        from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
+    else:
+        from accelerate.utils import FullyShardedDataParallelPlugin
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 
     is_lora = args.lora_r > 0
@@ -73,7 +81,7 @@ def get_fsdp_config(args, model: PreTrainedModel):
     prefetch_policy = (
         BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
     )
-    fsdp_plugin = FullyShardedDataParallelPlugin(
+    fsdp_plugin = (GaudiFullyShardedDataParallelPlugin if is_torch_hpu_available() else FullyShardedDataParallelPlugin)(
         auto_wrap_policy=wrap_policy,
         limit_all_gathers=True,
         backward_prefetch=prefetch_policy,
@@ -128,7 +136,7 @@ def setup_accelerator(args, model: PreTrainedModel, grad_accum):
         raise ValueError(
             f"Unknown sharding framework: {args.distributed_training_framework}"
         )
-    accelerator = Accelerator(
+    accelerator = (GaudiAccelerator if is_torch_hpu_available() else Accelerator)(
         **accel_args,
     )
     accelerator.even_batches = False
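The Gaudi classes are constructed with exactly the same keyword arguments as their stock accelerate counterparts, which is what lets the commit swap only the class object inside a conditional expression and leave the call sites untouched. The same pattern in isolation (the mixed_precision argument is illustrative, not taken from this commit):

from instructlab.training.hpu_utils import is_torch_hpu_available

if is_torch_hpu_available():
    from optimum.habana.accelerate import GaudiAccelerator as AcceleratorCls
else:
    from accelerate import Accelerator as AcceleratorCls

# One call site serves both backends; only the class object varies.
accelerator = AcceleratorCls(mixed_precision="bf16")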

src/instructlab/training/utils.py

Lines changed: 34 additions & 2 deletions
@@ -43,13 +43,15 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
+import numba
 
 # First Party
 from instructlab.training.config import (
     DistributedBackend,
     QuantizeDataType,
     TrainingArgs,
 )
+from instructlab.training.hpu_utils import is_torch_hpu_available
 
 logger = logging.getLogger("instructlab.training")
 
@@ -209,6 +211,8 @@ def listen(self):
 
 
 def supports_flash_attention(device_id=0):
     """Check if a GPU supports FlashAttention."""
+    # HPU has no CUDA device capability to query; report False here so the
+    # CUDA-only probe below is never reached.
+    if is_torch_hpu_available():
+        return False
     major, minor = torch.cuda.get_device_capability(device_id)
     # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
@@ -236,6 +240,30 @@ def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bool:
     return flash_enabled
 
 
+@numba.njit
+def simple_bucket(length):
+    # Round length up to a multiple of 2**(msb - 4), where msb is the bit
+    # length of the input: buckets get coarser as sequences get longer.
+    l = length
+    msb = 0
+    while l > 0:
+        msb += 1
+        l = l // 2
+
+    align = (1 << (msb - 4)) if msb >= 4 else 1
+
+    return (length + align - 1) // align * align
+
+
+torch_hpu_available = is_torch_hpu_available()
+
+
+@numba.njit
+def bucket(length):
+    # numba freezes module-level globals at compile time, so the HPU check is
+    # hoisted into the flag captured above rather than re-evaluated per call.
+    if torch_hpu_available:
+        return simple_bucket(length)
+    else:
+        return length
+
+
 def make_collate_fn(
     pad_token_id, use_dolomite=False, flash_enabled=True, max_batch_len=60000
 ):
@@ -298,7 +326,7 @@ def pad_collate_fn(batch):
 
     def pad_collate_fn(batch):
         lens = np.array([len(item["input_ids"]) for item in batch])
-        max_len = max(lens)
+        max_len = bucket(max(lens))
 
         input_ids = torch.stack(
             [
@@ -411,6 +439,7 @@ def reduce_sum_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        # unpack a mapping on both paths; **None would raise a TypeError
+        **(deprecated_arguments if is_torch_hpu_available() else {}),
     )
 
     return_dict = isinstance(output, dict)
@@ -1093,7 +1122,10 @@ def set_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if is_torch_hpu_available():
+        torch.hpu.manual_seed_all(seed)
+    else:
+        torch.cuda.manual_seed_all(seed)
 
 
 def save_checkpoint(
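For a feel of the bucket sizes: simple_bucket rounds a length up to a multiple of 2**(msb - 4), so each power-of-two range contains only eight distinct bucketed lengths and the padding overhead stays under 1/8 of the original value. A few worked examples, hand-checked against the definition above:

from instructlab.training.utils import simple_bucket

for n in (17, 100, 1000, 4096, 5000):
    print(n, "->", simple_bucket(n))
# 17 -> 18      (msb = 5,  align = 2)
# 100 -> 104    (msb = 7,  align = 8)
# 1000 -> 1024  (msb = 10, align = 64)
# 4096 -> 4096  (msb = 13, align = 512; already aligned)
# 5000 -> 5120  (msb = 13, align = 512)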
