flood/flood/facade/llm.py (17 additions, 0 deletions)
@@ -2053,3 +2053,20 @@ def update_digit(share_value, index, value):
         vals = list(str(share_value.value))
         vals[index] = str(min(value, 1))
         share_value.value = int("".join(vals))
+
+    def cleanup(self):
+        """Clean up worker processes and release CUDA resources."""
+        if hasattr(self, 'processes') and self.processes:
+            for process in self.processes:
+                if process.is_alive():
+                    process.terminate()
+                    process.join(timeout=2)
+                    if process.is_alive():
+                        process.kill()
+                        process.join()
+            self.processes = []
+
+
+    def __del__(self):
+        """Destructor to ensure cleanup on object deletion."""
+        self.cleanup()
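
For context, the shutdown logic above follows the usual terminate-then-kill pattern for multiprocessing workers. A minimal standalone sketch of that pattern is below; the worker function and the 2-second grace period are illustrative, not taken from flood.

import multiprocessing as mp
import time

def _worker():
    # Hypothetical long-running worker standing in for flood's inference workers.
    while True:
        time.sleep(1)

if __name__ == "__main__":
    procs = [mp.Process(target=_worker) for _ in range(2)]
    for p in procs:
        p.start()

    for p in procs:
        if p.is_alive():
            p.terminate()          # ask politely first (SIGTERM on POSIX)
            p.join(timeout=2)      # short grace period to exit cleanly
            if p.is_alive():
                p.kill()           # escalate (SIGKILL) if it did not exit
                p.join()

Note that CPython does not guarantee __del__ runs for objects still alive at interpreter exit, so an explicit cleanup() call (or registering cleanup with atexit) remains the reliable path; the destructor is only a fallback.
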
flood/flood/models/modeling_qwen3.py (13 additions, 4 deletions)
@@ -374,7 +374,9 @@ def forward(
             stream = streams[i]
             with torch.cuda.stream(stream):
                 if i == 0 and self.rank == 0:
-                    batch_meta_info.to(torch.device(0), non_blocking=True)
+                    # Move to the device where embedding layer is located
+                    emb_device = next(self.model.embed_tokens.parameters()).device
+                    batch_meta_info.to(emb_device, non_blocking=True)
                     hidden_states = self.model.embed_tokens(batch_meta_info.input_ids)
                     embeddings = batch_meta_info.embeddings
                     if embeddings is not None:
@@ -384,6 +386,11 @@
                                 continue
                             ss, se, ds, de = emb_idx
                             hidden_states[ds:de] = embeddings[ie][ss:se]
+                    # Move hidden_states to the device where first layer group is located
+                    if len(indices) > 0:
+                        first_layer_device = next(self.model.layers[indices[0]].parameters()).device
+                        hidden_states = hidden_states.to(first_layer_device, non_blocking=True)
+                        batch_meta_info.to(first_layer_device, non_blocking=True)
                 sync_layers[i]()
 
                 for j in indices:
@@ -396,9 +403,11 @@
                     )
 
                 if i < n_devices - 1:
-                    device = torch.device(i + 1)
-                    hidden_states = hidden_states.to(device, non_blocking=True)
-                    batch_meta_info.to(device, non_blocking=True)
+                    # Move to the device where next layer group is located
+                    if len(device_list[i + 1]) > 0:
+                        next_device = next(self.model.layers[device_list[i + 1][0]].parameters()).device
+                        hidden_states = hidden_states.to(next_device, non_blocking=True)
+                        batch_meta_info.to(next_device, non_blocking=True)
                 else:
                     if self.rank == self.world_size - 1 and hidden_states is not None:
                         hidden_states = self.model.norm(hidden_states)
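
All three hunks above apply the same idea: rather than hard-coding torch.device(0) or torch.device(i + 1) for a pipeline stage, ask the module that will consume the tensor where its weights actually live (next(module.parameters()).device) and move the activations there. A minimal sketch of that pattern, using a toy two-stage model rather than flood's Qwen3 implementation:

import torch
import torch.nn as nn

def module_device(module: nn.Module) -> torch.device:
    # Device of the module's first parameter; assumes the module owns parameters.
    return next(module.parameters()).device

if __name__ == "__main__":
    # Toy two-stage "pipeline"; in flood these would be layer groups, possibly on different GPUs.
    dev0 = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    embed = nn.Embedding(1000, 64).to(dev0)
    stages = nn.ModuleList([nn.Linear(64, 64).to(dev0),
                            nn.Linear(64, 64).to(dev0)])

    tokens = torch.randint(0, 1000, (4, 16), device=module_device(embed))
    hidden = embed(tokens)
    for stage in stages:
        # Follow the weights instead of assuming stage k sits on cuda:k.
        hidden = hidden.to(module_device(stage), non_blocking=True)
        hidden = stage(hidden)

Following the weights this way stays correct when layer groups are not mapped one-to-one onto consecutive CUDA device indices, which is the assumption the replaced torch.device(i + 1) code baked in.
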