diff --git a/modeling_dots_vision.py b/modeling_dots_vision.py
index 00c5cd8..1046513 100644
--- a/modeling_dots_vision.py
+++ b/modeling_dots_vision.py
@@ -4,16 +4,29 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from flash_attn import flash_attn_varlen_func
+
+flash_attn_available = True
+npu_available = True
+
+try:
+    from flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_available = False
+
 from torch.nn import LayerNorm
 from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_dots import DotsVisionConfig
 
+try:
+    import torch_npu
+except ImportError:
+    npu_available = False
+
 
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
+    x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)
@@ -48,15 +61,15 @@ class VisionRotaryEmbedding(nn.Module):
 
 class PatchMerger(nn.Module):
     def __init__(
-        self,
-        dim: int,
-        context_dim: int,
-        spatial_merge_size: int = 2,
-        pre_norm="layernorm",
-        init_merger_std=None,
+            self,
+            dim: int,
+            context_dim: int,
+            spatial_merge_size: int = 2,
+            pre_norm="layernorm",
+            init_merger_std=None,
     ) -> None:
         super().__init__()
-        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.hidden_size = context_dim * (spatial_merge_size ** 2)
         self.pre_norm = pre_norm
         if self.pre_norm == "layernorm":
             self.ln_q = LayerNorm(context_dim, eps=1e-6)
@@ -94,10 +107,10 @@ class VisionAttention(nn.Module):
         self.proj = nn.Linear(dim, dim, bias=bias)
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor = None,
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
@@ -109,7 +122,7 @@
         attention_mask = torch.full(
             [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
         )
         for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
@@ -134,10 +147,10 @@ class VisionFlashAttention2(nn.Module):
         self.is_causal = config.is_causal
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor = None,
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = (
@@ -154,6 +167,89 @@
         return attn_output
 
 
+class VisionAttentionV2(nn.Module):
+    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
+        self.proj = nn.Linear(dim, dim, bias=bias)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+        seqlens = torch.diff(cu_seqlens).tolist()
+
+        q_list = torch.split(q, seqlens, 0)
+        k_list = torch.split(k, seqlens, 0)
+        v_list = torch.split(v, seqlens, 0)
+        # Eager attention needs O(n^2) memory with n = b*s (batch_size * seq_len), so long sequences OOM easily.
+        # This implementation splits the sequence per batch item to cut peak memory, trading some speed vs continuous batching.
+        outputs = []
+        for q_i, k_i, v_i in zip(q_list, k_list, v_list):
+            q_i = q_i.transpose(0, 1)
+            k_i = k_i.transpose(0, 1)
+            v_i = v_i.transpose(0, 1)
+            out = torch.matmul(q_i, k_i.transpose(1, 2)) / math.sqrt(self.head_dim)
+            out = nn.functional.softmax(out, dim=-1, dtype=torch.float32).to(q.dtype)
+            out = torch.matmul(out, v_i)
+            out = out.transpose(0, 1)
+            outputs.append(out)
+
+        attn_output = torch.concat(outputs, dim=0)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
+class VisionAscendAttention(nn.Module):
+    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
+        self.proj = nn.Linear(dim, dim, bias=bias)
+        self.config = config
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+        attention_mask = torch.ones([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = False
+
+        q = q.transpose(0, 1).unsqueeze(0)
+        k = k.transpose(0, 1).unsqueeze(0)
+        v = v.transpose(0, 1).unsqueeze(0)
+
+        attn_output = torch_npu.npu_prompt_flash_attention(q, k, v,
+                                                           atten_mask=attention_mask,
+                                                           num_heads=self.num_heads, input_layout="BNSD",
+                                                           scale_value=self.head_dim ** -0.5)
+        attn_output = attn_output.squeeze(0).transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
 class VisionSdpaAttention(nn.Module):
     def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
         super().__init__()
@@ -163,10 +259,10 @@ class VisionSdpaAttention(nn.Module):
         self.config = config
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor = None,
+            self,
+            hidden_states: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            rotary_pos_emb: torch.Tensor = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
@@ -176,7 +272,7 @@
 
         attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
         for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
 
         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
@@ -192,8 +288,10 @@
 
 DOTS_VISION_ATTENTION_CLASSES = {
     "eager": VisionAttention,
+    "eager_v2": VisionAttentionV2,  # lower peak memory
     "flash_attention_2": VisionFlashAttention2,
     "sdpa": VisionSdpaAttention,
+    "ascend_fa": VisionAscendAttention,  # Ascend NPU; accuracy degrades badly on long sequences
 }
 
 
@@ -231,7 +329,6 @@ class DotsSwiGLUFFN(nn.Module):
         return x
 
 
-
 class DotsPatchEmbed(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -249,7 +346,7 @@
         self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
 
     def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
-        x = x.view(-1, self.num_channels, self.temporal_patch_size, self.patch_size, self.patch_size)[:, :, 0] 
+        x = x.view(-1, self.num_channels, self.temporal_patch_size, self.patch_size, self.patch_size)[:, :, 0]
         x = self.proj(x).view(-1, self.embed_dim)
         x = self.norm(x)
         return x
@@ -272,6 +369,16 @@ class DotsViTPreprocessor(nn.Module):
 class DotsVisionBlock(nn.Module):
     def __init__(self, config, attn_implementation: str = "flash_attention_2"):
         super().__init__()
+
+        if attn_implementation == "flash_attention_2" and not flash_attn_available:
+            # fallback to eager
+            attn_implementation = "eager"
+            print("flash attention not available! fallback to eager implementation ")
+
+        if attn_implementation == "ascend_fa" and not npu_available:
+            attn_implementation = "eager"
+            print("torch_npu not available! fallback to eager implementation ")
+
         self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation](
             config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias
         )
@@ -401,4 +508,4 @@ class DotsVisionTransformer(PreTrainedModel):
         hidden_states = self.post_trunk_norm(hidden_states)
         hidden_states = self.merger(hidden_states)
 
-        return hidden_states
\ No newline at end of file
+        return hidden_states