@@ -38,16 +38,20 @@ def test_model_inference(world_size, model_class, batch_size, input_len, output_
 
 def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
     import torch
-    from lightllm.distributed import set_custom_reduce
+    from lightllm.distributed import custom_comm_ops
+    from lightllm.utils.device_utils import set_current_device_id
+
     import torch.distributed as dist
 
     rank_id = model_kvargs["tp_rank"]
     world_size = model_kvargs["world_size"]
 
     torch.cuda.set_device(rank_id)
-
+    set_current_device_id(rank_id)
     dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
-    set_custom_reduce()
+
+    custom_comm_ops.set_custom_reduce()
+    custom_comm_ops.set_custom_gather()
     dist.barrier()
 
     torch.cuda.empty_cache()
@@ -59,7 +63,9 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
     test_data = test_data.reshape(-1)
     test_data = torch.from_numpy(test_data).cuda()
 
-    b_req_idx = model_part.req_manager.alloc(batch_size).int()
+    b_req_idx = torch.tensor(
+        [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+    )
     b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
@@ -68,7 +74,8 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
         b_seq_len[i] = input_len
 
     total_token_num = input_len * batch_size
-    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
+    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
+
     logics = model_part.forward(
         batch_size,
         total_token_num,
@@ -89,7 +96,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
         b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
+        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
         logics = model_part.forward(
             batch_size,
             total_token_num,
@@ -108,10 +115,6 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
     model_part.mem_manager.free_all()
     model_part.req_manager.free_all()
 
-    if rank_id == 0:
-        print("can use mem size:", model_part.mem_manager.can_use_mem_size)
-        print("can use req size:", model_part.req_manager.can_use_req_size)
-
     b_req_idx = None
     b_start_loc = None
     b_seq_len = None
@@ -124,15 +127,17 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
 
     prefill_start_time = time.time()
 
-    b_req_idx = model_part.req_manager.alloc(batch_size).int()
+    b_req_idx = torch.tensor(
+        [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+    )
     b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     for i in range(batch_size):
         b_start_loc[i] = i * input_len
         b_seq_len[i] = input_len
 
     total_token_num = batch_size * input_len
-    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
+    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
     logics = model_part.forward(
         batch_size,
         total_token_num,
@@ -159,7 +164,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
         b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
         total_token_num += batch_size
         b_seq_len += 1
-        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
+        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
         logics = model_part.forward(
             batch_size,
             total_token_num,
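
Note on the allocation change: the diff swaps the bulk `req_manager.alloc(batch_size).int()` call for a per-request loop that packs the returned indices into an int32 CUDA tensor, and it moves the `mem_manager.alloc(...)` result onto the GPU with `.cuda()`. Below is a minimal, self-contained sketch of that per-request pattern; it assumes (as the diff suggests) that `alloc()` with no arguments returns one free slot index, and it uses a hypothetical stand-in manager rather than lightllm's real `ReqManager`.

```python
import torch


class DummyReqManager:
    """Stand-in request manager: hands out free slot indices one at a time."""

    def __init__(self, max_requests: int):
        self.free_slots = list(range(max_requests))

    def alloc(self) -> int:
        # One slot per call, mirroring the per-request alloc() used in the diff.
        return self.free_slots.pop(0)


batch_size = 4
req_manager = DummyReqManager(max_requests=16)

# Same pattern as the updated test: collect one index per request,
# then pack them into an int32 tensor on the target device.
device = "cuda" if torch.cuda.is_available() else "cpu"
b_req_idx = torch.tensor(
    [req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device=device
)
print(b_req_idx)  # e.g. tensor([0, 1, 2, 3], dtype=torch.int32)
```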