Update modeling_dots_vision.py (#19)
- Update modeling_dots_vision.py (a455edb59a10ff47d298fd2ab6b3fcf53417b42a) Co-authored-by: chen.jian <chenj123@users.noreply.huggingface.co>
This commit is contained in:
parent
0c0b12f1db
commit
2464a795a5
@ -4,16 +4,29 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
|
flash_attn_available = True
|
||||||
|
npu_available = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
except ImportError:
|
||||||
|
flash_attn_available = False
|
||||||
|
|
||||||
from torch.nn import LayerNorm
|
from torch.nn import LayerNorm
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from .configuration_dots import DotsVisionConfig
|
from .configuration_dots import DotsVisionConfig
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch_npu
|
||||||
|
except ImportError:
|
||||||
|
npu_available = False
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
    """Rotate half the hidden dims of the input (RoPE helper).

    Splits the last dimension in two halves and returns
    ``cat((-second_half, first_half))`` along the last dim.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
|
|
||||||
|
|
||||||
@ -48,15 +61,15 @@ class VisionRotaryEmbedding(nn.Module):
|
|||||||
|
|
||||||
class PatchMerger(nn.Module):
|
class PatchMerger(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
context_dim: int,
|
context_dim: int,
|
||||||
spatial_merge_size: int = 2,
|
spatial_merge_size: int = 2,
|
||||||
pre_norm="layernorm",
|
pre_norm="layernorm",
|
||||||
init_merger_std=None,
|
init_merger_std=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
self.hidden_size = context_dim * (spatial_merge_size ** 2)
|
||||||
self.pre_norm = pre_norm
|
self.pre_norm = pre_norm
|
||||||
if self.pre_norm == "layernorm":
|
if self.pre_norm == "layernorm":
|
||||||
self.ln_q = LayerNorm(context_dim, eps=1e-6)
|
self.ln_q = LayerNorm(context_dim, eps=1e-6)
|
||||||
@ -94,10 +107,10 @@ class VisionAttention(nn.Module):
|
|||||||
self.proj = nn.Linear(dim, dim, bias=bias)
|
self.proj = nn.Linear(dim, dim, bias=bias)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor = None,
|
rotary_pos_emb: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
seq_length = hidden_states.shape[0]
|
seq_length = hidden_states.shape[0]
|
||||||
|
|
||||||
@ -109,7 +122,7 @@ class VisionAttention(nn.Module):
|
|||||||
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
||||||
)
|
)
|
||||||
for i in range(1, len(cu_seqlens)):
|
for i in range(1, len(cu_seqlens)):
|
||||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = 0
|
||||||
|
|
||||||
q = q.transpose(0, 1)
|
q = q.transpose(0, 1)
|
||||||
k = k.transpose(0, 1)
|
k = k.transpose(0, 1)
|
||||||
@ -134,10 +147,10 @@ class VisionFlashAttention2(nn.Module):
|
|||||||
self.is_causal = config.is_causal
|
self.is_causal = config.is_causal
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor = None,
|
rotary_pos_emb: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
seq_length = hidden_states.shape[0]
|
seq_length = hidden_states.shape[0]
|
||||||
q, k, v = (
|
q, k, v = (
|
||||||
@ -154,6 +167,89 @@ class VisionFlashAttention2(nn.Module):
|
|||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class VisionAttentionV2(nn.Module):
    """Eager attention that processes each packed sample separately.

    Plain eager attention materializes an O(n^2) score matrix where
    n = batch_size * seq_len, which easily OOMs on long sequences. This
    variant splits the packed sequence back into per-sample chunks (using
    ``cu_seqlens``) and attends within each chunk, so peak memory is much
    lower, at the cost of being slower than continuous batching.
    """

    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Fused projection producing query, key and value in one matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run per-sample eager attention over a packed (seq, dim) input.

        hidden_states: packed tokens of all samples, shape (total_seq, dim).
        cu_seqlens: cumulative sequence lengths delimiting each sample.
        rotary_pos_emb: rotary embedding applied to q and k.
        """
        seq_length = hidden_states.shape[0]

        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Per-sample lengths recovered from the cumulative boundaries.
        per_sample_lens = torch.diff(cu_seqlens).tolist()
        chunks = zip(
            torch.split(q, per_sample_lens, 0),
            torch.split(k, per_sample_lens, 0),
            torch.split(v, per_sample_lens, 0),
        )

        # softmax(QK^T / sqrt(d)) V computed one sample at a time to cap
        # the size of the score matrix.
        outputs = []
        for cq, ck, cv in chunks:
            cq, ck, cv = cq.transpose(0, 1), ck.transpose(0, 1), cv.transpose(0, 1)
            scores = torch.matmul(cq, ck.transpose(1, 2)) / math.sqrt(self.head_dim)
            weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
            outputs.append(torch.matmul(weights, cv).transpose(0, 1))

        attn_output = torch.concat(outputs, dim=0).reshape(seq_length, -1)
        return self.proj(attn_output)
|
|
||||||
|
|
||||||
|
class VisionAscendAttention(nn.Module):
    """Attention backend using the Ascend NPU fused prompt flash attention op.

    Requires ``torch_npu``; callers are expected to fall back to another
    backend when the NPU stack is unavailable.
    NOTE(review): accuracy reportedly degrades badly on long sequences with
    this backend — confirm before using it for large inputs.
    """

    def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Fused q/k/v projection, split after the matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Attend over a packed (total_seq, dim) input on an Ascend NPU."""
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Block-diagonal boolean mask: True means "masked out"; clear the
        # diagonal blocks so each sample only attends within itself.
        mask = torch.ones([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = False

        # (seq, heads, dim) -> (1, heads, seq, dim) for the BNSD layout.
        q = q.transpose(0, 1).unsqueeze(0)
        k = k.transpose(0, 1).unsqueeze(0)
        v = v.transpose(0, 1).unsqueeze(0)

        attn_output = torch_npu.npu_prompt_flash_attention(
            q,
            k,
            v,
            atten_mask=mask,
            num_heads=self.num_heads,
            input_layout="BNSD",
            scale_value=self.head_dim ** -0.5,
        )
        attn_output = attn_output.squeeze(0).transpose(0, 1).reshape(seq_length, -1)
        return self.proj(attn_output)
|
|
||||||
|
|
||||||
class VisionSdpaAttention(nn.Module):
|
class VisionSdpaAttention(nn.Module):
|
||||||
def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
|
def __init__(self, config, dim: int, num_heads: int = 16, bias=True) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -163,10 +259,10 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
rotary_pos_emb: torch.Tensor = None,
|
rotary_pos_emb: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
seq_length = hidden_states.shape[0]
|
seq_length = hidden_states.shape[0]
|
||||||
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||||
@ -176,7 +272,7 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
|
|
||||||
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
||||||
for i in range(1, len(cu_seqlens)):
|
for i in range(1, len(cu_seqlens)):
|
||||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
attention_mask[..., cu_seqlens[i - 1]: cu_seqlens[i], cu_seqlens[i - 1]: cu_seqlens[i]] = True
|
||||||
|
|
||||||
q = q.transpose(0, 1)
|
q = q.transpose(0, 1)
|
||||||
k = k.transpose(0, 1)
|
k = k.transpose(0, 1)
|
||||||
@ -192,8 +288,10 @@ class VisionSdpaAttention(nn.Module):
|
|||||||
|
|
||||||
# Mapping from the `attn_implementation` config string to its backend class.
DOTS_VISION_ATTENTION_CLASSES = {
    "eager": VisionAttention,
    "eager_v2": VisionAttentionV2,  # lower peak memory (per-sample attention)
    "flash_attention_2": VisionFlashAttention2,
    "sdpa": VisionSdpaAttention,
    "ascend_fa": VisionAscendAttention,  # Ascend NPU; accuracy degrades badly on long sequences
}
|
|
||||||
|
|
||||||
@ -231,7 +329,6 @@ class DotsSwiGLUFFN(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class DotsPatchEmbed(nn.Module):
|
class DotsPatchEmbed(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -272,6 +369,16 @@ class DotsViTPreprocessor(nn.Module):
|
|||||||
class DotsVisionBlock(nn.Module):
|
class DotsVisionBlock(nn.Module):
|
||||||
def __init__(self, config, attn_implementation: str = "flash_attention_2"):
|
def __init__(self, config, attn_implementation: str = "flash_attention_2"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if attn_implementation == "flash_attention_2" and not flash_attn_available:
|
||||||
|
# fallback to eager
|
||||||
|
attn_implementation = "eager"
|
||||||
|
print("flash attention not available! fallback to eager implementation ")
|
||||||
|
|
||||||
|
if attn_implementation == "ascend_fa" and not npu_available:
|
||||||
|
attn_implementation = "eager"
|
||||||
|
print("flash attention not available! fallback to eager implementation ")
|
||||||
|
|
||||||
self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation](
|
self.attn = DOTS_VISION_ATTENTION_CLASSES[attn_implementation](
|
||||||
config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias
|
config, config.embed_dim, num_heads=config.num_attention_heads, bias=config.use_bias
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user