mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 20:32:45 +08:00
Address Trellis VAE decode review feedback
This commit is contained in:
parent
c1fa56251e
commit
8816699e7c
@ -9,9 +9,11 @@ import scipy
|
|||||||
import copy
|
import copy
|
||||||
|
|
||||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||||
memory_required = max(1, int(vae.memory_used_decode(sample_shape, vae.vae_dtype)))
|
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||||
|
if len(sample_shape) == 5:
|
||||||
|
memory_required *= max(1, int(sample_shape[4]))
|
||||||
|
memory_required = max(1, int(memory_required))
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
comfy.model_management.free_memory(memory_required, device, for_dynamic=False)
|
|
||||||
comfy.model_management.load_models_gpu(
|
comfy.model_management.load_models_gpu(
|
||||||
[vae.patcher],
|
[vae.patcher],
|
||||||
memory_required=memory_required,
|
memory_required=memory_required,
|
||||||
@ -19,7 +21,7 @@ def prepare_trellis_vae_for_decode(vae, sample_shape):
|
|||||||
)
|
)
|
||||||
free_memory = vae.patcher.get_free_memory(device)
|
free_memory = vae.patcher.get_free_memory(device)
|
||||||
batch_number = max(1, int(free_memory / memory_required))
|
batch_number = max(1, int(free_memory / memory_required))
|
||||||
return min(sample_shape[0], batch_number)
|
return batch_number
|
||||||
|
|
||||||
|
|
||||||
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
||||||
|
|||||||
@ -73,6 +73,57 @@ class DummyModel:
|
|||||||
self.model = inner_model
|
self.model = inner_model
|
||||||
|
|
||||||
|
|
||||||
|
class DummyPatcher:
|
||||||
|
def __init__(self, free_memory):
|
||||||
|
self.free_memory = free_memory
|
||||||
|
|
||||||
|
def get_free_memory(self, device):
|
||||||
|
return self.free_memory
|
||||||
|
|
||||||
|
|
||||||
|
class DummyVAE:
|
||||||
|
vae_dtype = torch.float16
|
||||||
|
|
||||||
|
def __init__(self, free_memory, memory_factor=2):
|
||||||
|
self.patcher = DummyPatcher(free_memory)
|
||||||
|
self.memory_factor = memory_factor
|
||||||
|
|
||||||
|
def memory_used_decode(self, shape, dtype):
|
||||||
|
return shape[2] * shape[3] * self.memory_factor
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrepareTrellisVaeForDecode(unittest.TestCase):
|
||||||
|
def test_uses_load_models_gpu_without_pre_freeing_memory(self):
|
||||||
|
vae = DummyVAE(free_memory=1000)
|
||||||
|
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"):
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "free_memory") as free_memory:
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu:
|
||||||
|
batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (3, 32, 10, 1))
|
||||||
|
|
||||||
|
free_memory.assert_not_called()
|
||||||
|
load_models_gpu.assert_called_once_with(
|
||||||
|
[vae.patcher],
|
||||||
|
memory_required=20,
|
||||||
|
force_full_load=False,
|
||||||
|
)
|
||||||
|
self.assertEqual(batch_number, 50)
|
||||||
|
|
||||||
|
def test_scales_memory_estimate_for_5d_structure_latents(self):
|
||||||
|
vae = DummyVAE(free_memory=40960, memory_factor=1)
|
||||||
|
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"):
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu:
|
||||||
|
batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (2, 8, 16, 16, 16))
|
||||||
|
|
||||||
|
load_models_gpu.assert_called_once_with(
|
||||||
|
[vae.patcher],
|
||||||
|
memory_required=4096,
|
||||||
|
force_full_load=False,
|
||||||
|
)
|
||||||
|
self.assertEqual(batch_number, 10)
|
||||||
|
|
||||||
|
|
||||||
class TestRunConditioningRestore(unittest.TestCase):
|
class TestRunConditioningRestore(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.intermediate_patch = patch.object(
|
self.intermediate_patch = patch.object(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user