mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
removing seeds from node display
This commit is contained in:
parent
94adce93ab
commit
9d0f678f6f
@ -802,6 +802,11 @@ class Trellis2(nn.Module):
|
|||||||
mode = "structure_generation"
|
mode = "structure_generation"
|
||||||
not_struct_mode = False
|
not_struct_mode = False
|
||||||
|
|
||||||
|
if not not_struct_mode:
|
||||||
|
bsz = x.size(0)
|
||||||
|
x = x[:, :8]
|
||||||
|
x = x.view(bsz, 8, 16, 16, 16)
|
||||||
|
|
||||||
if is_1024 and not_struct_mode and not is_512_run:
|
if is_1024 and not_struct_mode and not is_512_run:
|
||||||
context = embeds
|
context = embeds
|
||||||
|
|
||||||
@ -821,7 +826,7 @@ class Trellis2(nn.Module):
|
|||||||
orig_bsz = x.shape[0]
|
orig_bsz = x.shape[0]
|
||||||
rule = txt_rule if mode == "texture_generation" else shape_rule
|
rule = txt_rule if mode == "texture_generation" else shape_rule
|
||||||
|
|
||||||
# 1. CFG Bypass Slicing
|
# CFG Bypass Slicing
|
||||||
if rule and orig_bsz > 1:
|
if rule and orig_bsz > 1:
|
||||||
half = orig_bsz // 2
|
half = orig_bsz // 2
|
||||||
x_eval = x[half:]
|
x_eval = x[half:]
|
||||||
@ -834,7 +839,7 @@ class Trellis2(nn.Module):
|
|||||||
|
|
||||||
B, N, C = x_eval.shape
|
B, N, C = x_eval.shape
|
||||||
|
|
||||||
# 2. Vectorized SparseTensor Construction (NO FOR LOOPS!)
|
# Vectorized SparseTensor Construction
|
||||||
if mode in ["shape_generation", "texture_generation"]:
|
if mode in ["shape_generation", "texture_generation"]:
|
||||||
if coord_counts is not None:
|
if coord_counts is not None:
|
||||||
logical_batch = coord_counts.shape[0]
|
logical_batch = coord_counts.shape[0]
|
||||||
@ -880,14 +885,14 @@ class Trellis2(nn.Module):
|
|||||||
if slat is None:
|
if slat is None:
|
||||||
raise ValueError("shape_slat can't be None")
|
raise ValueError("shape_slat can't be None")
|
||||||
|
|
||||||
slat_feats = slat.feats
|
slat_feats = slat
|
||||||
# Duplicate shape context if CFG is active
|
# Duplicate shape context if CFG is active
|
||||||
if coord_counts is not None and B > coord_counts.shape[0]:
|
if coord_counts is not None and B > coord_counts.shape[0]:
|
||||||
slat_feats = torch.cat([slat_feats, slat_feats], dim=0)
|
slat_feats = torch.cat([slat_feats, slat_feats], dim=0)
|
||||||
elif coord_counts is None:
|
elif coord_counts is None:
|
||||||
slat_feats = slat.feats[:N].repeat(B, 1)
|
slat_feats = slat_feats[:N].repeat(B, 1)
|
||||||
|
|
||||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats], dim=-1))
|
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1))
|
||||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||||
|
|
||||||
else: # structure
|
else: # structure
|
||||||
@ -901,9 +906,6 @@ class Trellis2(nn.Module):
|
|||||||
else:
|
else:
|
||||||
out = self.structure_model(x, timestep, context)
|
out = self.structure_model(x, timestep, context)
|
||||||
|
|
||||||
# ==================================================
|
|
||||||
# RE-PAD AND FORMAT OUTPUT
|
|
||||||
# ==================================================
|
|
||||||
if not_struct_mode:
|
if not_struct_mode:
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# Instantly scatter the valid tokens back into a padded rectangular tensor
|
# Instantly scatter the valid tokens back into a padded rectangular tensor
|
||||||
@ -916,7 +918,7 @@ class Trellis2(nn.Module):
|
|||||||
if rule and orig_bsz > 1:
|
if rule and orig_bsz > 1:
|
||||||
out_tensor = out_tensor.repeat(2, 1, 1, 1)
|
out_tensor = out_tensor.repeat(2, 1, 1, 1)
|
||||||
return out_tensor
|
return out_tensor
|
||||||
#else:
|
else:
|
||||||
# out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 24, 0))
|
out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 0, 24))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -159,37 +159,6 @@ def split_batched_coords(coords, coord_counts):
|
|||||||
items.append(coords_i)
|
items.append(coords_i)
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
||||||
def normalize_batch_index(batch_index):
|
|
||||||
if batch_index is None:
|
|
||||||
return None
|
|
||||||
if isinstance(batch_index, int):
|
|
||||||
return [int(batch_index)]
|
|
||||||
return list(batch_index)
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_sample_indices(batch_index, batch_size):
|
|
||||||
sample_indices = normalize_batch_index(batch_index)
|
|
||||||
if sample_indices is None:
|
|
||||||
return list(range(batch_size))
|
|
||||||
if len(sample_indices) != batch_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"Trellis2 batch_index length {len(sample_indices)} does not match batch size {batch_size}"
|
|
||||||
)
|
|
||||||
return sample_indices
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_singleton_sample_index(batch_index):
|
|
||||||
sample_indices = normalize_batch_index(batch_index)
|
|
||||||
if sample_indices is None:
|
|
||||||
return 0
|
|
||||||
if len(sample_indices) != 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"Trellis2 batch_index must be an int or single-element iterable for singleton coords, got {sample_indices}"
|
|
||||||
)
|
|
||||||
return int(sample_indices[0])
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_batched_sparse_latent(samples, coords, coord_counts):
|
def flatten_batched_sparse_latent(samples, coords, coord_counts):
|
||||||
samples = samples.squeeze(-1).transpose(1, 2)
|
samples = samples.squeeze(-1).transpose(1, 2)
|
||||||
if coord_counts is None:
|
if coord_counts is None:
|
||||||
@ -218,7 +187,6 @@ def split_batched_sparse_latent(samples, coords, coord_counts):
|
|||||||
items.append((samples[i, :count], coords_i))
|
items.append((samples[i, :count], coords_i))
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
||||||
def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
|
def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
|
||||||
"""
|
"""
|
||||||
Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field.
|
Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field.
|
||||||
@ -232,15 +200,15 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
|
|||||||
# map voxels
|
# map voxels
|
||||||
voxel_pos = voxel_coords.to(device).float() * voxel_size + origin
|
voxel_pos = voxel_coords.to(device).float() * voxel_size + origin
|
||||||
verts = mesh.vertices.to(device).squeeze(0)
|
verts = mesh.vertices.to(device).squeeze(0)
|
||||||
voxel_colors = voxel_colors.cpu()
|
voxel_colors = voxel_colors.to(device)
|
||||||
|
|
||||||
voxel_pos_np = voxel_pos.cpu().numpy()
|
voxel_pos_np = voxel_pos.numpy()
|
||||||
verts_np = verts.cpu().numpy()
|
verts_np = verts.numpy()
|
||||||
|
|
||||||
tree = scipy.spatial.cKDTree(voxel_pos_np)
|
tree = scipy.spatial.cKDTree(voxel_pos_np)
|
||||||
|
|
||||||
# nearest neighbour k=1
|
# nearest neighbour k=1
|
||||||
_, nearest_idx_np = tree.query(verts_np, k=1, workers=1)
|
_, nearest_idx_np = tree.query(verts_np, k=1, workers=-1)
|
||||||
|
|
||||||
nearest_idx = torch.from_numpy(nearest_idx_np).long()
|
nearest_idx = torch.from_numpy(nearest_idx_np).long()
|
||||||
v_colors = voxel_colors[nearest_idx]
|
v_colors = voxel_colors[nearest_idx]
|
||||||
@ -253,7 +221,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
|
|||||||
|
|
||||||
final_colors = linear_colors.unsqueeze(0)
|
final_colors = linear_colors.unsqueeze(0)
|
||||||
|
|
||||||
out_mesh = copy.copy(mesh)
|
out_mesh = copy.deepcopy(mesh)
|
||||||
out_mesh.colors = final_colors
|
out_mesh.colors = final_colors
|
||||||
|
|
||||||
return out_mesh
|
return out_mesh
|
||||||
@ -411,10 +379,10 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
def execute(cls, samples, vae, resolution):
|
def execute(cls, samples, vae, resolution):
|
||||||
resolution = int(resolution)
|
resolution = int(resolution)
|
||||||
sample_tensor = samples["samples"]
|
sample_tensor = samples["samples"]
|
||||||
|
sample_tensor = sample_tensor[:, :8]
|
||||||
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
decoder = vae.first_stage_model.struct_dec
|
decoder = vae.first_stage_model.struct_dec
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
batch_index = normalize_batch_index(samples.get("batch_index"))
|
|
||||||
decoded_batches = []
|
decoded_batches = []
|
||||||
for start in range(0, sample_tensor.shape[0], batch_number):
|
for start in range(0, sample_tensor.shape[0], batch_number):
|
||||||
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
||||||
@ -426,8 +394,6 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
ratio = current_res // resolution
|
ratio = current_res // resolution
|
||||||
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
|
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
|
||||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||||
if batch_index is not None:
|
|
||||||
out.batch_index = normalize_batch_index(batch_index)
|
|
||||||
return IO.NodeOutput(out)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
class Trellis2UpsampleCascade(IO.ComfyNode):
|
class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||||
@ -453,7 +419,6 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
|||||||
prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape)
|
prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape)
|
||||||
|
|
||||||
coord_counts = shape_latent_512.get("coord_counts")
|
coord_counts = shape_latent_512.get("coord_counts")
|
||||||
batch_index = normalize_batch_index(shape_latent_512.get("batch_index"))
|
|
||||||
decoder = vae.first_stage_model.shape_dec
|
decoder = vae.first_stage_model.shape_dec
|
||||||
lr_resolution = 512
|
lr_resolution = 512
|
||||||
target_resolution = int(target_resolution)
|
target_resolution = int(target_resolution)
|
||||||
@ -529,14 +494,11 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
|||||||
final_coords_list.append(final_coords_i)
|
final_coords_list.append(final_coords_i)
|
||||||
output_coord_counts.append(int(final_coords_i.shape[0]))
|
output_coord_counts.append(int(final_coords_i.shape[0]))
|
||||||
|
|
||||||
normalized_batch_index = normalize_batch_index(batch_index)
|
|
||||||
output = {
|
output = {
|
||||||
"coords": torch.cat(final_coords_list, dim=0),
|
"coords": torch.cat(final_coords_list, dim=0),
|
||||||
"coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64),
|
"coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64),
|
||||||
"resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64),
|
"resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64),
|
||||||
}
|
}
|
||||||
if normalized_batch_index is not None:
|
|
||||||
output["batch_index"] = normalized_batch_index
|
|
||||||
|
|
||||||
return IO.NodeOutput(output,)
|
return IO.NodeOutput(output,)
|
||||||
|
|
||||||
@ -547,8 +509,6 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
|||||||
model_internal = model.model
|
model_internal = model.model
|
||||||
device = comfy.model_management.intermediate_device()
|
device = comfy.model_management.intermediate_device()
|
||||||
torch_device = comfy.model_management.get_torch_device()
|
torch_device = comfy.model_management.get_torch_device()
|
||||||
had_image_size = hasattr(model_internal, "image_size")
|
|
||||||
original_image_size = getattr(model_internal, "image_size", None)
|
|
||||||
|
|
||||||
def prepare_tensor(pil_img, size):
|
def prepare_tensor(pil_img, size):
|
||||||
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
|
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
|
||||||
@ -556,21 +516,15 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
|||||||
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
||||||
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||||
|
|
||||||
cond_1024 = None
|
model_internal.image_size = 512
|
||||||
try:
|
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
||||||
model_internal.image_size = 512
|
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
|
||||||
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
|
||||||
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
|
|
||||||
|
|
||||||
if include_1024:
|
cond_1024 = None
|
||||||
model_internal.image_size = 1024
|
if include_1024:
|
||||||
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
model_internal.image_size = 1024
|
||||||
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
|
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
||||||
finally:
|
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
|
||||||
if not had_image_size:
|
|
||||||
delattr(model_internal, "image_size")
|
|
||||||
else:
|
|
||||||
model_internal.image_size = original_image_size
|
|
||||||
|
|
||||||
conditioning = {
|
conditioning = {
|
||||||
'cond_512': cond_512.to(device),
|
'cond_512': cond_512.to(device),
|
||||||
@ -580,7 +534,6 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
|||||||
conditioning['cond_1024'] = cond_1024.to(device)
|
conditioning['cond_1024'] = cond_1024.to(device)
|
||||||
|
|
||||||
return conditioning
|
return conditioning
|
||||||
|
|
||||||
class Trellis2Conditioning(IO.ComfyNode):
|
class Trellis2Conditioning(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -693,7 +646,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.AnyType.Input("structure_or_coords"),
|
IO.AnyType.Input("structure_or_coords"),
|
||||||
IO.Model.Input("model"),
|
IO.Model.Input("model"),
|
||||||
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
@ -702,58 +654,25 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_or_coords, model, seed):
|
def execute(cls, structure_or_coords, model):
|
||||||
# to accept the upscaled coords
|
# to accept the upscaled coords
|
||||||
is_512_pass = False
|
is_512_pass = False
|
||||||
coord_counts = None
|
|
||||||
coord_resolutions = None
|
|
||||||
batch_index = None
|
|
||||||
|
|
||||||
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
||||||
decoded = structure_or_coords.data.unsqueeze(1)
|
decoded = structure_or_coords.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
is_512_pass = True
|
is_512_pass = True
|
||||||
batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None))
|
|
||||||
|
|
||||||
elif isinstance(structure_or_coords, dict):
|
|
||||||
coords = structure_or_coords["coords"].int()
|
|
||||||
coord_counts = structure_or_coords.get("coord_counts")
|
|
||||||
coord_resolutions = structure_or_coords.get("resolutions")
|
|
||||||
batch_index = normalize_batch_index(structure_or_coords.get("batch_index"))
|
|
||||||
is_512_pass = False
|
|
||||||
|
|
||||||
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
||||||
coords = structure_or_coords.int()
|
coords = structure_or_coords.int()
|
||||||
is_512_pass = False
|
is_512_pass = False
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}")
|
raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}")
|
||||||
|
|
||||||
|
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords)
|
# image like format
|
||||||
if coord_counts is not None:
|
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
|
||||||
coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device)
|
|
||||||
if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts):
|
|
||||||
raise ValueError(
|
|
||||||
f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
coord_counts = inferred_coord_counts
|
|
||||||
if batch_size == 1:
|
|
||||||
sample_index = resolve_singleton_sample_index(batch_index)
|
|
||||||
generator = torch.Generator(device="cpu")
|
|
||||||
generator.manual_seed(int(seed) + sample_index)
|
|
||||||
latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator)
|
|
||||||
else:
|
|
||||||
sample_indices = resolve_sample_indices(batch_index, batch_size)
|
|
||||||
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
|
|
||||||
for i, sample_index in enumerate(sample_indices):
|
|
||||||
count = int(coord_counts[i].item())
|
|
||||||
generator = torch.Generator(device="cpu")
|
|
||||||
generator.manual_seed(int(seed) + int(sample_index))
|
|
||||||
latent_i = torch.randn(1, in_channels, count, 1, generator=generator)
|
|
||||||
latent[i, :, :count] = latent_i[0]
|
|
||||||
if coord_counts is not None:
|
|
||||||
latent.trellis_coord_counts = coord_counts.clone()
|
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options = model.model_options.copy()
|
model.model_options = model.model_options.copy()
|
||||||
if "transformer_options" in model.model_options:
|
if "transformer_options" in model.model_options:
|
||||||
@ -762,20 +681,11 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
model.model_options["transformer_options"] = {}
|
model.model_options["transformer_options"] = {}
|
||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
if coord_counts is not None:
|
|
||||||
model.model_options["transformer_options"]["coord_counts"] = coord_counts
|
|
||||||
if is_512_pass:
|
if is_512_pass:
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
|
||||||
else:
|
else:
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||||
output = {"samples": latent, "coords": coords, "type": "trellis2"}
|
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model)
|
||||||
if batch_index is not None:
|
|
||||||
output["batch_index"] = normalize_batch_index(batch_index)
|
|
||||||
if coord_counts is not None:
|
|
||||||
output["coord_counts"] = coord_counts
|
|
||||||
if coord_resolutions is not None:
|
|
||||||
output["resolutions"] = coord_resolutions
|
|
||||||
return IO.NodeOutput(output, model)
|
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -787,7 +697,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
IO.Voxel.Input("structure_or_coords"),
|
IO.Voxel.Input("structure_or_coords"),
|
||||||
IO.Latent.Input("shape_latent"),
|
IO.Latent.Input("shape_latent"),
|
||||||
IO.Model.Input("model"),
|
IO.Model.Input("model"),
|
||||||
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
@ -796,68 +705,22 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_or_coords, shape_latent, model, seed):
|
def execute(cls, structure_or_coords, shape_latent, model):
|
||||||
channels = 32
|
channels = 32
|
||||||
coord_counts = None
|
|
||||||
batch_index = None
|
|
||||||
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
||||||
decoded = structure_or_coords.data.unsqueeze(1)
|
decoded = structure_or_coords.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None))
|
|
||||||
|
|
||||||
elif isinstance(structure_or_coords, dict):
|
|
||||||
coords = structure_or_coords["coords"].int()
|
|
||||||
coord_counts = structure_or_coords.get("coord_counts")
|
|
||||||
batch_index = normalize_batch_index(structure_or_coords.get("batch_index"))
|
|
||||||
|
|
||||||
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
||||||
coords = structure_or_coords.int()
|
coords = structure_or_coords.int()
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"structure_or_coords must be a voxel input with data.ndim == 4, "
|
|
||||||
f'a dict containing "coords", or a 2D torch.Tensor; got {type(structure_or_coords).__name__}'
|
|
||||||
)
|
|
||||||
|
|
||||||
shape_batch_index = normalize_batch_index(shape_latent.get("batch_index"))
|
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||||
if batch_index is None:
|
|
||||||
batch_index = shape_batch_index
|
|
||||||
shape_latent = shape_latent["samples"]
|
shape_latent = shape_latent["samples"]
|
||||||
batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords)
|
|
||||||
if coord_counts is not None:
|
|
||||||
coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device)
|
|
||||||
if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts):
|
|
||||||
raise ValueError(
|
|
||||||
f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
coord_counts = inferred_coord_counts
|
|
||||||
if shape_latent.ndim == 4:
|
if shape_latent.ndim == 4:
|
||||||
if shape_latent.shape[0] != batch_size:
|
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||||
raise ValueError(
|
|
||||||
f"shape_latent batch {shape_latent.shape[0]} doesn't match coords batch {batch_size}"
|
|
||||||
)
|
|
||||||
shape_latent = shape_latent.squeeze(-1).transpose(1, 2)
|
|
||||||
if shape_latent.shape[1] < max_tokens:
|
|
||||||
raise ValueError(
|
|
||||||
f"shape_latent tokens {shape_latent.shape[1]} can't cover coords max tokens {max_tokens}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_size == 1:
|
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
||||||
sample_index = resolve_singleton_sample_index(batch_index)
|
|
||||||
generator = torch.Generator(device="cpu")
|
|
||||||
generator.manual_seed(int(seed) + sample_index)
|
|
||||||
latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator)
|
|
||||||
else:
|
|
||||||
sample_indices = resolve_sample_indices(batch_index, batch_size)
|
|
||||||
latent = torch.zeros(batch_size, channels, max_tokens, 1)
|
|
||||||
for i, sample_index in enumerate(sample_indices):
|
|
||||||
count = int(coord_counts[i].item())
|
|
||||||
generator = torch.Generator(device="cpu")
|
|
||||||
generator.manual_seed(int(seed) + int(sample_index))
|
|
||||||
latent_i = torch.randn(1, channels, count, 1, generator=generator)
|
|
||||||
latent[i, :, :count] = latent_i[0]
|
|
||||||
if coord_counts is not None:
|
|
||||||
latent.trellis_coord_counts = coord_counts.clone()
|
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options = model.model_options.copy()
|
model.model_options = model.model_options.copy()
|
||||||
if "transformer_options" in model.model_options:
|
if "transformer_options" in model.model_options:
|
||||||
@ -866,16 +729,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
model.model_options["transformer_options"] = {}
|
model.model_options["transformer_options"] = {}
|
||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
if coord_counts is not None:
|
|
||||||
model.model_options["transformer_options"]["coord_counts"] = coord_counts
|
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
|
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
|
||||||
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
||||||
output = {"samples": latent, "coords": coords, "type": "trellis2"}
|
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model)
|
||||||
if batch_index is not None:
|
|
||||||
output["batch_index"] = normalize_batch_index(batch_index)
|
|
||||||
if coord_counts is not None:
|
|
||||||
output["coord_counts"] = coord_counts
|
|
||||||
return IO.NodeOutput(output, model)
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@ -886,29 +742,20 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||||
IO.Int.Input("batch_index_start", default=0, min=0, max=4096, tooltip="Starting sample index for per-sample sampler noise."),
|
|
||||||
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, batch_size, batch_index_start, seed):
|
def execute(cls, batch_size):
|
||||||
in_channels = 8
|
in_channels = 8
|
||||||
resolution = 16
|
resolution = 16
|
||||||
sample_indices = [int(batch_index_start) + i for i in range(batch_size)]
|
|
||||||
latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution)
|
latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution)
|
||||||
for i, sample_index in enumerate(sample_indices):
|
|
||||||
generator = torch.Generator(device="cpu")
|
|
||||||
generator.manual_seed(int(seed) + sample_index)
|
|
||||||
latent[i] = torch.randn(1, in_channels, resolution, resolution, resolution, generator=generator)[0]
|
|
||||||
output = {
|
output = {
|
||||||
"samples": latent,
|
"samples": latent,
|
||||||
"type": "trellis2",
|
"type": "trellis2",
|
||||||
}
|
}
|
||||||
if batch_size > 1 or batch_index_start != 0:
|
|
||||||
output["batch_index"] = sample_indices
|
|
||||||
return IO.NodeOutput(output)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None):
|
def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None):
|
||||||
@ -939,7 +786,7 @@ def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=Non
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
out_v, out_f, out_c = _qem_simplify_robust(
|
out_v, out_f, out_c = _qem_simplify(
|
||||||
verts_np, faces_np, colors_np, target, device, max_edge_length
|
verts_np, faces_np, colors_np, target, device, max_edge_length
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -952,7 +799,7 @@ def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=Non
|
|||||||
)
|
)
|
||||||
return final_v, final_f, final_c
|
return final_v, final_f, final_c
|
||||||
|
|
||||||
def _qem_simplify_robust(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None):
|
def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None):
|
||||||
verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64)
|
verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64)
|
||||||
faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64)
|
faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64)
|
||||||
colors = (
|
colors = (
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user