Skip to content

Commit c7cba6d

Browse files
committed
chore: format with ruff
1 parent 0211e3a commit c7cba6d

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

dia/layers.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,12 @@ def patch_fused_qkv(self):
430430

431431
self.qkv = FusedQKV(
432432
self.kv_embed_dim,
433-
(
434-
self.num_query_heads * self.head_dim
435-
+ 2 * (self.num_kv_heads * self.head_dim)
436-
),
433+
(self.num_query_heads * self.head_dim + 2 * (self.num_kv_heads * self.head_dim)),
437434
bias=False,
438435
num_q_heads=self.num_query_heads,
439436
q_head_dim=self.head_dim,
440437
num_kv_heads=self.num_kv_heads,
441-
kv_head_dim=self.head_dim
438+
kv_head_dim=self.head_dim,
442439
)
443440
self.qkv.linear.weight.data = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
444441

dia/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,12 @@ def _prepare_generation(
372372
encoder_out, enc_state.positions, enc_state.padding_mask
373373
)
374374
dec_state = DecoderInferenceState.new(
375-
self.config, enc_state, encoder_out, dec_cross_attn_cache, self.compute_dtype, max_generation_length=max_tokens
375+
self.config,
376+
enc_state,
377+
encoder_out,
378+
dec_cross_attn_cache,
379+
self.compute_dtype,
380+
max_generation_length=max_tokens,
376381
)
377382
prefill, prefill_steps = self._prepare_audio_prompt(audio_prompts)
378383

0 commit comments

Comments
 (0)