diff --git a/comfy/ldm/hunyuan_foley/model.py b/comfy/ldm/hunyuan_foley/model.py
index 588797541..6b7294bbd 100644
--- a/comfy/ldm/hunyuan_foley/model.py
+++ b/comfy/ldm/hunyuan_foley/model.py
@@ -778,8 +778,6 @@ class HunyuanVideoFoley(nn.Module):
         self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
         self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)
 
-        self.conditions = None
-
     def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
         len = len if len is not None else self.clip_len
         if bs is None:
@@ -858,35 +856,25 @@ class HunyuanVideoFoley(nn.Module):
         bs, _, ol = x.shape
         tl = ol // self.patch_size
 
-        if self.conditions is None:
+        def remove_padding(tensor):
+            mask = tensor.sum(dim=-1) != 0
+            out = torch.stack([tensor[b][mask[b]] for b in range(tensor.size(0))], dim=0)
+            return out
 
-            uncondition, condition = torch.chunk(context, 2)
+        cond_, uncond = torch.chunk(context, 2)
+        uncond, cond_ = uncond.view(3, -1, self.condition_dim), cond_.view(3, -1, self.condition_dim)
+        clip_feat, sync_feat, cond_pos = cond_.chunk(3)
+        uncond_1, uncond_2, cond_neg = uncond.chunk(3)
+        clip_feat, sync_feat, cond_pos, cond_neg = [remove_padding(t) for t in (clip_feat, sync_feat, cond_pos, cond_neg)]
 
-            condition = condition.view(3, context.size(1) // 3, -1)
-            uncondition = uncondition.view(3, context.size(1) // 3, -1)
+        diff = cond_pos.shape[1] - cond_neg.shape[1]
+        if cond_neg.shape[1] < cond_pos.shape[1]:
+            cond_neg = F.pad(cond_neg, (0, 0, 0, diff))
+        elif diff < 0:
+            cond_pos = F.pad(cond_pos, (0, 0, 0, abs(diff)))
 
-            uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
-            clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
-            cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
-
-            uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
-            uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
-
-            uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
-
-            diff = cond_pos.shape[1] - cond_neg.shape[1]
-            if cond_neg.shape[1] < cond_pos.shape[1]:
-                cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
-            elif diff < 0:
-                cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
-
-            clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
-            clip_feat = clip_feat.view(2, -1, 768)
-
-            self.conditions = (clip_feat, sync_feat, cond)
-
-        else:
-            clip_feat, sync_feat, cond = self.conditions
+        clip_feat, sync_feat, cond = \
+            torch.cat([uncond_1[:, :clip_feat.size(1), :], clip_feat]), torch.cat([uncond_2[:, :sync_feat.size(1), :], sync_feat]), torch.cat([cond_neg, cond_pos])
 
         if drop_visual is not None:
             clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
diff --git a/comfy/ldm/hunyuan_foley/vae.py b/comfy/ldm/hunyuan_foley/vae.py
index 7c634bce0..387f46fd7 100644
--- a/comfy/ldm/hunyuan_foley/vae.py
+++ b/comfy/ldm/hunyuan_foley/vae.py
@@ -213,7 +213,9 @@ class FoleyVae(torch.nn.Module):
                 v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ]
         )
-        self.decode_sample_rate = self.dac.sample_rate
+
+    def decode_sample_rate(self):
+        return self.dac.sample_rate
 
     def decode(self, x, vae_options = {}):
         return self.dac.decode(x)
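Context for the model.py hunks: the conditioning node (later in this patch) right-pads each stream with zero rows to a common length, and the new `remove_padding` strips those rows back out per batch element. A small sketch of that round trip, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

def remove_padding(tensor):
    # Keep only rows whose feature vector is non-zero. This assumes a real
    # feature row never sums to exactly zero, which holds in practice for
    # float embeddings but is not guaranteed in general.
    mask = tensor.sum(dim=-1) != 0
    return torch.stack([tensor[b][mask[b]] for b in range(tensor.size(0))], dim=0)

feat = torch.randn(1, 5, 768)        # hypothetical clip-style feature sequence
padded = F.pad(feat, (0, 0, 0, 3))   # zero-pad the sequence dim from 5 to 8
restored = remove_padding(padded)

assert torch.equal(restored, feat)
```

Note that `torch.stack` requires every batch element to keep the same number of rows, which is the case here because all elements are padded to one shared length.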
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 9acfde78b..309ac77db 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -88,7 +88,10 @@ class VAEDecodeAudio:
         std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
         std[std < 1.0] = 1.0
         audio /= std
-        sample_rate = vae.first_stage_model.decode_sample_rate or 44100
+        if hasattr(vae.first_stage_model, "decode_sample_rate"):
+            sample_rate = vae.first_stage_model.decode_sample_rate()
+        else:
+            sample_rate = 44100
         return ({"waveform": audio, "sample_rate": sample_rate}, )
 
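The `hasattr` guard matters because the old code read `decode_sample_rate` as a plain attribute, which raises `AttributeError` on any VAE that never defined it; only the Foley VAE exposes it, now as a method per the vae.py hunk above. A minimal sketch of the duck-typed fallback, with stub classes standing in for real models (all names here are hypothetical, not ComfyUI APIs):

```python
class _DacStub:
    sample_rate = 48000

class FoleyVaeLike:
    dac = _DacStub()

    def decode_sample_rate(self):
        return self.dac.sample_rate

class GenericVae:
    pass  # no decode_sample_rate -> caller falls back to 44100

def resolve_sample_rate(first_stage_model):
    if hasattr(first_stage_model, "decode_sample_rate"):
        return first_stage_model.decode_sample_rate()
    return 44100

assert resolve_sample_rate(FoleyVaeLike()) == 48000
assert resolve_sample_rate(GenericVae()) == 44100
```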
diff --git a/comfy_extras/nodes_hunyuan_foley.py b/comfy_extras/nodes_hunyuan_foley.py
index af914d9bf..9d6625ee2 100644
--- a/comfy_extras/nodes_hunyuan_foley.py
+++ b/comfy_extras/nodes_hunyuan_foley.py
@@ -1,5 +1,6 @@
 import torch
 import comfy.model_management
+import torch.nn.functional as F
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
 
@@ -20,60 +21,13 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
     @classmethod
     def execute(cls, length, batch_size, video = None):
         if video is not None:
+            video = video.images
             length = video.size(0)
             length /= 25
         shape = (batch_size, 128, int(50 * length))
         latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
         return io.NodeOutput({"samples": latent, "type": "hunyuan_foley"}, )
 
-class CpuLockedTensor(torch.Tensor):
-    def __new__(cls, data):
-        base = torch.as_tensor(data, device='cpu')
-        return torch.Tensor._make_subclass(cls, base, require_grad=False)
-
-    @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-
-        if kwargs is None:
-            kwargs = {}
-
-        # if any of the args/kwargs were CpuLockedTensor, it will cause infinite recursion
-        def unwrap(x):
-            return x.as_subclass(torch.Tensor) if isinstance(x, CpuLockedTensor) else x
-
-        unwrapped_args = torch.utils._pytree.tree_map(unwrap, args)
-        unwrapped_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
-
-        result = func(*unwrapped_args, **unwrapped_kwargs)
-
-        # rewrap the resulted tensors
-        if isinstance(result, torch.Tensor):
-            return CpuLockedTensor(result.detach().cpu())
-        elif isinstance(result, (list, tuple)):
-            return type(result)(
-                CpuLockedTensor(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
-                for x in result
-            )
-        return result
-
-    def to(self, *args, allow_gpu=False, **kwargs):
-        if allow_gpu:
-            return super().to(*args, **kwargs)
-        return self.detach().clone().cpu()
-
-    def cuda(self, *args, **kwargs):
-        return self
-
-    def cpu(self):
-        return self
-
-    def pin_memory(self):
-        return self
-
-    def detach(self):
-        out = super().detach()
-        return CpuLockedTensor(out)
-
 class HunyuanFoleyConditioning(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -96,45 +50,14 @@ class HunyuanFoleyConditioning(io.ComfyNode):
         text_encoding_positive = text_encoding_positive[0][0]
         text_encoding_negative = text_encoding_negative[0][0]
         all_ = (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)
-
-        max_l = max([t.size(1) for t in all_])
-        max_d = max([t.size(2) for t in all_])
-
-        def repeat_shapes(max_value, input, dim = 1):
-
-            if input.shape[dim] == max_value:
-                return input
-
-            # temporary repeat values on the cpu
-            factor_pos, remainder = divmod(max_value, input.shape[dim])
-
-            positions = [1] * input.ndim
-            positions[dim] = factor_pos
-            input = input.cpu().repeat(*positions)
-
-            if remainder > 0:
-                if dim == 1:
-                    pad = input[:, :remainder, :]
-                else:
-                    pad = input[:, :, :remainder]
-                input = torch.cat([input, pad], dim = dim)
-
-            return input
-
-        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
-        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in
-                                                                                                     (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)]
-
-        embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)
-
-        x = siglip_encoding_1
-        positive_tensor = CpuLockedTensor(torch.cat([torch.zeros_like(embeds), text_encoding_negative])
-                                          .contiguous().view(1, -1, x.size(-1)))
-        negative_tensor = CpuLockedTensor(torch.cat([embeds, text_encoding_positive])
-                                          .contiguous().view(1, -1, x.size(-1)))
-
-        negative = [[positive_tensor, {}]]
-        positive = [[negative_tensor, {}]]
+        biggest = max([t.size(1) for t in all_])
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [
+            F.pad(t, (0, 0, 0, biggest - t.size(1), 0, 0)) for t in all_
+        ]
+        positive_tensor = torch.cat([siglip_encoding_1, synchformer_encoding_2, text_encoding_positive])
+        negative_tensor = torch.cat([torch.zeros_like(siglip_encoding_1), torch.zeros_like(synchformer_encoding_2), text_encoding_negative])
+        negative = [[positive_tensor.view(1, -1, siglip_encoding_1.size(-1)), {}]]
+        positive = [[negative_tensor.view(1, -1, siglip_encoding_1.size(-1)), {}]]
 
         return io.NodeOutput(positive, negative)
 
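To make the packing layout concrete: the node pads all streams to one sequence length, concatenates along the batch dimension, and flattens to `(1, -1, D)`; the model's `view(3, -1, self.condition_dim)` plus `chunk(3)` (first file in this patch) reverses that. A hedged sketch with made-up dimensions:

```python
import torch
import torch.nn.functional as F

D = 768                         # hypothetical shared feature width
clip = torch.randn(1, 64, D)    # stand-ins for the three conditioning streams
sync = torch.randn(1, 30, D)
text = torch.randn(1, 77, D)

# Pad every stream's sequence dim up to the longest one.
biggest = max(t.size(1) for t in (clip, sync, text))
clip, sync, text = [F.pad(t, (0, 0, 0, biggest - t.size(1))) for t in (clip, sync, text)]

# Node side: stack the streams and flatten into a single conditioning tensor.
packed = torch.cat([clip, sync, text]).view(1, -1, D)

# Model side: reshape back into three streams and split them apart.
clip_r, sync_r, text_r = packed.view(3, -1, D).chunk(3)
assert torch.equal(clip_r, clip) and torch.equal(text_r, text)
```

The zero rows added by `F.pad` are what `remove_padding` in model.py later strips; the `torch.zeros_like` halves of `negative_tensor` are likewise recognized as padding on the model side.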
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 3daf51e96..001013d17 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -50,7 +50,8 @@ class EncodeVideo(io.ComfyNode):
 
     @classmethod
     def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
-
+
+        video = video.images
         if not isinstance(video, torch.Tensor):
             video = torch.from_numpy(video)
 
@@ -135,6 +136,8 @@ class ResampleVideo(io.ComfyNode):
     @classmethod
     def execute(cls, video, target_fps: int):
         # doesn't support upsampling
+
+        video_components = video.get_components()
         with av.open(video.get_stream_source(), mode="r") as container:
             stream = container.streams.video[0]
             frames = []
@@ -147,11 +150,7 @@ class ResampleVideo(io.ComfyNode):
 
             # yield original frames if asked for upsampling
             if target_fps > src_fps:
-                for packet in container.demux(stream):
-                    for frame in packet.decode():
-                        arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
-                        frames.append(arr)
-                return io.NodeOutput(torch.stack(frames))
+                return io.NodeOutput(video_components)
 
             stream.thread_type = "AUTO"
 
@@ -168,25 +167,13 @@ class ResampleVideo(io.ComfyNode):
                     frames.append(arr)
                     next_time += step
 
-        return io.NodeOutput(torch.stack(frames))
-
-class VideoToImage(io.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return io.Schema(
-            node_id="VideoToImage",
-            category="image/video",
-            display_name = "Video To Images",
-            inputs=[io.Video.Input("video")],
-            outputs=[io.Image.Output("images")]
-        )
-    @classmethod
-    def execute(cls, video):
-        with av.open(video.get_stream_source(), mode="r") as container:
-            components = video.get_components_internal(container)
-
-            images = components.images
-        return io.NodeOutput(images)
+        new_components = VideoComponents(
+            images=torch.stack(frames),
+            audio=video_components.audio,
+            frame_rate=Fraction(target_fps, 1),
+            metadata=video_components.metadata,
+        )
+        return io.NodeOutput(new_components)
 
 class SaveWEBM(io.ComfyNode):
     @classmethod
@@ -388,7 +375,6 @@ class VideoExtension(ComfyExtension):
             LoadVideo,
             EncodeVideo,
             ResampleVideo,
-            VideoToImage
         ]
 
 async def comfy_entrypoint() -> VideoExtension:
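With the `VideoToImage` node removed, frames stay reachable through the components API that the `ResampleVideo` change above already uses. A minimal sketch of the equivalent access pattern (the helper name is hypothetical):

```python
def video_to_images(video):
    # VideoInput.get_components() returns a VideoComponents whose `images`
    # field is the decoded frame tensor (frames, height, width, channels),
    # alongside `audio`, `frame_rate`, and `metadata`.
    components = video.get_components()
    return components.images
```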