checkpoint

This commit is contained in:
Yousef Rafat 2026-02-06 20:35:33 +02:00
parent cdd7ced1e8
commit 64a52f5585
4 changed files with 36 additions and 12 deletions

View File

@ -797,6 +797,7 @@ class Trellis2(nn.Module):
share_mod = True,
qk_rms_norm = True,
qk_rms_norm_cross = True,
init_txt_model=False, # for now
dtype=None, device=None, operations=None):
super().__init__()
@ -806,7 +807,8 @@ class Trellis2(nn.Module):
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
}
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
if init_txt_model:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
args.pop("out_channels")
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)

View File

@ -1481,17 +1481,18 @@ class Vae(nn.Module):
def __init__(self, config, operations=None):
super().__init__()
operations = operations or torch.nn
self.txt_dec = SparseUnetVaeDecoder(
out_channels=6,
model_channels=[1024, 512, 256, 128, 64],
latent_channels=32,
num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
pred_subdiv=False
)
init_txt_model = config.get("init_txt_model", False)
if init_txt_model:
self.txt_dec = SparseUnetVaeDecoder(
out_channels=6,
model_channels=[1024, 512, 256, 128, 64],
latent_channels=32,
num_blocks=[4, 16, 8, 4, 0],
block_type=["SparseConvNeXtBlock3d"] * 5,
up_block_type=["SparseResBlockC2S3d"] * 4,
block_args=[{}, {}, {}, {}, {}],
pred_subdiv=False
)
self.shape_dec = FlexiDualGridVaeDecoder(
resolution=256,

View File

@ -106,6 +106,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config = {}
unet_config["image_model"] = "trellis2"
if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config["init_txt_model"] = True
else:
unet_config["init_txt_model"] = False
if metadata is not None:
if metadata["is_512"] is True:
unet_config["resolution"] = 32
else:
unet_config["resolution"] = 64
return unet_config
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
unet_config = {}
unet_config["audio_model"] = "dit1.0"

View File

@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.trellis2.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
@ -492,6 +493,12 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2
if "txt_dec.blocks.1.16.norm1.weight" in sd:
config["init_txt_model"] = True
else:
config["init_txt_model"] = False
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config)
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}