def prepare_trellis_vae_for_decode(vae, sample_shape):
    """Load a Trellis VAE for decoding and compute a safe decode batch size.

    Estimates the memory one decode pass needs from ``sample_shape``, asks the
    model manager to load the VAE's patcher with that much headroom, then
    returns how many samples fit in the memory left on the device afterwards.

    Args:
        vae: VAE wrapper exposing ``memory_used_decode``, ``vae_dtype`` and a
            ``patcher``; may optionally define ``disable_offload`` to force a
            full (non-offloaded) load.
        sample_shape: Shape of the latent batch about to be decoded. A 5-D
            shape is treated as carrying an extra per-sample axis in dim 4
            that multiplies the estimate -- assumption from the structure
            latent callers, TODO confirm.

    Returns:
        int: Number of samples to decode per forward pass (always >= 1).
    """
    memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
    if len(sample_shape) == 5:
        # 5-D structure latents: scale the per-sample estimate by the trailing
        # dimension; clamp so a degenerate axis cannot zero the estimate.
        memory_required *= max(1, int(sample_shape[4]))
    # Clamp to at least 1 byte so the division below is always defined.
    memory_required = max(1, int(memory_required))
    device = comfy.model_management.get_torch_device()
    comfy.model_management.load_models_gpu(
        [vae.patcher],
        memory_required=memory_required,
        force_full_load=getattr(vae, "disable_offload", False),
    )
    free_memory = vae.patcher.get_free_memory(device)
    # Integer floor division instead of int(free / required): float
    # true-division can lose precision once byte counts exceed 2**53.
    batch_number = max(1, int(free_memory) // memory_required)
    return batch_number
coords_i) - mesh_i, subs_i = vae.decode_shape_slat(sample_i, resolution) + mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i, resolution) mesh.append(mesh_i[0]) subs_per_sample.append(subs_i) @@ -332,13 +346,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, shape_mesh, samples, vae, shape_subs, resolution): + sample_tensor = samples["samples"] resolution = int(resolution) - patcher = vae.patcher device = comfy.model_management.get_torch_device() - comfy.model_management.load_model_gpu(patcher) - - vae = vae.first_stage_model coords = samples["coords"] + prepare_trellis_vae_for_decode(vae, sample_tensor.shape) + trellis_vae = vae.first_stage_model coord_counts = samples.get("coord_counts") samples = samples["samples"] @@ -349,7 +362,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = SparseTensor(feats = samples, coords=coords.to(device)) samples = samples * std + mean - voxel = vae.decode_tex_slat(samples, shape_subs) + voxel = trellis_vae.decode_tex_slat(samples, shape_subs) color_feats = voxel.feats[:, :3] voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] @@ -397,22 +410,16 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod def execute(cls, samples, vae, resolution): resolution = int(resolution) - vae = vae.first_stage_model - decoder = vae.struct_dec + sample_tensor = samples["samples"] + batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) + decoder = vae.first_stage_model.struct_dec load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.vae_offload_device() - decoder = decoder.to(load_device) batch_index = normalize_batch_index(samples.get("batch_index")) - samples = samples["samples"] - samples = samples.to(load_device) - if samples.shape[0] > 1: - decoded_items = [] - for i in range(samples.shape[0]): - decoded_items.append(decoder(samples[i:i + 1]) > 0) - decoded = torch.cat(decoded_items, dim=0) - else: - decoded = 
class DummyPatcher:
    """Stand-in patcher that reports a fixed amount of free device memory."""

    def __init__(self, free_memory):
        self.free_memory = free_memory

    def get_free_memory(self, device):
        # The device argument is ignored; the fake always reports the
        # configured value.
        return self.free_memory


class DummyVAE:
    """Minimal VAE double exposing only what the prepare helper reads."""

    vae_dtype = torch.float16

    def __init__(self, free_memory, memory_factor=2):
        self.patcher = DummyPatcher(free_memory)
        self.memory_factor = memory_factor

    def memory_used_decode(self, shape, dtype):
        # Deterministic estimate built from two spatial dims only.
        return shape[2] * shape[3] * self.memory_factor


class TestPrepareTrellisVaeForDecode(unittest.TestCase):
    def test_uses_load_models_gpu_without_pre_freeing_memory(self):
        """Memory management must be delegated entirely to load_models_gpu."""
        fake_vae = DummyVAE(free_memory=1000)

        mm = nodes_trellis2.comfy.model_management
        with patch.object(mm, "get_torch_device", return_value="cuda"), \
                patch.object(mm, "free_memory") as free_memory_mock, \
                patch.object(mm, "load_models_gpu") as load_models_gpu_mock:
            batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(
                fake_vae, (3, 32, 10, 1))

        # memory_used_decode -> 10 * 1 * 2 = 20; 1000 // 20 = 50.
        free_memory_mock.assert_not_called()
        load_models_gpu_mock.assert_called_once_with(
            [fake_vae.patcher],
            memory_required=20,
            force_full_load=False,
        )
        self.assertEqual(batch_number, 50)

    def test_scales_memory_estimate_for_5d_structure_latents(self):
        """A 5-D latent shape multiplies the estimate by its trailing dim."""
        fake_vae = DummyVAE(free_memory=40960, memory_factor=1)

        mm = nodes_trellis2.comfy.model_management
        with patch.object(mm, "get_torch_device", return_value="cuda"), \
                patch.object(mm, "load_models_gpu") as load_models_gpu_mock:
            batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(
                fake_vae, (2, 8, 16, 16, 16))

        # 16 * 16 * 1 = 256, scaled by dim 4 (16) -> 4096; 40960 // 4096 = 10.
        load_models_gpu_mock.assert_called_once_with(
            [fake_vae.patcher],
            memory_required=4096,
            force_full_load=False,
        )
        self.assertEqual(batch_number, 10)