diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 1bf404498..ff17d78d6 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -226,8 +226,11 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): def __init__(self, config, dtype, device, operations): super().__init__() - if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False): + use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True) + if dtype == torch.float16 and use_bf16: dtype = torch.bfloat16 + elif dtype == torch.float16 and not use_bf16: + dtype = torch.float32 num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] num_attention_heads = config["num_attention_heads"] diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 19de93b96..0b9c12294 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -10,8 +10,8 @@ import comfy.ops def var_attn_arg(kwargs): cu_seqlens_q = kwargs.get("cu_seqlens_q", None) max_seqlen_q = kwargs.get("max_seqlen_q", None) - cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", cu_seqlens_q) - max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q) + cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q) + max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q) assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k @@ -183,6 +183,7 @@ def calc_window_partition( def sparse_scaled_dot_product_attention(*args, **kwargs): + q=None arg_names_dict = { 1: ['qkv'], 2: ['q', 'kv'], @@ -250,6 +251,12 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + # TODO: change + if q is not None: + heads = q + else: + heads = qkv + heads = heads.shape[2] if optimized_attention.__name__ == 'attention_xformers': if 'xops' not in globals(): import xformers.ops as xops @@ -279,11 +286,15 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) if num_all_args in [2, 3]: cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + else: + cu_seqlens_kv = cu_seqlens_q if num_all_args == 1: q, k, v = qkv.unbind(dim=1) elif num_all_args == 2: k, v = kv.unbind(dim=1) - out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen), + skip_reshape=True, skip_output_reshape=True) if s is not None: return s.replace(out) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 52242b4e0..a565ec37e 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -232,6 +232,8 @@ class SparseMultiHeadAttention(nn.Module): else: q = self._linear(self.to_q, x) q = self._reshape_chs(q, (self.num_heads, -1)) + dtype = next(self.to_kv.parameters()).dtype + context = context.to(dtype) kv = self._linear(self.to_kv, context) kv = self._fused_pre(kv, num_fused=2) if self.qk_rms_norm: @@ -760,15 +762,13 @@ class Trellis2(nn.Module): self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): - # FIXME: should find a way to distinguish between 512/1024 models - # currently assumes 1024 transformer_options = kwargs.get("transformer_options", {}) embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") - #_, cond = context.chunk(2) # TODO - cond = embeds.chunk(2)[0] - context = torch.cat([torch.zeros_like(cond), cond]) + is_1024 = self.img2shape.resolution == 1024 + if is_1024: + context = embeds coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 2b44a19eb..1b43f7f62 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,6 +2,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import torch.nn.functional as TF import comfy.model_management +from comfy.utils import ProgressBar from PIL import Image import numpy as np import torch @@ -250,7 +251,7 @@ class Trellis2Conditioning(IO.ComfyNode): conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that positive = [[conditioning["cond_512"], {"embeds": embeds}]] - negative = [[conditioning["neg_cond"], {"embeds": embeds}]] + negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -512,15 +513,23 @@ class PostProcessMesh(IO.ComfyNode): ) @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + bar = ProgressBar(2) mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces if fill_holes_perimeter != 0.0: verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) + bar.update(1) + else: + bar.update(1) if simplify != 0: verts, faces = simplify_fn(verts, faces, simplify) + bar.update(1) + else: + bar.update(1) + # potentially adding laplacian smoothing mesh.vertices = verts mesh.faces = faces