Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-13 23:12:35 +08:00)

Commit 4b6c08110d ("large optimizations and some fixes"), parent 663d971830.
@@ -635,23 +635,45 @@ class SingleStreamBlock(nn.Module):
         return x


+def _ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def find_period_by_first_row(mat):
+    L, _ = mat.shape
+
+    first = mat[0:1]
+    matches = (mat[1:] == first).all(dim=1)
+    candidate_positions = (torch.nonzero(matches).squeeze(-1) + 1).tolist()
+    if isinstance(candidate_positions, int):
+        candidate_positions = [candidate_positions]
+    if not candidate_positions:
+        return L
+
+    for p in sorted(candidate_positions):
+        base = mat[:p]
+        reps = _ceil_div(L, p)
+        tiled = base.repeat(reps, 1)[:L]
+        if torch.equal(tiled, mat):
+            return p
+
+    for p in range(1, L + 1):
+        base = mat[:p]
+        reps = _ceil_div(L, p)
+        tiled = base.repeat(reps, 1)[:L]
+        if torch.equal(tiled, mat):
+            return p
+
+    return L
+
+
 def trim_repeats(expanded):
-    _, L, D = expanded.shape
     seq = expanded[0]
+    p_len = find_period_by_first_row(seq)

-    repeat_len = L
-    for k in range(1, L // 2 + 1):
-        if torch.equal(seq[:k], seq[k:2*k]):
-            repeat_len = k
-            break
+    seq_T = seq.transpose(0, 1)
+    p_dim = find_period_by_first_row(seq_T)

-    repeat_dim = D
-    for k in range(1, D // 2 + 1):
-        if torch.equal(seq[:, :k], seq[:, k:2*k]):
-            repeat_dim = k
-            break
-
-    return expanded[:, :repeat_len, :repeat_dim]
+    return expanded[:, :p_len, :p_dim]


 class HunyuanVideoFoley(nn.Module):
     def __init__(
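As a quick illustration of the new period detection (a minimal sketch, not part of the commit; it assumes `find_period_by_first_row` and `trim_repeats` are importable as defined above), a tensor built by tiling a small base block is trimmed back to that block:

import torch

# Hypothetical values: build a (1, L, D) tensor by tiling a small base block
# along both the sequence and feature dimensions, then confirm it is recovered.
base = torch.randn(1, 4, 8)
expanded = base.repeat(1, 3, 2)          # shape (1, 12, 16)

trimmed = trim_repeats(expanded)
assert trimmed.shape == (1, 4, 8)
assert torch.equal(trimmed, base)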
@@ -845,11 +867,12 @@ class HunyuanVideoFoley(nn.Module):
         uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
         clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
-        cond_pos, cond_neg = trim_repeats(cond_pos), trim_repeats(cond_neg)
+        cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
+        uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
+        uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]

-        uncond_1, clip_feat = uncond_1.to(device, non_blocking = True), clip_feat.to(device, non_blocking=True)
-        uncond_2, sync_feat = uncond_2.to(device, non_blocking = True), sync_feat.to(device, non_blocking=True)
-        cond_neg, cond_pos = cond_neg.to(device, non_blocking = True), cond_pos.to(device, non_blocking=True)
+        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [t.to(device, allow_gpu=True) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]

         clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
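For orientation, a hedged sketch of what this hunk does with the stacked conditioning (hypothetical shapes, not code from the commit; it assumes the `trim_repeats` helper added above is in scope): the tensors arrive stacked in thirds, each conditioning slot is trimmed back to its period, and the unconditional slot is sliced to match.

import torch

condition = torch.randn(1, 8, 16).repeat(3, 2, 1)       # (3, 16, 16), sequence dim repeated
uncondition = torch.zeros(3, 16, 16)

uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)

clip_feat = trim_repeats(clip_feat)                      # back to (1, 8, 16)
uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
print(uncond_1.shape)                                    # torch.Size([1, 8, 16])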
@@ -211,16 +211,24 @@ class FoleyVae(torch.nn.Module):
         return self.syncformer(x)

     def video_encoding(self, video, step: int):
+        t, h, w, c = video.shape
+
         if not isinstance(video, torch.Tensor):
-            video = torch.from_numpy(video).permute(0, 3, 1, 2)
+            video = torch.from_numpy(video)

-        video = self.syncformer_preprocess(video).unsqueeze(0)
+        video = video.permute(0, 3, 1, 2)
+
+        video = torch.stack([self.syncformer_preprocess(t) for t in video]).unsqueeze(0)
         seg_len = 16
-        t = video.size(1)
+        t = video.size(0)
         nseg = max(0, (t - seg_len) // step + 1)
-        clips = [video[:, i*step:i*step + seg_len] for i in range(nseg)]
-        data = torch.stack(clips, dim=1)
+        stride_t, stride_c, stride_h, stride_w = video.stride()
+
+        # no copies
+        data = video.as_strided(
+            size=(nseg, seg_len, c, h, w),
+            stride=(stride_t * step, stride_t, stride_c, stride_h, stride_w),
+        )
         data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
         return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))
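The `as_strided` call replaces the per-clip slicing with a zero-copy sliding window over the frame dimension. A standalone sketch of the same idea with made-up shapes (illustrative only, not part of the commit); `Tensor.unfold` produces the equivalent view and is easier to get right:

import torch

# Hypothetical frame tensor: 32 frames of 3x224x224, windows of 16 frames every 8 frames.
frames = torch.randn(32, 3, 224, 224)
seg_len, step = 16, 8
nseg = max(0, (frames.size(0) - seg_len) // step + 1)

s_t, s_c, s_h, s_w = frames.stride()
windows = frames.as_strided(
    size=(nseg, seg_len, 3, 224, 224),
    stride=(s_t * step, s_t, s_c, s_h, s_w),
)

# Same view via unfold: (nseg, c, h, w, seg_len) -> (nseg, seg_len, c, h, w)
windows2 = frames.unfold(0, seg_len, step).permute(0, 4, 1, 2, 3)
assert torch.equal(windows, windows2)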
@@ -26,6 +26,54 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
         latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
         return io.NodeOutput({"samples": latent, "type": "hunyuan_foley"}, )


+class CpuLockedTensor(torch.Tensor):
+    def __new__(cls, data):
+        base = torch.as_tensor(data, device='cpu')
+        return torch.Tensor._make_subclass(cls, base, require_grad=False)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        # if any of the args/kwargs were CpuLockedTensor, it would cause infinite recursion
+        def unwrap(x):
+            return x.as_subclass(torch.Tensor) if isinstance(x, CpuLockedTensor) else x
+
+        unwrapped_args = torch.utils._pytree.tree_map(unwrap, args)
+        unwrapped_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
+
+        result = func(*unwrapped_args, **unwrapped_kwargs)
+
+        # rewrap the resulting tensors
+        if isinstance(result, torch.Tensor):
+            return CpuLockedTensor(result.detach().cpu())
+        elif isinstance(result, (list, tuple)):
+            return type(result)(
+                CpuLockedTensor(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
+                for x in result
+            )
+        return result
+
+    def to(self, *args, allow_gpu=False, **kwargs):
+        if allow_gpu:
+            return super().to(*args, **kwargs)
+        return self.detach().clone().cpu()
+
+    def cuda(self, *args, **kwargs):
+        return self
+
+    def cpu(self):
+        return self
+
+    def pin_memory(self):
+        return self
+
+    def detach(self):
+        out = super().detach()
+        return CpuLockedTensor(out)
+
+
 class HunyuanFoleyConditioning(io.ComfyNode):
     @classmethod
     def define_schema(cls):
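A hedged sketch of how the subclass behaves (illustrative only, not part of the commit; it assumes `CpuLockedTensor` is importable as defined above): the usual device-moving calls are neutralized, and intercepted torch ops hand back CPU-locked results.

import torch

t = CpuLockedTensor(torch.randn(2, 8))

print(t.device)             # cpu
print(t.cuda().device)      # cpu: .cuda() is overridden to return self
print(t.to("cuda").device)  # cpu: without allow_gpu=True, .to() hands back a CPU copy

out = t * 2                 # torch ops are routed through __torch_function__ ...
print(type(out).__name__, out.device)  # ... so the result is re-wrapped and kept on the CPU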
@@ -53,6 +101,10 @@ class HunyuanFoleyConditioning(io.ComfyNode):
         max_d = max([t.size(2) for t in all_])

         def repeat_shapes(max_value, input, dim = 1):
+            if input.shape[dim] == max_value:
+                return input
+
             # temporary repeat values on the cpu
             factor_pos, remainder = divmod(max_value, input.shape[dim])
@@ -61,19 +113,28 @@ class HunyuanFoleyConditioning(io.ComfyNode):
             input = input.cpu().repeat(*positions)

             if remainder > 0:
-                pad = input[:, :remainder, :]
-                input = torch.cat([input, pad], dim =1)
+                if dim == 1:
+                    pad = input[:, :remainder, :]
+                else:
+                    pad = input[:, :, :remainder]
+                input = torch.cat([input, pad], dim = dim)

             return input

         siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
-        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in all_]
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in
+            (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)]

         embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)

         x = siglip_encoding_1
-        negative = [[torch.cat([torch.zeros_like(embeds), text_encoding_negative]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
-        positive = [[torch.cat([embeds, text_encoding_positive]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
+        positive_tensor = CpuLockedTensor(torch.cat([torch.zeros_like(embeds), text_encoding_negative])
+                                          .contiguous().view(1, -1, x.size(-1)))
+        negative_tensor = CpuLockedTensor(torch.cat([embeds, text_encoding_positive])
+                                          .contiguous().view(1, -1, x.size(-1)))
+
+        negative = [[positive_tensor, {}]]
+        positive = [[negative_tensor, {}]]

         return io.NodeOutput(positive, negative)
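A hedged, self-contained sketch of the divmod-based tiling that `repeat_shapes` performs along a chosen dimension (hypothetical shapes; simplified to a plain function rather than the node's nested helper):

import torch

def tile_to_length(x, target, dim=1):
    # Whole repeats first, then pad the remainder with the leading slice.
    if x.shape[dim] == target:
        return x
    factor, remainder = divmod(target, x.shape[dim])
    reps = [1] * x.dim()
    reps[dim] = factor
    x = x.repeat(*reps)
    if remainder > 0:
        pad = x.narrow(dim, 0, remainder)
        x = torch.cat([x, pad], dim=dim)
    return x

a = torch.randn(1, 5, 8)
print(tile_to_length(a, 12, dim=1).shape)   # torch.Size([1, 12, 8])
print(tile_to_length(a, 20, dim=2).shape)   # torch.Size([1, 5, 20])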
@@ -52,6 +52,7 @@ class EncodeVideo(io.ComfyNode):
     def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
         t, c, h, w = video.shape
+        device = video.device
         b = 1
         batch_size = b * t
@@ -62,6 +63,8 @@ class EncodeVideo(io.ComfyNode):
         model = vae.first_stage_model if vae is not None else clip_vision.model
         vae = vae if vae is not None else clip_vision

+        # should be the offload device
+        video = video.cpu()
         if hasattr(model, "video_encoding"):
             data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
@@ -72,25 +75,31 @@ class EncodeVideo(io.ComfyNode):
         if processing_batch_size != -1:
             batch_size = processing_batch_size

-        outputs = []
+        outputs = None
         total = data.shape[0]
         pbar = comfy.utils.ProgressBar(total/batch_size)
         with torch.inference_mode():
             for i in range(0, total, batch_size):
-                chunk = data[i : i + batch_size]
+                chunk = data[i : i + batch_size].to(device, non_blocking = True)
                 if hasattr(vae, "encode"):
                     out = vae.encode(chunk)
                 else:
                     out = vae.encode_image(chunk)
                     out = out["image_embeds"]
-                outputs.append(out)
-                del out, chunk
+
+                out_cpu = out.cpu()
+                if outputs is None:
+                    full_shape = (total, *out_cpu.shape[1:])
+                    outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=True)
+
+                chunk_len = out_cpu.shape[0]
+                outputs[i : i + chunk_len].copy_(out_cpu)
+
+                del out, chunk, out_cpu
                 torch.cuda.empty_cache()
                 pbar.update(1)

-        output = torch.cat(outputs)
-
-        return io.NodeOutput(output_fn(output))
+        return io.NodeOutput(output_fn(outputs))


 class ResampleVideo(io.ComfyNode):
     @classmethod
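The encoding loop now streams each chunk into a preallocated pinned CPU buffer instead of accumulating GPU tensors and concatenating at the end. A minimal sketch of that pattern with a stand-in `encode` function (hypothetical shapes, not the node's real encoder; the CUDA-only parts are guarded):

import torch

def encode(chunk):
    # stand-in for vae.encode / clip_vision.encode_image
    return chunk.mean(dim=(2, 3))

device = "cuda" if torch.cuda.is_available() else "cpu"
data = torch.randn(10, 3, 64, 64)            # pretend frames, kept on the CPU
batch_size, total = 4, data.shape[0]

outputs = None
for i in range(0, total, batch_size):
    chunk = data[i : i + batch_size].to(device, non_blocking=True)
    out_cpu = encode(chunk).cpu()
    if outputs is None:
        # allocate the full result once, pinned so later host/device copies can be async
        outputs = torch.empty((total, *out_cpu.shape[1:]), dtype=out_cpu.dtype,
                              pin_memory=torch.cuda.is_available())
    outputs[i : i + out_cpu.shape[0]].copy_(out_cpu)

print(outputs.shape)                          # torch.Size([10, 3])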