mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 20:32:45 +08:00
Merge pull request #15 from pollockjj/issue_80
Fix Trellis VAE decode memory management
This commit is contained in:
commit
0bada2e9a2
@ -8,6 +8,21 @@ import torch
|
|||||||
import scipy
|
import scipy
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||||
|
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||||
|
if len(sample_shape) == 5:
|
||||||
|
memory_required *= max(1, int(sample_shape[4]))
|
||||||
|
memory_required = max(1, int(memory_required))
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
comfy.model_management.load_models_gpu(
|
||||||
|
[vae.patcher],
|
||||||
|
memory_required=memory_required,
|
||||||
|
force_full_load=getattr(vae, "disable_offload", False),
|
||||||
|
)
|
||||||
|
free_memory = vae.patcher.get_free_memory(device)
|
||||||
|
batch_number = max(1, int(free_memory / memory_required))
|
||||||
|
return batch_number
|
||||||
|
|
||||||
|
|
||||||
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
||||||
batch_size = len(vertices)
|
batch_size = len(vertices)
|
||||||
@ -271,19 +286,18 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
def execute(cls, samples, vae, resolution):
|
def execute(cls, samples, vae, resolution):
|
||||||
|
|
||||||
resolution = int(resolution)
|
resolution = int(resolution)
|
||||||
patcher = vae.patcher
|
sample_tensor = samples["samples"]
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
comfy.model_management.load_model_gpu(patcher)
|
|
||||||
|
|
||||||
vae = vae.first_stage_model
|
|
||||||
coords = samples["coords"]
|
coords = samples["coords"]
|
||||||
|
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
|
trellis_vae = vae.first_stage_model
|
||||||
coord_counts = samples.get("coord_counts")
|
coord_counts = samples.get("coord_counts")
|
||||||
|
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
if coord_counts is None:
|
if coord_counts is None:
|
||||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||||
samples = shape_norm(samples.to(device), coords.to(device))
|
samples = shape_norm(samples.to(device), coords.to(device))
|
||||||
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
mesh, subs = trellis_vae.decode_shape_slat(samples, resolution)
|
||||||
else:
|
else:
|
||||||
split_items = split_batched_sparse_latent(samples, coords, coord_counts)
|
split_items = split_batched_sparse_latent(samples, coords, coord_counts)
|
||||||
mesh = []
|
mesh = []
|
||||||
@ -292,7 +306,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
coords_i = coords_i.to(device).clone()
|
coords_i = coords_i.to(device).clone()
|
||||||
coords_i[:, 0] = 0
|
coords_i[:, 0] = 0
|
||||||
sample_i = shape_norm(feats_i.to(device), coords_i)
|
sample_i = shape_norm(feats_i.to(device), coords_i)
|
||||||
mesh_i, subs_i = vae.decode_shape_slat(sample_i, resolution)
|
mesh_i, subs_i = trellis_vae.decode_shape_slat(sample_i, resolution)
|
||||||
mesh.append(mesh_i[0])
|
mesh.append(mesh_i[0])
|
||||||
subs_per_sample.append(subs_i)
|
subs_per_sample.append(subs_i)
|
||||||
|
|
||||||
@ -332,13 +346,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, shape_mesh, samples, vae, shape_subs, resolution):
|
def execute(cls, shape_mesh, samples, vae, shape_subs, resolution):
|
||||||
|
|
||||||
|
sample_tensor = samples["samples"]
|
||||||
resolution = int(resolution)
|
resolution = int(resolution)
|
||||||
patcher = vae.patcher
|
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
comfy.model_management.load_model_gpu(patcher)
|
|
||||||
|
|
||||||
vae = vae.first_stage_model
|
|
||||||
coords = samples["coords"]
|
coords = samples["coords"]
|
||||||
|
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
|
trellis_vae = vae.first_stage_model
|
||||||
coord_counts = samples.get("coord_counts")
|
coord_counts = samples.get("coord_counts")
|
||||||
|
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
@ -349,7 +362,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
voxel = vae.decode_tex_slat(samples, shape_subs)
|
voxel = trellis_vae.decode_tex_slat(samples, shape_subs)
|
||||||
color_feats = voxel.feats[:, :3]
|
color_feats = voxel.feats[:, :3]
|
||||||
voxel_coords = voxel.coords[:, 1:]
|
voxel_coords = voxel.coords[:, 1:]
|
||||||
voxel_batch_idx = voxel.coords[:, 0]
|
voxel_batch_idx = voxel.coords[:, 0]
|
||||||
@ -397,22 +410,16 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae, resolution):
|
def execute(cls, samples, vae, resolution):
|
||||||
resolution = int(resolution)
|
resolution = int(resolution)
|
||||||
vae = vae.first_stage_model
|
sample_tensor = samples["samples"]
|
||||||
decoder = vae.struct_dec
|
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
|
decoder = vae.first_stage_model.struct_dec
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
offload_device = comfy.model_management.vae_offload_device()
|
|
||||||
decoder = decoder.to(load_device)
|
|
||||||
batch_index = normalize_batch_index(samples.get("batch_index"))
|
batch_index = normalize_batch_index(samples.get("batch_index"))
|
||||||
samples = samples["samples"]
|
decoded_batches = []
|
||||||
samples = samples.to(load_device)
|
for start in range(0, sample_tensor.shape[0], batch_number):
|
||||||
if samples.shape[0] > 1:
|
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
||||||
decoded_items = []
|
decoded_batches.append(decoder(sample_chunk) > 0)
|
||||||
for i in range(samples.shape[0]):
|
decoded = torch.cat(decoded_batches, dim=0)
|
||||||
decoded_items.append(decoder(samples[i:i + 1]) > 0)
|
|
||||||
decoded = torch.cat(decoded_items, dim=0)
|
|
||||||
else:
|
|
||||||
decoded = decoder(samples) > 0
|
|
||||||
decoder.to(offload_device)
|
|
||||||
current_res = decoded.shape[2]
|
current_res = decoded.shape[2]
|
||||||
|
|
||||||
if current_res != resolution:
|
if current_res != resolution:
|
||||||
@ -443,7 +450,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
|
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
comfy.model_management.load_model_gpu(vae.patcher)
|
prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape)
|
||||||
|
|
||||||
coord_counts = shape_latent_512.get("coord_counts")
|
coord_counts = shape_latent_512.get("coord_counts")
|
||||||
batch_index = normalize_batch_index(shape_latent_512.get("batch_index"))
|
batch_index = normalize_batch_index(shape_latent_512.get("batch_index"))
|
||||||
|
|||||||
@ -73,6 +73,57 @@ class DummyModel:
|
|||||||
self.model = inner_model
|
self.model = inner_model
|
||||||
|
|
||||||
|
|
||||||
|
class DummyPatcher:
|
||||||
|
def __init__(self, free_memory):
|
||||||
|
self.free_memory = free_memory
|
||||||
|
|
||||||
|
def get_free_memory(self, device):
|
||||||
|
return self.free_memory
|
||||||
|
|
||||||
|
|
||||||
|
class DummyVAE:
|
||||||
|
vae_dtype = torch.float16
|
||||||
|
|
||||||
|
def __init__(self, free_memory, memory_factor=2):
|
||||||
|
self.patcher = DummyPatcher(free_memory)
|
||||||
|
self.memory_factor = memory_factor
|
||||||
|
|
||||||
|
def memory_used_decode(self, shape, dtype):
|
||||||
|
return shape[2] * shape[3] * self.memory_factor
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrepareTrellisVaeForDecode(unittest.TestCase):
|
||||||
|
def test_uses_load_models_gpu_without_pre_freeing_memory(self):
|
||||||
|
vae = DummyVAE(free_memory=1000)
|
||||||
|
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"):
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "free_memory") as free_memory:
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu:
|
||||||
|
batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (3, 32, 10, 1))
|
||||||
|
|
||||||
|
free_memory.assert_not_called()
|
||||||
|
load_models_gpu.assert_called_once_with(
|
||||||
|
[vae.patcher],
|
||||||
|
memory_required=20,
|
||||||
|
force_full_load=False,
|
||||||
|
)
|
||||||
|
self.assertEqual(batch_number, 50)
|
||||||
|
|
||||||
|
def test_scales_memory_estimate_for_5d_structure_latents(self):
|
||||||
|
vae = DummyVAE(free_memory=40960, memory_factor=1)
|
||||||
|
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"):
|
||||||
|
with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu:
|
||||||
|
batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (2, 8, 16, 16, 16))
|
||||||
|
|
||||||
|
load_models_gpu.assert_called_once_with(
|
||||||
|
[vae.patcher],
|
||||||
|
memory_required=4096,
|
||||||
|
force_full_load=False,
|
||||||
|
)
|
||||||
|
self.assertEqual(batch_number, 10)
|
||||||
|
|
||||||
|
|
||||||
class TestRunConditioningRestore(unittest.TestCase):
|
class TestRunConditioningRestore(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.intermediate_patch = patch.object(
|
self.intermediate_patch = patch.object(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user