Skip to content

Commit 95ac241

Browse files
committed
initial tests
Signed-off-by: jthomson04 <[email protected]>
1 parent 9aebd26 commit 95ac241

File tree

4 files changed

+31
-30
lines changed

4 files changed

+31
-30
lines changed

components/backends/trtllm/performance_sweeps/benchmark_disagg.slurm

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ srun -l --container-name=${CONTAINER_NAME} \
8585
--num_ctx_servers ${num_ctx_servers} \
8686
--ctx_tp_size ${ctx_tp_size} \
8787
--ctx_ep_size ${ctx_ep_size} \
88-
--ctx_enable_attention_dp ${ctx_enable_attention_dp} \
8988
--ctx_batch_size ${ctx_batch_size} \
9089
--ctx_max_num_tokens ${ctx_max_num_tokens} \
9190
--ctx_max_seq_len ${ctx_max_seq_len} \
@@ -180,7 +179,7 @@ for ((i=1; i<=DECODE_COUNT; i++)); do
180179
--ntasks $gen_tp_size \
181180
--oversubscribe \
182181
--overlap \
183-
bash ${SCRIPTS_DIR}/scripts/start_disagg_worker.sh ${full_logdir}/decode_config.yaml ${ctx_gpus} ${nsys_on} ${served_model_name} ${model_path} 'decode' &> ${full_logdir}/output_decode_worker_${i}.log &
182+
bash ${SCRIPTS_DIR}/scripts/start_disagg_worker.sh ${full_logdir}/decode_config.yaml ${ctx_gpus} ${served_model_name} ${model_path} 'decode' &> ${full_logdir}/output_decode_worker_${i}.log &
184183
echo "$!" >> "$PID_FILE"
185184
done
186185

@@ -203,9 +202,9 @@ for ((i=1; i<=PREFILL_COUNT; i++)); do
203202
--mpi=pmix --overlap -w ${nodes[node_idx]} \
204203
--oversubscribe \
205204
--overlap \
206-
--ntasks $(( tp_size < 4 ? tp_size : 4 )) \
205+
--ntasks $(( ctx_tp_size < 4 ? ctx_tp_size : 4 )) \
207206
--nodes 1 \
208-
bash ${SCRIPTS_DIR}/scripts/start_disagg_worker.sh ${full_logdir}/prefill_config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} ${served_model_name} ${model_path} 'prefill' &> ${full_logdir}/output_prefill_worker_${i}.log &
207+
bash ${SCRIPTS_DIR}/scripts/start_disagg_worker.sh ${full_logdir}/prefill_config.yaml ${ctx_gpus} ${served_model_name} ${model_path} 'prefill' &> ${full_logdir}/output_prefill_worker_${i}.log &
209208
prefill_pids+=($!)
210209
echo "$!" >> "$PID_FILE"
211210
done

