mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 12:22:37 +08:00
Fix Trellis VAE decode memory management
This commit is contained in:
parent
880d7823e8
commit
f15bf73d5c
@ -1,6 +1,6 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor
|
from comfy.ldm.trellis2.vae import SparseTensor, sparse_cat
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -8,6 +8,25 @@ import torch
|
|||||||
import scipy
|
import scipy
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
def prepare_trellis_vae_for_decode(vae, sample_shape):
    """Free device memory, load the VAE, and pick a decode chunk size.

    Args:
        vae: VAE wrapper exposing ``memory_used_decode``, ``vae_dtype`` and
            ``patcher`` (and optionally ``disable_offload``).
        sample_shape: shape of the latent tensor about to be decoded;
            ``sample_shape[0]`` caps the returned value (presumably the
            batch dimension -- confirm against callers).

    Returns:
        The number of items to decode per pass: at least 1 and never more
        than ``sample_shape[0]``.
    """
    mm = comfy.model_management

    # Clamp to 1 so the division below can never be by zero.
    estimate = int(vae.memory_used_decode(sample_shape, vae.vae_dtype))
    memory_required = max(1, estimate)

    device = mm.get_torch_device()
    mm.free_memory(memory_required, device, for_dynamic=False)

    full_load = getattr(vae, "disable_offload", False)
    mm.load_models_gpu(
        [vae.patcher],
        memory_required=memory_required,
        force_full_load=full_load,
    )

    # Free memory is measured after the load, then divided by the estimate.
    free_memory = vae.patcher.get_free_memory(device)
    batch_number = max(1, int(free_memory / memory_required))
    return min(sample_shape[0], batch_number)
|
||||||
|
|
||||||
|
|
||||||
|
def combine_sparse_sub_batches(sub_batches):
    """Merge per-chunk lists of sparse tensors into one list, level by level.

    Args:
        sub_batches: non-empty sequence of per-chunk results, each an
            indexable sequence with one entry per level.

    Returns:
        ``sub_batches[0]`` itself when there is only one chunk; otherwise a
        new list whose ``i``-th entry is ``sparse_cat`` (``dim=0``) of the
        ``i``-th entry of every chunk. The level count is taken from the
        first chunk.
    """
    # Single chunk: nothing to concatenate, hand back the same object.
    if len(sub_batches) == 1:
        return sub_batches[0]

    merged = []
    for level in range(len(sub_batches[0])):
        per_chunk = [chunk[level] for chunk in sub_batches]
        merged.append(sparse_cat(per_chunk, dim=0))
    return merged
