fix flash attention
This commit is contained in:
parent
e97599635d
commit
a63aff0377
|
|
@ -276,7 +276,7 @@ def pay_attention(
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
cu_seqlens_q= cu_seqlens_q,
|
cu_seqlens_q= cu_seqlens_q,
|
||||||
cu_seqlens_kv= cu_seqlens_k,
|
cu_seqlens_k= cu_seqlens_k,
|
||||||
seqused_q=None,
|
seqused_q=None,
|
||||||
seqused_k=None,
|
seqused_k=None,
|
||||||
max_seqlen_q=lq,
|
max_seqlen_q=lq,
|
||||||
|
|
@ -289,8 +289,8 @@ def pay_attention(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
cu_seqlens_q= [0, lq],
|
cu_seqlens_q= cu_seqlens_q,
|
||||||
cu_seqlens_kv=[0, lk],
|
cu_seqlens_k= cu_seqlens_k,
|
||||||
max_seqlen_q=lq,
|
max_seqlen_q=lq,
|
||||||
max_seqlen_k=lk,
|
max_seqlen_k=lk,
|
||||||
dropout_p=dropout_p,
|
dropout_p=dropout_p,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue