From 704e1b54621a78b30d529b5367a2951383bd890c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 9 Feb 2026 00:41:01 +0200 Subject: [PATCH] small bug fixes --- comfy/ldm/trellis2/attention.py | 2 +- comfy/ldm/trellis2/model.py | 5 ++++- comfy/ldm/trellis2/vae.py | 5 ++--- comfy/model_detection.py | 12 +++++++----- comfy/sd.py | 7 +++---- comfy_extras/nodes_trellis2.py | 8 ++++---- 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index edc85ce83..3038f4023 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -2,7 +2,7 @@ import torch import math from comfy.ldm.modules.attention import optimized_attention from typing import Tuple, Union, List -from vae import VarLenTensor +from comfy.ldm.trellis2.vae import VarLenTensor FLASH_ATTN_3_AVA = True try: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8ca112b13..9aab045c7 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -798,9 +798,12 @@ class Trellis2(nn.Module): qk_rms_norm = True, qk_rms_norm_cross = True, init_txt_model=False, # for now - dtype=None, device=None, operations=None): + dtype=None, device=None, operations=None, **kwargs): super().__init__() + self.dtype = dtype + # for some reason it passes num_heads = -1 + num_heads = 12 args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 1d26986cc..6e13afd8d 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -6,7 +6,7 @@ from fractions import Fraction import torch.nn.functional as F from dataclasses import dataclass import numpy as np -from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: @@ -1457,10 +1457,9 @@ class SparseStructureDecoder(nn.Module): return h class Vae(nn.Module): - def __init__(self, config, operations=None): + def __init__(self, init_txt_model, operations=None): super().__init__() operations = operations or torch.nn - init_txt_model = config.get("init_txt_model", False) if init_txt_model: self.txt_dec = SparseUnetVaeDecoder( out_channels=6, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4f5542af5..004adbf71 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -109,15 +109,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config = {} unet_config["image_model"] = "trellis2" + + unet_config["init_txt_model"] = False if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config["init_txt_model"] = True - else: - unet_config["init_txt_model"] = False + + unet_config["resolution"] = 64 if metadata is not None: - if metadata["is_512"] is True: + if "is_512" in metadata and metadata["metadata"]: unet_config["resolution"] = 32 - else: - unet_config["resolution"] = 64 + + unet_config["num_heads"] = 12 return unet_config if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit diff --git a/comfy/sd.py b/comfy/sd.py index be3d1b4f0..25fd3ba7b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -494,11 +494,10 @@ class VAE: self.downscale_ratio = 32 self.latent_channels = 16 elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 + init_txt_model = False if "txt_dec.blocks.1.16.norm1.weight" in sd: - config["init_txt_model"] = True - else: - config["init_txt_model"] = False - self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config) + init_txt_model = True + self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8497b83e2..17ba94ec8 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -198,7 +198,7 @@ class Trellis2Conditioning(IO.ComfyNode): inputs=[ IO.ClipVision.Input("clip_vision_model"), IO.Image.Input("image"), - IO.MultiCombo.Input("background_color", options=["black", "gray", "white"], default="black") + IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black") ], outputs=[ IO.Conditioning.Output(display_name="positive"), @@ -219,7 +219,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ IO.Mesh.Input("structure_output"), @@ -242,7 +242,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ IO.Mesh.Input("structure_output"), @@ -264,7 +264,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyStructureLatentTrellis2", category="latent/3d", inputs=[ IO.Int.Input("resolution", default=3072, min=1, max=8192),