mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-13 01:30:32 +08:00
updates
This commit is contained in:
parent
6980e15921
commit
ca7fe65e7e
@ -779,15 +779,18 @@ class Trellis2(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
transformer_options = kwargs.get("transformer_options", {})
|
transformer_options = kwargs.get("transformer_options", {})
|
||||||
|
model_options = {}
|
||||||
|
if hasattr(self, "meta"):
|
||||||
|
model_options = self.meta
|
||||||
timestep = timestep.to(x.dtype)
|
timestep = timestep.to(x.dtype)
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
if embeds is None:
|
if embeds is None:
|
||||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||||
|
|
||||||
is_1024 = self.img2shape.resolution == 1024
|
is_1024 = self.img2shape.resolution == 1024
|
||||||
coords = transformer_options.get("coords", None)
|
coords = model_options.get("coords", None)
|
||||||
coord_counts = transformer_options.get("coord_counts", None)
|
coord_counts = model_options.get("coord_counts", None)
|
||||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
mode = model_options.get("generation_mode", "structure_generation")
|
||||||
|
|
||||||
is_512_run = False
|
is_512_run = False
|
||||||
if mode == "shape_generation_512":
|
if mode == "shape_generation_512":
|
||||||
@ -881,7 +884,7 @@ class Trellis2(nn.Module):
|
|||||||
elif mode == "texture_generation":
|
elif mode == "texture_generation":
|
||||||
if self.shape2txt is None:
|
if self.shape2txt is None:
|
||||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||||
slat = transformer_options.get("shape_slat")
|
slat = model_options.get("shape_slat")
|
||||||
if slat is None:
|
if slat is None:
|
||||||
raise ValueError("shape_slat can't be None")
|
raise ValueError("shape_slat can't be None")
|
||||||
|
|
||||||
|
|||||||
@ -129,6 +129,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
unet_config["num_heads"] = 12
|
unet_config["num_heads"] = 12
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
|
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
|
||||||
|
unet_config = {}
|
||||||
|
unet_config["image_model"] = "trellis2"
|
||||||
|
unet_config["resolution"] = 64
|
||||||
|
unet_config["num_heads"] = 12
|
||||||
|
unet_config["txt_only"] = True
|
||||||
|
return unet_config
|
||||||
|
|
||||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||||
unet_config = {}
|
unet_config = {}
|
||||||
unet_config["audio_model"] = "dit1.0"
|
unet_config["audio_model"] = "dit1.0"
|
||||||
|
|||||||
@ -257,7 +257,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae):
|
def execute(cls, samples, vae):
|
||||||
|
|
||||||
resolution = int(vae.resolution.item())
|
resolution = int(vae.first_stage_model.resolution.item())
|
||||||
sample_tensor = samples["samples"]
|
sample_tensor = samples["samples"]
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
coords = samples["coords"]
|
coords = samples["coords"]
|
||||||
@ -322,7 +322,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
def execute(cls, mesh, samples, vae, shape_subdivides):
|
def execute(cls, mesh, samples, vae, shape_subdivides):
|
||||||
shape_mesh = mesh
|
shape_mesh = mesh
|
||||||
sample_tensor = samples["samples"]
|
sample_tensor = samples["samples"]
|
||||||
resolution = int(vae.resolution.item())
|
resolution = int(vae.first_stage_model.resolution.item())
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
coords = samples["coords"]
|
coords = samples["coords"]
|
||||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
@ -662,7 +662,6 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
|||||||
node_id="EmptyTrellis2ShapeLatent",
|
node_id="EmptyTrellis2ShapeLatent",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Model.Input("model"),
|
|
||||||
IO.MultiType.Input(
|
IO.MultiType.Input(
|
||||||
"voxel",
|
"voxel",
|
||||||
types=[IO.Voxel, HighResVoxel],
|
types=[IO.Voxel, HighResVoxel],
|
||||||
@ -673,13 +672,12 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Model.Output(),
|
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, voxel):
|
def execute(cls, voxel):
|
||||||
# to accept the upscaled coords
|
# to accept the upscaled coords
|
||||||
is_512_pass = False
|
is_512_pass = False
|
||||||
|
|
||||||
@ -698,19 +696,13 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
|||||||
in_channels = 32
|
in_channels = 32
|
||||||
# image like format
|
# image like format
|
||||||
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
|
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
|
||||||
model = model.clone()
|
|
||||||
model.model_options = model.model_options.copy()
|
|
||||||
if "transformer_options" in model.model_options:
|
|
||||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
|
||||||
else:
|
|
||||||
model.model_options["transformer_options"] = {}
|
|
||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
|
||||||
if is_512_pass:
|
if is_512_pass:
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
|
generation_mode = "shape_generation_512"
|
||||||
else:
|
else:
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
generation_mode = "shape_generation"
|
||||||
return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2",
|
||||||
|
"model_options": {"generation_mode": generation_mode, "coords": coords, "coords_counts": counts}})
|
||||||
|
|
||||||
class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -719,32 +711,30 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
|||||||
node_id="EmptyTrellis2LatentTexture",
|
node_id="EmptyTrellis2LatentTexture",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Model.Input("model"),
|
|
||||||
IO.MultiType.Input(
|
IO.MultiType.Input(
|
||||||
"shape_structure",
|
"voxel",
|
||||||
types=[IO.Voxel, HighResVoxel],
|
types=[IO.Voxel, HighResVoxel],
|
||||||
tooltip=(
|
tooltip=(
|
||||||
"Shape structure input. Accepts either a voxel structure "
|
"Shape structure input. Accepts either a voxel structure "
|
||||||
"or upsampled coordinates from a previous cascade stage."
|
"or upsampled voxel coordinates from a previous cascade stage."
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
IO.Latent.Input("shape_latent"),
|
IO.Latent.Input("shape_latent"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Model.Output(),
|
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, shape_structure, shape_latent):
|
def execute(cls, voxel, shape_latent):
|
||||||
channels = 32
|
channels = 32
|
||||||
if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4:
|
if hasattr(voxel, "data") and voxel.data.ndim == 4:
|
||||||
decoded = shape_structure.data.unsqueeze(1)
|
decoded = voxel.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
|
|
||||||
elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2:
|
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
|
||||||
coords = shape_structure.int()
|
coords = voxel.int()
|
||||||
|
|
||||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||||
|
|
||||||
@ -753,17 +743,9 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
|||||||
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||||
|
|
||||||
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
||||||
model = model.clone()
|
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coords_counts": counts,
|
||||||
model.model_options = model.model_options.copy()
|
"model_options": {"generation_mode": "texture_generation",
|
||||||
if "transformer_options" in model.model_options:
|
"coords": coords, "coords_counts": counts, "shape_slat": shape_latent}})
|
||||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
|
||||||
else:
|
|
||||||
model.model_options["transformer_options"] = {}
|
|
||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
|
|
||||||
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
|
||||||
return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyTrellis2LatentStructure(IO.ComfyNode):
|
class EmptyTrellis2LatentStructure(IO.ComfyNode):
|
||||||
|
|||||||
4
nodes.py
4
nodes.py
@ -1532,6 +1532,10 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
|
|||||||
if "noise_mask" in latent:
|
if "noise_mask" in latent:
|
||||||
noise_mask = latent["noise_mask"]
|
noise_mask = latent["noise_mask"]
|
||||||
|
|
||||||
|
if "model_options" in latent:
|
||||||
|
inner = model.model.diffusion_model
|
||||||
|
inner.meta = latent["model_options"]
|
||||||
|
|
||||||
callback = latent_preview.prepare_callback(model, steps)
|
callback = latent_preview.prepare_callback(model, steps)
|
||||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||||
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user