diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py
index b27b95b5f..ef04556da 100644
--- a/comfy/image_encoders/dino3.py
+++ b/comfy/image_encoders/dino3.py
@@ -24,7 +24,6 @@ class DINOv3ViTAttention(nn.Module):
         self.embed_dim = hidden_size
         self.num_heads = num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.is_causal = False
         self.scaling = self.head_dim**-0.5
         self.is_causal = False
 
@@ -53,18 +52,41 @@ class DINOv3ViTAttention(nn.Module):
         key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
 
-        cos, sin = position_embeddings
-        position_embeddings = torch.stack([cos, sin], dim = -1)
-        query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
 
-        attn_output, attn_weights = optimized_attention_for_device(
-            query_states, key_states, value_states, attention_mask, skip_reshape=True, skip_output_reshape=True
+            num_tokens = query_states.shape[-2]
+            num_patches = cos.shape[-2]
+            num_prefix_tokens = num_tokens - num_patches
+
+            q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2)
+            k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2)
+
+            cos = cos[..., :self.head_dim // 2]
+            sin = sin[..., :self.head_dim // 2]
+
+            f_cis_0 = torch.stack([cos, sin], dim=-1)
+            f_cis_1 = torch.stack([-sin, cos], dim=-1)
+            freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
+
+            while freqs_cis.ndim < q_patches.ndim + 1:
+                freqs_cis = freqs_cis.unsqueeze(0)
+
+            q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis)
+
+            query_states = torch.cat((q_prefix, q_patches), dim=-2)
+            key_states = torch.cat((k_prefix, k_patches), dim=-2)
+
+        attn = optimized_attention_for_device(query_states.device, mask=False)
+
+        attn_output = attn(
+            query_states,
+            key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
         )
 
         attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
         attn_output = self.o_proj(attn_output)
 
-        return attn_output, attn_weights
+        return attn_output
 
 class DINOv3ViTGatedMLP(nn.Module):
     def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
@@ -187,7 +209,7 @@ class DINOv3ViTLayer(nn.Module):
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.norm1(hidden_states)
-        hidden_states, _ = self.attention(
+        hidden_states = self.attention(
             hidden_states,
             attention_mask=attention_mask,
             position_embeddings=position_embeddings,
@@ -250,6 +272,7 @@ class DINOv3ViTModel(nn.Module):
                 position_embeddings=position_embeddings,
             )
 
+        self.norm = self.norm.to(hidden_states.device)
         sequence_output = self.norm(hidden_states)
         pooled_output = sequence_output[:, 0, :]
 
diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py
index 9aab045c7..8bc8e8f7a 100644
--- a/comfy/ldm/trellis2/model.py
+++ b/comfy/ldm/trellis2/model.py
@@ -113,6 +113,13 @@ class SparseRotaryPositionEmbedder(nn.Module):
         q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
         return q.replace(q_feats), k.replace(k_feats)
 
+    @staticmethod
+    def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
+        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        x_rotated = x_complex * phases.unsqueeze(-2)
+        x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
+        return x_embed
+
 class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
         phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
@@ -559,6 +566,7 @@ class MultiHeadAttention(nn.Module):
     def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
         B, L, C = x.shape
         if self._type == "self":
+            x = x.to(next(self.to_qkv.parameters()).dtype)
             qkv = self.to_qkv(x)
             qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
 
@@ -688,7 +696,7 @@ class SparseStructureFlowModel(nn.Module):
         num_heads: Optional[int] = None,
         num_head_channels: Optional[int] = 64,
         mlp_ratio: float = 4,
-        pe_mode: Literal["ape", "rope"] = "ape",
+        pe_mode: Literal["ape", "rope"] = "rope",
         rope_freq: Tuple[float, float] = (1.0, 10000.0),
         dtype: str = 'float32',
         use_checkpoint: bool = False,
@@ -756,14 +764,14 @@
         self.out_layer = nn.Linear(model_channels, out_channels)
 
     def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+        x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
         assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
             f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
 
         h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
+        h = h.to(next(self.input_layer.parameters()).dtype)
         h = self.input_layer(h)
-        if self.pe_mode == "ape":
-            h = h + self.pos_emb[None]
         t_emb = self.t_embedder(t, out_dtype = t.dtype)
         if self.share_mod:
             t_emb = self.adaLN_modulation(t_emb)
@@ -816,7 +824,8 @@
         self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
 
     def forward(self, x: NestedTensor, timestep, context, **kwargs):
-        x = x.tensors[0]
+        if isinstance(x, NestedTensor):
+            x = x.tensors[0]
         embeds = kwargs.get("embeds")
         if not hasattr(x, "feats"):
             mode = "structure_generation"
diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py
index f53d36736..4eff2dbc3 100644
--- a/comfy_extras/nodes_trellis2.py
+++ b/comfy_extras/nodes_trellis2.py
@@ -97,13 +97,13 @@ def run_conditioning(
         return image.to(torch_device).float()
 
     pil_image = set_image_size(pil_image, 512)
-    cond_512 = model(pil_image)
+    cond_512 = model(pil_image)[0]
     cond_1024 = None
 
     if include_1024:
         model.image_size = 1024
         pil_image = set_image_size(pil_image, 1024)
-        cond_1024 = model([pil_image])
+        cond_1024 = model(pil_image)[0]
 
     neg_cond = torch.zeros_like(cond_512)
 
@@ -115,7 +115,7 @@ def run_conditioning(
         conditioning['cond_1024'] = cond_1024.to(device)
 
     preprocessed_tensor = pil_image.to(torch.float32) / 255.0
-    preprocessed_tensor = torch.from_numpy(preprocessed_tensor).unsqueeze(0)
+    preprocessed_tensor = preprocessed_tensor.unsqueeze(0)
 
     return conditioning, preprocessed_tensor
 
@@ -217,7 +217,7 @@ class Trellis2Conditioning(IO.ComfyNode):
         conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
         embeds = conditioning["cond_1024"] # should add that
         positive = [[conditioning["cond_512"], {"embeds": embeds}]]
-        negative = [[conditioning["cond_neg"], {"embeds": embeds}]]
+        negative = [[conditioning["neg_cond"], {"embeds": embeds}]]
         return IO.NodeOutput(positive, negative)
 
 class EmptyShapeLatentTrellis2(IO.ComfyNode):
@@ -272,7 +272,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
             node_id="EmptyStructureLatentTrellis2",
             category="latent/3d",
             inputs=[
-                IO.Int.Input("resolution", default=256, min=1, max=8192),
                 IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
             ],
             outputs=[
@@ -280,8 +279,9 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
         ]
     )
     @classmethod
-    def execute(cls, resolution, batch_size):
-        in_channels = 32
+    def execute(cls, batch_size):
+        in_channels = 8
+        resolution = 16
         latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
         latent = NestedTensor([latent])
         return IO.NodeOutput({"samples": latent, "type": "trellis2"})