Update modeling_dots_vision.py (#19)

- Update modeling_dots_vision.py (a455edb59a10ff47d298fd2ab6b3fcf53417b42a)


Co-authored-by: chen.jian <chenj123@users.noreply.huggingface.co>
This commit is contained in:
Cherrytest 2025-08-20 16:56:13 +00:00
parent 0c0b12f1db
commit 2464a795a5

View File

@ -4,16 +4,29 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
# Optional-dependency probes: prefer fused attention kernels when available,
# otherwise the model falls back to slower implementations at construction time.
flash_attn_available = True
npu_available = True
try:
    from flash_attn import flash_attn_varlen_func
except ImportError:
    # flash_attn is optional; DotsVisionBlock falls back to eager attention.
    flash_attn_available = False
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from .configuration_dots import DotsVisionConfig from .configuration_dots import DotsVisionConfig
try:
import torch_npu
except ImportError:
npu_available = False
def rotate_half(x): def rotate_half(x):
"""Rotates half the hidden dims of the input.""" """Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2] x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :] x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
@ -48,15 +61,15 @@ class VisionRotaryEmbedding(nn.Module):
class PatchMerger(nn.Module): class PatchMerger(nn.Module):
def __init__( def __init__(
self, self,
dim: int, dim: int,
context_dim: int, context_dim: int,
spatial_merge_size: int = 2, spatial_merge_size: int = 2,
pre_norm="layernorm", pre_norm="layernorm",
init_merger_std=None, init_merger_std=None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2) self.hidden_size = context_dim * (spatial_merge_size ** 2)
self.pre_norm = pre_norm self.pre_norm = pre_norm
if self.pre_norm == "layernorm": if self.pre_norm == "layernorm":
self.ln_q = LayerNorm(context_dim, eps=1e-6) self.ln_q = LayerNorm(context_dim, eps=1e-6)
@ -94,10 +107,10 @@ class VisionAttention(nn.Module):
self.proj = nn.Linear(dim, dim, bias=bias) self.proj = nn.Linear(dim, dim, bias=bias)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None, rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = hidden_states.shape[0] seq_length = hidden_states.shape[0]
@ -109,7 +122,7 @@ class VisionAttention(nn.Module):
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
) )
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
q = q.transpose(0, 1) q = q.transpose(0, 1)
k = k.transpose(0, 1) k = k.transpose(0, 1)
@ -134,10 +147,10 @@ class VisionFlashAttention2(nn.Module):
self.is_causal = config.is_causal self.is_causal = config.is_causal
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None, rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = hidden_states.shape[0] seq_length = hidden_states.shape[0]
q, k, v = ( q, k, v = (
@ -154,6 +167,89 @@ class VisionFlashAttention2(nn.Module):
return attn_output return attn_output
class VisionAttentionV2(nn.Module):
    """Memory-frugal eager attention over a packed (varlen) batch.

    The plain eager implementation materialises one (n, n) score matrix for
    the whole packed batch (n = total token count across all sequences),
    which is O(n^2) memory and OOMs on long inputs.  This variant splits the
    packed tokens back into their individual sequences via ``cu_seqlens`` and
    runs attention per sequence, trading some speed (vs. continuous batching)
    for a much smaller peak-memory footprint.
    """

    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Fused projection producing q, k and v in a single matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        total_tokens = hidden_states.shape[0]

        # (tokens, 3, heads, head_dim) -> three (tokens, heads, head_dim) tensors.
        fused = self.qkv(hidden_states).reshape(total_tokens, 3, self.num_heads, -1)
        q, k, v = fused.permute(1, 0, 2, 3).unbind(0)

        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Recover per-sequence lengths from the cumulative offsets, then
        # split the packed tensors back into individual sequences.
        lengths = torch.diff(cu_seqlens).tolist()
        scale = math.sqrt(self.head_dim)

        per_seq_out = []
        for q_seq, k_seq, v_seq in zip(
            torch.split(q, lengths, 0),
            torch.split(k, lengths, 0),
            torch.split(v, lengths, 0),
        ):
            # (len, heads, head_dim) -> (heads, len, head_dim)
            q_seq, k_seq, v_seq = (t.transpose(0, 1) for t in (q_seq, k_seq, v_seq))
            scores = torch.matmul(q_seq, k_seq.transpose(1, 2)) / scale
            # Softmax in float32 for numerical stability, then cast back.
            weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
            per_seq_out.append(torch.matmul(weights, v_seq).transpose(0, 1))

        attn_output = torch.concat(per_seq_out, dim=0).reshape(total_tokens, -1)
        return self.proj(attn_output)
class VisionAscendAttention(nn.Module):
    """Vision attention backed by Huawei Ascend NPU's fused flash attention.

    Uses ``torch_npu.npu_prompt_flash_attention``; only usable when the
    optional ``torch_npu`` package imported successfully (``npu_available``).
    """

    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Per-head width; assumes `dim` is divisible by `num_heads`.
        self.head_dim = dim // num_heads
        # Fused projection producing q, k and v in a single matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        # hidden_states: packed tokens of all sequences, shape (total_tokens, dim).
        seq_length = hidden_states.shape[0]
        # (tokens, 3, heads, head_dim) -> three (tokens, heads, head_dim) tensors.
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
        # Block-diagonal mask from cu_seqlens: positions within the same
        # sequence are set to False, everything else stays True.
        # NOTE(review): presumably True marks positions to be masked OUT by
        # npu_prompt_flash_attention — confirm against torch_npu docs.
        attention_mask = torch.ones([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = False
        # (tokens, heads, head_dim) -> (1, heads, tokens, head_dim), i.e. BNSD layout.
        q = q.transpose(0, 1).unsqueeze(0)
        k = k.transpose(0, 1).unsqueeze(0)
        v = v.transpose(0, 1).unsqueeze(0)
        attn_output = torch_npu.npu_prompt_flash_attention(q, k, v,
                                                           atten_mask=attention_mask,
                                                           num_heads=self.num_heads, input_layout="BNSD",
                                                           scale_value=self.head_dim ** -0.5)
        # Back to (tokens, heads, head_dim), then flatten heads for the output proj.
        attn_output = attn_output.squeeze(0).transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output
class VisionSdpaAttention(nn.Module): class VisionSdpaAttention(nn.Module):
def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None: def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
super().__init__() super().__init__()
@ -163,10 +259,10 @@ class VisionSdpaAttention(nn.Module):
self.config = config self.config = config
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None, rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
seq_length = hidden_states.shape[0] seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
@ -176,7 +272,7 @@ class VisionSdpaAttention(nn.Module):
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)): for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
q = q.transpose(0, 1) q = q.transpose(0, 1)
k = k.transpose(0, 1) k = k.transpose(0, 1)
@ -192,8 +288,10 @@ class VisionSdpaAttention(nn.Module):
DOTS_VISION_ATTENTION_CLASSES = { DOTS_VISION_ATTENTION_CLASSES = {
"eager": VisionAttention, "eager": VisionAttention,
"eager_v2": VisionAttentionV2,  # lower peak memory than "eager" "flash_attention_2": VisionFlashAttention2, "flash_attention_2": VisionFlashAttention2,
"flash_attention_2": VisionFlashAttention2, "flash_attention_2": VisionFlashAttention2,
"sdpa": VisionSdpaAttention, "sdpa": VisionSdpaAttention,
"ascend_fa": VisionAscendAttention,  # Ascend NPU: precision degrades badly on long sequences.
} }
@ -231,7 +329,6 @@ class DotsSwiGLUFFN(nn.Module):
return x return x
class DotsPatchEmbed(nn.Module): class DotsPatchEmbed(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@ -249,7 +346,7 @@ class DotsPatchEmbed(nn.Module):
self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
x = x.view(-1, self.num_channels, self.temporal_patch_size, self.patch_size, self.patch_size)[:, :, 0] x = x.view(-1, self.num_channels, self.temporal_patch_size, self.patch_size, self.patch_size)[:, :, 0]
x = self.proj(x).view(-1, self.embed_dim) x = self.proj(x).view(-1, self.embed_dim)
x = self.norm(x) x = self.norm(x)
return x return x
@ -272,6 +369,16 @@ class DotsViTPreprocessor(nn.Module):
class DotsVisionBlock(nn.Module): class DotsVisionBlock(nn.Module):
def __init__(self, config, attn_implementation: str = "flash_attention_2"): def __init__(self, config, attn_implementation: str = "flash_attention_2"):
super().__init__() super().__init__()
# Downgrade the requested attention backend when its optional dependency is
# missing, so model construction never crashes on an unavailable kernel.
if attn_implementation == "flash_attention_2" and not flash_attn_available:
    # flash_attn could not be imported; fall back to eager attention.
    attn_implementation = "eager"
    print("flash attention not available! fallback to eager implementation ")
if attn_implementation == "ascend_fa" and not npu_available:
    # torch_npu could not be imported; fall back to eager attention.
    # (Fixed copy-paste bug: the message previously blamed flash attention.)
    attn_implementation = "eager"
    print("torch_npu not available! fallback to eager implementation")
self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation]( self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation](
config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias
) )
@ -401,4 +508,4 @@ class DotsVisionTransformer(PreTrainedModel):
hidden_states = self.post_trunk_norm(hidden_states) hidden_states = self.post_trunk_norm(hidden_states)
hidden_states = self.merger(hidden_states) hidden_states = self.merger(hidden_states)
return hidden_states return hidden_states