From 1f0c298e18b70dc27b6e7546b66e5dedcd3e736c Mon Sep 17 00:00:00 2001
From: Cherrytest
Date: Sat, 13 Sep 2025 14:38:43 +0000
Subject: [PATCH] Update VisionSdpaAttention to support memory efficient
 backend. (#27)

- Update VisionSdpaAttention to support memory efficient backend.
  (fc8b0b11b92c381639e616506cca574f1b05af09)

Co-authored-by: warren wang
---
 modeling_dots_vision.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/modeling_dots_vision.py b/modeling_dots_vision.py
index 1046513..42a570e 100644
--- a/modeling_dots_vision.py
+++ b/modeling_dots_vision.py
@@ -274,12 +274,21 @@ class VisionSdpaAttention(nn.Module):
         for i in range(1, len(cu_seqlens)):
             attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
 
-        q = q.transpose(0, 1)
-        k = k.transpose(0, 1)
-        v = v.transpose(0, 1)
+        # Convert q, k, v to 4D so SDPA can use the memory-efficient backend
+        q = q.transpose(0, 1).unsqueeze(0)  # (1, num_heads, seq_length, head_dim)
+        k = k.transpose(0, 1).unsqueeze(0)
+        v = v.transpose(0, 1).unsqueeze(0)
 
-        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-        attn_output = attn_output.transpose(0, 1)
+        # The mask's last dim must be contiguous; see https://github.com/pytorch/pytorch/issues/127523
+        if attention_mask.stride(-1) != 1:
+            attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)
+
+        # Use the memory-efficient SDPA backend
+        from torch.nn.attention import SDPBackend, sdpa_kernel
+        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+            attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+
+        attn_output = attn_output.squeeze(0).transpose(0, 1)  # (seq_length, num_heads, head_dim)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
 
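
For reference, below is a minimal standalone sketch of the pattern this patch introduces: a block-diagonal boolean mask built from cu_seqlens, the mask-contiguity workaround, and explicit backend selection. The tensor shapes, cu_seqlens values, and the CPU fallback are illustrative assumptions, not part of the patch; sdpa_kernel lives in torch.nn.attention in recent PyTorch releases (2.3+).

# Minimal standalone sketch of the SDPA pattern from this patch.
# Shapes and cu_seqlens values below are illustrative assumptions.
import contextlib

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_length, num_heads, head_dim = 16, 4, 32
q = torch.randn(seq_length, num_heads, head_dim, device=device)
k = torch.randn(seq_length, num_heads, head_dim, device=device)
v = torch.randn(seq_length, num_heads, head_dim, device=device)

# Block-diagonal boolean mask: tokens attend only within their own sequence.
cu_seqlens = [0, 6, 16]  # cumulative sequence lengths (assumed example)
attention_mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool, device=device)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = True

# The memory-efficient kernel rejects masks whose innermost stride is not 1
# (https://github.com/pytorch/pytorch/issues/127523), so force contiguity.
if attention_mask.stride(-1) != 1:
    attention_mask = torch.empty_like(attention_mask, memory_format=torch.contiguous_format).copy_(attention_mask)

# SDPA expects (batch, num_heads, seq_length, head_dim); the (1, S, S) bool
# mask broadcasts across the head dimension.
q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))

# EFFICIENT_ATTENTION is a CUDA kernel; fall back to default backend selection on CPU.
ctx = sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION) if device == "cuda" else contextlib.nullcontext()
with ctx:
    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)

attn_output = attn_output.squeeze(0).transpose(0, 1).reshape(seq_length, -1)
print(attn_output.shape)  # torch.Size([16, 128])

The empty_like(...).copy_() idiom materializes the mask into freshly allocated contiguous storage, mirroring the workaround referenced in the upstream issue.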