Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-10 13:32:36 +08:00)

commit 86348baa9e
parent 25f7bbed78

    updated based on feedback
@@ -778,8 +778,6 @@ class HunyuanVideoFoley(nn.Module):
        self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)

        self.conditions = None

    def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
        len = len if len is not None else self.clip_len
        if bs is None:
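For context, the frozen `empty_clip_feat` parameter above backs `get_empty_clip_sequence`, which expands a single zero vector into a batch of "empty" visual features. A minimal runnable sketch of that pattern, with assumed dimensions (the real model reads them from its config):

import torch
import torch.nn as nn

class EmptyFeatDemo(nn.Module):
    # 768 channels and a 64-step sequence are assumptions for illustration only.
    def __init__(self, visual_in_channels=768, clip_len=64):
        super().__init__()
        self.clip_len = clip_len
        # A frozen all-zero "empty" feature, as in the diff above.
        self.empty_clip_feat = nn.Parameter(
            torch.zeros(1, visual_in_channels), requires_grad=False
        )

    def get_empty_clip_sequence(self, bs=1, length=None):
        length = length if length is not None else self.clip_len
        # Broadcast the single zero vector to (bs, length, channels) without copying.
        return self.empty_clip_feat.unsqueeze(0).expand(bs, length, -1)

feat = EmptyFeatDemo().get_empty_clip_sequence(bs=2)
print(feat.shape)  # torch.Size([2, 64, 768])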
@@ -858,35 +856,25 @@ class HunyuanVideoFoley(nn.Module):
        bs, _, ol = x.shape
        tl = ol // self.patch_size

        if self.conditions is None:
        def remove_padding(tensor):
            mask = tensor.sum(dim=-1) != 0
            out = torch.stack([tensor[b][mask[b]] for b in range(tensor.size(0))], dim=0)
            return out

        uncondition, condition = torch.chunk(context, 2)
        cond_, uncond = torch.chunk(context, 2)
        uncond, cond_ = uncond.view(3, -1, self.condition_dim), cond_.view(3, -1, self.condition_dim)
        clip_feat, sync_feat, cond_pos = cond_.chunk(3)
        uncond_1, uncond_2, cond_neg = uncond.chunk(3)
        clip_feat, sync_feat, cond_pos, cond_neg = [remove_padding(t) for t in (clip_feat, sync_feat, cond_pos, cond_neg)]

        condition = condition.view(3, context.size(1) // 3, -1)
        uncondition = uncondition.view(3, context.size(1) // 3, -1)
        diff = cond_pos.shape[1] - cond_neg.shape[1]
        if cond_neg.shape[1] < cond_pos.shape[1]:
            cond_neg = F.pad(cond_neg, (0, 0, 0, diff))
        elif diff < 0:
            cond_pos = F.pad(cond_pos, (0, 0, 0, abs(diff)))

        uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
        clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
        cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]

        uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
        uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]

        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]

        diff = cond_pos.shape[1] - cond_neg.shape[1]
        if cond_neg.shape[1] < cond_pos.shape[1]:
            cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
        elif diff < 0:
            cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, abs(diff)))

        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
        clip_feat = clip_feat.view(2, -1, 768)

        self.conditions = (clip_feat, sync_feat, cond)

        else:
            clip_feat, sync_feat, cond = self.conditions
            clip_feat, sync_feat, cond = \
                torch.cat([uncond_1[:, :clip_feat.size(1), :], clip_feat]), torch.cat([uncond_2[:, :sync_feat.size(1), :], sync_feat]), torch.cat([cond_neg, cond_pos])

        if drop_visual is not None:
            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
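Two pieces of this hunk are easy to sanity-check in isolation: `remove_padding` drops all-zero rows (which assumes every batch item keeps the same number of non-zero rows, otherwise `torch.stack` would fail), and the `F.pad` block right-pads the shorter of the positive/negative text embeddings along the sequence dimension so the pair can be concatenated. A small self-contained sketch with toy shapes (the shapes are assumptions, not the model's real dimensions):

import torch
import torch.nn.functional as F

def remove_padding(tensor):
    # Keep only rows whose feature vectors are not all zero.
    mask = tensor.sum(dim=-1) != 0
    return torch.stack([tensor[b][mask[b]] for b in range(tensor.size(0))], dim=0)

x = torch.cat([torch.ones(1, 3, 4), torch.zeros(1, 2, 4)], dim=1)
print(remove_padding(x).shape)  # torch.Size([1, 3, 4])

# Align two sequences before torch.cat: F.pad takes pairs per dimension,
# starting from the last, so (0, 0, 0, n) pads dim 1 on the right by n.
cond_pos, cond_neg = torch.randn(1, 10, 4), torch.randn(1, 7, 4)
diff = cond_pos.shape[1] - cond_neg.shape[1]
if diff > 0:
    cond_neg = F.pad(cond_neg, (0, 0, 0, diff))
elif diff < 0:
    cond_pos = F.pad(cond_pos, (0, 0, 0, -diff))
assert cond_pos.shape == cond_neg.shape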
@@ -213,7 +213,9 @@ class FoleyVae(torch.nn.Module):
                v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.decode_sample_rate = self.dac.sample_rate

    def decode_sample_rate(self):
        return self.dac.sample_rate

    def decode(self, x, vae_options = {}):
        return self.dac.decode(x)
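The change above turns `decode_sample_rate` from an attribute into a method, so the decode node below can probe for it with `hasattr` and fall back to a default. A stand-in sketch (the `_FakeDac` object and its 48000 rate are invented for illustration):

class _FakeDac:
    sample_rate = 48000  # assumed stand-in value

class FoleyVaeSketch:
    def __init__(self):
        self.dac = _FakeDac()

    def decode_sample_rate(self):
        # Exposed as a method rather than an attribute so callers can
        # distinguish "has a rate to report" from "plain VAE".
        return self.dac.sample_rate

vae = FoleyVaeSketch()
rate = vae.decode_sample_rate() if hasattr(vae, "decode_sample_rate") else 44100
print(rate)  # 48000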
@@ -88,7 +88,10 @@ class VAEDecodeAudio:
        std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
        std[std < 1.0] = 1.0
        audio /= std
        sample_rate = vae.first_stage_model.decode_sample_rate or 44100
        if hasattr(vae.first_stage_model, "decode_sample_rate"):
            sample_rate = vae.first_stage_model.decode_sample_rate()
        else:
            sample_rate = 44100
        return ({"waveform": audio, "sample_rate": sample_rate}, )
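Two details of this node are worth seeing end to end: the std-based loudness clamp only ever attenuates (quiet audio is left untouched), and the `hasattr` probe keeps VAEs without a `decode_sample_rate` method working. A minimal sketch with assumed shapes:

import torch

# Loudness clamp: divide by 5x the std, but never by less than 1,
# so quiet clips are not amplified.
audio = torch.randn(1, 2, 44100)  # assumed (batch, channels, samples)
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio = audio / std

# Sample-rate probe with a default for models that don't expose one.
class _PlainVae:  # hypothetical stand-in without decode_sample_rate
    pass

model = _PlainVae()
if hasattr(model, "decode_sample_rate"):
    sample_rate = model.decode_sample_rate()
else:
    sample_rate = 44100
print(sample_rate)  # 44100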
@@ -1,5 +1,6 @@
import torch
import comfy.model_management
import torch.nn.functional as F
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
@@ -20,60 +21,13 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
    @classmethod
    def execute(cls, length, batch_size, video = None):
        if video is not None:
            video = video.images
            length = video.size(0)
            length /= 25
        shape = (batch_size, 128, int(50 * length))
        latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
        return io.NodeOutput({"samples": latent, "type": "hunyuan_foley"}, )
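The latent sizing above assumes 25 fps input video and 50 latent steps per second, i.e. two latent positions per source frame across 128 channels. Worked out for a concrete clip (the frame count is chosen for illustration):

import torch

frames = 125                     # e.g. 5 seconds of 25 fps video
seconds = frames / 25
latent = torch.randn(1, 128, int(50 * seconds))
print(latent.shape)              # torch.Size([1, 128, 250])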
class CpuLockedTensor(torch.Tensor):
    def __new__(cls, data):
        base = torch.as_tensor(data, device='cpu')
        return torch.Tensor._make_subclass(cls, base, require_grad=False)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # unwrap first: if any of the args/kwargs stayed a CpuLockedTensor, dispatch would recurse infinitely
        def unwrap(x):
            return x.as_subclass(torch.Tensor) if isinstance(x, CpuLockedTensor) else x

        unwrapped_args = torch.utils._pytree.tree_map(unwrap, args)
        unwrapped_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)

        result = func(*unwrapped_args, **unwrapped_kwargs)

        # rewrap the resulting tensors
        if isinstance(result, torch.Tensor):
            return CpuLockedTensor(result.detach().cpu())
        elif isinstance(result, (list, tuple)):
            return type(result)(
                CpuLockedTensor(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
                for x in result
            )
        return result

    def to(self, *args, allow_gpu=False, **kwargs):
        if allow_gpu:
            return super().to(*args, **kwargs)
        return self.detach().clone().cpu()

    def cuda(self, *args, **kwargs):
        return self

    def cpu(self):
        return self

    def pin_memory(self):
        return self

    def detach(self):
        out = super().detach()
        return CpuLockedTensor(out)
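For reference, the core trick in `CpuLockedTensor` is the `__torch_function__` override: unwrap any subclass instances to plain tensors before dispatch (otherwise the override would re-enter itself), run the real op, then rewrap the outputs. A minimal sketch of just that pattern:

import torch
import torch.utils._pytree as pytree

class TaggedTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap subclass instances so func dispatches on plain tensors.
        unwrap = lambda x: x.as_subclass(torch.Tensor) if isinstance(x, TaggedTensor) else x
        out = func(*pytree.tree_map(unwrap, args), **pytree.tree_map(unwrap, kwargs))
        # Rewrap single-tensor results so the tag survives the op.
        return out.as_subclass(TaggedTensor) if isinstance(out, torch.Tensor) else out

t = torch.ones(2).as_subclass(TaggedTensor)
print(type(t + 1).__name__)  # TaggedTensor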
class HunyuanFoleyConditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
@@ -96,45 +50,14 @@ class HunyuanFoleyConditioning(io.ComfyNode):
        text_encoding_positive = text_encoding_positive[0][0]
        text_encoding_negative = text_encoding_negative[0][0]
        all_ = (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)

        max_l = max([t.size(1) for t in all_])
        max_d = max([t.size(2) for t in all_])

        def repeat_shapes(max_value, input, dim = 1):
            if input.shape[dim] == max_value:
                return input

            # temporarily repeat values on the cpu
            factor_pos, remainder = divmod(max_value, input.shape[dim])

            positions = [1] * input.ndim
            positions[dim] = factor_pos
            input = input.cpu().repeat(*positions)

            if remainder > 0:
                if dim == 1:
                    pad = input[:, :remainder, :]
                else:
                    pad = input[:, :, :remainder]
                input = torch.cat([input, pad], dim = dim)

            return input

        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in
            (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)]

        embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)

        x = siglip_encoding_1
        positive_tensor = CpuLockedTensor(torch.cat([torch.zeros_like(embeds), text_encoding_negative])
            .contiguous().view(1, -1, x.size(-1)))
        negative_tensor = CpuLockedTensor(torch.cat([embeds, text_encoding_positive])
            .contiguous().view(1, -1, x.size(-1)))

        negative = [[positive_tensor, {}]]
        positive = [[negative_tensor, {}]]
        biggest = max([t.size(1) for t in all_])
        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [
            F.pad(t, (0, 0, 0, biggest - t.size(1), 0, 0)) for t in all_
        ]
        positive_tensor = torch.cat([siglip_encoding_1, synchformer_encoding_2, text_encoding_positive])
        negative_tensor = torch.cat([torch.zeros_like(siglip_encoding_1), torch.zeros_like(synchformer_encoding_2), text_encoding_negative])
        negative = [[positive_tensor.view(1, -1, siglip_encoding_1.size(-1)), {}]]
        positive = [[negative_tensor.view(1, -1, siglip_encoding_1.size(-1)), {}]]

        return io.NodeOutput(positive, negative)
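Two alignment strategies appear side by side in this hunk: tiling a tensor along a dimension up to a target size with a wrap-around remainder (`repeat_shapes`), and right-padding everything to the longest sequence before concatenation. A toy, self-contained version of both (the shapes are invented for illustration):

import torch
import torch.nn.functional as F

def repeat_to(t, target, dim=1):
    # Tile t along `dim` to `target`, wrapping the leftover from the start.
    if t.shape[dim] == target:
        return t
    factor, remainder = divmod(target, t.shape[dim])
    reps = [1] * t.ndim
    reps[dim] = factor
    out = t.repeat(*reps)
    if remainder:
        out = torch.cat([out, out.narrow(dim, 0, remainder)], dim=dim)
    return out

print(repeat_to(torch.arange(5).view(1, 5, 1), 12).shape)  # torch.Size([1, 12, 1])

# Pad-to-longest: equalize dim 1 across tensors, then cat along dim 0.
tensors = [torch.randn(1, n, 8) for n in (3, 5, 4)]
biggest = max(t.size(1) for t in tensors)
padded = torch.cat([F.pad(t, (0, 0, 0, biggest - t.size(1))) for t in tensors])
print(padded.shape)  # torch.Size([3, 5, 8])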
@@ -50,7 +50,8 @@ class EncodeVideo(io.ComfyNode):

    @classmethod
    def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
        video = video.images
        if not isinstance(video, torch.Tensor):
            video = torch.from_numpy(video)
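The guard added above normalizes the node's input: frames may arrive as a NumPy array or a tensor, and everything downstream expects a tensor. A sketch (the (T, H, W, C) layout is an assumption for illustration):

import numpy as np
import torch

video = np.zeros((8, 64, 64, 3), dtype=np.float32)  # assumed (T, H, W, C) frames
if not isinstance(video, torch.Tensor):
    video = torch.from_numpy(video)
print(video.shape)  # torch.Size([8, 64, 64, 3])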
@@ -135,6 +136,8 @@ class ResampleVideo(io.ComfyNode):
    @classmethod
    def execute(cls, video, target_fps: int):
        # doesn't support upsampling

        video_components = video.get_components()
        with av.open(video.get_stream_source(), mode="r") as container:
            stream = container.streams.video[0]
            frames = []
@@ -147,11 +150,7 @@

            # yield original frames if asked for upsampling
            if target_fps > src_fps:
                for packet in container.demux(stream):
                    for frame in packet.decode():
                        arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
                        frames.append(arr)
                return io.NodeOutput(torch.stack(frames))
                return io.NodeOutput(video_components)

            stream.thread_type = "AUTO"
@@ -168,25 +167,13 @@
                    frames.append(arr)
                    next_time += step

        return io.NodeOutput(torch.stack(frames))

class VideoToImage(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="VideoToImage",
            category="image/video",
            display_name = "Video To Images",
            inputs=[io.Video.Input("video")],
            outputs=[io.Image.Output("images")]
        )

    @classmethod
    def execute(cls, video):
        with av.open(video.get_stream_source(), mode="r") as container:
            components = video.get_components_internal(container)

        images = components.images
        return io.NodeOutput(images)

        new_components = VideoComponents(
            images=torch.stack(frames),
            audio=video_components.audio,
            frame_rate=Fraction(target_fps, 1),
            metadata=video_components.metadata,
        )
        return io.NodeOutput(new_components)

class SaveWEBM(io.ComfyNode):
    @classmethod
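The resampling loop above keeps a source frame whenever its timestamp crosses the next output tick, advancing the tick by 1 / target_fps each time, then wraps the kept frames in a VideoComponents with the new frame rate. A toy version of just the selection logic (the rates are chosen for illustration):

src_fps, target_fps = 30, 12      # assumed rates
step = 1.0 / target_fps
next_time, kept = 0.0, []
for i in range(src_fps):          # one second of source frames
    t = i / src_fps
    if t >= next_time:            # frame time passed the next output tick
        kept.append(i)
        next_time += step
print(len(kept))                  # 12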
@@ -388,7 +375,6 @@ class VideoExtension(ComfyExtension):
            LoadVideo,
            EncodeVideo,
            ResampleVideo,
            VideoToImage
        ]

async def comfy_entrypoint() -> VideoExtension: