Update modeling_dots_vision.py (#19)

- Update modeling_dots_vision.py (a455edb59a10ff47d298fd2ab6b3fcf53417b42a)


Co-authored-by: chen.jian <chenj123@users.noreply.huggingface.co>
This commit is contained in:
Cherrytest 2025-08-20 16:56:13 +00:00
parent 0c0b12f1db
commit 2464a795a5

View File

@ -4,11 +4,24 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
flash_attn_available = True
npu_available = True
try:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_available = False
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from .configuration_dots import DotsVisionConfig from .configuration_dots import DotsVisionConfig
try:
import torch_npu
except ImportError:
npu_available = False
def rotate_half(x): def rotate_half(x):
"""Rotates half the hidden dims of the input.""" """Rotates half the hidden dims of the input."""
@ -154,6 +167,89 @@ class VisionFlashAttention2(nn.Module):
return attn_output return attn_output
class VisionAttentionV2(nn.Module):
    """Eager attention that attends within each packed sequence separately.

    Plain eager attention materializes a full (n x n) score matrix with
    n = batch_size * seq_len — O(n^2) memory — which easily OOMs on long
    packed sequences.  This variant splits the packed tokens back into the
    per-sample sequences given by ``cu_seqlens`` and runs attention on each
    one independently, greatly reducing peak memory at the cost of being
    slower than a continuous-batching implementation.
    """

    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads  # per-head width; assumes dim % num_heads == 0
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run per-sequence eager attention over packed tokens.

        Args:
            hidden_states: packed tokens, shape (total_len, dim).
            cu_seqlens: cumulative sequence lengths, shape (batch + 1,);
                token span i is ``cu_seqlens[i-1]:cu_seqlens[i]``.
            rotary_pos_emb: rotary position embedding applied to q and k.

        Returns:
            Tensor of shape (total_len, dim).
        """
        seq_length = hidden_states.shape[0]
        # Fused qkv projection, then split into (total_len, num_heads, head_dim) each.
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Split packed tensors into one chunk per sample so each sample only
        # attends within itself (no cross-sequence attention, no mask needed).
        seqlens = torch.diff(cu_seqlens).tolist()
        q_list = torch.split(q, seqlens, 0)
        k_list = torch.split(k, seqlens, 0)
        v_list = torch.split(v, seqlens, 0)

        scale = 1.0 / math.sqrt(self.head_dim)  # hoisted: loop-invariant
        outputs = []
        for q_i, k_i, v_i in zip(q_list, k_list, v_list):
            # (seq_i, heads, head_dim) -> (heads, seq_i, head_dim)
            q_i = q_i.transpose(0, 1)
            k_i = k_i.transpose(0, 1)
            v_i = v_i.transpose(0, 1)
            attn = torch.matmul(q_i, k_i.transpose(1, 2)) * scale
            # Softmax in fp32 for numerical stability, then cast back.
            attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
            out = torch.matmul(attn, v_i)
            outputs.append(out.transpose(0, 1))  # back to (seq_i, heads, head_dim)
        attn_output = torch.cat(outputs, dim=0)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output
class VisionAscendAttention(nn.Module):
    """Attention backend for Huawei Ascend NPUs via ``torch_npu``.

    Builds a block-diagonal boolean mask from ``cu_seqlens`` so that each
    packed sequence only attends within itself, then calls the fused
    ``npu_prompt_flash_attention`` kernel.
    """

    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads  # per-head width; assumes dim % num_heads == 0
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,  # packed tokens, shape (total_len, dim)
        cu_seqlens: torch.Tensor,  # cumulative sequence lengths, shape (batch + 1,)
        rotary_pos_emb: torch.Tensor = None,  # rotary embedding applied to q and k
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        # Fused qkv projection, split into (total_len, num_heads, head_dim) each.
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
        # Start fully masked (True), then open (False) the diagonal blocks so
        # tokens attend only within their own sequence.
        # NOTE(review): this assumes the NPU kernel treats True as "masked
        # out" — confirm against the torch_npu atten_mask documentation.
        attention_mask = torch.ones([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = False
        # (total_len, heads, head_dim) -> (1, heads, total_len, head_dim),
        # matching the declared "BNSD" input layout below.
        q = q.transpose(0, 1).unsqueeze(0)
        k = k.transpose(0, 1).unsqueeze(0)
        v = v.transpose(0, 1).unsqueeze(0)
        attn_output = torch_npu.npu_prompt_flash_attention(q, k, v,
                                                           atten_mask=attention_mask,
                                                           num_heads=self.num_heads, input_layout="BNSD",
                                                           scale_value=self.head_dim ** -0.5)
        # Back to (total_len, heads * head_dim), then the output projection.
        attn_output = attn_output.squeeze(0).transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output
class VisionSdpaAttention(nn.Module): class VisionSdpaAttention(nn.Module):
def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None: def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
super().__init__() super().__init__()
@ -192,8 +288,10 @@ class VisionSdpaAttention(nn.Module):
DOTS_VISION_ATTENTION_CLASSES = { DOTS_VISION_ATTENTION_CLASSES = {
"eager": VisionAttention, "eager": VisionAttention,
"eager_v2": VisionAttentionV2, # lower peak memory than "eager"
"flash_attention_2": VisionFlashAttention2, "flash_attention_2": VisionFlashAttention2,
"sdpa": VisionSdpaAttention, "sdpa": VisionSdpaAttention,
"ascend_fa": VisionAscendAttention, # NOTE: severe accuracy degradation on long sequences on Ascend
} }
@ -231,7 +329,6 @@ class DotsSwiGLUFFN(nn.Module):
return x return x
class DotsPatchEmbed(nn.Module): class DotsPatchEmbed(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@ -272,6 +369,16 @@ class DotsViTPreprocessor(nn.Module):
class DotsVisionBlock(nn.Module): class DotsVisionBlock(nn.Module):
def __init__(self, config, attn_implementation: str = "flash_attention_2"): def __init__(self, config, attn_implementation: str = "flash_attention_2"):
super().__init__() super().__init__()
if attn_implementation == "flash_attention_2" and not flash_attn_available:
# fallback to eager
attn_implementation = "eager"
print("flash attention not available! fallback to eager implementation ")
if attn_implementation == "ascend_fa" and not npu_available:
attn_implementation = "eager"
print("flash attention not available! fallback to eager implementation ")
self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation]( self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation](
config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias
) )