diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 14810d56d..a331cb502 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -779,15 +779,18 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): transformer_options = kwargs.get("transformer_options", {}) + model_options = {} + if hasattr(self, "meta"): + model_options = self.meta timestep = timestep.to(x.dtype) embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") is_1024 = self.img2shape.resolution == 1024 - coords = transformer_options.get("coords", None) - coord_counts = transformer_options.get("coord_counts", None) - mode = transformer_options.get("generation_mode", "structure_generation") + coords = model_options.get("coords", None) + coord_counts = model_options.get("coord_counts", None) + mode = model_options.get("generation_mode", "structure_generation") is_512_run = False if mode == "shape_generation_512": @@ -881,7 +884,7 @@ class Trellis2(nn.Module): elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") - slat = transformer_options.get("shape_slat") + slat = model_options.get("shape_slat") if slat is None: raise ValueError("shape_slat can't be None") diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 223c69085..013189bb9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -129,6 +129,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["num_heads"] = 12 return unet_config + if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture + unet_config = {} + unet_config["image_model"] = "trellis2" + unet_config["resolution"] = 64 + unet_config["num_heads"] = 12 + unet_config["txt_only"] = True + return unet_config + if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit unet_config = {} unet_config["audio_model"] = "dit1.0" diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 55777e1e9..d2e844561 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -257,7 +257,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, vae): - resolution = int(vae.resolution.item()) + resolution = int(vae.first_stage_model.resolution.item()) sample_tensor = samples["samples"] device = comfy.model_management.get_torch_device() coords = samples["coords"] @@ -322,7 +322,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): def execute(cls, mesh, samples, vae, shape_subdivides): shape_mesh = mesh sample_tensor = samples["samples"] - resolution = int(vae.resolution.item()) + resolution = int(vae.first_stage_model.resolution.item()) device = comfy.model_management.get_torch_device() coords = samples["coords"] prepare_trellis_vae_for_decode(vae, sample_tensor.shape) @@ -662,7 +662,6 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): node_id="EmptyTrellis2ShapeLatent", category="latent/3d", inputs=[ - IO.Model.Input("model"), IO.MultiType.Input( "voxel", types=[IO.Voxel, HighResVoxel], @@ -673,13 +672,12 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): ) ], outputs=[ - IO.Model.Output(), IO.Latent.Output(), ] ) @classmethod - def execute(cls, model, voxel): + def execute(cls, voxel): # to accept the upscaled coords is_512_pass = False @@ -698,19 +696,13 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): in_channels = 32 # image like format latent = torch.zeros(batch_size, in_channels, max_tokens, 1) - model = model.clone() - model.model_options = model.model_options.copy() - if "transformer_options" in model.model_options: - model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() - else: - model.model_options["transformer_options"] = {} - model.model_options["transformer_options"]["coords"] = coords if is_512_pass: - model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" + generation_mode = "shape_generation_512" else: - model.model_options["transformer_options"]["generation_mode"] = "shape_generation" - return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}) + generation_mode = "shape_generation" + return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2", + "model_options": {"generation_mode": generation_mode, "coords": coords, "coords_counts": counts}}) class EmptyTrellis2LatentTexture(IO.ComfyNode): @classmethod @@ -719,32 +711,30 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): node_id="EmptyTrellis2LatentTexture", category="latent/3d", inputs=[ - IO.Model.Input("model"), IO.MultiType.Input( - "shape_structure", + "voxel", types=[IO.Voxel, HighResVoxel], tooltip=( "Shape structure input. Accepts either a voxel structure " - "or upsampled coordinates from a previous cascade stage." + "or upsampled voxel coordinates from a previous cascade stage." ) ), IO.Latent.Input("shape_latent"), ], outputs=[ - IO.Model.Output(), IO.Latent.Output(), ] ) @classmethod - def execute(cls, model, shape_structure, shape_latent): + def execute(cls, voxel, shape_latent): channels = 32 - if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4: - decoded = shape_structure.data.unsqueeze(1) + if hasattr(voxel, "data") and voxel.data.ndim == 4: + decoded = voxel.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() - elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2: - coords = shape_structure.int() + elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2: + coords = voxel.int() batch_size, counts, max_tokens = infer_batched_coord_layout(coords) @@ -753,17 +743,9 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) latent = torch.zeros(batch_size, channels, max_tokens, 1) - model = model.clone() - model.model_options = model.model_options.copy() - if "transformer_options" in model.model_options: - model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() - else: - model.model_options["transformer_options"] = {} - - model.model_options["transformer_options"]["coords"] = coords - model.model_options["transformer_options"]["generation_mode"] = "texture_generation" - model.model_options["transformer_options"]["shape_slat"] = shape_latent - return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coords_counts": counts, + "model_options": {"generation_mode": "texture_generation", + "coords": coords, "coords_counts": counts, "shape_slat": shape_latent}}) class EmptyTrellis2LatentStructure(IO.ComfyNode): diff --git a/nodes.py b/nodes.py index 944695e7c..b78d388a6 100644 --- a/nodes.py +++ b/nodes.py @@ -1532,6 +1532,10 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] + if "model_options" in latent: + inner = model.model.diffusion_model + inner.meta = latent["model_options"] + callback = latent_preview.prepare_callback(model, steps) disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,