diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8121e261b..89fb2443e 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,6 +1,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types -from comfy.ldm.trellis2.vae import SparseTensor +from comfy.ldm.trellis2.vae import SparseTensor, sparse_cat import comfy.model_management from PIL import Image import numpy as np @@ -8,6 +8,25 @@ import torch import scipy import copy +def prepare_trellis_vae_for_decode(vae, sample_shape): + memory_required = max(1, int(vae.memory_used_decode(sample_shape, vae.vae_dtype))) + device = comfy.model_management.get_torch_device() + comfy.model_management.free_memory(memory_required, device, for_dynamic=False) + comfy.model_management.load_models_gpu( + [vae.patcher], + memory_required=memory_required, + force_full_load=getattr(vae, "disable_offload", False), + ) + free_memory = vae.patcher.get_free_memory(device) + batch_number = max(1, int(free_memory / memory_required)) + return min(sample_shape[0], batch_number) + + +def combine_sparse_sub_batches(sub_batches): + if len(sub_batches) == 1: + return sub_batches[0] + return [sparse_cat([batch[level] for batch in sub_batches], dim=0) for level in range(len(sub_batches[0]))] + def pack_variable_mesh_batch(vertices, faces, colors=None): batch_size = len(vertices) @@ -163,18 +182,24 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): def execute(cls, samples, vae, resolution): resolution = int(resolution) - patcher = vae.patcher + sample_tensor = samples["samples"] device = comfy.model_management.get_torch_device() - comfy.model_management.load_model_gpu(patcher) - - vae = vae.first_stage_model coords = samples["coords"] + batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) + trellis_vae = vae.first_stage_model - samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - samples = shape_norm(samples, coords) + shape_samples = sample_tensor.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + shape_latent = shape_norm(shape_samples, coords.to(device)) - mesh, subs = vae.decode_shape_slat(samples, resolution) + mesh = [] + sub_batches = [] + for start in range(0, shape_latent.shape[0], batch_number): + end = start + batch_number + mesh_chunk, subs_chunk = trellis_vae.decode_shape_slat(shape_latent[start:end], resolution) + mesh.extend(mesh_chunk) + sub_batches.append(subs_chunk) + + subs = combine_sparse_sub_batches(sub_batches) face_list = [m.faces for m in mesh] vert_list = [m.vertices for m in mesh] if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): @@ -204,21 +229,24 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): def execute(cls, shape_mesh, samples, vae, shape_subs): resolution = 1024 - patcher = vae.patcher + sample_tensor = samples["samples"] device = comfy.model_management.get_torch_device() - comfy.model_management.load_model_gpu(patcher) - - vae = vae.first_stage_model coords = samples["coords"] + batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) + trellis_vae = vae.first_stage_model - samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - std = tex_slat_normalization["std"].to(samples) - mean = tex_slat_normalization["mean"].to(samples) - samples = SparseTensor(feats = samples, coords=coords) - samples = samples * std + mean + tex_samples = sample_tensor.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + std = tex_slat_normalization["std"].to(tex_samples) + mean = tex_slat_normalization["mean"].to(tex_samples) + tex_latent = SparseTensor(feats=tex_samples, coords=coords.to(device)) + tex_latent = tex_latent * std + mean - voxel = vae.decode_tex_slat(samples, shape_subs) + voxel_batches = [] + for start in range(0, tex_latent.shape[0], batch_number): + end = start + batch_number + guide_subs = [sub[start:end] for sub in shape_subs] + voxel_batches.append(trellis_vae.decode_tex_slat(tex_latent[start:end], guide_subs)) + voxel = voxel_batches[0] if len(voxel_batches) == 1 else sparse_cat(voxel_batches, dim=0) color_feats = voxel.feats[:, :3] voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] @@ -266,15 +294,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod def execute(cls, samples, vae, resolution): resolution = int(resolution) - vae = vae.first_stage_model - decoder = vae.struct_dec + sample_tensor = samples["samples"] + batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) + decoder = vae.first_stage_model.struct_dec load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.vae_offload_device() - decoder = decoder.to(load_device) - samples = samples["samples"] - samples = samples.to(load_device) - decoded = decoder(samples)>0 - decoder.to(offload_device) + decoded_batches = [] + for start in range(0, sample_tensor.shape[0], batch_number): + sample_chunk = sample_tensor[start:start + batch_number].to(load_device) + decoded_batches.append(decoder(sample_chunk) > 0) + decoded = torch.cat(decoded_batches, dim=0) current_res = decoded.shape[2] if current_res != resolution: @@ -303,7 +331,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): @classmethod def execute(cls, shape_latent_512, vae, target_resolution, max_tokens): device = comfy.model_management.get_torch_device() - comfy.model_management.load_model_gpu(vae.patcher) + prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape) feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) coords_512 = shape_latent_512["coords"].to(device)