updated based on feedback

Yousef Rafat 2025-11-19 23:01:17 +02:00
parent 25f7bbed78
commit 86348baa9e
5 changed files with 45 additions and 143 deletions

View File

@@ -778,8 +778,6 @@ class HunyuanVideoFoley(nn.Module):
self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)
self.conditions = None
def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
len = len if len is not None else self.clip_len
if bs is None:
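As a side note, a minimal sketch (separate from the diff above) of how a learned "empty" feature like empty_clip_feat can be expanded into a per-sample placeholder sequence; the names empty_feat, get_empty_sequence, and the 768 dimension are assumptions for this example only.
import torch

empty_feat = torch.zeros(1, 768)                      # analogous to self.empty_clip_feat
def get_empty_sequence(bs: int, seq_len: int) -> torch.Tensor:
    # (1, dim) -> (bs, seq_len, dim); clone() materializes the broadcast view
    return empty_feat.unsqueeze(1).expand(bs, seq_len, -1).clone()

print(get_empty_sequence(2, 64).shape)                # torch.Size([2, 64, 768])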
@@ -858,35 +856,25 @@ class HunyuanVideoFoley(nn.Module):
bs, _, ol = x.shape
tl = ol // self.patch_size
if self.conditions is None:
def remove_padding(tensor):
mask = tensor.sum(dim=-1) != 0
out = torch.stack([tensor[b][mask[b]] for b in range(tensor.size(0))], dim=0)
return out
uncondition, condition = torch.chunk(context, 2)
cond_, uncond = torch.chunk(context, 2)
uncond, cond_ = uncond.view(3, -1, self.condition_dim), cond_.view(3, -1, self.condition_dim)
clip_feat, sync_feat, cond_pos = cond_.chunk(3)
uncond_1, uncond_2, cond_neg = uncond.chunk(3)
clip_feat, sync_feat, cond_pos, cond_neg = [remove_padding(t) for t in (clip_feat, sync_feat, cond_pos, cond_neg)]
condition = condition.view(3, context.size(1) // 3, -1)
uncondition = uncondition.view(3, context.size(1) // 3, -1)
diff = cond_pos.shape[1] - cond_neg.shape[1]
if cond_neg.shape[1] < cond_pos.shape[1]:
cond_neg = F.pad(cond_neg, (0, 0, 0, diff))
elif diff < 0:
cond_pos = F.pad(cond_pos, (0, 0, 0, abs(diff)))
uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
diff = cond_pos.shape[1] - cond_neg.shape[1]
if cond_neg.shape[1] < cond_pos.shape[1]:
cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
elif diff < 0:
cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
clip_feat = clip_feat.view(2, -1, 768)
self.conditions = (clip_feat, sync_feat, cond)
else:
clip_feat, sync_feat, cond = self.conditions
clip_feat, sync_feat, cond = \
torch.cat([uncond_1[:, :clip_feat.size(1), :], clip_feat]), torch.cat([uncond_2[:, :sync_feat.size(1), :], sync_feat]), torch.cat([cond_neg, cond_pos])
if drop_visual is not None:
clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
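A minimal sketch, separate from the diff, of the general unpack-and-align idea this hunk revolves around: split a flattened context into unconditional and conditional halves, view each back into its three streams, and zero-pad the shorter text branch so both CFG branches can be concatenated. Shapes, names, and the demo dimensions are assumptions, not the module's exact code.
import torch
import torch.nn.functional as F

def unpack_and_align(context: torch.Tensor, condition_dim: int):
    # context packs (clip, sync, text) for the uncond and cond branches
    uncond_flat, cond_flat = torch.chunk(context, 2)
    cond_ = cond_flat.view(3, -1, condition_dim)
    uncond = uncond_flat.view(3, -1, condition_dim)
    clip_feat, sync_feat, cond_pos = cond_.chunk(3)
    uncond_1, uncond_2, cond_neg = uncond.chunk(3)
    # zero-pad the shorter text sequence so both branches match in length
    diff = cond_pos.shape[1] - cond_neg.shape[1]
    if diff > 0:
        cond_neg = F.pad(cond_neg, (0, 0, 0, diff))
    elif diff < 0:
        cond_pos = F.pad(cond_pos, (0, 0, 0, -diff))
    return (torch.cat([uncond_1, clip_feat]),
            torch.cat([uncond_2, sync_feat]),
            torch.cat([cond_neg, cond_pos]))

ctx = torch.randn(2, 3 * 32, 768)
clip_f, sync_f, text = unpack_and_align(ctx, 768)
print(clip_f.shape, sync_f.shape, text.shape)   # each torch.Size([2, 32, 768])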

View File

@@ -213,7 +213,9 @@ class FoleyVae(torch.nn.Module):
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
)
self.decode_sample_rate = self.dac.sample_rate
def decode_sample_rate(self):
return self.dac.sample_rate
def decode(self, x, vae_options = {}):
return self.dac.decode(x)

View File

@@ -88,7 +88,10 @@ class VAEDecodeAudio:
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
sample_rate = vae.first_stage_model.decode_sample_rate or 44100
if hasattr(vae.first_stage_model, "decode_sample_rate"):
sample_rate = vae.first_stage_model.decode_sample_rate()
else:
sample_rate = 44100
return ({"waveform": audio, "sample_rate": sample_rate}, )

View File

@@ -1,5 +1,6 @@
import torch
import comfy.model_management
import torch.nn.functional as F
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
@@ -20,60 +21,13 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
@classmethod
def execute(cls, length, batch_size, video = None):
if video is not None:
video = video.images
length = video.size(0)
length /= 25
shape = (batch_size, 128, int(50 * length))
latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent, "type": "hunyuan_foley"}, )
class CpuLockedTensor(torch.Tensor):
def __new__(cls, data):
base = torch.as_tensor(data, device='cpu')
return torch.Tensor._make_subclass(cls, base, require_grad=False)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
# if any of the args/kwargs are CpuLockedTensor, dispatching func on them would recurse infinitely
def unwrap(x):
return x.as_subclass(torch.Tensor) if isinstance(x, CpuLockedTensor) else x
unwrapped_args = torch.utils._pytree.tree_map(unwrap, args)
unwrapped_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
result = func(*unwrapped_args, **unwrapped_kwargs)
# rewrap the resulting tensors
if isinstance(result, torch.Tensor):
return CpuLockedTensor(result.detach().cpu())
elif isinstance(result, (list, tuple)):
return type(result)(
CpuLockedTensor(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
for x in result
)
return result
def to(self, *args, allow_gpu=False, **kwargs):
if allow_gpu:
return super().to(*args, **kwargs)
return self.detach().clone().cpu()
def cuda(self, *args, **kwargs):
return self
def cpu(self):
return self
def pin_memory(self):
return self
def detach(self):
out = super().detach()
return CpuLockedTensor(out)
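A generic, minimal sketch of the __torch_function__ interception pattern the removed class relies on, written here with torch._C.DisableTorchFunctionSubclass instead of the unwrap/rewrap helpers; the name CpuOnlyTensor and this particular idiom are assumptions for the example, not the extension's code.
import torch

class CpuOnlyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # run the op with subclass dispatch disabled so the call below
        # does not re-enter this hook and recurse forever
        with torch._C.DisableTorchFunctionSubclass():
            out = func(*args, **kwargs)
        if isinstance(out, torch.Tensor):
            # pull results back to the CPU and keep the subclass type
            return out.detach().cpu().as_subclass(cls)
        return out

t = torch.randn(4).as_subclass(CpuOnlyTensor)
print(type(t * 2).__name__, (t * 2).device)   # CpuOnlyTensor cpu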
class HunyuanFoleyConditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -96,45 +50,14 @@ class HunyuanFoleyConditioning(io.ComfyNode):
text_encoding_positive = text_encoding_positive[0][0]
text_encoding_negative = text_encoding_negative[0][0]
all_ = (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)
max_l = max([t.size(1) for t in all_])
max_d = max([t.size(2) for t in all_])
def repeat_shapes(max_value, input, dim = 1):
if input.shape[dim] == max_value:
return input
# temporarily repeat values on the CPU
factor_pos, remainder = divmod(max_value, input.shape[dim])
positions = [1] * input.ndim
positions[dim] = factor_pos
input = input.cpu().repeat(*positions)
if remainder > 0:
if dim == 1:
pad = input[:, :remainder, :]
else:
pad = input[:, :, :remainder]
input = torch.cat([input, pad], dim = dim)
return input
siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in
(siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)]
embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)
x = siglip_encoding_1
positive_tensor = CpuLockedTensor(torch.cat([torch.zeros_like(embeds), text_encoding_negative])
.contiguous().view(1, -1, x.size(-1)))
negative_tensor = CpuLockedTensor(torch.cat([embeds, text_encoding_positive])
.contiguous().view(1, -1, x.size(-1)))
negative = [[positive_tensor, {}]]
positive = [[negative_tensor, {}]]
biggest = max([t.size(1) for t in all_])
siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [
F.pad(t, (0, 0, 0, biggest - t.size(1), 0, 0)) for t in all_
]
positive_tensor = torch.cat([siglip_encoding_1, synchformer_encoding_2, text_encoding_positive])
negative_tensor = torch.cat([torch.zeros_like(siglip_encoding_1), torch.zeros_like(synchformer_encoding_2), text_encoding_negative])
negative = [[positive_tensor.view(1, -1, siglip_encoding_1.size(-1)), {}]]
positive = [[negative_tensor.view(1, -1, siglip_encoding_1.size(-1)), {}]]
return io.NodeOutput(positive, negative)
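A short sketch of the packing side of the same mechanism, with assumed shapes: pad several conditioning tensors to the longest sequence length, then concatenate and flatten them into one (1, N, D) tensor that can later be view()ed back into its parts.
import torch
import torch.nn.functional as F

parts = [torch.randn(1, n, 768) for n in (64, 128, 77)]
longest = max(t.size(1) for t in parts)
parts = [F.pad(t, (0, 0, 0, longest - t.size(1))) for t in parts]
packed = torch.cat(parts).view(1, -1, 768)       # (1, 3 * longest, 768)
print(packed.shape)                               # torch.Size([1, 384, 768])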

View File

@@ -50,7 +50,8 @@ class EncodeVideo(io.ComfyNode):
@classmethod
def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
video = video.images
if not isinstance(video, torch.Tensor):
video = torch.from_numpy(video)
@@ -135,6 +136,8 @@ class ResampleVideo(io.ComfyNode):
@classmethod
def execute(cls, video, target_fps: int):
# doesn't support upsampling
video_components = video.get_components()
with av.open(video.get_stream_source(), mode="r") as container:
stream = container.streams.video[0]
frames = []
@@ -147,11 +150,7 @@ class ResampleVideo(io.ComfyNode):
# yield original frames if asked for upsampling
if target_fps > src_fps:
for packet in container.demux(stream):
for frame in packet.decode():
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
frames.append(arr)
return io.NodeOutput(torch.stack(frames))
return io.NodeOutput(video_components)
stream.thread_type = "AUTO"
@@ -168,25 +167,13 @@ class ResampleVideo(io.ComfyNode):
frames.append(arr)
next_time += step
return io.NodeOutput(torch.stack(frames))
class VideoToImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoToImage",
category="image/video",
display_name = "Video To Images",
inputs=[io.Video.Input("video")],
outputs=[io.Image.Output("images")]
)
@classmethod
def execute(cls, video):
with av.open(video.get_stream_source(), mode="r") as container:
components = video.get_components_internal(container)
images = components.images
return io.NodeOutput(images)
new_components = VideoComponents(
images=torch.stack(frames),
audio=video_components.audio,
frame_rate=Fraction(target_fps, 1),
metadata=video_components.metadata,
)
return io.NodeOutput(new_components)
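A hedged sketch of timestamp-based downsampling with PyAV, separate from the node's exact code: decode every frame, keep the first frame at or after each 1/target_fps boundary, and drop the rest. The function name and the path argument are assumptions for the example.
import av
import torch

def resample(path: str, target_fps: int) -> torch.Tensor:
    frames, next_time, step = [], 0.0, 1.0 / target_fps
    with av.open(path, mode="r") as container:
        stream = container.streams.video[0]
        stream.thread_type = "AUTO"
        for frame in container.decode(stream):
            # skip frames that arrive before the next sampling boundary
            if frame.time is None or frame.time + 1e-6 < next_time:
                continue
            frames.append(torch.from_numpy(frame.to_ndarray(format="rgb24")).float())
            next_time += step
    return torch.stack(frames)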
class SaveWEBM(io.ComfyNode):
@classmethod
@@ -388,7 +375,6 @@ class VideoExtension(ComfyExtension):
LoadVideo,
EncodeVideo,
ResampleVideo,
VideoToImage
]
async def comfy_entrypoint() -> VideoExtension: