From 8816699e7c2b4d1c5c8d3595541928e92026677a Mon Sep 17 00:00:00 2001
From: John Pollock
Date: Mon, 20 Apr 2026 22:10:15 -0500
Subject: [PATCH] Address Trellis VAE decode review feedback

---
 comfy_extras/nodes_trellis2.py                |  8 +--
 .../comfy_extras_test/nodes_trellis2_test.py  | 51 +++++++++++++++++++
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py
index 397453562..bc2d6bcab 100644
--- a/comfy_extras/nodes_trellis2.py
+++ b/comfy_extras/nodes_trellis2.py
@@ -9,9 +9,11 @@ import scipy
 import copy
 
 def prepare_trellis_vae_for_decode(vae, sample_shape):
-    memory_required = max(1, int(vae.memory_used_decode(sample_shape, vae.vae_dtype)))
+    memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
+    if len(sample_shape) == 5:
+        memory_required *= max(1, int(sample_shape[4]))
+    memory_required = max(1, int(memory_required))
     device = comfy.model_management.get_torch_device()
-    comfy.model_management.free_memory(memory_required, device, for_dynamic=False)
     comfy.model_management.load_models_gpu(
         [vae.patcher],
         memory_required=memory_required,
@@ -19,7 +21,7 @@ def prepare_trellis_vae_for_decode(vae, sample_shape):
     )
     free_memory = vae.patcher.get_free_memory(device)
     batch_number = max(1, int(free_memory / memory_required))
-    return min(sample_shape[0], batch_number)
+    return batch_number
 
 
 def pack_variable_mesh_batch(vertices, faces, colors=None):
diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py
index 49e872bc7..96fb4395a 100644
--- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py
+++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py
@@ -73,6 +73,57 @@ class DummyModel:
         self.model = inner_model
 
 
+class DummyPatcher:
+    def __init__(self, free_memory):
+        self.free_memory = free_memory
+
+    def get_free_memory(self, device):
+        return self.free_memory
+
+
+class DummyVAE:
+    vae_dtype = torch.float16
+
+    def __init__(self, free_memory, memory_factor=2):
+        self.patcher = DummyPatcher(free_memory)
+        self.memory_factor = memory_factor
+
+    def memory_used_decode(self, shape, dtype):
+        return shape[2] * shape[3] * self.memory_factor
+
+
+class TestPrepareTrellisVaeForDecode(unittest.TestCase):
+    def test_uses_load_models_gpu_without_pre_freeing_memory(self):
+        vae = DummyVAE(free_memory=1000)
+
+        with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"):
+            with patch.object(nodes_trellis2.comfy.model_management, "free_memory") as free_memory:
+                with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu:
+                    batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (3, 32, 10, 1))
+
+        free_memory.assert_not_called()
+        load_models_gpu.assert_called_once_with(
+            [vae.patcher],
+            memory_required=20,
+            force_full_load=False,
+        )
+        self.assertEqual(batch_number, 50)
+
+    def test_scales_memory_estimate_for_5d_structure_latents(self):
+        vae = DummyVAE(free_memory=40960, memory_factor=1)
+
+        with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"):
+            with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu:
+                batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (2, 8, 16, 16, 16))
+
+        load_models_gpu.assert_called_once_with(
+            [vae.patcher],
+            memory_required=4096,
+            force_full_load=False,
+        )
+        self.assertEqual(batch_number, 10)
+
+
 class TestRunConditioningRestore(unittest.TestCase):
     def setUp(self):
         self.intermediate_patch = patch.object(