components/backends/trtllm/performance_sweeps/scripts/gen_yaml.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,14 @@ class ModelType(Enum):
1515
"""
1616
GPT_OSS = "gpt_oss"
1717
DSR1 = "dsr1"
18-
19-
def infer_model_type(self, model_path: str) -> str:
20-
if "r1" in model_path.lower():
21-
return self.DSR1
22-
else:
23-
return self.GPT_OSS
24-
25-
CONFIG_MAPPING = {
26-
ModelType.GPT_OSS: None,
27-
ModelType.DSR1: generate_dsr1_config,
28-
}
18+
19+
def get_model_type(model_path: str) -> str:
20+
if "r1" in model_path.lower():
21+
print("Inferring DSR1-type model")
22+
return ModelType.DSR1
23+
else:
24+
print("Inferring GPT-oss-type model")
25+
return ModelType.GPT_OSS
2926

3027
def generate_dsr1_config(
3128
config_path: str,
@@ -89,7 +86,7 @@ def generate_dsr1_config(
8986
"max_seq_len": args.gen_max_seq_len,
9087
"cuda_graph_config": {
9188
"enable_padding": True,
92-
"batch_sizes": args.gen_cuda_graph_batch_sizes,
89+
"batch_sizes": gen_cuda_graph_batch_sizes,
9390
},
9491
"print_iter_log": True,
9592
"kv_cache_config": {
@@ -160,7 +157,7 @@ def generate_gpt_oss_config(
160157
768,
161158
1024,
162159
2048,
163-
gen_batch_size,
160+
args.gen_batch_size,
164161
]
165162

166163
gen_moe_backend = "TRTLLM"
@@ -210,7 +207,7 @@ def generate_gpt_oss_config(
210207
"max_seq_len": args.gen_max_seq_len,
211208
"cuda_graph_config": {
212209
"enable_padding": True,
213-
"batch_sizes": args.gen_cuda_graph_batch_sizes,
210+
"batch_sizes": gen_cuda_graph_batch_sizes,
214211
},
215212
"print_iter_log": True,
216213
"kv_cache_config": {
@@ -257,6 +254,11 @@ def generate_gpt_oss_config(
257254

258255
return prefill_config, decode_config
259256

257+
CONFIG_MAPPING = {
258+
ModelType.GPT_OSS: generate_gpt_oss_config,
259+
ModelType.DSR1: generate_dsr1_config,
260+
}
261+
260262
def process_node_and_task() -> tuple[int, List[str], List[str]]:
261263
"""
262264
Process SLURM node and task environment variables.
@@ -429,7 +431,7 @@ def gen_config_file(
429431
server_port: Server port
430432
"""
431433

432-
model_type = ModelType.get_model_type(model_path)
434+
model_type = get_model_type(model_path)
433435

434436
prefill_config, decode_config = CONFIG_MAPPING[model_type](
435437
config_path,
@@ -471,12 +473,6 @@ def gen_config_file(
471473
required=True,
472474
help="Expert parallel size for context servers",
473475
)
474-
parser.add_argument(
475-
"--ctx_enable_attention_dp",
476-
dest="ctx_enable_attention_dp",
477-
action="store_true",
478-
help="Enable attention DP for context servers",
479-
)
480476
parser.add_argument(
481477
"--ctx_batch_size",
482478
type=int,
@@ -519,6 +515,12 @@ def gen_config_file(
519515
required=True,
520516
help="Tensor parallel size for generation servers",
521517
)
518+
parser.add_argument(
519+
"--gen_ep_size",
520+
type=int,
521+
required=True,
522+
help="Expert parallel size for generation servers",
523+
)
522524
parser.add_argument(
523525
"--gen_batch_size",
524526
type=int,

components/backends/trtllm/performance_sweeps/scripts/start_disagg_worker.sh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
config_file=$1
6-
ctx_gpus=$3
7-
model_name=$4
8-
model_path=$5
9-
disaggregation_mode=$6
6+
ctx_gpus=$2
7+
model_name=$3
8+
model_path=$4
9+
disaggregation_mode=$5
1010
unset UCX_TLS
1111
echo "config_file: ${config_file}, ctx_gpus: ${ctx_gpus}, disaggregation_mode: ${disaggregation_mode}"
1212

components/backends/trtllm/performance_sweeps/submit_disagg.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ run_single() {
9898
total_nodes=$((ctx_num + gen_nodes))
9999
total_tasks=$((total_nodes * 4))
100100
set -x
101-
sbatch --nodes=${total_nodes} --ntasks=${total_tasks} --ntasks-per-node=${NTASKS_PER_NODE} --segment=${total_nodes} ${slurm_args} benchmark_disagg.slurm ${ctx_num} ${ctx_tp_size} ${ctx_ep_size} ${ctx_enable_attention_dp} 30 20000 ${gen_num} ${gen_tp_size} ${gen_batch_size} ${gen_max_num_tokens} ${gen_enable_attention_dp} ${gen_gpu_memory_fraction} ${gen_eplb_num_slots} ${gen_mtp_size} "${gen_concurrency_list}" ${gen_nodes} ${kind} ${MODEL_PATH} ${SERVED_MODEL_NAME} ${IMAGE} ${ISL} ${OSL}
101+
sbatch --nodes=${total_nodes} --ntasks=${total_tasks} --ntasks-per-node=${NTASKS_PER_NODE} --segment=${total_nodes} ${slurm_args} benchmark_disagg.slurm ${ctx_num} ${ctx_tp_size} ${ctx_ep_size} ${ctx_enable_attention_dp} 30 20000 ${gen_num} ${gen_tp_size} ${gen_ep_size} ${gen_batch_size} ${gen_max_num_tokens} ${gen_enable_attention_dp} ${gen_gpu_memory_fraction} ${gen_eplb_num_slots} ${gen_mtp_size} "${gen_concurrency_list}" ${gen_nodes} ${kind} ${MODEL_PATH} ${SERVED_MODEL_NAME} ${IMAGE} ${ISL} ${OSL}
102102
set +x
103103
}
104104

0 commit comments

Comments
 (0)