From 64a52f5585d46a3e804567640da2a9627f48257e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:35:33 +0200 Subject: [PATCH] checkpoint --- comfy/ldm/trellis2/model.py | 4 +++- comfy/ldm/trellis2/vae.py | 23 ++++++++++++----------- comfy/model_detection.py | 14 ++++++++++++++ comfy/sd.py | 7 +++++++ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 2367fc42c..8ca112b13 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -797,6 +797,7 @@ class Trellis2(nn.Module): share_mod = True, qk_rms_norm = True, qk_rms_norm_cross = True, + init_txt_model=False, # for now dtype=None, device=None, operations=None): super().__init__() @@ -806,7 +807,8 @@ class Trellis2(nn.Module): "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) - self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) + if init_txt_model: + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 2bbfa938c..d997bbc41 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1481,17 +1481,18 @@ class Vae(nn.Module): def __init__(self, config, operations=None): super().__init__() operations = operations or torch.nn - - self.txt_dec = SparseUnetVaeDecoder( - out_channels=6, - model_channels=[1024, 512, 256, 128, 64], - latent_channels=32, - num_blocks=[4, 16, 8, 4, 0], - block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockC2S3d"] * 4, - block_args=[{}, {}, {}, {}, {}], - pred_subdiv=False - ) + init_txt_model = config.get("init_txt_model", False) + if init_txt_model: + self.txt_dec = SparseUnetVaeDecoder( + out_channels=6, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], + pred_subdiv=False + ) self.shape_dec = FlexiDualGridVaeDecoder( resolution=256, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8cea16e50..4f5542af5 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -106,6 +106,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] return unet_config + if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + unet_config = {} + unet_config["image_model"] = "trellis2" + if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + unet_config["init_txt_model"] = True + else: + unet_config["init_txt_model"] = False + if metadata is not None: + if metadata["is_512"] is True: + unet_config["resolution"] = 32 + else: + unet_config["resolution"] = 64 + return unet_config + if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit unet_config = {} unet_config["audio_model"] = "dit1.0" diff --git a/comfy/sd.py b/comfy/sd.py index fd0ac85e7..be3d1b4f0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.trellis2.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline @@ -492,6 +493,12 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 + if "txt_dec.blocks.1.16.norm1.weight" in sd: + config["init_txt_model"] = True + else: + config["init_txt_model"] = False + self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}