@@ -778,42 +778,44 @@ AttentionModuleOutput ROCmDevice::contextAttention(const AttentionModuleParams&
         }
         check_cuda_error();
     } else {
-        DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype,
-                                         invokeAddFusedQKVBiasTranspose,
-                                         nullptr,
-                                         q_output->data(),
-                                         k_output->data(),
-                                         v_output->data(),
-                                         &prefix_prompt_param,
-                                         params.input.data(),
+        DISPATCH_CUDA_FUNCTION_DATA_TYPE(
+            datatype,
+            invokeAddFusedQKVBiasTranspose,
+            nullptr,
+            q_output->data(),
+            k_output->data(),
+            v_output->data(),
+            &prefix_prompt_param,
+            params.input.data(),
+            nullptr,
+            params.common.position_ids ? params.common.position_ids->dataWithOffset<int>(
+                decoder_batch_size * params.configs.rope_config.index_factor) :
             nullptr,
-                                         params.common.position_ids ?
-                                             params.common.position_ids->dataWithOffset<int>(
-                                                 decoder_batch_size * params.configs.rope_config.index_factor) :
-                                             nullptr,
-                                         params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias ?
-                                             params.weights.qkv_weight->bias->data() :
-                                             nullptr,
-                                         params.common.padding_offset->data<int>(),
-                                         params.common.cu_seqlens->data<int>(),
-                                         params.common.cu_seqlens_without_prefix->data<int>(),
-                                         batch_size,
-                                         seq_len,
-                                         token_num,
-                                         head_num,
-                                         kv_head_num,
-                                         size_per_head,
-                                         params.configs.rope_config,
-                                         params.configs.use_logn_attn,
-                                         scale_out_ptr,
-                                         int8_mode,
-                                         false,
-                                         store_qkv,
-                                         false,
-                                         store_q,
-                                         store_kv,
-                                         store_cache,
-                                         stream_);
+            params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias ?
+                params.weights.qkv_weight->bias->data() :
+                nullptr,
+            params.common.padding_offset->data<int>(),
+            params.common.cu_seqlens->data<int>(),
+            params.common.cu_seqlens_without_prefix->data<int>(),
+            use_rope_cache_,
+            use_rope_cache_ && rope_cache_.defined() ? rope_cache_.data_ptr<float>() : nullptr,
+            batch_size,
+            seq_len,
+            token_num,
+            head_num,
+            kv_head_num,
+            size_per_head,
+            params.configs.rope_config,
+            params.configs.use_logn_attn,
+            scale_out_ptr,
+            int8_mode,
+            false,
+            store_qkv,
+            false,
+            store_q,
+            store_kv,
+            store_cache,
+            stream_);
         check_cuda_error();
     }
     writeCacheStore(params);
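
Beyond the argument reflow, the substance of this hunk is the pair of new parameters threading a precomputed rotary-embedding (RoPE) table into invokeAddFusedQKVBiasTranspose: use_rope_cache_ selects the cached path, and rope_cache_.data_ptr<float>() supplies the table when it is defined, trading memory for per-token sin/cos evaluation inside the fused kernel. Below is a minimal standalone sketch of that idea, assuming an interleaved (cos, sin) layout of shape [max_pos, rot_dim]; buildRopeCache and applyRope are illustrative names, not the repository's actual kernel code.

// Illustrative sketch only, not the kernel's real implementation.
// Layout assumption: cache[pos * rot_dim + 2*i] holds cos and
// cache[pos * rot_dim + 2*i + 1] holds sin for dimension pair i.
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> buildRopeCache(int max_pos, int rot_dim, float base = 10000.f) {
    std::vector<float> cache(static_cast<size_t>(max_pos) * rot_dim);
    for (int pos = 0; pos < max_pos; ++pos) {
        for (int i = 0; i < rot_dim / 2; ++i) {
            float inv_freq = std::pow(base, -2.f * i / rot_dim);
            cache[pos * rot_dim + 2 * i]     = std::cos(pos * inv_freq);
            cache[pos * rot_dim + 2 * i + 1] = std::sin(pos * inv_freq);
        }
    }
    return cache;
}

// Rotate one (x0, x1) pair; a null cache means "compute trig on the fly",
// mirroring the use_rope_cache_ == false path above.
void applyRope(float& x0, float& x1, int pos, int i, int rot_dim, const float* cache) {
    float c, s;
    if (cache != nullptr) {
        c = cache[pos * rot_dim + 2 * i];
        s = cache[pos * rot_dim + 2 * i + 1];
    } else {
        float inv_freq = std::pow(10000.f, -2.f * i / rot_dim);
        c = std::cos(pos * inv_freq);
        s = std::sin(pos * inv_freq);
    }
    float r0 = x0 * c - x1 * s;
    float r1 = x0 * s + x1 * c;
    x0 = r0;
    x1 = r1;
}

int main() {
    auto cache = buildRopeCache(/*max_pos=*/16, /*rot_dim=*/8);
    float q0 = 1.f, q1 = 0.f;
    applyRope(q0, q1, /*pos=*/3, /*i=*/0, /*rot_dim=*/8, cache.data());
    std::printf("rotated: %.4f %.4f\n", q0, q1);  // cos(3), sin(3) for i == 0
    return 0;
}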