mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
Compare commits
1 Commits
d7e58f53d3
...
3ee6571cbd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ee6571cbd |
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any, Tuple
|
||||
from typing import Optional, Any
|
||||
import math
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
@ -32,7 +32,6 @@ class Llama2Config:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Mistral3Small24BConfig:
|
||||
@ -55,7 +54,6 @@ class Mistral3Small24BConfig:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@ -78,7 +76,6 @@ class Qwen25_3BConfig:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_06BConfig:
|
||||
@ -101,7 +98,6 @@ class Qwen3_06BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
@ -124,7 +120,6 @@ class Qwen3_4BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_8BConfig:
|
||||
@ -147,7 +142,6 @@ class Qwen3_8BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
@ -170,7 +164,6 @@ class Ovis25_2BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@ -193,7 +186,6 @@ class Qwen25_7BVLI_Config:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@ -217,7 +209,6 @@ class Gemma2_2B_Config:
|
||||
sliding_attention = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Config:
|
||||
@ -241,7 +232,6 @@ class Gemma3_4B_Config:
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma3_12B_Config:
|
||||
@ -265,7 +255,6 @@ class Gemma3_12B_Config:
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
@ -367,7 +356,6 @@ class Attention(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
xq = self.q_proj(hidden_states)
|
||||
@ -385,30 +373,11 @@ class Attention(nn.Module):
|
||||
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
index = 0
|
||||
num_tokens = xk.shape[2]
|
||||
if len(past_key_value) > 0:
|
||||
past_key, past_value, index = past_key_value
|
||||
if past_key.shape[2] >= (index + num_tokens):
|
||||
past_key[:, :, index:index + xk.shape[2]] = xk
|
||||
past_value[:, :, index:index + xv.shape[2]] = xv
|
||||
xk = past_key[:, :, :index + xk.shape[2]]
|
||||
xv = past_value[:, :, :index + xv.shape[2]]
|
||||
present_key_value = (past_key, past_value, index + num_tokens)
|
||||
else:
|
||||
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
|
||||
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
return self.o_proj(output), present_key_value
|
||||
return self.o_proj(output)
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
@ -439,17 +408,15 @@ class TransformerBlock(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
# Self Attention
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x, present_key_value = self.self_attn(
|
||||
x = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
x = residual + x
|
||||
|
||||
@ -459,7 +426,7 @@ class TransformerBlock(nn.Module):
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
|
||||
return x, present_key_value
|
||||
return x
|
||||
|
||||
class TransformerBlockGemma2(nn.Module):
|
||||
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
|
||||
@ -484,7 +451,6 @@ class TransformerBlockGemma2(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
if self.transformer_type == 'gemma3':
|
||||
if self.sliding_attention:
|
||||
@ -502,12 +468,11 @@ class TransformerBlockGemma2(nn.Module):
|
||||
# Self Attention
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x, present_key_value = self.self_attn(
|
||||
x = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
x = self.post_attention_layernorm(x)
|
||||
@ -520,7 +485,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
x = self.post_feedforward_layernorm(x)
|
||||
x = residual + x
|
||||
|
||||
return x, present_key_value
|
||||
return x
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
@ -551,10 +516,9 @@ class Llama2_(nn.Module):
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
@ -563,13 +527,8 @@ class Llama2_(nn.Module):
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
past_len = past_key_values[0][2]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
||||
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
@ -580,16 +539,14 @@ class Llama2_(nn.Module):
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
|
||||
if seq_len > 1:
|
||||
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
mask = causal_mask
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
mask = causal_mask
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
intermediate = None
|
||||
@ -605,27 +562,16 @@ class Llama2_(nn.Module):
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
next_key_values = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
if all_intermediate is not None:
|
||||
if only_layers is None or (i in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
past_kv = None
|
||||
if past_key_values is not None:
|
||||
past_kv = past_key_values[i] if len(past_key_values) > 0 else []
|
||||
|
||||
x, current_kv = layer(
|
||||
x = layer(
|
||||
x=x,
|
||||
attention_mask=mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_kv,
|
||||
)
|
||||
|
||||
if current_kv is not None:
|
||||
next_key_values.append(current_kv)
|
||||
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
@ -642,10 +588,7 @@ class Llama2_(nn.Module):
|
||||
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
if len(next_key_values) > 0:
|
||||
return x, intermediate, next_key_values
|
||||
else:
|
||||
return x, intermediate
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Gemma3MultiModalProjector(torch.nn.Module):
|
||||
|
||||
@ -1248,7 +1248,6 @@ class Hidden(str, Enum):
|
||||
class NodeInfoV1:
|
||||
input: dict=None
|
||||
input_order: dict[str, list[str]]=None
|
||||
is_input_list: bool=None
|
||||
output: list[str]=None
|
||||
output_is_list: list[bool]=None
|
||||
output_name: list[str]=None
|
||||
@ -1475,7 +1474,6 @@ class Schema:
|
||||
info = NodeInfoV1(
|
||||
input=input,
|
||||
input_order={key: list(value.keys()) for (key, value) in input.items()},
|
||||
is_input_list=self.is_input_list,
|
||||
output=output,
|
||||
output_is_list=output_is_list,
|
||||
output_name=output_name,
|
||||
|
||||
@ -656,7 +656,6 @@ class PromptServer():
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||
info['is_input_list'] = getattr(obj_class, "INPUT_IS_LIST", False)
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
|
||||
Loading…
Reference in New Issue
Block a user