Commit 1d392d3 (parent: 337c97a)

fix: ci test

File tree: 1 file changed (+37 −35 lines)

rtp_llm/cpp/devices/rocm_impl/ROCmAttentionOp.cc

Lines changed: 37 additions & 35 deletions
@@ -778,42 +778,44 @@ AttentionModuleOutput ROCmDevice::contextAttention(const AttentionModuleParams&
         }
         check_cuda_error();
     } else {
-        DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype,
-                                         invokeAddFusedQKVBiasTranspose,
-                                         nullptr,
-                                         q_output->data(),
-                                         k_output->data(),
-                                         v_output->data(),
-                                         &prefix_prompt_param,
-                                         params.input.data(),
+        DISPATCH_CUDA_FUNCTION_DATA_TYPE(
+            datatype,
+            invokeAddFusedQKVBiasTranspose,
+            nullptr,
+            q_output->data(),
+            k_output->data(),
+            v_output->data(),
+            &prefix_prompt_param,
+            params.input.data(),
+            nullptr,
+            params.common.position_ids ? params.common.position_ids->dataWithOffset<int>(
+                decoder_batch_size * params.configs.rope_config.index_factor) :
                 nullptr,
-                                         params.common.position_ids ?
-                                             params.common.position_ids->dataWithOffset<int>(
-                                                 decoder_batch_size * params.configs.rope_config.index_factor) :
-                                             nullptr,
-                                         params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias ?
-                                             params.weights.qkv_weight->bias->data() :
-                                             nullptr,
-                                         params.common.padding_offset->data<int>(),
-                                         params.common.cu_seqlens->data<int>(),
-                                         params.common.cu_seqlens_without_prefix->data<int>(),
-                                         batch_size,
-                                         seq_len,
-                                         token_num,
-                                         head_num,
-                                         kv_head_num,
-                                         size_per_head,
-                                         params.configs.rope_config,
-                                         params.configs.use_logn_attn,
-                                         scale_out_ptr,
-                                         int8_mode,
-                                         false,
-                                         store_qkv,
-                                         false,
-                                         store_q,
-                                         store_kv,
-                                         store_cache,
-                                         stream_);
+            params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias ?
+                params.weights.qkv_weight->bias->data() :
+                nullptr,
+            params.common.padding_offset->data<int>(),
+            params.common.cu_seqlens->data<int>(),
+            params.common.cu_seqlens_without_prefix->data<int>(),
+            use_rope_cache_,
+            use_rope_cache_ && rope_cache_.defined() ? rope_cache_.data_ptr<float>() : nullptr,
+            batch_size,
+            seq_len,
+            token_num,
+            head_num,
+            kv_head_num,
+            size_per_head,
+            params.configs.rope_config,
+            params.configs.use_logn_attn,
+            scale_out_ptr,
+            int8_mode,
+            false,
+            store_qkv,
+            false,
+            store_q,
+            store_kv,
+            store_cache,
+            stream_);
         check_cuda_error();
     }
     writeCacheStore(params);
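
Note on the functional change: besides reflowing the DISPATCH_CUDA_FUNCTION_DATA_TYPE call, the commit threads two new arguments into invokeAddFusedQKVBiasTranspose between cu_seqlens_without_prefix and batch_size: a use_rope_cache_ flag and a float pointer into rope_cache_ (nullptr when the cache tensor is undefined). The .defined() and .data_ptr<float>() calls suggest rope_cache_ is a torch::Tensor holding precomputed rotary-embedding (RoPE) values. Below is a minimal sketch of how such a cache could be precomputed; the helper name makeRopeCache and the cos/sin row layout are illustrative assumptions, not code from this repository.

    #include <torch/torch.h>

    // Hypothetical helper, not from the repository: precompute a
    // [max_seq_len, rotary_dim] float32 table where each row stores
    // cos(p * inv_freq) in the first rotary_dim/2 columns and
    // sin(p * inv_freq) in the rest, so a kernel can look angles up
    // by position instead of evaluating cos/sin per token.
    torch::Tensor makeRopeCache(int64_t max_seq_len, int64_t rotary_dim, double base = 10000.0) {
        // inv_freq[i] = base^(-2i / rotary_dim): one frequency per rotated pair.
        auto inv_freq = torch::pow(
            base, torch::arange(0, rotary_dim, 2, torch::kFloat64) / -double(rotary_dim));
        auto positions = torch::arange(max_seq_len, torch::kFloat64);
        // angles[p][i] = p * inv_freq[i], shape [max_seq_len, rotary_dim / 2].
        auto angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0);
        // Store cos and sin side by side; float32 matches the
        // rope_cache_.data_ptr<float>() seen in the diff above.
        return torch::cat({angles.cos(), angles.sin()}, /*dim=*/1).to(torch::kFloat32);
    }

Built once at startup, such a tensor could be stored in rope_cache_, with use_rope_cache_ gating the fast path; the dispatch then passes rope_cache_.data_ptr<float>() when the cache is defined and falls back to nullptr otherwise, exactly as the added arguments do.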
