mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-10 16:27:33 +08:00
updates
This commit is contained in:
parent
6980e15921
commit
ca7fe65e7e
@ -779,15 +779,18 @@ class Trellis2(nn.Module):
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
model_options = {}
|
||||
if hasattr(self, "meta"):
|
||||
model_options = self.meta
|
||||
timestep = timestep.to(x.dtype)
|
||||
embeds = kwargs.get("embeds")
|
||||
if embeds is None:
|
||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||
|
||||
is_1024 = self.img2shape.resolution == 1024
|
||||
coords = transformer_options.get("coords", None)
|
||||
coord_counts = transformer_options.get("coord_counts", None)
|
||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||
coords = model_options.get("coords", None)
|
||||
coord_counts = model_options.get("coord_counts", None)
|
||||
mode = model_options.get("generation_mode", "structure_generation")
|
||||
|
||||
is_512_run = False
|
||||
if mode == "shape_generation_512":
|
||||
@ -881,7 +884,7 @@ class Trellis2(nn.Module):
|
||||
elif mode == "texture_generation":
|
||||
if self.shape2txt is None:
|
||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||
slat = transformer_options.get("shape_slat")
|
||||
slat = model_options.get("shape_slat")
|
||||
if slat is None:
|
||||
raise ValueError("shape_slat can't be None")
|
||||
|
||||
|
||||
@ -129,6 +129,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config["num_heads"] = 12
|
||||
return unet_config
|
||||
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
unet_config["resolution"] = 64
|
||||
unet_config["num_heads"] = 12
|
||||
unet_config["txt_only"] = True
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
unet_config = {}
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
|
||||
@ -257,7 +257,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, samples, vae):
|
||||
|
||||
resolution = int(vae.resolution.item())
|
||||
resolution = int(vae.first_stage_model.resolution.item())
|
||||
sample_tensor = samples["samples"]
|
||||
device = comfy.model_management.get_torch_device()
|
||||
coords = samples["coords"]
|
||||
@ -322,7 +322,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
def execute(cls, mesh, samples, vae, shape_subdivides):
|
||||
shape_mesh = mesh
|
||||
sample_tensor = samples["samples"]
|
||||
resolution = int(vae.resolution.item())
|
||||
resolution = int(vae.first_stage_model.resolution.item())
|
||||
device = comfy.model_management.get_torch_device()
|
||||
coords = samples["coords"]
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
@ -662,7 +662,6 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
node_id="EmptyTrellis2ShapeLatent",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.MultiType.Input(
|
||||
"voxel",
|
||||
types=[IO.Voxel, HighResVoxel],
|
||||
@ -673,13 +672,12 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, voxel):
|
||||
def execute(cls, voxel):
|
||||
# to accept the upscaled coords
|
||||
is_512_pass = False
|
||||
|
||||
@ -698,19 +696,13 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
in_channels = 32
|
||||
# image like format
|
||||
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
|
||||
model = model.clone()
|
||||
model.model_options = model.model_options.copy()
|
||||
if "transformer_options" in model.model_options:
|
||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||
else:
|
||||
model.model_options["transformer_options"] = {}
|
||||
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
if is_512_pass:
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
|
||||
generation_mode = "shape_generation_512"
|
||||
else:
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||
return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
|
||||
generation_mode = "shape_generation"
|
||||
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2",
|
||||
"model_options": {"generation_mode": generation_mode, "coords": coords, "coords_counts": counts}})
|
||||
|
||||
class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -719,32 +711,30 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
node_id="EmptyTrellis2LatentTexture",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.MultiType.Input(
|
||||
"shape_structure",
|
||||
"voxel",
|
||||
types=[IO.Voxel, HighResVoxel],
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled coordinates from a previous cascade stage."
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
)
|
||||
),
|
||||
IO.Latent.Input("shape_latent"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, shape_structure, shape_latent):
|
||||
def execute(cls, voxel, shape_latent):
|
||||
channels = 32
|
||||
if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4:
|
||||
decoded = shape_structure.data.unsqueeze(1)
|
||||
if hasattr(voxel, "data") and voxel.data.ndim == 4:
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
|
||||
elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2:
|
||||
coords = shape_structure.int()
|
||||
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
|
||||
coords = voxel.int()
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
|
||||
@ -753,17 +743,9 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||
|
||||
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
||||
model = model.clone()
|
||||
model.model_options = model.model_options.copy()
|
||||
if "transformer_options" in model.model_options:
|
||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||
else:
|
||||
model.model_options["transformer_options"] = {}
|
||||
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
|
||||
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
||||
return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coords_counts": counts,
|
||||
"model_options": {"generation_mode": "texture_generation",
|
||||
"coords": coords, "coords_counts": counts, "shape_slat": shape_latent}})
|
||||
|
||||
|
||||
class EmptyTrellis2LatentStructure(IO.ComfyNode):
|
||||
|
||||
4
nodes.py
4
nodes.py
@ -1532,6 +1532,10 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
||||
if "noise_mask" in latent:
|
||||
noise_mask = latent["noise_mask"]
|
||||
|
||||
if "model_options" in latent:
|
||||
inner = model.model.diffusion_model
|
||||
inner.meta = latent["model_options"]
|
||||
|
||||
callback = latent_preview.prepare_callback(model, steps)
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user