diff --git a/comfy/ldm/hunyuan_foley/model.py b/comfy/ldm/hunyuan_foley/model.py
index b8705fb8d..bce9b25fc 100644
--- a/comfy/ldm/hunyuan_foley/model.py
+++ b/comfy/ldm/hunyuan_foley/model.py
@@ -635,23 +635,45 @@ class SingleStreamBlock(nn.Module):
         return x
 
+def _ceil_div(a, b):
+    return (a + b - 1) // b
+
+def find_period_by_first_row(mat):
+    # smallest period p such that mat[:p] tiled along dim 0 reproduces mat
+    L, _ = mat.shape
+
+    first = mat[0:1]
+    matches = (mat[1:] == first).all(dim=1)
+    candidate_positions = (torch.nonzero(matches).squeeze(-1) + 1).tolist()
+    if isinstance(candidate_positions, int):
+        candidate_positions = [candidate_positions]
+    if not candidate_positions:
+        return L
+
+    # any period p < L must have mat[p] == mat[0], so these candidates are exhaustive
+    for p in sorted(candidate_positions):
+        base = mat[:p]
+        reps = _ceil_div(L, p)
+        tiled = base.repeat(reps, 1)[:L]
+        if torch.equal(tiled, mat):
+            return p
+
+    return L
+
 def trim_repeats(expanded):
-    _, L, D = expanded.shape
     seq = expanded[0]
+    # keep one period along the sequence axis and one along the feature axis
+    p_len = find_period_by_first_row(seq)
 
-    repeat_len = L
-    for k in range(1, L // 2 + 1):
-        if torch.equal(seq[:k], seq[k:2*k]):
-            repeat_len = k
-            break
+    seq_T = seq.transpose(0, 1)
+    p_dim = find_period_by_first_row(seq_T)
 
-    repeat_dim = D
-    for k in range(1, D // 2 + 1):
-        if torch.equal(seq[:, :k], seq[:, k:2*k]):
-            repeat_dim = k
-            break
-
-    return expanded[:, :repeat_len, :repeat_dim]
+    return expanded[:, :p_len, :p_dim]
 
 class HunyuanVideoFoley(nn.Module):
     def __init__(
@@ -845,11 +867,12 @@ class HunyuanVideoFoley(nn.Module):
             uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
             clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
 
-            cond_pos, cond_neg = trim_repeats(cond_pos), trim_repeats(cond_neg)
+            cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
+
+            # the unconditional halves were padded to the same repeated shape, so cut them down to match
+            uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
+            uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
 
-            uncond_1, clip_feat = uncond_1.to(device, non_blocking = True), clip_feat.to(device, non_blocking=True)
-            uncond_2, sync_feat = uncond_2.to(device, non_blocking = True), sync_feat.to(device, non_blocking=True)
-            cond_neg, cond_pos = cond_neg.to(device, non_blocking = True), cond_pos.to(device, non_blocking=True)
+            # allow_gpu=True is the explicit opt-out of the CPU lock that the conditioning node puts on these tensors
+            uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [t.to(device, allow_gpu=True) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
 
             clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
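Not part of the patch, just a quick sanity check of the period detection above, assuming trim_repeats is importable from comfy.ldm.hunyuan_foley.model: a tensor tiled 4x along the sequence axis and 2x along the feature axis should collapse back to a single period of each.

import torch
from comfy.ldm.hunyuan_foley.model import trim_repeats

base = torch.arange(6.0).reshape(2, 3)       # (L=2, D=3) base block
expanded = base.repeat(4, 2).unsqueeze(0)    # (1, 8, 6) after tiling
trimmed = trim_repeats(expanded)             # expected shape (1, 2, 3)
assert torch.equal(trimmed[0], base)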
diff --git a/comfy/ldm/hunyuan_foley/vae.py b/comfy/ldm/hunyuan_foley/vae.py
index e8567c11b..8e29a05ae 100644
--- a/comfy/ldm/hunyuan_foley/vae.py
+++ b/comfy/ldm/hunyuan_foley/vae.py
@@ -211,16 +211,24 @@ class FoleyVae(torch.nn.Module):
         return self.syncformer(x)
 
     def video_encoding(self, video, step: int):
         if not isinstance(video, torch.Tensor):
-            video = torch.from_numpy(video).permute(0, 3, 1, 2)
+            video = torch.from_numpy(video)
 
-        video = self.syncformer_preprocess(video).unsqueeze(0)
+        # incoming frames are (T, H, W, C); preprocess them one by one to keep peak memory down
+        video = video.permute(0, 3, 1, 2)
+        video = torch.stack([self.syncformer_preprocess(frame) for frame in video]).unsqueeze(0)
 
         seg_len = 16
-        t = video.size(1)
+        b, t, c, h, w = video.shape
         nseg = max(0, (t - seg_len) // step + 1)
-        clips = [video[:, i*step:i*step + seg_len] for i in range(nseg)]
-        data = torch.stack(clips, dim=1)
+        stride_b, stride_t, stride_c, stride_h, stride_w = video.stride()
+
+        # overlapping segments as a strided view over the frames, no copies
+        data = video.as_strided(
+            size=(b, nseg, seg_len, c, h, w),
+            stride=(stride_b, stride_t * step, stride_t, stride_c, stride_h, stride_w),
+        )
         data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
 
         return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))
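The as_strided call builds the overlapping seg_len-frame windows as a view over the preprocessed frames rather than stacking copies. A self-contained sketch of the same stride arithmetic on a toy 1-D "video" (hypothetical shapes, unrelated to the real preprocessing):

import torch

frames = torch.arange(10).reshape(10, 1)     # stand-in for (T, features) frames
seg_len, step = 4, 2
nseg = max(0, (frames.size(0) - seg_len) // step + 1)
s_t, s_f = frames.stride()
windows = frames.as_strided(size=(nseg, seg_len, 1), stride=(s_t * step, s_t, s_f))
# windows[k] aliases frames[k*step : k*step + seg_len]; as_strided itself copies nothing
assert torch.equal(windows[1, :, 0], frames[2:6, 0])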
diff --git a/comfy_extras/nodes_hunyuan_foley.py b/comfy_extras/nodes_hunyuan_foley.py
index 89eaf2394..e5f168c53 100644
--- a/comfy_extras/nodes_hunyuan_foley.py
+++ b/comfy_extras/nodes_hunyuan_foley.py
@@ -26,6 +26,54 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
         latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
         return io.NodeOutput({"samples": latent, "type": "hunyuan_foley"}, )
 
+class CpuLockedTensor(torch.Tensor):
+    def __new__(cls, data):
+        base = torch.as_tensor(data, device='cpu')
+        return torch.Tensor._make_subclass(cls, base, require_grad=False)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        # unwrap CpuLockedTensor args before dispatching; re-entering
+        # __torch_function__ here would recurse forever
+        def unwrap(x):
+            return x.as_subclass(torch.Tensor) if isinstance(x, CpuLockedTensor) else x
+
+        unwrapped_args = torch.utils._pytree.tree_map(unwrap, args)
+        unwrapped_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
+
+        result = func(*unwrapped_args, **unwrapped_kwargs)
+
+        # rewrap any resulting tensors so they stay locked to the CPU
+        if isinstance(result, torch.Tensor):
+            return CpuLockedTensor(result.detach().cpu())
+        elif isinstance(result, (list, tuple)):
+            return type(result)(
+                CpuLockedTensor(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
+                for x in result
+            )
+        return result
+
+    def to(self, *args, allow_gpu=False, **kwargs):
+        if allow_gpu:
+            # unwrap first and return a plain tensor: going through super().to() would be
+            # re-intercepted by __torch_function__ and pinned back onto the CPU
+            return self.as_subclass(torch.Tensor).to(*args, **kwargs)
+        return self.detach().clone().cpu()
+
+    def cuda(self, *args, **kwargs):
+        return self
+
+    def cpu(self):
+        return self
+
+    def pin_memory(self):
+        return self
+
+    def detach(self):
+        out = super().detach()
+        return CpuLockedTensor(out)
+
 class HunyuanFoleyConditioning(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -53,6 +101,10 @@ class HunyuanFoleyConditioning(io.ComfyNode):
         max_d = max([t.size(2) for t in all_])
 
         def repeat_shapes(max_value, input, dim = 1):
+
+            if input.shape[dim] == max_value:
+                return input
+
            # temporary repeat values on the cpu
            factor_pos, remainder = divmod(max_value, input.shape[dim])
@@ -61,19 +113,28 @@ class HunyuanFoleyConditioning(io.ComfyNode):
             input = input.cpu().repeat(*positions)
 
             if remainder > 0:
-                pad = input[:, :remainder, :]
-                input = torch.cat([input, pad], dim =1)
+                if dim == 1:
+                    pad = input[:, :remainder, :]
+                else:
+                    pad = input[:, :, :remainder]
+                input = torch.cat([input, pad], dim=dim)
 
             return input
 
         siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
-        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in all_]
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in
+            (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)]
 
         embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)
         x = siglip_encoding_1
 
-        negative = [[torch.cat([torch.zeros_like(embeds), text_encoding_negative]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
-        positive = [[torch.cat([embeds, text_encoding_positive]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
+        negative_tensor = CpuLockedTensor(torch.cat([torch.zeros_like(embeds), text_encoding_negative])
+                                          .contiguous().view(1, -1, x.size(-1)))
+        positive_tensor = CpuLockedTensor(torch.cat([embeds, text_encoding_positive])
+                                          .contiguous().view(1, -1, x.size(-1)))
+
+        negative = [[negative_tensor, {}]]
+        positive = [[positive_tensor, {}]]
 
         return io.NodeOutput(positive, negative)
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index e9509500f..45277c9c4 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -52,6 +52,7 @@ class EncodeVideo(io.ComfyNode):
     def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
 
         t, c, h, w = video.shape
+        device = video.device
         b = 1
         batch_size = b * t
@@ -62,6 +63,8 @@ class EncodeVideo(io.ComfyNode):
         model = vae.first_stage_model if vae is not None else clip_vision.model
         vae = vae if vae is not None else clip_vision
 
+        # keep the frames on the offload device (CPU) while the segment windows are built
+        video = video.cpu()
         if hasattr(model, "video_encoding"):
             data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
@@ -72,25 +75,31 @@ class EncodeVideo(io.ComfyNode):
         if processing_batch_size != -1:
             batch_size = processing_batch_size
 
-        outputs = []
+        outputs = None
         total = data.shape[0]
         pbar = comfy.utils.ProgressBar(total/batch_size)
         with torch.inference_mode():
             for i in range(0, total, batch_size):
-                chunk = data[i : i + batch_size]
+                chunk = data[i : i + batch_size].to(device, non_blocking=True)
 
                 if hasattr(vae, "encode"):
                     out = vae.encode(chunk)
                 else:
                     out = vae.encode_image(chunk)
                     out = out["image_embeds"]
-                outputs.append(out)
-                del out, chunk
+
+                out_cpu = out.cpu()
+                if outputs is None:
+                    # allocate the full result buffer once; pin it only when CUDA is present
+                    full_shape = (total, *out_cpu.shape[1:])
+                    outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=torch.cuda.is_available())
+
+                chunk_len = out_cpu.shape[0]
+                outputs[i : i + chunk_len].copy_(out_cpu)
+
+                del out, chunk, out_cpu
                 torch.cuda.empty_cache()
                 pbar.update(1)
 
-        output = torch.cat(outputs)
-
-        return io.NodeOutput(output_fn(output))
+        return io.NodeOutput(output_fn(outputs))
 
 class ResampleVideo(io.ComfyNode):
     @classmethod
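Not part of the patch: a stand-alone sketch of the staging pattern now used in EncodeVideo.execute, with toy shapes and plain randn standing in for the encoder. The output buffer is allocated once on the CPU (page-locked only when CUDA is present) and each chunk is copied into its slice instead of being collected in a list and concatenated at the end:

import torch

total, batch, feat = 10, 4, 8
outputs = None
for i in range(0, total, batch):
    out_cpu = torch.randn(min(batch, total - i), feat)   # stand-in for one encoded chunk
    if outputs is None:
        outputs = torch.empty((total, feat), dtype=out_cpu.dtype,
                              pin_memory=torch.cuda.is_available())
    outputs[i:i + out_cpu.shape[0]].copy_(out_cpu)

assert outputs.shape == (total, feat)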
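Likewise a small illustration of the CpuLockedTensor contract defined in nodes_hunyuan_foley.py, assuming the class is importable from that module: ordinary ops and plain .to()/.cuda() calls keep the data on the CPU, and only the explicit allow_gpu=True escape hatch used in model.py yields a device copy.

import torch
from comfy_extras.nodes_hunyuan_foley import CpuLockedTensor

locked = CpuLockedTensor(torch.randn(2, 4))

doubled = locked * 2                              # routed through __torch_function__
assert isinstance(doubled, CpuLockedTensor) and doubled.device.type == "cpu"

assert locked.to("cuda").device.type == "cpu"     # silently refused
assert locked.cuda().device.type == "cpu"

if torch.cuda.is_available():
    unlocked = locked.to("cuda", allow_gpu=True)  # explicit opt-in returns a plain tensor
    assert unlocked.device.type == "cuda"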