|
||||||
|
|
||||||
|
|
||||||
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
||||||
batch_size = len(vertices)
|
batch_size = len(vertices)
|
||||||
@ -163,18 +182,24 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
def execute(cls, samples, vae, resolution):
|
def execute(cls, samples, vae, resolution):
|
||||||
|
|
||||||
resolution = int(resolution)
|
resolution = int(resolution)
|
||||||
patcher = vae.patcher
|
sample_tensor = samples["samples"]
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
comfy.model_management.load_model_gpu(patcher)
|
|
||||||
|
|
||||||
vae = vae.first_stage_model
|
|
||||||
coords = samples["coords"]
|
coords = samples["coords"]
|
||||||
|
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
|
trellis_vae = vae.first_stage_model
|
||||||
|
|
||||||
samples = samples["samples"]
|
shape_samples = sample_tensor.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
shape_latent = shape_norm(shape_samples, coords.to(device))
|
||||||
samples = shape_norm(samples, coords)
|
|
||||||
|
|
||||||
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
mesh = []
|
||||||
|
sub_batches = []
|
||||||
|
for start in range(0, shape_latent.shape[0], batch_number):
|
||||||
|
end = start + batch_number
|
||||||
|
mesh_chunk, subs_chunk = trellis_vae.decode_shape_slat(shape_latent[start:end], resolution)
|
||||||
|
mesh.extend(mesh_chunk)
|
||||||
|
sub_batches.append(subs_chunk)
|
||||||
|
|
||||||
|
subs = combine_sparse_sub_batches(sub_batches)
|
||||||
face_list = [m.faces for m in mesh]
|
face_list = [m.faces for m in mesh]
|
||||||
vert_list = [m.vertices for m in mesh]
|
vert_list = [m.vertices for m in mesh]
|
||||||
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
||||||
@ -204,21 +229,24 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
def execute(cls, shape_mesh, samples, vae, shape_subs):
|
def execute(cls, shape_mesh, samples, vae, shape_subs):
|
||||||
|
|
||||||
resolution = 1024
|
resolution = 1024
|
||||||
patcher = vae.patcher
|
sample_tensor = samples["samples"]
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
comfy.model_management.load_model_gpu(patcher)
|
|
||||||
|
|
||||||
vae = vae.first_stage_model
|
|
||||||
coords = samples["coords"]
|
coords = samples["coords"]
|
||||||
|
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
|
trellis_vae = vae.first_stage_model
|
||||||
|
|
||||||
samples = samples["samples"]
|
tex_samples = sample_tensor.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
std = tex_slat_normalization["std"].to(tex_samples)
|
||||||
std = tex_slat_normalization["std"].to(samples)
|
mean = tex_slat_normalization["mean"].to(tex_samples)
|
||||||
mean = tex_slat_normalization["mean"].to(samples)
|
tex_latent = SparseTensor(feats=tex_samples, coords=coords.to(device))
|
||||||
samples = SparseTensor(feats = samples, coords=coords)
|
tex_latent = tex_latent * std + mean
|
||||||
samples = samples * std + mean
|
|
||||||
|
|
||||||
voxel = vae.decode_tex_slat(samples, shape_subs)
|
voxel_batches = []
|
||||||
|
for start in range(0, tex_latent.shape[0], batch_number):
|
||||||
|
end = start + batch_number
|
||||||
|
guide_subs = [sub[start:end] for sub in shape_subs]
|
||||||
|
voxel_batches.append(trellis_vae.decode_tex_slat(tex_latent[start:end], guide_subs))
|
||||||
|
voxel = voxel_batches[0] if len(voxel_batches) == 1 else sparse_cat(voxel_batches, dim=0)
|
||||||
color_feats = voxel.feats[:, :3]
|
color_feats = voxel.feats[:, :3]
|
||||||
voxel_coords = voxel.coords[:, 1:]
|
voxel_coords = voxel.coords[:, 1:]
|
||||||
voxel_batch_idx = voxel.coords[:, 0]
|
voxel_batch_idx = voxel.coords[:, 0]
|
||||||
@ -266,15 +294,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, vae, resolution):
|
def execute(cls, samples, vae, resolution):
|
||||||
resolution = int(resolution)
|
resolution = int(resolution)
|
||||||
vae = vae.first_stage_model
|
sample_tensor = samples["samples"]
|
||||||
decoder = vae.struct_dec
|
batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
|
decoder = vae.first_stage_model.struct_dec
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
offload_device = comfy.model_management.vae_offload_device()
|
decoded_batches = []
|
||||||
decoder = decoder.to(load_device)
|
for start in range(0, sample_tensor.shape[0], batch_number):
|
||||||
samples = samples["samples"]
|
sample_chunk = sample_tensor[start:start + batch_number].to(load_device)
|
||||||
samples = samples.to(load_device)
|
decoded_batches.append(decoder(sample_chunk) > 0)
|
||||||
decoded = decoder(samples)>0
|
decoded = torch.cat(decoded_batches, dim=0)
|
||||||
decoder.to(offload_device)
|
|
||||||
current_res = decoded.shape[2]
|
current_res = decoded.shape[2]
|
||||||
|
|
||||||
if current_res != resolution:
|
if current_res != resolution:
|
||||||
@ -303,7 +331,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
|
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
comfy.model_management.load_model_gpu(vae.patcher)
|
prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape)
|
||||||
|
|
||||||
feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||||
coords_512 = shape_latent_512["coords"].to(device)
|
coords_512 = shape_latent_512["coords"].to(device)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user