
Conversation

@tsdocode (Contributor) commented May 27, 2025

Done in this PR:

  1. Split Attention into two classes, SelfAttention and CrossAttention, for further optimization
  2. Added a fused QKV projection: q_proj, k_proj, and v_proj now run as a single matmul (see the sketch right after this list)
  3. Enhanced RoPE: the cos/sin tables are reused for self-attention
  4. Adjusted the KV cache so max_length == max_tokens instead of always using the model's max length => less VRAM use, slightly faster (a sizing sketch follows the test script below)
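
A minimal sketch of what a fused QKV patch can look like, assuming q, k, and v share the same output dimension; the class body, shapes, and weight layout are illustrative, only the patch_fused_qkv name matches the PR:

import torch
import torch.nn as nn

class SelfAttentionSketch(nn.Module):
    """Illustrative only: replace three separate projections with one GEMM."""

    def __init__(self, embed_dim: int):
        super().__init__()
        # Separate projections, as before the patch.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.qkv_proj = None

    def patch_fused_qkv(self):
        # Stack the three weight matrices so one matmul yields q, k, and v.
        out_dim, in_dim = self.q_proj.weight.shape
        fused = nn.Linear(in_dim, 3 * out_dim, bias=False)
        with torch.no_grad():
            fused.weight.copy_(torch.cat(
                [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0))
        self.qkv_proj = fused.to(self.q_proj.weight.device, self.q_proj.weight.dtype)

    def project_qkv(self, x: torch.Tensor):
        # One kernel launch instead of three once the patch has been applied.
        if self.qkv_proj is not None:
            return self.qkv_proj(x).chunk(3, dim=-1)
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)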

Test script:

from random import choice

import torch

from dia.model import Dia


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

# debugging
torch._logging.set_logs(graph_breaks=True, recompiles=True)

model_name = "nari-labs/Dia-1.6B"
compute_dtype = "float16"

model = Dia.from_pretrained(model_name, compute_dtype=compute_dtype)


# Patch every decoder layer so q/k/v are computed by a single fused projection
for idx in range(len(model.model.decoder.layers)):
    layer = model.model.decoder.layers[idx]
    layer.self_attention.patch_fused_qkv()


test_cases = [
    "[S1] Dia is an open weights text to dialogue model.",
    "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face.",
    "[S1] torch.compile is a new feature in PyTorch that allows you to compile your model with a single line of code.",
    "[S1] torch.compile is a new feature in PyTorch that allows you to compile your model with a single line of code. [S2] It is a new feature in PyTorch that allows you to compile your model with a single line of code.",
]


use_torch_compile = True

MAX_TOKENS = 86 * 5  # ~5 seconds of audio at the ~86 tokens/s frame rate

# Warm up (lets torch.compile finish compiling before the benchmark)
for _ in range(2):
    text = choice(test_cases)
    output = model.generate(text, use_torch_compile=use_torch_compile, verbose=True, max_tokens=MAX_TOKENS)

text = choice(test_cases)

# Benchmark
for i in range(10):
    output = model.generate(text, use_torch_compile=use_torch_compile, verbose=True, max_tokens=MAX_TOKENS)
    text = choice(test_cases)
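
As a side note on item 4 above, a minimal sketch of sizing the KV cache to the requested max_tokens instead of the model's full context length; the function name and tensor layout are illustrative, not the PR's actual code:

import torch

def allocate_kv_cache(batch, num_heads, head_dim, max_tokens, model_max_len,
                      dtype=torch.float16, device="cuda"):
    # Before: the cache was always model_max_len long.
    # After: it is only as long as the generation actually needs.
    cache_len = min(max_tokens, model_max_len)
    k_cache = torch.zeros(batch, num_heads, cache_len, head_dim, dtype=dtype, device=device)
    v_cache = torch.zeros(batch, num_heads, cache_len, head_dim, dtype=dtype, device=device)
    return k_cache, v_cache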

Result on an A100 80GB:

~216 tokens/s => ~232 tokens/s (~7% faster)

Remaining room for speed-up:

  • The flame graph shows a roughly 20% gap in per-token generation time between the GPU launch (_decode_step) and the sampling phase (computing the next token and checking the stop condition), which looks like a further optimization target. A sketch of one possible direction follows.

(Flame graph screenshots attached to the PR.)
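
Not part of this PR, just a rough sketch of one way that gap could be narrowed: keep sampling and the stop check on the GPU and synchronize with the host only every few steps, so the CPU is not blocked after every token. decode_step, sample_next_token, and the tensor shapes are hypothetical placeholders, not the Dia API:

import torch

@torch.inference_mode()
def decode_loop_sketch(decode_step, sample_next_token, first_token, eos_id,
                       max_tokens, sync_every=8):
    token = first_token                        # (B,) int tensor already on the GPU
    finished = torch.zeros_like(token, dtype=torch.bool)
    generated = []
    for step in range(max_tokens):
        logits = decode_step(token)            # queues GPU work, no sync
        token = sample_next_token(logits)      # sampling stays on the GPU
        finished |= token == eos_id
        generated.append(token)
        # Pay for a device -> host transfer only every `sync_every` steps.
        if (step + 1) % sync_every == 0 and bool(finished.all()):
            break
    return torch.stack(generated, dim=-1)      # (B, steps generated)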

@V12Hero (Contributor) commented May 27, 2025

Could you please sync your repo with the original? Yours is 16 commits behind, and it doesn't work on a MacBook due to some issues I fixed later, which aren't included in your repo. I'm mentioning this because I want to test it.

@tsdocode (Contributor, Author)

I synced the latest code (screenshot attached); please help me check it.

@buttercrab (Collaborator) left a comment


LGTM

@buttercrab (Collaborator)

Could you fix the lint & format?

buttercrab merged commit cb07e05 into nari-labs:main on May 28, 2025 (1 check passed)
@V12Hero (Contributor) commented May 28, 2025

@tsdocode I know the issue is closed, but I did get to test it. I saw a ~30% increase in generation speed on my M3 MacBook Pro with 36GB of unified memory. It's a huge, very noticeable difference.

Here are the logs from my recent run of example/simple-mac.py:

generate step 86: speed=6.268 tokens/s, realtime factor=0.073x
generate step 172: speed=11.486 tokens/s, realtime factor=0.134x
generate step 258: speed=11.468 tokens/s, realtime factor=0.133x
generate step 344: speed=11.472 tokens/s, realtime factor=0.133x
generate step 430: speed=11.379 tokens/s, realtime factor=0.132x
generate step 516: speed=11.226 tokens/s, realtime factor=0.131x
generate step 602: speed=11.396 tokens/s, realtime factor=0.133x
generate step 688: speed=11.337 tokens/s, realtime factor=0.132x
generate: avg steps=758.0, total duration=75.467s

And below are the logs of a previous run for the same script with the previous version:

generate: starting generation loop
generate step 86: speed=7.759 tokens/s, realtime factor=0.090x
generate step 172: speed=8.470 tokens/s, realtime factor=0.098x
generate step 258: speed=8.536 tokens/s, realtime factor=0.099x
generate step 344: speed=8.489 tokens/s, realtime factor=0.099x
generate step 430: speed=8.607 tokens/s, realtime factor=0.100x
generate step 516: speed=8.615 tokens/s, realtime factor=0.100x
generate step 602: speed=8.592 tokens/s, realtime factor=0.100x
generate step 688: speed=8.597 tokens/s, realtime factor=0.100x
generate: avg steps=747.0, total duration=91.026s

@tsdocode (Contributor, Author)

@V12Hero Add this to your code before running generate; it patches in the fused QKV operation and may give a little more speedup:

for idx in range(len(model.model.decoder.layers)):
    layer = model.model.decoder.layers[idx]
    layer.self_attention.patch_fused_qkv()

# generate code

@V12Hero (Contributor) commented May 28, 2025

@tsdocode is this how you would want it placed?

from dia.model import Dia


model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")

text = "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face."

for idx in range(len(model.model.decoder.layers)):
    layer = model.model.decoder.layers[idx]
    layer.self_attention.patch_fused_qkv()
    
# Set `use_torch_compile=False` when using Dia on macOS,
# because torch.compile is not supported there.
output = model.generate(text, use_torch_compile=False, verbose=True)

model.save_audio("simple.mp3", output)

@tsdocode (Contributor, Author)

Yes

@V12Hero (Contributor) commented May 28, 2025

OK, I can't see any difference. The slight dip in performance is because I'm now running a bit low on battery, but overall I'd say the performance is the same.

generate: starting generation loop
generate step 86: speed=9.641 tokens/s, realtime factor=0.112x
generate step 172: speed=11.205 tokens/s, realtime factor=0.130x
generate step 258: speed=11.287 tokens/s, realtime factor=0.131x
generate step 344: speed=11.277 tokens/s, realtime factor=0.131x
generate step 430: speed=11.266 tokens/s, realtime factor=0.131x
generate step 516: speed=11.194 tokens/s, realtime factor=0.130x
generate step 602: speed=11.253 tokens/s, realtime factor=0.131x
generate step 688: speed=11.281 tokens/s, realtime factor=0.131x
generate step 774: speed=11.173 tokens/s, realtime factor=0.130x
generate: avg steps=772.0, total duration=71.836s
