diff --git a/flood/flood/facade/llm.py b/flood/flood/facade/llm.py
index ba6fad6..8acd7cd 100644
--- a/flood/flood/facade/llm.py
+++ b/flood/flood/facade/llm.py
@@ -2053,3 +2053,35 @@ def update_digit(share_value, index, value):
         vals = list(str(share_value.value))
         vals[index] = str(min(value, 1))
         share_value.value = int("".join(vals))
+
+    def cleanup(self):
+        """Terminate any worker processes that are still running.
+
+        Safe to call repeatedly and on a partially constructed instance:
+        a missing or empty ``processes`` attribute makes this a no-op.
+        """
+        processes = getattr(self, 'processes', None)
+        if not processes:
+            return
+        # Signal every live worker first so the grace period below is
+        # shared rather than paid once per process.
+        alive = [p for p in processes if p.is_alive()]
+        for process in alive:
+            process.terminate()
+        for process in alive:
+            process.join(timeout=2)
+            if process.is_alive():
+                # Still running after the grace period: force-kill it.
+                process.kill()
+                # Bounded wait so cleanup cannot block on a stuck worker.
+                process.join(timeout=2)
+        self.processes = []
+
+    def __del__(self):
+        """Best-effort cleanup when the object is garbage collected."""
+        try:
+            self.cleanup()
+        except Exception:
+            # Exceptions cannot propagate out of __del__ (Python only
+            # prints a warning), so failures here are deliberately ignored.
+            pass
diff --git a/flood/flood/models/modeling_qwen3.py b/flood/flood/models/modeling_qwen3.py
index 29ec712..ddd223f 100644
--- a/flood/flood/models/modeling_qwen3.py
+++ b/flood/flood/models/modeling_qwen3.py
@@ -374,7 +374,9 @@ def forward(
             stream = streams[i]
             with torch.cuda.stream(stream):
                 if i == 0 and self.rank == 0:
-                    batch_meta_info.to(torch.device(0), non_blocking=True)
+                    # Move to the device where embedding layer is located
+                    emb_device = next(self.model.embed_tokens.parameters()).device
+                    batch_meta_info.to(emb_device, non_blocking=True)
                     hidden_states = self.model.embed_tokens(batch_meta_info.input_ids)
                     embeddings = batch_meta_info.embeddings
                     if embeddings is not None:
@@ -384,6 +386,11 @@ def forward(
                                 continue
                             ss, se, ds, de = emb_idx
                             hidden_states[ds:de] = embeddings[ie][ss:se]
+                    # Move hidden_states to the device where first layer group is located
+                    if len(indices) > 0:
+                        first_layer_device = next(self.model.layers[indices[0]].parameters()).device
+                        hidden_states = hidden_states.to(first_layer_device, non_blocking=True)
+                        batch_meta_info.to(first_layer_device, non_blocking=True)
                     sync_layers[i]()
 
                 for j in indices:
@@ -396,9 +403,11 @@ def forward(
                     )
 
                 if i < n_devices - 1:
-                    device = torch.device(i + 1)
-                    hidden_states = hidden_states.to(device, non_blocking=True)
-                    batch_meta_info.to(device, non_blocking=True)
+                    # Move to the device where next layer group is located
+                    if len(device_list[i + 1]) > 0:
+                        next_device = next(self.model.layers[device_list[i + 1][0]].parameters()).device
+                        hidden_states = hidden_states.to(next_device, non_blocking=True)
+                        batch_meta_info.to(next_device, non_blocking=True)
                 else:
                     if self.rank == self.world_size - 1 and hidden_states is not None:
                         hidden_states = self.model.norm(hidden_states)