
Conversation


Sulfur6 commented Aug 26, 2025

1. Motivation

Two-Batch Overlap (TBO) yields suboptimal gains for the decode phase on low-compute cards such as the H20, for two main reasons. First, on the Hopper architecture the WGMMA block_m is 64, so when TBO is enabled with a small decode batch size, the MLP GEMMs suffer redundant computation; a positive throughput gain only appears at larger batch sizes (e.g., 64 or 128). Second, at those larger batch sizes, low-compute cards like the H20 fail to meet the SLA guarantees for TPOT/ITL.
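As a back-of-the-envelope illustration (not code from this PR), the padding waste caused by block_m can be estimated as follows; BLOCK_M and the utilization formula here are illustrative assumptions:

```python
# Illustrative only: with WGMMA block_m = 64, each GEMM tile covers 64 rows
# regardless of how many hold real tokens, so small decode batches pad heavily.
BLOCK_M = 64

def tile_utilization(tokens: int) -> float:
    """Fraction of tile rows occupied by real tokens (1.0 = no padding waste)."""
    padded = ((tokens + BLOCK_M - 1) // BLOCK_M) * BLOCK_M
    return tokens / padded

# TBO splits a batch into two micro-batches, halving the per-GEMM token count:
for bs in (16, 32, 64, 128):
    print(f"bs={bs:3d}: no-TBO {tile_utilization(bs):.0%}, "
          f"TBO micro-batch {tile_utilization(bs // 2):.0%}")
```

This is why TBO only pays off once the per-micro-batch token count approaches a multiple of block_m, while SBO avoids splitting the batch in the first place.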

Therefore, we need a solution that improves decode throughput even at small batch sizes. Single Batch Overlap (SBO) is such a solution.

We implement SBO for DeepSeek-V3/R1: the overlap of the Shared Expert with the Dispatch Recv, and the overlap of the Down GEMM with the Combine Send.

The Down GEMM / Combine Send overlap is implemented by modifying DeepEP and DeepGEMM, with the detailed implementation available in the branches below:

Since the latest version of SGLang depends on the branch https://github.com/sgl-project/DeepGEMM/tree/sgl, do not use the branch specified in the PR above when starting SGLang. Instead, use the branch developed on top of the sgl branch: https://github.com/Sulfur6/DeepGEMM/tree/sbo.v2.sgl

2. Overlap Design

SBO implements two overlaps for the MoE layers of DeepSeek-V3/R1: one overlaps the Shared Expert computation with the Dispatch Recv communication, and the other overlaps the Down GEMM computation with the Combine Send communication.
The interaction between the Down GEMM and the Combine Send is structured as a producer-consumer model synchronized by signals. For each local expert, a signal unit is allocated for every block_m tokens. The Down GEMM computes the results for these block_m tokens and atomically increments the signal unit after completing a portion of the work. The Combine Send polls this signal unit; once its value reaches a threshold, it sends the corresponding block_m tokens.
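The producer-consumer signaling above can be sketched with host-side Python threads; the threshold, the lock-based stand-in for atomicAdd, and all names here are illustrative assumptions, not the actual CUDA implementation:

```python
import threading
import time

BLOCK_M = 4     # toy block size (the real kernels use e.g. 64)
THRESHOLD = 8   # signal increments expected per block (assumed value)

signal = [0]    # one signal unit for one block of tokens
lock = threading.Lock()
sent = []

def down_gemm_producer():
    # The Down GEMM increments the signal unit after finishing
    # each slice of work for the block's tokens.
    for _ in range(THRESHOLD):
        time.sleep(0.001)       # stand-in for compute
        with lock:
            signal[0] += 1      # models atomicAdd on the signal unit

def combine_send_consumer():
    # The Combine Send polls the signal unit; once it reaches the
    # threshold, the corresponding BLOCK_M tokens are sent.
    while True:
        with lock:
            if signal[0] >= THRESHOLD:
                sent.append(list(range(BLOCK_M)))  # "send" the block
                break
        time.sleep(0.0005)

c = threading.Thread(target=combine_send_consumer)
p = threading.Thread(target=down_gemm_producer)
c.start(); p.start(); p.join(); c.join()
print("sent blocks:", sent)
```

The key property is that the consumer can start sending each block as soon as its signal unit fills, without waiting for the whole GEMM to finish.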

3. Modifications

New Server Arguments

  • --enable-single-batch-overlap: enables SBO (Single Batch Overlap).
  • --deepep-mode low_latency_overlap: a newly added DeepEP mode, mainly for SBO.

Add an assertion ensuring that when SBO is enabled, --moe-a2a-backend is "deepep" and --deepep-mode is "auto" for the mixed deployment or "low_latency_overlap" for PD disaggregation.
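A minimal sketch of that assertion (the function and parameter names are hypothetical; the real check lives in SGLang's server-args validation):

```python
def validate_sbo_args(enable_sbo: bool, moe_a2a_backend: str,
                      deepep_mode: str, pd_disaggregation: bool) -> None:
    """Sketch of the SBO argument check described above."""
    if not enable_sbo:
        return
    assert moe_a2a_backend == "deepep", \
        "SBO requires --moe-a2a-backend deepep"
    # "low_latency_overlap" for PD disaggregation, "auto" for mixed deployment.
    expected = "low_latency_overlap" if pd_disaggregation else "auto"
    assert deepep_mode == expected, \
        f"SBO requires --deepep-mode {expected}"

# A PD-disaggregated decode launch with the flags from this PR passes:
validate_sbo_args(True, "deepep", "low_latency_overlap", pd_disaggregation=True)
```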

EPLB manager

Unify the EPLB distribution recorder for both the low latency and low latency overlap modes.

Deepep Token Dispatcher

Add a new dispatcher implementation, _DeepEPDispatcherImplLowLatencyOverlap, for the "low_latency_overlap" DeepEP mode used by SBO.
Add new parameters to DeepEPDispatcher.

DeepGEMM Wrapper

Add a wrapper for the masked signal GEMM used by SBO.

Model & Layer

  • Add forward_deepep_sbo to DeepseekV2MoE in the deepseek_v2 model.
  • Add forward_deepgemm_signal to the related DeepEPMoE layer.
  • Move the shared-expert forward pass between the dispatch send and the dispatch recv.
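The ordering those changes establish can be sketched as follows; the stub dispatcher and the callables are placeholders so the control flow is runnable, not SGLang's real classes:

```python
class _StubDispatcher:
    # Minimal stand-in so the control flow below is runnable; the real
    # class would be SGLang's low-latency-overlap DeepEP dispatcher.
    def dispatch_send(self, x): return x
    def dispatch_recv(self, handle): return handle
    def combine(self, x): return x

def moe_forward_sbo(hidden_states, dispatcher, shared_expert, routed_experts):
    """Hypothetical sketch of the SBO MoE forward ordering."""
    handle = dispatcher.dispatch_send(hidden_states)  # async dispatch send
    shared_out = shared_expert(hidden_states)         # overlaps dispatch recv
    routed_in = dispatcher.dispatch_recv(handle)      # wait for routed tokens
    routed_out = routed_experts(routed_in)            # Down GEMM (signals Combine)
    return dispatcher.combine(routed_out) + shared_out

out = moe_forward_sbo(1.0, _StubDispatcher(),
                      shared_expert=lambda x: x * 0.5,
                      routed_experts=lambda x: x * 2.0)
print(out)  # 2.5
```

Placing the shared-expert call between the send and the recv is what lets its compute hide under the all-to-all communication.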

4. Evaluation

4.1. Experiment Setup

  • 5 nodes, each with 8 × H20 GPUs: the 3 prefill nodes use TP8, and the 2 decode nodes use DP_Attn 16 + EP 16.
  • Input length 4096, output length 1536.

4.2. Performance Evaluation

  • bs 32, origin
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    4.8
Max request concurrency:                 512
Successful requests:                     10240
Benchmark duration (s):                  2359.16
Total input tokens:                      41943040
Total generated tokens:                  15728640
Total generated tokens (retokenized):    15672509
Request throughput (req/s):              4.34
Input token throughput (tok/s):          17778.82
Output token throughput (tok/s):         6667.06
Total token throughput (tok/s):          24445.88
Concurrency:                             490.01
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   112892.31
Median E2E Latency (ms):                 113847.19
---------------Time to First Token----------------
Mean TTFT (ms):                          640.62
Median TTFT (ms):                        545.06
P99 TTFT (ms):                           1543.37
---------------Inter-Token Latency----------------
Mean ITL (ms):                           73.11
Median ITL (ms):                         71.81
P95 ITL (ms):                            86.02
P99 ITL (ms):                            155.32
Max ITL (ms):                            1543.26
==================================================
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    5.0
Max request concurrency:                 512
Successful requests:                     10240
Benchmark duration (s):                  2357.80
Total input tokens:                      41943040
Total generated tokens:                  15728640
Total generated tokens (retokenized):    15673361
Request throughput (req/s):              4.34
Input token throughput (tok/s):          17789.05
Output token throughput (tok/s):         6670.89
Total token throughput (tok/s):          24459.95
Concurrency:                             490.83
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   113015.97
Median E2E Latency (ms):                 113951.58
---------------Time to First Token----------------
Mean TTFT (ms):                          724.98
Median TTFT (ms):                        624.73
P99 TTFT (ms):                           1693.64
---------------Inter-Token Latency----------------
Mean ITL (ms):                           73.13
Median ITL (ms):                         71.84
P95 ITL (ms):                            86.57
P99 ITL (ms):                            155.21
Max ITL (ms):                            1081.95
==================================================
  • bs 32, sbo
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    4.8
Max request concurrency:                 512
Successful requests:                     10240
Benchmark duration (s):                  2211.76
Total input tokens:                      41943040
Total generated tokens:                  15728640
Total generated tokens (retokenized):    15673456
Request throughput (req/s):              4.63
Input token throughput (tok/s):          18963.67
Output token throughput (tok/s):         7111.38
Total token throughput (tok/s):          26075.05
Concurrency:                             481.58
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   104017.64
Median E2E Latency (ms):                 105363.65
---------------Time to First Token----------------
Mean TTFT (ms):                          606.28
Median TTFT (ms):                        508.61
P99 TTFT (ms):                           1475.44
---------------Inter-Token Latency----------------
Mean ITL (ms):                           67.35
Median ITL (ms):                         66.10
P95 ITL (ms):                            81.58
P99 ITL (ms):                            141.96
Max ITL (ms):                            1422.74
==================================================
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    5.0
Max request concurrency:                 512
Successful requests:                     10240
Benchmark duration (s):                  2194.12
Total input tokens:                      41943040
Total generated tokens:                  15728640
Total generated tokens (retokenized):    15672577
Request throughput (req/s):              4.67
Input token throughput (tok/s):          19116.14
Output token throughput (tok/s):         7168.55
Total token throughput (tok/s):          26284.70
Concurrency:                             487.92
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   104545.42
Median E2E Latency (ms):                 105483.50
---------------Time to First Token----------------
Mean TTFT (ms):                          619.03
Median TTFT (ms):                        511.23
P99 TTFT (ms):                           1504.27
---------------Inter-Token Latency----------------
Mean ITL (ms):                           67.68
Median ITL (ms):                         66.44
P95 ITL (ms):                            82.13
P99 ITL (ms):                            142.48
Max ITL (ms):                            1024.85
==================================================

4.3. Accuracy Tests

  • bs 32, origin
#python -m benchmark.gsm8k.bench_sglang --port 8000 --num-questions 1000
100%|█████████████████████████████████████████████████████████████| 1000/1000 [01:20<00:00, 12.41it/s]
Accuracy: 0.951
Invalid: 0.000
Latency: 80.802 s
Output throughput: 1183.468 token/s
  • bs 32, sbo
#python -m benchmark.gsm8k.bench_sglang --port 8000 --num-questions 1000
100%|█████████████████████████████████████████████████████████████| 1000/1000 [01:17<00:00, 12.87it/s]
Accuracy: 0.950
Invalid: 0.000
Latency: 78.056 s
Output throughput: 1217.443 token/s

4.4. Repro Script

#------------------------------------------- Variables For PD start -------------------------------------------#
# Configuration for PD disaggregation.
MODEL_PATH=""  # Your own model path.
WORK_DIR=""  # Your own work dir.

IB_DEVICE="bond_0,bond_1,bond_2,bond_3"  # Your own IB device names.

NUM_PREFILL=3  # Number of prefill nodes.
NUM_DECODE=2  # Number of decode nodes.

# Example array of IP addresses for the cluster nodes.
NODE_IPS=( \
  "192.168.1.1" \
  "192.168.1.2"  \
  "192.168.1.3" \
  "192.168.1.4" \
  "192.168.1.5"
)

PREFILL_MAIN_IP="192.168.1.1"  # IP of the main node for prefill.
DECODE_MAIN_IP="192.168.1.4"  # IP of the main node for decode (offset by NUM_PREFILL).

DEFAULT_PORT=61001  # Default port for SGLang services.
MAIN_PORT=62001  # Port for the main node communication.
MINI_LB_PORT=8000  # Port for the load balancer.
MAX_RUNNING_REQUEST_PER_GPU=32  # Maximum concurrent requests per GPU.
CHUNKED_PREFILL_SIZE_PER_DP_RANK=4096  # Size of chunked prefill per data-parallel rank.

DP_SIZE_PER_PREFILL_NODE=8  # Data-parallel size for each prefill node.
DP_SIZE_PER_DECODE_NODE=8  # Data-parallel size for each decode node.

GPUS_PER_NODE=8  # Number of GPUs per node (referenced below).
PAGE_SIZE=64
PREFILL_ATTENTION_BACKEND="fa3"  # Attention backend for prefill.
DECODE_ATTENTION_BACKEND="flashmla"  # Attention backend for decode.

# Calculate maximum requests per data-parallel rank based on GPU capacity.
MAX_RUNNING_REQUEST_PER_DP_RANK=$(( MAX_RUNNING_REQUEST_PER_GPU * GPUS_PER_NODE / DP_SIZE_PER_DECODE_NODE ))
CUDA_GRAPH_MAX_BATCH_SIZE=$MAX_RUNNING_REQUEST_PER_DP_RANK  # Set CUDA graph batch size to match max requests per rank.

#---------------------------------------- For Prefill Nodes Start -------------------------------------------#
SGL_ENABLE_JIT_DEEPGEMM=1 \
nohup python3 -m sglang.launch_server \
--trust-remote-code \
--model-path ${MODEL_PATH} \
--disaggregation-mode prefill \
--disaggregation-transfer-backend mooncake \
--disaggregation-ib-device ${IB_DEVICE} \
--host 0.0.0.0 \
--port ${DEFAULT_PORT} \
--tp-size 8 \
--page-size ${PAGE_SIZE} \
--attention-backend ${PREFILL_ATTENTION_BACKEND} \
--mem-fraction-static 0.92 \
--chunked-prefill-size 32768 \
--max-running-requests $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * MAX_RUNNING_REQUEST_PER_DP_RANK)) \
--max-total-tokens 131076 \
--context-length 65535 \
--enable-cache-report \
--log-level info \
> ${WORK_DIR}/stdout.log 2>&1 &

#---------------------------------------- For Decode Main Node Start -------------------------------------------#
SGL_ENABLE_JIT_DEEPGEMM=1 \
nohup python3 -m sglang.launch_server \
--model-path ${MODEL_PATH} \
--disaggregation-mode decode \
--disaggregation-transfer-backend mooncake \
--disaggregation-ib-device ${IB_DEVICE} \
--attention-backend ${DECODE_ATTENTION_BACKEND} \
--host 0.0.0.0 \
--port ${DEFAULT_PORT} \
--trust-remote-code \
--dist-init-addr ${DECODE_MAIN_IP}:${MAIN_PORT} \
--nnodes ${NUM_DECODE} \
--node-rank 0 \
--tp-size $((GPUS_PER_NODE * NUM_DECODE)) \
--dp-size $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE)) \
--enable-dp-attention \
--mem-fraction-static 0.88 \
--chunked-prefill-size $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * CHUNKED_PREFILL_SIZE_PER_DP_RANK)) \
--max-running-requests $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * MAX_RUNNING_REQUEST_PER_DP_RANK)) \
--context-length 32768 \
--log-level info \
--decode-log-interval 50 \
--page-size ${PAGE_SIZE} \
--schedule-conservativeness 0.3 \
--enable-cache-report \
--moe-dense-tp-size 1 \
--enable-dp-lm-head \
--cuda-graph-max-bs ${CUDA_GRAPH_MAX_BATCH_SIZE} \
--load-balance-method minimum_tokens \
--moe-a2a-backend deepep \
--enable-single-batch-overlap \
--deepep-mode low_latency_overlap \
> ${WORK_DIR}/stdout.log 2>&1 &

#---------------------------------------- For Decode Worker Node Start -------------------------------------------#
SGL_ENABLE_JIT_DEEPGEMM=1 \
nohup python3 -m sglang.launch_server \
--model-path ${MODEL_PATH} \
--disaggregation-mode decode \
--disaggregation-transfer-backend mooncake \
--disaggregation-ib-device ${IB_DEVICE} \
--attention-backend ${DECODE_ATTENTION_BACKEND} \
--host 0.0.0.0 \
--port ${DEFAULT_PORT} \
--trust-remote-code \
--dist-init-addr ${DECODE_MAIN_IP}:${MAIN_PORT} \
--nnodes ${NUM_DECODE} \
--node-rank 1 \
--tp-size $((GPUS_PER_NODE * NUM_DECODE)) \
--dp-size $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE)) \
--enable-dp-attention \
--mem-fraction-static 0.88 \
--chunked-prefill-size $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * CHUNKED_PREFILL_SIZE_PER_DP_RANK)) \
--max-running-requests $((DP_SIZE_PER_DECODE_NODE * NUM_DECODE * MAX_RUNNING_REQUEST_PER_DP_RANK)) \
--context-length 32768 \
--log-level info \
--decode-log-interval 50 \
--page-size ${PAGE_SIZE} \
--schedule-conservativeness 0.3 \
--enable-cache-report \
--moe-dense-tp-size 1 \
--enable-dp-lm-head \
--cuda-graph-max-bs ${CUDA_GRAPH_MAX_BATCH_SIZE} \
--load-balance-method minimum_tokens \
--moe-a2a-backend deepep \
--enable-single-batch-overlap \
--deepep-mode low_latency_overlap \
> ${WORK_DIR}/stdout.log 2>&1 &

#--------------------------------------- For mini_lb start -------------------------------------------#
nohup python3 -m sglang_router.launch_router \
  --pd-disaggregation \
  --host 0.0.0.0 \
  --port 8000 \
  --decode http://192.168.1.4:61001 \
  --prefill http://192.168.1.1:61001 \
  --prefill http://192.168.1.2:61001 \
  --prefill http://192.168.1.3:61001


fzyzcjy commented Aug 26, 2025

Interesting, I am also doing some overlapping recently. Do we need to use a changed DeepEP or DeepGEMM?


Sulfur6 commented Aug 27, 2025

> Interesting, I am also doing some overlapping recently. Do we need to use a changed DeepEP or DeepGEMM?

@fzyzcjy Thanks for the comments. To overlap the Down GEMM with the Combine Send, we modified DeepGEMM and DeepEP. The Shared Expert / Dispatch Recv overlap only required modifications to SGLang. We are currently cleaning up the DeepGEMM and DeepEP code and will submit PRs within the next two days.


fzyzcjy commented Aug 27, 2025

Sure, I mean you need to paste the corresponding deepgemm/deepep branches as well (when ready).


Sulfur6 commented Aug 27, 2025

> Sure, I mean you need to paste the corresponding deepgemm/deepep branches as well (when ready).

@fzyzcjy We updated the PR and added our modified DeepEP and DeepGEMM branches:


fzyzcjy commented Aug 27, 2025

Got it. BTW the speedup looks like only ~1%, so I am curious whether the overlappable region is tiny or the overhead of overlap is large, and also how much SBO improves over the simple standard overlap-shared-with-communication. Could you please share a pair of profiles (one w/o overlap, one w/ overlap)?


Sulfur6 commented Aug 27, 2025

> get, btw the speedup looks like only ~1%, thus I am curious whether it is because the overlappable region is tiny or the overhead of overlap is large, and also how much does the SBO improve over the simple standard overlap-shared-with-comunication. could you please share a pair of profile (one w/o overlap, one w/ overlap) about them?

@fzyzcjy Thank you for your reminder. We pasted the wrong result when creating the draft, and have now updated it to the correct one.


Sulfur6 commented Aug 27, 2025

> get, btw the speedup looks like only ~1%, thus I am curious whether it is because the overlappable region is tiny or the overhead of overlap is large, and also how much does the SBO improve over the simple standard overlap-shared-with-comunication. could you please share a pair of profile (one w/o overlap, one w/ overlap) about them?

@fzyzcjy We recorded the profiles with and without overlap at batch size 32. Below are screenshots of the profile of a single DeepseekV2Decoder layer on DP0_TP0_EP0 on a decode node:

  • w/o overlap
  • w/ overlap


fzyzcjy commented Aug 27, 2025

I see, yes that looks reasonable on your H20 hardware (I do not have an H20 and thus don't know the per-kernel timings).


fzyzcjy commented Aug 28, 2025

BTW, I briefly glanced at the code and it seems atomicAdd is used to send signals, so I am curious whether this memory ordering and send location are strong enough.



Sulfur6 commented Aug 28, 2025

Since SGLang has merged PR #9340 to upgrade to DeepGEMM v2, we are working on the relevant adaptation.


fzyzcjy commented Sep 2, 2025

This change looks great, but I am still a bit worried: (1) shall we use atomicAdd (the doc says relaxed ordering) or release ordering? (2) will the extra TMA store wait make that warp group slower (i.e., shall we signal on the next existing TMA store wait)?

FYI my naive implementations are in flashinfer-ai/flashinfer#1569 (have not tested it since the nvfp4 code path has not arrived yet...)



Sulfur6 commented Sep 2, 2025

> this change looks great, but I am still a bit worried (1) shall we use atomicAdd (doc says relaxed ordering) or use released ordering (2) will the extra tma store wait make that warp group slower (i.e. shall we signal on the next existing tma store wait).
>
> FYI my naive implementations are in flashinfer-ai/flashinfer#1569 (have not tested it since the nvfp4 code path has not arrived yet...)

@fzyzcjy For (1), we will conduct a more in-depth investigation. For (2), in our tests tma_store_wait<0>() does not add overhead; we speculate this is because tma_store_wait<0>() must be executed before tma_store_arrive() anyway. However, __threadfence() does introduce a performance overhead of ~4% on the Down GEMM (in the EP16 scenario), but we believe it is necessary to ensure memory ordering.


fzyzcjy commented Sep 3, 2025

FYI I am waiting for the refactored deepgemm (hopper), since I need to implement deepgemm blackwell and want to be aligned with your style to avoid two conflicting styles


Sulfur6 commented Sep 3, 2025

> FYI I am waiting for the refactored deepgemm (hopper), since I need to implement deepgemm blackwell and want to be aligned with your style to avoid two conflicting styles

@fzyzcjy We have submitted a pull request to DeepGEMM.v2 deepseek-ai/DeepGEMM#183, which contains the GEMM interface and implementation required for overlap. We would like to know if you have any suggestions for modification.


fzyzcjy commented Sep 3, 2025

@Sulfur6 I made a tiny nit comment there

@AniZpZ AniZpZ changed the title Single Batch Overlap for MoE Models [WIP]Single Batch Overlap for MoE Models Sep 3, 2025
@AniZpZ AniZpZ marked this pull request as ready for review September 3, 2025 09:40
@Fridge003

Hi @Sulfur6 , how is the progress for this PR, is it ready for merge?


Sulfur6 commented Nov 27, 2025

> Hi @Sulfur6 , how is the progress for this PR, is it ready for merge?

@Fridge003 The local tests have passed and I am currently debugging CI. This PR depends on sgl-project/DeepGEMM#14, which is ready and waiting for approval and merge.

