large optimizations and some fixes

Yousef Rafat 2025-10-04 23:33:52 +03:00
parent 663d971830
commit 4b6c08110d
4 changed files with 135 additions and 34 deletions

View File

@@ -635,23 +635,45 @@ class SingleStreamBlock(nn.Module):
         return x
+def _ceil_div(a, b):
+    return (a + b - 1) // b
+def find_period_by_first_row(mat):
+    L, _ = mat.shape
+    first = mat[0:1]
+    matches = (mat[1:] == first).all(dim=1)
+    candidate_positions = (torch.nonzero(matches).squeeze(-1) + 1).tolist()
+    if isinstance(candidate_positions, int):
+        candidate_positions = [candidate_positions]
+    if not candidate_positions:
+        return L
+    for p in sorted(candidate_positions):
+        base = mat[:p]
+        reps = _ceil_div(L, p)
+        tiled = base.repeat(reps, 1)[:L]
+        if torch.equal(tiled, mat):
+            return p
+    for p in range(1, L + 1):
+        base = mat[:p]
+        reps = _ceil_div(L, p)
+        tiled = base.repeat(reps, 1)[:L]
+        if torch.equal(tiled, mat):
+            return p
+    return L
 def trim_repeats(expanded):
     _, L, D = expanded.shape
     seq = expanded[0]
-    repeat_len = L
-    for k in range(1, L // 2 + 1):
-        if torch.equal(seq[:k], seq[k:2*k]):
-            repeat_len = k
-            break
-    repeat_dim = D
-    for k in range(1, D // 2 + 1):
-        if torch.equal(seq[:, :k], seq[:, k:2*k]):
-            repeat_dim = k
-            break
-    return expanded[:, :repeat_len, :repeat_dim]
+    p_len = find_period_by_first_row(seq)
+    seq_T = seq.transpose(0, 1)
+    p_dim = find_period_by_first_row(seq_T)
+    return expanded[:, :p_len, :p_dim]
 class HunyuanVideoFoley(nn.Module):
     def __init__(
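The old trim_repeats probed up to L/2 candidate prefixes with a full equality check each; find_period_by_first_row first narrows the candidates to positions whose row equals row 0 and only falls back to the exhaustive scan when that prescan finds nothing. A minimal, self-contained sketch of the underlying period test (simplified name, standalone):

    import torch

    def _ceil_div(a, b):
        return (a + b - 1) // b

    def find_period(mat):
        # smallest p such that tiling mat[:p] reproduces mat exactly
        L = mat.shape[0]
        for p in range(1, L + 1):
            if torch.equal(mat[:p].repeat(_ceil_div(L, p), 1)[:L], mat):
                return p
        return L

    base = torch.randn(3, 4)
    tiled = base.repeat(5, 1)        # (15, 4): base stacked five times
    assert find_period(tiled) == 3   # recovers the un-tiled length

The prescan pays off because the conditioning tensors here are exact tiles: row 0 reappears precisely at multiples of the period, so most lengths never need the full comparison.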
@@ -845,11 +867,12 @@ class HunyuanVideoFoley(nn.Module):
         uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
         clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
-        cond_pos, cond_neg = trim_repeats(cond_pos), trim_repeats(cond_neg)
+        cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
+        uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
+        uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
-        uncond_1, clip_feat = uncond_1.to(device, non_blocking = True), clip_feat.to(device, non_blocking=True)
-        uncond_2, sync_feat = uncond_2.to(device, non_blocking = True), sync_feat.to(device, non_blocking=True)
-        cond_neg, cond_pos = cond_neg.to(device, non_blocking = True), cond_pos.to(device, non_blocking=True)
+        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [t.to(device, allow_gpu=True) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
         clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
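Note that allow_gpu is not a standard Tensor.to keyword; it is consumed by the CpuLockedTensor subclass this commit adds in the node file below, which otherwise keeps its data pinned to the CPU. For ordinary tensors, the equivalent batched transfer looks like the sketch below; the shapes are stand-ins, and non_blocking=True only overlaps the copy with compute when the source lives in pinned host memory and the target is a CUDA device.

    import torch

    # stand-in conditioning features; real shapes come from the encoders
    feats = [torch.randn(1, 8, 16) for _ in range(6)]
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # pinned source memory is what lets non_blocking=True be asynchronous
        feats = [t.pin_memory().to(device, non_blocking=True) for t in feats]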

View File

@@ -211,16 +211,24 @@ class FoleyVae(torch.nn.Module):
         return self.syncformer(x)
     def video_encoding(self, video, step: int):
+        t, h, w, c = video.shape
         if not isinstance(video, torch.Tensor):
-            video = torch.from_numpy(video).permute(0, 3, 1, 2)
-        video = self.syncformer_preprocess(video).unsqueeze(0)
+            video = torch.from_numpy(video)
+        video = video.permute(0, 3, 1, 2)
+        video = torch.stack([self.syncformer_preprocess(t) for t in video]).unsqueeze(0)
         seg_len = 16
-        t = video.size(1)
+        t = video.size(0)
         nseg = max(0, (t - seg_len) // step + 1)
-        clips = [video[:, i*step:i*step + seg_len] for i in range(nseg)]
-        data = torch.stack(clips, dim=1)
+        stride_t, stride_c, stride_h, stride_w = video.stride()
+        # no copies
+        data = video.as_strided(
+            size=(nseg, seg_len, c, h, w),
+            stride=(stride_t * step, stride_t, stride_c, stride_h, stride_w),
+        )
         data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
         return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))
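The stack-of-slices version materialized every overlapping clip; as_strided lays out the same (nseg, seg_len, c, h, w) windows as views over the original storage. A self-contained sketch with toy sizes and a 4-D (t, c, h, w) layout, checking the strided view against the explicit copy:

    import torch

    video = torch.randn(10, 3, 4, 4)          # (t, c, h, w)
    seg_len, step = 4, 2
    t, c, h, w = video.shape
    nseg = max(0, (t - seg_len) // step + 1)

    st, sc, sh, sw = video.stride()
    clips = video.as_strided(                  # overlapping windows, no copy
        size=(nseg, seg_len, c, h, w),
        stride=(st * step, st, sc, sh, sw),
    )
    ref = torch.stack([video[i * step : i * step + seg_len] for i in range(nseg)])
    assert torch.equal(clips, ref)

The windows alias the source storage, so they must be treated as read-only; Tensor.unfold(0, seg_len, step) is the built-in alternative, though it places the window dimension last.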

View File

@@ -26,6 +26,54 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
         latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
         return io.NodeOutput({"samples": latent, "type": "hunyuan_foley"}, )
+class CpuLockedTensor(torch.Tensor):
+    def __new__(cls, data):
+        base = torch.as_tensor(data, device='cpu')
+        return torch.Tensor._make_subclass(cls, base, require_grad=False)
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        # if any of the args/kwargs stayed CpuLockedTensor, func would recurse into this hook
+        def unwrap(x):
+            return x.as_subclass(torch.Tensor) if isinstance(x, CpuLockedTensor) else x
+        unwrapped_args = torch.utils._pytree.tree_map(unwrap, args)
+        unwrapped_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
+        result = func(*unwrapped_args, **unwrapped_kwargs)
+        # re-wrap the resulting tensors
+        if isinstance(result, torch.Tensor):
+            return CpuLockedTensor(result.detach().cpu())
+        elif isinstance(result, (list, tuple)):
+            return type(result)(
+                CpuLockedTensor(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
+                for x in result
+            )
+        return result
+    def to(self, *args, allow_gpu=False, **kwargs):
+        if allow_gpu:
+            return super().to(*args, **kwargs)
+        return self.detach().clone().cpu()
+    def cuda(self, *args, **kwargs):
+        return self
+    def cpu(self):
+        return self
+    def pin_memory(self):
+        return self
+    def detach(self):
+        out = super().detach()
+        return CpuLockedTensor(out)
 class HunyuanFoleyConditioning(io.ComfyNode):
     @classmethod
     def define_schema(cls):
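CpuLockedTensor uses the standard __torch_function__ unwrap/re-wrap pattern: each torch op runs on plain Tensors (passing the subclass back into func would re-dispatch to this hook and recurse), and tensor results are converted back to the subclass afterward. A hypothetical, stripped-down subclass showing just that mechanism:

    import torch
    from torch.utils._pytree import tree_map

    class TaggedTensor(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            unwrap = lambda x: x.as_subclass(torch.Tensor) if isinstance(x, TaggedTensor) else x
            # run the real op on plain Tensors so it cannot re-enter this hook
            out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
            # re-wrap single-tensor results so the subclass survives the op
            return out.as_subclass(TaggedTensor) if isinstance(out, torch.Tensor) else out

    x = torch.ones(3).as_subclass(TaggedTensor)
    print(type(x * 2).__name__)   # TaggedTensor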
@@ -53,6 +101,10 @@ class HunyuanFoleyConditioning(io.ComfyNode):
         max_d = max([t.size(2) for t in all_])
         def repeat_shapes(max_value, input, dim = 1):
+            if input.shape[dim] == max_value:
+                return input
             # temporarily repeat values on the cpu
             factor_pos, remainder = divmod(max_value, input.shape[dim])
@@ -61,19 +113,28 @@ class HunyuanFoleyConditioning(io.ComfyNode):
             input = input.cpu().repeat(*positions)
             if remainder > 0:
-                pad = input[:, :remainder, :]
-                input = torch.cat([input, pad], dim =1)
+                if dim == 1:
+                    pad = input[:, :remainder, :]
+                else:
+                    pad = input[:, :, :remainder]
+                input = torch.cat([input, pad], dim = dim)
             return input
         siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
-        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in all_]
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in
+            (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)]
         embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)
         x = siglip_encoding_1
-        negative = [[torch.cat([torch.zeros_like(embeds), text_encoding_negative]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
-        positive = [[torch.cat([embeds, text_encoding_positive]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
+        negative_tensor = CpuLockedTensor(torch.cat([torch.zeros_like(embeds), text_encoding_negative])
+            .contiguous().view(1, -1, x.size(-1)))
+        positive_tensor = CpuLockedTensor(torch.cat([embeds, text_encoding_positive])
+            .contiguous().view(1, -1, x.size(-1)))
+        negative = [[negative_tensor, {}]]
+        positive = [[positive_tensor, {}]]
         return io.NodeOutput(positive, negative)
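The remainder fix above pads along whichever axis is being repeated; the previous code always sliced dim 1 even when broadcasting the feature dim. Tensor.narrow makes the same logic dimension-agnostic, as in this sketch (repeat_to is a hypothetical helper name):

    import torch

    def repeat_to(t, target, dim=1):
        # tile along `dim`, then pad the remainder with a leading slice
        factor, remainder = divmod(target, t.shape[dim])
        reps = [1] * t.dim()
        reps[dim] = factor
        t = t.repeat(*reps)
        if remainder > 0:
            t = torch.cat([t, t.narrow(dim, 0, remainder)], dim=dim)
        return t

    x = torch.randn(1, 3, 8)
    print(repeat_to(x, 7, dim=1).shape)   # torch.Size([1, 7, 8])
    print(repeat_to(x, 20, dim=2).shape)  # torch.Size([1, 3, 20])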

View File

@@ -52,6 +52,7 @@ class EncodeVideo(io.ComfyNode):
     def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
         t, c, h, w = video.shape
+        device = video.device
         b = 1
         batch_size = b * t
@@ -62,6 +63,8 @@ class EncodeVideo(io.ComfyNode):
         model = vae.first_stage_model if vae is not None else clip_vision.model
         vae = vae if vae is not None else clip_vision
+        # should be the offload device
+        video = video.cpu()
         if hasattr(model, "video_encoding"):
             data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
@@ -72,25 +75,31 @@ class EncodeVideo(io.ComfyNode):
         if processing_batch_size != -1:
             batch_size = processing_batch_size
-        outputs = []
+        outputs = None
         total = data.shape[0]
         pbar = comfy.utils.ProgressBar(total/batch_size)
         with torch.inference_mode():
             for i in range(0, total, batch_size):
-                chunk = data[i : i + batch_size]
+                chunk = data[i : i + batch_size].to(device, non_blocking = True)
                 if hasattr(vae, "encode"):
                     out = vae.encode(chunk)
                 else:
                     out = vae.encode_image(chunk)
                     out = out["image_embeds"]
-                outputs.append(out)
-                del out, chunk
+                out_cpu = out.cpu()
+                if outputs is None:
+                    full_shape = (total, *out_cpu.shape[1:])
+                    outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=True)
+                chunk_len = out_cpu.shape[0]
+                outputs[i : i + chunk_len].copy_(out_cpu)
+                del out, chunk, out_cpu
                 torch.cuda.empty_cache()
                 pbar.update(1)
-        output = torch.cat(outputs)
-        return io.NodeOutput(output_fn(output))
+        return io.NodeOutput(output_fn(outputs))
class ResampleVideo(io.ComfyNode):
@classmethod
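The encode loop above replaces append-then-torch.cat with a single pinned host buffer sized from the first chunk's output, so the full result and the per-chunk list never coexist in memory. A minimal sketch of the pattern, assuming a CUDA-enabled build (pin_memory=True requires one) and a hypothetical encode callable:

    import torch

    def encode_in_chunks(encode, data, batch_size, device):
        # hypothetical driver for the preallocated pinned-buffer pattern
        outputs, total = None, data.shape[0]
        with torch.inference_mode():
            for i in range(0, total, batch_size):
                chunk = data[i : i + batch_size].to(device, non_blocking=True)
                out = encode(chunk).cpu()
                if outputs is None:
                    # size the full result from the first chunk; pinned memory
                    # keeps later host/device copies asynchronous
                    outputs = torch.empty((total, *out.shape[1:]),
                                          dtype=out.dtype, pin_memory=True)
                outputs[i : i + out.shape[0]].copy_(out)
        return outputs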