@@ -291,7 +291,8 @@ void UnfusedAttentionTest::addFusedQKVBiasTransposeTest(size_t batch_size,
291291 params.common .cu_seqlens ->data <int >(),
292292 params.common .cu_seqlens_without_prefix ->data <int >(),
293293 device->use_rope_cache_ ,
294- device->use_rope_cache_ && device->rope_cache_ .defined () ?
294+ device->use_rope_cache_ && device->rope_cache_ .defined ()
295+ && device->rope_cache_dim_ == params.configs .rope_config .dim ?
295296 device->rope_cache_ .data_ptr <float >() :
296297 nullptr ,
297298 batch_size,
@@ -335,7 +336,8 @@ void UnfusedAttentionTest::addFusedQKVBiasTransposeTest(size_t batch_size,
335336 params.common .cu_seqlens ->data <int >(),
336337 params.common .cu_seqlens_without_prefix ->data <int >(),
337338 device->use_rope_cache_ ,
338- device->use_rope_cache_ && device->rope_cache_ .defined () ?
339+ device->use_rope_cache_ && device->rope_cache_ .defined ()
340+ && device->rope_cache_dim_ == params.configs .rope_config .dim ?
339341 device->rope_cache_ .data_ptr <float >() :
340342 nullptr ,
341343 batch_size,
@@ -374,40 +376,42 @@ void UnfusedAttentionTest::addFusedQKVBiasTransposeTest(size_t batch_size,
374376 bool store_kv = true ;
375377 bool store_cache = false ;
376378
377- DISPATCH_CUDA_FUNCTION_DATA_TYPE (
378- params.input .type (),
379- invokeAddFusedQKVBiasTranspose,
380- q_no_transpose_output->data (),
381- q_output->data (),
382- k_output->data (),
383- v_output->data (),
384- &prefix_prompt_param,
385- params.input .data (),
386- qkv_buf_fp8 ? qkv_buf_fp8->data () : nullptr ,
387- params.common .position_ids ->data <int >(),
388- params.weights .qkv_weight ->bias ->data (),
389- params.common .padding_offset ->data <int >(),
390- params.common .cu_seqlens ->data <int >(),
391- params.common .cu_seqlens_without_prefix ->data <int >(),
392- device->use_rope_cache_ ,
393- device->use_rope_cache_ && device->rope_cache_ .defined () ? device->rope_cache_ .data_ptr <float >() : nullptr ,
394- batch_size,
395- seq_len,
396- token_num,
397- num_heads,
398- num_key_value_heads,
399- head_dim,
400- params.configs .rope_config ,
401- params.configs .use_logn_attn ,
402- scale_out_ptr,
403- int8_mode,
404- use_paged_fmha,
405- store_qkv,
406- store_q_no_transpose,
407- store_q,
408- store_kv,
409- store_cache,
410- device->getStream ());
379+ DISPATCH_CUDA_FUNCTION_DATA_TYPE (params.input .type (),
380+ invokeAddFusedQKVBiasTranspose,
381+ q_no_transpose_output->data (),
382+ q_output->data (),
383+ k_output->data (),
384+ v_output->data (),
385+ &prefix_prompt_param,
386+ params.input .data (),
387+ qkv_buf_fp8 ? qkv_buf_fp8->data () : nullptr ,
388+ params.common .position_ids ->data <int >(),
389+ params.weights .qkv_weight ->bias ->data (),
390+ params.common .padding_offset ->data <int >(),
391+ params.common .cu_seqlens ->data <int >(),
392+ params.common .cu_seqlens_without_prefix ->data <int >(),
393+ device->use_rope_cache_ ,
394+ device->use_rope_cache_ && device->rope_cache_ .defined ()
395+ && device->rope_cache_dim_ == params.configs .rope_config .dim ?
396+ device->rope_cache_ .data_ptr <float >() :
397+ nullptr ,
398+ batch_size,
399+ seq_len,
400+ token_num,
401+ num_heads,
402+ num_key_value_heads,
403+ head_dim,
404+ params.configs .rope_config ,
405+ params.configs .use_logn_attn ,
406+ scale_out_ptr,
407+ int8_mode,
408+ use_paged_fmha,
409+ store_qkv,
410+ store_q_no_transpose,
411+ store_q,
412+ store_kv,
413+ store_cache,
414+ device->getStream ());
411415
412416 device->syncAndCheck ();
413417
0 commit comments