Update VisionSdpaAttention to support memory efficient backend. (#27)
- Update VisionSdpaAttention to support memory efficient backend. (fc8b0b11b92c381639e616506cca574f1b05af09) Co-authored-by: warren wang <warrenwjk@users.noreply.huggingface.co>
This commit is contained in:
parent
2464a795a5
commit
1f0c298e18
@ -274,12 +274,21 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
for i in range(1, len(cu_seqlens)):
|
for i in range(1, len(cu_seqlens)):
|
||||||
attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
|
attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
|
||||||
|
|
||||||
q = q.transpose(0, 1)
|
# Convert q, k, v to 4D to enable the memory-efficient backend: (1, num_heads, seq_length, head_dim)
|
||||||
k = k.transpose(0, 1)
|
q = q.transpose(0, 1).unsqueeze(0) # (1, num_heads, seq_length, head_dim)
|
||||||
v = v.transpose(0, 1)
|
k = k.transpose(0, 1).unsqueeze(0)
|
||||||
|
v = v.transpose(0, 1).unsqueeze(0)
|
||||||
|
|
||||||
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
# See: https://github.com/pytorch/pytorch/issues/127523
|
||||||
attn_output = attn_output.transpose(0, 1)
|
if attention_mask.stride(-1) != 1:
|
||||||
|
attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)
|
||||||
|
|
||||||
|
# use memory efficient backend
|
||||||
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
|
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
|
||||||
|
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
||||||
|
|
||||||
|
attn_output = attn_output.squeeze(0).transpose(0, 1) # (seq_length, num_heads, head_dim)
|
||||||
attn_output = attn_output.reshape(seq_length, -1)
|
attn_output = attn_output.reshape(seq_length, -1)
|
||||||
|
|
||||||
attn_output = self.proj(attn_output)
|
attn_output = self.proj(attn_output)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user