mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-03 05:42:31 +08:00
dinov3 fixes + other
This commit is contained in:
parent
2eef826def
commit
f4059c189e
@ -24,7 +24,6 @@ class DINOv3ViTAttention(nn.Module):
|
|||||||
self.embed_dim = hidden_size
|
self.embed_dim = hidden_size
|
||||||
self.num_heads = num_attention_heads
|
self.num_heads = num_attention_heads
|
||||||
self.head_dim = self.embed_dim // self.num_heads
|
self.head_dim = self.embed_dim // self.num_heads
|
||||||
self.is_causal = False
|
|
||||||
|
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.is_causal = False
|
self.is_causal = False
|
||||||
@ -53,18 +52,41 @@ class DINOv3ViTAttention(nn.Module):
|
|||||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
cos, sin = position_embeddings
|
if position_embeddings is not None:
|
||||||
position_embeddings = torch.stack([cos, sin], dim = -1)
|
cos, sin = position_embeddings
|
||||||
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
|
|
||||||
|
|
||||||
attn_output, attn_weights = optimized_attention_for_device(
|
num_tokens = query_states.shape[-2]
|
||||||
query_states, key_states, value_states, attention_mask, skip_reshape=True, skip_output_reshape=True
|
num_patches = cos.shape[-2]
|
||||||
|
num_prefix_tokens = num_tokens - num_patches
|
||||||
|
|
||||||
|
q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2)
|
||||||
|
k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2)
|
||||||
|
|
||||||
|
cos = cos[..., :self.head_dim // 2]
|
||||||
|
sin = sin[..., :self.head_dim // 2]
|
||||||
|
|
||||||
|
f_cis_0 = torch.stack([cos, sin], dim=-1)
|
||||||
|
f_cis_1 = torch.stack([-sin, cos], dim=-1)
|
||||||
|
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
|
||||||
|
|
||||||
|
while freqs_cis.ndim < q_patches.ndim + 1:
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(0)
|
||||||
|
|
||||||
|
q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis)
|
||||||
|
|
||||||
|
query_states = torch.cat((q_prefix, q_patches), dim=-2)
|
||||||
|
key_states = torch.cat((k_prefix, k_patches), dim=-2)
|
||||||
|
|
||||||
|
attn = optimized_attention_for_device(query_states.device, mask=False)
|
||||||
|
|
||||||
|
attn_output = attn(
|
||||||
|
query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights
|
return attn_output
|
||||||
|
|
||||||
class DINOv3ViTGatedMLP(nn.Module):
|
class DINOv3ViTGatedMLP(nn.Module):
|
||||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||||
@ -187,7 +209,7 @@ class DINOv3ViTLayer(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.norm1(hidden_states)
|
hidden_states = self.norm1(hidden_states)
|
||||||
hidden_states, _ = self.attention(
|
hidden_states = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
@ -250,6 +272,7 @@ class DINOv3ViTModel(nn.Module):
|
|||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.norm = self.norm.to(hidden_states.device)
|
||||||
sequence_output = self.norm(hidden_states)
|
sequence_output = self.norm(hidden_states)
|
||||||
pooled_output = sequence_output[:, 0, :]
|
pooled_output = sequence_output[:, 0, :]
|
||||||
|
|
||||||
|
|||||||
@ -113,6 +113,13 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
|||||||
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
|
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
|
||||||
return q.replace(q_feats), k.replace(k_feats)
|
return q.replace(q_feats), k.replace(k_feats)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
||||||
|
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||||
|
x_rotated = x_complex * phases.unsqueeze(-2)
|
||||||
|
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
||||||
|
return x_embed
|
||||||
|
|
||||||
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
||||||
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
||||||
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||||
@ -559,6 +566,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
B, L, C = x.shape
|
B, L, C = x.shape
|
||||||
if self._type == "self":
|
if self._type == "self":
|
||||||
|
x = x.to(next(self.to_qkv.parameters()).dtype)
|
||||||
qkv = self.to_qkv(x)
|
qkv = self.to_qkv(x)
|
||||||
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
||||||
|
|
||||||
@ -688,7 +696,7 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
num_heads: Optional[int] = None,
|
num_heads: Optional[int] = None,
|
||||||
num_head_channels: Optional[int] = 64,
|
num_head_channels: Optional[int] = 64,
|
||||||
mlp_ratio: float = 4,
|
mlp_ratio: float = 4,
|
||||||
pe_mode: Literal["ape", "rope"] = "ape",
|
pe_mode: Literal["ape", "rope"] = "rope",
|
||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||||
dtype: str = 'float32',
|
dtype: str = 'float32',
|
||||||
use_checkpoint: bool = False,
|
use_checkpoint: bool = False,
|
||||||
@ -756,14 +764,14 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
self.out_layer = nn.Linear(model_channels, out_channels)
|
self.out_layer = nn.Linear(model_channels, out_channels)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
||||||
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
|
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
|
||||||
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
|
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
|
||||||
|
|
||||||
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
|
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
h = h.to(next(self.input_layer.parameters()).dtype)
|
||||||
h = self.input_layer(h)
|
h = self.input_layer(h)
|
||||||
if self.pe_mode == "ape":
|
|
||||||
h = h + self.pos_emb[None]
|
|
||||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||||
if self.share_mod:
|
if self.share_mod:
|
||||||
t_emb = self.adaLN_modulation(t_emb)
|
t_emb = self.adaLN_modulation(t_emb)
|
||||||
@ -816,7 +824,8 @@ class Trellis2(nn.Module):
|
|||||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||||
|
|
||||||
def forward(self, x: NestedTensor, timestep, context, **kwargs):
|
def forward(self, x: NestedTensor, timestep, context, **kwargs):
|
||||||
x = x.tensors[0]
|
if isinstance(x, NestedTensor):
|
||||||
|
x = x.tensors[0]
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
if not hasattr(x, "feats"):
|
if not hasattr(x, "feats"):
|
||||||
mode = "structure_generation"
|
mode = "structure_generation"
|
||||||
|
|||||||
@ -97,13 +97,13 @@ def run_conditioning(
|
|||||||
return image.to(torch_device).float()
|
return image.to(torch_device).float()
|
||||||
|
|
||||||
pil_image = set_image_size(pil_image, 512)
|
pil_image = set_image_size(pil_image, 512)
|
||||||
cond_512 = model(pil_image)
|
cond_512 = model(pil_image)[0]
|
||||||
|
|
||||||
cond_1024 = None
|
cond_1024 = None
|
||||||
if include_1024:
|
if include_1024:
|
||||||
model.image_size = 1024
|
model.image_size = 1024
|
||||||
pil_image = set_image_size(pil_image, 1024)
|
pil_image = set_image_size(pil_image, 1024)
|
||||||
cond_1024 = model([pil_image])
|
cond_1024 = model(pil_image)[0]
|
||||||
|
|
||||||
neg_cond = torch.zeros_like(cond_512)
|
neg_cond = torch.zeros_like(cond_512)
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ def run_conditioning(
|
|||||||
conditioning['cond_1024'] = cond_1024.to(device)
|
conditioning['cond_1024'] = cond_1024.to(device)
|
||||||
|
|
||||||
preprocessed_tensor = pil_image.to(torch.float32) / 255.0
|
preprocessed_tensor = pil_image.to(torch.float32) / 255.0
|
||||||
preprocessed_tensor = torch.from_numpy(preprocessed_tensor).unsqueeze(0)
|
preprocessed_tensor = preprocessed_tensor.unsqueeze(0)
|
||||||
|
|
||||||
return conditioning, preprocessed_tensor
|
return conditioning, preprocessed_tensor
|
||||||
|
|
||||||
@ -217,7 +217,7 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
|
conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
|
||||||
embeds = conditioning["cond_1024"] # should add that
|
embeds = conditioning["cond_1024"] # should add that
|
||||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||||
negative = [[conditioning["cond_neg"], {"embeds": embeds}]]
|
negative = [[conditioning["neg_cond"], {"embeds": embeds}]]
|
||||||
return IO.NodeOutput(positive, negative)
|
return IO.NodeOutput(positive, negative)
|
||||||
|
|
||||||
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||||
@ -272,7 +272,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
node_id="EmptyStructureLatentTrellis2",
|
node_id="EmptyStructureLatentTrellis2",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Int.Input("resolution", default=256, min=1, max=8192),
|
|
||||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
@ -280,8 +279,9 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, resolution, batch_size):
|
def execute(cls, batch_size):
|
||||||
in_channels = 32
|
in_channels = 8
|
||||||
|
resolution = 16
|
||||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||||
latent = NestedTensor([latent])
|
latent = NestedTensor([latent])
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user