mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
checkpoint
This commit is contained in:
parent
cdd7ced1e8
commit
64a52f5585
@ -797,6 +797,7 @@ class Trellis2(nn.Module):
|
||||
share_mod = True,
|
||||
qk_rms_norm = True,
|
||||
qk_rms_norm_cross = True,
|
||||
init_txt_model=False, # for now
|
||||
dtype=None, device=None, operations=None):
|
||||
|
||||
super().__init__()
|
||||
@ -806,7 +807,8 @@ class Trellis2(nn.Module):
|
||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||
}
|
||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
if init_txt_model:
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
args.pop("out_channels")
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||
|
||||
|
||||
@ -1481,17 +1481,18 @@ class Vae(nn.Module):
|
||||
def __init__(self, config, operations=None):
|
||||
super().__init__()
|
||||
operations = operations or torch.nn
|
||||
|
||||
self.txt_dec = SparseUnetVaeDecoder(
|
||||
out_channels=6,
|
||||
model_channels=[1024, 512, 256, 128, 64],
|
||||
latent_channels=32,
|
||||
num_blocks=[4, 16, 8, 4, 0],
|
||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
pred_subdiv=False
|
||||
)
|
||||
init_txt_model = config.get("init_txt_model", False)
|
||||
if init_txt_model:
|
||||
self.txt_dec = SparseUnetVaeDecoder(
|
||||
out_channels=6,
|
||||
model_channels=[1024, 512, 256, 128, 64],
|
||||
latent_channels=32,
|
||||
num_blocks=[4, 16, 8, 4, 0],
|
||||
block_type=["SparseConvNeXtBlock3d"] * 5,
|
||||
up_block_type=["SparseResBlockC2S3d"] * 4,
|
||||
block_args=[{}, {}, {}, {}, {}],
|
||||
pred_subdiv=False
|
||||
)
|
||||
|
||||
self.shape_dec = FlexiDualGridVaeDecoder(
|
||||
resolution=256,
|
||||
|
||||
@ -106,6 +106,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
return unet_config
|
||||
|
||||
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config["init_txt_model"] = True
|
||||
else:
|
||||
unet_config["init_txt_model"] = False
|
||||
if metadata is not None:
|
||||
if metadata["is_512"] is True:
|
||||
unet_config["resolution"] = 32
|
||||
else:
|
||||
unet_config["resolution"] = 64
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
unet_config = {}
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
|
||||
@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.trellis2.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
@ -492,6 +493,12 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2
|
||||
if "txt_dec.blocks.1.16.norm1.weight" in sd:
|
||||
config["init_txt_model"] = True
|
||||
else:
|
||||
config["init_txt_model"] = False
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config)
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user