fix flash attention
parent e97599635d
commit a63aff0377
@@ -276,7 +276,7 @@ def pay_attention(
         k=k,
         v=v,
         cu_seqlens_q= cu_seqlens_q,
-        cu_seqlens_kv= cu_seqlens_k,
+        cu_seqlens_k= cu_seqlens_k,
         seqused_q=None,
         seqused_k=None,
         max_seqlen_q=lq,
@@ -289,8 +289,8 @@ def pay_attention(
         q=q,
         k=k,
         v=v,
-        cu_seqlens_q= [0, lq],
-        cu_seqlens_kv=[0, lk],
+        cu_seqlens_q= cu_seqlens_q,
+        cu_seqlens_k= cu_seqlens_k,
         max_seqlen_q=lq,
         max_seqlen_k=lk,
         dropout_p=dropout_p,
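For context, the fix renames the bad cu_seqlens_kv keyword to cu_seqlens_k and passes the real cumulative sequence lengths instead of the hard-coded [0, lq] / [0, lk] pair. A minimal sketch of the intended call pattern, assuming the wrapped kernel is flash_attn's flash_attn_varlen_func; the helper name, tensor shapes, and length lists below are illustrative, not from this repository:

    # Sketch under the assumption that flash_attn >= 2.x is installed.
    import torch
    import torch.nn.functional as F
    from flash_attn import flash_attn_varlen_func

    def varlen_attention(q, k, v, q_lens, k_lens, dropout_p=0.0):
        # q, k, v: packed tensors of shape (total_tokens, num_heads, head_dim).
        # q_lens, k_lens: per-sequence token counts, e.g. [3, 5, 2].
        device = q.device
        # cu_seqlens_* are cumulative lengths prefixed with 0, int32, e.g. [0, 3, 8, 10].
        # A fixed [0, lq] (as in the old code) only describes a single sequence.
        cu_seqlens_q = F.pad(
            torch.cumsum(torch.tensor(q_lens, device=device), 0, dtype=torch.int32), (1, 0))
        cu_seqlens_k = F.pad(
            torch.cumsum(torch.tensor(k_lens, device=device), 0, dtype=torch.int32), (1, 0))
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,   # keyword is cu_seqlens_k, not cu_seqlens_kv
            max_seqlen_q=max(q_lens),
            max_seqlen_k=max(k_lens),
            dropout_p=dropout_p,
        )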