Update modeling_dots_vision.py (#19)
- Update modeling_dots_vision.py (a455edb59a10ff47d298fd2ab6b3fcf53417b42a) Co-authored-by: chen.jian <chenj123@users.noreply.huggingface.co>
parent 0c0b12f1db
commit 2464a795a5
@@ -4,16 +4,29 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from flash_attn import flash_attn_varlen_func
+
+flash_attn_available = True
+npu_available = True
+
+try:
+    from flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_available = False

 from torch.nn import LayerNorm
 from transformers.modeling_utils import PreTrainedModel
 from .configuration_dots import DotsVisionConfig

+try:
+    import torch_npu
+except ImportError:
+    npu_available = False
+

 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
+    x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)

@@ -56,7 +69,7 @@ class PatchMerger(nn.Module):
         init_merger_std=None,
     ) -> None:
         super().__init__()
-        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.hidden_size = context_dim * (spatial_merge_size ** 2)
         self.pre_norm = pre_norm
         if self.pre_norm == "layernorm":
             self.ln_q = LayerNorm(context_dim, eps=1e-6)
@@ -109,7 +122,7 @@ class VisionAttention(nn.Module):
             [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
         )
         for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0

         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
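For readers unfamiliar with the packed-sequence layout, here is a minimal standalone sketch (not part of the commit) of how cu_seqlens drives the block-diagonal mask built in VisionAttention and VisionSdpaAttention: cu_seqlens holds cumulative patch counts, so each consecutive pair delimits one image's patches, and only those patches may attend to each other. The sizes below are made up for illustration.

import torch

# Hypothetical cumulative sequence lengths for three images with 4, 2, and 3 patches.
cu_seqlens = torch.tensor([0, 4, 6, 9])
seq_length = int(cu_seqlens[-1])

# Same construction as the eager path: start fully masked, then open one square block per image.
attention_mask = torch.full([1, seq_length, seq_length], float("-inf"))
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0

assert torch.isinf(attention_mask[0, 0, 5]).item()   # a patch of image 0 cannot see image 1
assert (attention_mask[0, 0, 3] == 0).item()         # patches within image 0 can see each other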
@@ -154,6 +167,89 @@ class VisionFlashAttention2(nn.Module):
         return attn_output


+class VisionAttentionV2(nn.Module):
+    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
+        self.proj = nn.Linear(dim, dim, bias=bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+        seqlens = torch.diff(cu_seqlens).tolist()
+
+        q_list = torch.split(q, seqlens, 0)
+        k_list = torch.split(k, seqlens, 0)
+        v_list = torch.split(v, seqlens, 0)
+        # Eager attention has O(n^2) memory cost, where n = b * s (batch_size * seq_len), so long sequences easily OOM.
+        # This implementation splits the packed sequence per batch item to reduce memory; it is slower than continuous batching.
+        outputs = []
+        for q_i, k_i, v_i in zip(q_list, k_list, v_list):
+            q_i = q_i.transpose(0, 1)
+            k_i = k_i.transpose(0, 1)
+            v_i = v_i.transpose(0, 1)
+            out = torch.matmul(q_i, k_i.transpose(1, 2)) / math.sqrt(self.head_dim)
+            out = nn.functional.softmax(out, dim=-1, dtype=torch.float32).to(q.dtype)
+            out = torch.matmul(out, v_i)
+            out = out.transpose(0, 1)
+            outputs.append(out)
+
+        attn_output = torch.concat(outputs, dim=0)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
+class VisionAscendAttention(nn.Module):
+    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
+        self.proj = nn.Linear(dim, dim, bias=bias)
+        self.config = config
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+
+        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+        attention_mask = torch.ones([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = False
+
+        q = q.transpose(0, 1).unsqueeze(0)
+        k = k.transpose(0, 1).unsqueeze(0)
+        v = v.transpose(0, 1).unsqueeze(0)
+
+        attn_output = torch_npu.npu_prompt_flash_attention(q, k, v,
+                                                           atten_mask=attention_mask,
+                                                           num_heads=self.num_heads, input_layout="BNSD",
+                                                           scale_value=self.head_dim ** -0.5)
+        attn_output = attn_output.squeeze(0).transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
+
+
 class VisionSdpaAttention(nn.Module):
     def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
         super().__init__()
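To make the memory argument in the comment above concrete, a rough back-of-the-envelope sketch (not from the commit; all sizes are arbitrary): packing every image into one eager score matrix scales with the square of the total patch count, while the per-image split in VisionAttentionV2 only materializes one per-image block at a time.

# Assumed example sizes: 8 images, 1024 patches each, 16 heads, fp32 attention scores.
num_images, patches_per_image, num_heads = 8, 1024, 16
total_patches = num_images * patches_per_image

bytes_packed = num_heads * total_patches ** 2 * 4         # one packed score matrix
bytes_split = num_heads * patches_per_image ** 2 * 4      # largest per-image block

print(f"packed eager scores: {bytes_packed / 2**30:.1f} GiB")   # 4.0 GiB
print(f"per-image block:     {bytes_split / 2**20:.1f} MiB")    # 64.0 MiB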
@@ -176,7 +272,7 @@ class VisionSdpaAttention(nn.Module):

         attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
         for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+            attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True

         q = q.transpose(0, 1)
         k = k.transpose(0, 1)
@@ -192,8 +288,10 @@ class VisionSdpaAttention(nn.Module):

 DOTS_VISION_ATTENTION_CLASSES = {
     "eager": VisionAttention,
+    "eager_v2": VisionAttentionV2,  # uses less memory
     "flash_attention_2": VisionFlashAttention2,
     "sdpa": VisionSdpaAttention,
+    "ascend_fa": VisionAscendAttention,  # Ascend; accuracy degrades severely on long sequences.
 }


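As a hedged sketch of how this mapping is meant to be consumed (assuming it runs inside modeling_dots_vision.py, where the availability flags and the dict are defined; the requested key is illustrative), selection mirrors the fallback logic added to DotsVisionBlock below:

# Illustrative key; the same guards DotsVisionBlock applies before the lookup.
requested = "ascend_fa"
if requested == "flash_attention_2" and not flash_attn_available:
    requested = "eager"
if requested == "ascend_fa" and not npu_available:
    requested = "eager"

attn_cls = DOTS_VISION_ATTENTION_CLASSES[requested]
print(attn_cls.__name__)   # VisionAscendAttention on an NPU machine, VisionAttention otherwise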
@@ -231,7 +329,6 @@ class DotsSwiGLUFFN(nn.Module):
         return x


-
 class DotsPatchEmbed(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -272,6 +369,16 @@ class DotsViTPreprocessor(nn.Module):
 class DotsVisionBlock(nn.Module):
     def __init__(self, config, attn_implementation: str = "flash_attention_2"):
         super().__init__()
+
+        if attn_implementation == "flash_attention_2" and not flash_attn_available:
+            # fall back to eager
+            attn_implementation = "eager"
+            print("flash attention not available! falling back to the eager implementation")
+
+        if attn_implementation == "ascend_fa" and not npu_available:
+            attn_implementation = "eager"
+            print("torch_npu not available! falling back to the eager implementation")
+
         self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation](
             config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias
         )
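Finally, a hedged usage sketch of the new fallback path (the config fields used here, embed_dim / num_attention_heads / use_bias, appear in this file; constructing DotsVisionConfig with its defaults, and running this inside the module where these classes are defined, are assumptions):

config = DotsVisionConfig()                                   # assuming the defaults are usable
block = DotsVisionBlock(config, attn_implementation="ascend_fa")
# On a machine without torch_npu the constructor prints the fallback message and
# ends up with the plain eager implementation:
print(type(block.attn).__name__)                              # VisionAttention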