Commit 250d7ad

Enable CC_METHOD for dsv3 by default && fix test script && fix tgi stream api (#732)
Co-authored-by: shihaobai <[email protected]>
1 parent c07e3a2 commit 250d7ad

5 files changed: +23 −17 lines

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         if mscale_all_dim:
             mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
             self.softmax_scale = self.softmax_scale * mscale * mscale
-        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
+        self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
         if self.enable_dp:
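For context, the flag flips from opt-in to opt-out: the CC method is now on unless DISABLE_CC_METHOD is set to ON, TRUE, or 1 (Python parses "not x in y" as "not (x in y)", so the new expression behaves as intended). A minimal sketch of the resulting semantics; the helper below is illustrative only and not part of the commit:

    import os

    def cc_method_enabled() -> bool:
        # Mirrors the expression introduced in this commit: the default "False"
        # means the method stays enabled when the variable is unset.
        return not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]

    # unset DISABLE_CC_METHOD -> True  (enabled by default)
    # DISABLE_CC_METHOD=1     -> False (disabled)
    # DISABLE_CC_METHOD=off   -> True  ("OFF" is not in the accepted list)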

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def fuse_vb_o(self, layer_weight):
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
-        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
+        self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         return
 

lightllm/server/api_tgi.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 import json
 
 
-def format_tgi_params(params, num_beam: int):
+def format_tgi_params(params, num_beam: int = 1):
     """
     tgi params format -> lightllm server params format
     pub(crate) struct GenerateParameters {
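With a default of 1, callers that do not use beam search can omit the argument, which is presumably what the TGI stream path relies on. A hedged usage sketch; the parameter dict uses standard TGI GenerateParameters field names but is illustrative, not taken from this commit:

    # Stream-endpoint style: no beam count passed, the default num_beam=1 applies.
    lightllm_params = format_tgi_params({"temperature": 0.7, "max_new_tokens": 16})

    # Beam-search callers can still pass the count explicitly.
    lightllm_params = format_tgi_params({"best_of": 2}, num_beam=2)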

test/model/model_infer.py

Lines changed: 18 additions & 13 deletions
@@ -38,16 +38,20 @@ def test_model_inference(world_size, model_class, batch_size, input_len, output_
 
 def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
     import torch
-    from lightllm.distributed import set_custom_reduce
+    from lightllm.distributed import custom_comm_ops
+    from lightllm.utils.device_utils import set_current_device_id
+
     import torch.distributed as dist
 
     rank_id = model_kvargs["tp_rank"]
     world_size = model_kvargs["world_size"]
 
     torch.cuda.set_device(rank_id)
-
+    set_current_device_id(rank_id)
     dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
-    set_custom_reduce()
+
+    custom_comm_ops.set_custom_reduce()
+    custom_comm_ops.set_custom_gather()
     dist.barrier()
 
     torch.cuda.empty_cache()
@@ -59,7 +63,9 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
     test_data = test_data.reshape(-1)
     test_data = torch.from_numpy(test_data).cuda()
 
-    b_req_idx = model_part.req_manager.alloc(batch_size).int()
+    b_req_idx = torch.tensor(
+        [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+    )
     b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
@@ -68,7 +74,8 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
         b_seq_len[i] = input_len
 
     total_token_num = input_len * batch_size
-    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
+    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+
     logics = model_part.forward(
         batch_size,
         total_token_num,
@@ -89,7 +96,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
         b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
+        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
         logics = model_part.forward(
             batch_size,
             total_token_num,
@@ -108,10 +115,6 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
     model_part.mem_manager.free_all()
     model_part.req_manager.free_all()
 
-    if rank_id == 0:
-        print("can use mem size:", model_part.mem_manager.can_use_mem_size)
-        print("can use req size:", model_part.req_manager.can_use_req_size)
-
     b_req_idx = None
     b_start_loc = None
     b_seq_len = None
@@ -124,15 +127,17 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
 
     prefill_start_time = time.time()
 
-    b_req_idx = model_part.req_manager.alloc(batch_size).int()
+    b_req_idx = torch.tensor(
+        [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+    )
     b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     for i in range(batch_size):
         b_start_loc[i] = i * input_len
         b_seq_len[i] = input_len
 
     total_token_num = batch_size * input_len
-    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
+    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
     logics = model_part.forward(
         batch_size,
         total_token_num,
@@ -159,7 +164,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
         b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
+        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
         logics = model_part.forward(
             batch_size,
             total_token_num,
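Condensing the hunks above, the per-rank setup in the test now reads roughly as follows. The sketch restates the diff; the standalone function wrappers and the assumption that mem_manager.alloc returns host-side indices (hence the explicit .cuda()) are mine:

    import torch
    import torch.distributed as dist
    from lightllm.distributed import custom_comm_ops
    from lightllm.utils.device_utils import set_current_device_id

    def init_rank(rank_id: int, world_size: int) -> None:
        # Bind the process to its GPU before any NCCL initialization.
        torch.cuda.set_device(rank_id)
        set_current_device_id(rank_id)
        dist.init_process_group(
            "nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size
        )
        # The updated test registers both custom collectives; the old one
        # only set the custom reduce.
        custom_comm_ops.set_custom_reduce()
        custom_comm_ops.set_custom_gather()
        dist.barrier()

    def build_batch_tensors(model_part, batch_size, test_data):
        # req_manager.alloc() is now called once per request, and the returned
        # slots are collected into an int32 CUDA tensor by the caller.
        b_req_idx = torch.tensor(
            [model_part.req_manager.alloc() for _ in range(batch_size)],
            dtype=torch.int32,
            device="cuda",
        )
        # Token indices are moved to the GPU explicitly after allocation.
        mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
        return b_req_idx, mem_indexes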

test/model/test_model.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 from lightllm.utils.config_utils import get_dtype
 from lightllm.utils.config_utils import get_config_json
 
+
 def get_model(weight_dir):
     model_cfg = get_config_json(weight_dir)
     model_type = model_cfg["model_type"]
@@ -68,7 +69,7 @@ def get_model(weight_dir):
 
 class TestModelInfer(unittest.TestCase):
     def test_model_infer(self):
-        model_dir = "/nvme/ci_performance/models/DeepSeek-V2-Lite-Chat/"
+        model_dir = "/nvme/models/llama3/Meta-Llama-3-8B/"
         model_class = get_model(model_dir)
         data_type = get_dtype(model_dir)
         mode = "triton_gqa_flashdecoding"
