17 changes: 15 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -3,6 +3,7 @@
 import copy
 import bisect
 from typing import Optional
+from tqdm import tqdm
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
@@ -191,7 +192,12 @@ def warmup(self, model):
         model: TpPartBaseModel = model

         # decode cuda graph init
-        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+        progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
+        for batch_size in progress_bar:
+            # Get available memory info
+            avail_mem, total_mem = torch.cuda.mem_get_info()
+            avail_mem_gb = avail_mem / (1024 ** 3)
+            progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
             seq_len = 2
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
@@ -246,7 +252,14 @@ def warmup_overlap(self, model):

         model: TpPartBaseModel = model

-        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+        progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
+        for batch_size in progress_bar:
+            # Get available memory info
+            avail_mem, total_mem = torch.cuda.mem_get_info()
+            avail_mem_gb = avail_mem / (1024 ** 3)
+            progress_bar.set_description(
+                f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
+            )
             decode_batches = []
             for micro_batch_index in [0, 1]:
                 # dummy decoding, capture the cudagraph
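For reference on the pattern used in both hunks above: torch.cuda.mem_get_info() returns a (free_bytes, total_bytes) pair for the current device, and tqdm's set_description() rewrites the bar label on each iteration, which is what produces the per-batch AvailMem readout. A minimal standalone sketch (illustrative only; capture_with_progress and the placeholder loop body are not lightllm APIs):

import torch
from tqdm import tqdm


def capture_with_progress(batch_sizes):
    # Largest batch first, mirroring the [::-1] iteration order in warmup().
    progress_bar = tqdm(sorted(batch_sizes, reverse=True), desc="Capturing CUDA graphs")
    for batch_size in progress_bar:
        free_bytes, total_bytes = torch.cuda.mem_get_info()  # both values are in bytes
        progress_bar.set_description(
            f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {free_bytes / (1024 ** 3):.2f}GB"
        )
        # ... capture the CUDA graph for this batch size here (placeholder) ...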
2 changes: 1 addition & 1 deletion lightllm/common/basemodel/layer_weights/hf_load_utils.py
@@ -60,7 +60,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye
         transformer_layer_list=transformer_layer_list,
         weight_dir=weight_dir,
     ) # noqa
-    worker = int(os.environ.get("LOADWORKER", 1))
+    worker = int(os.environ.get("LOADWORKER", 16))
     with Pool(worker) as p:
         iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1)
         desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers"
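This raises the default weight-loading parallelism from 1 to 16 processes; setting the LOADWORKER environment variable still overrides it. A hedged sketch of the pattern the changed code relies on (load_one_file, "demo", and parallel_load are stand-ins for illustration, not the real loader):

import os
from functools import partial
from multiprocessing import Pool


def load_one_file(prefix, file_path):
    # Stand-in for the real per-file weight loading done in hf_load_utils.
    return f"{prefix}:{file_path}"


def parallel_load(candidate_files):
    # Defaults to 16 workers after this change; export LOADWORKER=4 to dial it down.
    worker = int(os.environ.get("LOADWORKER", 16))
    partial_func = partial(load_one_file, "demo")
    with Pool(worker) as p:
        for _ in p.imap_unordered(partial_func, candidate_files, chunksize=1):
            pass

More workers shorten loading when the checkpoint is split across many shard files, at the cost of higher peak host memory while several shards are open at once.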
@@ -9,3 +9,4 @@
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
 from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
 from .fused_moe_weight_ep import FusedMoeWeightEP
+from .parameter_weight import ParameterWeight, TpParameterWeight
@@ -0,0 +1,44 @@
import torch
from typing import Dict
from .base_weight import BaseWeightTpl
from lightllm.utils.dist_utils import get_current_device_id


class ParameterWeight(BaseWeightTpl):
    def __init__(self, weight_name: str, data_type: torch.dtype, bias_name: str = None):
        super().__init__()
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.data_type_ = data_type
        self.weight = None
        self.bias = None

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        if self.weight_name in weights:
            self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id())
        if self.bias_name in weights:
            self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id())

    def verify_load(self):
        load_ok = True
        # Verify weight. The weight must not be None.
        load_ok = load_ok and self.weight is not None
        # Verify bias. If bias_name is set, the bias must not be None.
        if self.bias_name is not None:
            load_ok = load_ok and self.bias is not None
        return load_ok


class TpParameterWeight(ParameterWeight):
    def __init__(self, weight_name: str, data_type: torch.dtype, split_n_embed: int, bias_name: str = None):
        super().__init__(weight_name, data_type, bias_name)
        self.split_n_embed = split_n_embed

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        start = self.split_n_embed * self.tp_rank_
        end = self.split_n_embed * (self.tp_rank_ + 1)

        if self.weight_name in weights:
            self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id())
        if self.bias_name in weights:
            self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id())
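A rough usage sketch of the two classes above (checkpoint keys, sizes, and the dtype are invented; it assumes a CUDA device, ParameterWeight/TpParameterWeight imported from the new module, and tp_rank_ provided by BaseWeightTpl as in the code above):

import torch

# Invented checkpoint entries for illustration.
weights = {
    "model.scale": torch.randn(4096),
    "model.scale_bias": torch.randn(4096),
}

# Non-split parameter: every rank keeps the full 4096 values.
w = ParameterWeight("model.scale", torch.bfloat16, bias_name="model.scale_bias")
w.load_hf_weights(weights)
assert w.verify_load()

# Tensor-parallel parameter: with a tp world size of 2, rank r keeps the slice
# [r * 2048 : (r + 1) * 2048], selected through tp_rank_.
tp_w = TpParameterWeight("model.scale", torch.bfloat16, split_n_embed=4096 // 2, bias_name="model.scale_bias")
tp_w.load_hf_weights(weights)
assert tp_w.weight.shape[0] == 2048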
91 changes: 91 additions & 0 deletions lightllm/common/basemodel/triton_kernel/alloc_buffer_kernel.py
@@ -0,0 +1,91 @@
import torch
import triton
import triton.language as tl


@triton.jit
def alloc_buffer_for_req_kernel(
    req_index_ptr,  # [num_reqs] - indices of requests to allocate buffers for
    buffer_indexes_ptr,  # [num_reqs] - buffer indices to assign (from CPU)
    req_to_buffer_index_ptr,  # [max_request_num + 1] - tensor mapping req_idx to buffer_idx
    num_reqs,  # number of requests to process
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Mask for valid indices
    mask = offsets < num_reqs

    # Load request indices and buffer indices
    req_indices = tl.load(req_index_ptr + offsets, mask=mask, other=0)
    buffer_indices = tl.load(buffer_indexes_ptr + offsets, mask=mask, other=0)

    # Update req_to_buffer_index[req_indices] = buffer_indices
    tl.store(req_to_buffer_index_ptr + req_indices, buffer_indices, mask=mask)


def alloc_buffer_for_req_triton(
    req_index: torch.Tensor,  # [num_reqs] int32/int64 tensor on CUDA
    buffer_indexes: torch.Tensor,  # [num_reqs] int32 tensor (can be CPU or CUDA)
    req_to_buffer_index: torch.Tensor,  # [max_request_num + 1] int32 tensor on CUDA
):
    num_reqs = req_index.shape[0]

    # Ensure inputs are on CUDA
    if not req_index.is_cuda:
        req_index = req_index.cuda()
    if not buffer_indexes.is_cuda:
        buffer_indexes = buffer_indexes.cuda()

    # Ensure correct dtypes
    if req_index.dtype not in [torch.int32, torch.int64]:
        req_index = req_index.to(torch.int32)
    if buffer_indexes.dtype != torch.int32:
        buffer_indexes = buffer_indexes.to(torch.int32)

    # Launch kernel
    BLOCK_SIZE = 256
    grid = (triton.cdiv(num_reqs, BLOCK_SIZE),)

    alloc_buffer_for_req_kernel[grid](
        req_index,
        buffer_indexes,
        req_to_buffer_index,
        num_reqs,
        BLOCK_SIZE=BLOCK_SIZE,
    )


# Convenience function that matches the original API
def alloc_buffer_for_req_wrapper(
    req_manager,
    req_index: list,
    buffer_indexes: torch.Tensor,
):
    """
    Wrapper function to integrate with ReqManagerWithBuffer.

    Usage in ReqManagerWithBuffer:
        def alloc_buffer_for_req(self, req_index: List[int]):
            self.req_has_buffer[req_index] = True
            buffer_indexes = self.mem_manager.alloc_buffer(len(req_index))  # cpu tensor
            # Replace the next line with the Triton kernel
            # self.req_to_buffer_index[req_index] = buffer_indexes
            from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton
            req_index_tensor = torch.tensor(req_index, dtype=torch.int32, device="cuda")
            alloc_buffer_for_req_triton(
                req_index_tensor,
                buffer_indexes,
                self.req_to_buffer_index,
            )
    """
    req_index_tensor = torch.tensor(req_index, dtype=torch.int32, device="cuda")
    alloc_buffer_for_req_triton(
        req_index_tensor,
        buffer_indexes,
        req_manager.req_to_buffer_index,
    )
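A small exercise of the kernel entry point with dummy tensors (the table size, index values, and the -1 fill are invented for illustration; ReqManagerWithBuffer and mem_manager are only referenced in the docstring and are not defined in this file):

import torch
from lightllm.common.basemodel.triton_kernel.alloc_buffer_kernel import alloc_buffer_for_req_triton

max_request_num = 8
req_to_buffer_index = torch.full((max_request_num + 1,), -1, dtype=torch.int32, device="cuda")

req_index = torch.tensor([2, 5, 7], dtype=torch.int32, device="cuda")
buffer_indexes = torch.tensor([10, 11, 12], dtype=torch.int32)  # CPU is fine; it is moved to CUDA inside

alloc_buffer_for_req_triton(req_index, buffer_indexes, req_to_buffer_index)
# The kernel performs a scatter equivalent to: req_to_buffer_index[[2, 5, 7]] = [10, 11, 12]
assert req_to_buffer_index[torch.tensor([2, 5, 7], device="cuda")].cpu().tolist() == [10, 11, 12]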