Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-14 15:32:35 +08:00)
some fixes in model loading and nodes
This commit is contained in:
parent ab01aceaa2
commit cc3a1389ad
@@ -5,7 +5,7 @@ from typing import List
 import torch.nn as nn
 from einops import rearrange
 from torchvision.transforms import v2
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 
 from comfy.ldm.hunyuan_foley.syncformer import Synchformer
 
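The one real change in this hunk tracks a PyTorch deprecation: torch.nn.utils.weight_norm has been deprecated (since PyTorch 2.1) in favor of the parametrization-based torch.nn.utils.parametrizations.weight_norm. A minimal sketch of the replacement API; the call shape is unchanged:

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

# The new helper registers an nn.utils.parametrize parametrization instead of
# weight_g / weight_v attributes, so the two APIs serialize different
# state-dict keys:
#   old: conv.weight_g, conv.weight_v
#   new: conv.parametrizations.weight.original0, conv.parametrizations.weight.original1
conv = weight_norm(nn.Conv1d(16, 32, kernel_size=3, padding=1))
y = conv(torch.randn(1, 16, 100))  # -> shape (1, 32, 100)

The differing key names matter when loading pretrained checkpoints across the two APIs.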
@@ -96,12 +96,6 @@ class DACEncoder(nn.Module):
             d_model *= 2
             self.block += [DACEncoderBlock(d_model, stride=stride, device = device, dtype = dtype, operations = operations)]
 
-        # Create last convolution
-        self.block += [
-            Snake1d(d_model),
-            WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device = device, dtype = dtype, operations = operations),
-        ]
-
         # Wrap black into nn.Sequential
         self.block = nn.Sequential(*self.block)
         self.enc_dim = d_model
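This hunk (and the matching DACDecoder hunk below) drops a trailing Snake1d + WNConv1d stage from the module list built in the constructor. For context, WNConv1d in DAC-style codecs is conventionally just a weight-normalized nn.Conv1d; a sketch under that assumption, using the new import from the first hunk (the repo's version additionally threads device, dtype, and operations through to the conv):

import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

# Assumption: WNConv1d mirrors the reference descript-audio-codec helper.
def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))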
@@ -151,12 +145,6 @@ class DACDecoder(nn.Module):
             output_dim = channels // 2 ** (i + 1)
             layers += [DACDecoderBlock(input_dim, output_dim, stride, device = device, dtype = dtype, operations = operations)]
 
-        layers += [
-            Snake1d(output_dim, device = device, dtype = dtype),
-            WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations),
-            nn.Tanh(),
-        ]
-
         self.model = nn.Sequential(*layers)
 
     def forward(self, x):
@@ -1319,6 +1319,10 @@ class HunyuanFoley(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.clap_model.ClapLargeTokenizer, comfy.text_encoders.clap_model.ClapTextEncoderModel)
 
+    def process_clip_state_dict(self, state_dict):
+        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "transformer." for k in self.text_encoder_key_prefix}, filter_keys=True)
+        return state_dict
+
 class QwenImage(supported_models_base.BASE):
     unet_config = {
         "image_model": "qwen_image",
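The added process_clip_state_dict remaps text-encoder keys before the CLAP model is loaded. comfy.utils.state_dict_prefix_replace renames keys by prefix and, with filter_keys=True, drops keys that match none of the given prefixes. A rough, self-contained approximation of what the method does, assuming text_encoder_key_prefix is a list such as ["text_encoders."]:

def prefix_replace(state_dict, replace_prefix, filter_keys=False):
    # Approximation of comfy.utils.state_dict_prefix_replace: rename keys whose
    # prefix matches; with filter_keys=True, non-matching keys are dropped.
    out = {}
    for k, v in state_dict.items():
        for old, new in replace_prefix.items():
            if k.startswith(old):
                out[new + k[len(old):]] = v
                break
        else:
            if not filter_keys:
                out[k] = v
    return out

# Hypothetical keys for illustration:
sd = {"text_encoders.clap.embeddings.weight": 0, "unet.proj.weight": 1}
print(prefix_replace(sd, {"text_encoders.": "transformer."}, filter_keys=True))
# {'transformer.clap.embeddings.weight': 0}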
@@ -34,8 +34,8 @@ class HunyuanFoleyConditioning(io.ComfyNode):
             display_name="HunyuanFoleyConditioning",
             category="conditioning/video_models",
             inputs = [
-                io.Conditioning.Input("video_encoding_1"),
-                io.Conditioning.Input("video_encoding_2"),
+                io.Conditioning.Input("siglip_encoding_1"),
+                io.Conditioning.Input("synchformer_encoding_2"),
                 io.Conditioning.Input("text_encoding"),
             ],
             outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")]
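These renames are not purely cosmetic: in the V3 node schema, the first argument of each io.*.Input is the input's id, and, if the schema follows the usual V3 convention, that id becomes the keyword argument execute() receives. A hypothetical fragment showing the coupling:

class HunyuanFoleyConditioning(io.ComfyNode):
    @classmethod
    def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding):
        ...  # parameter names must track the input ids declared in the schema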
@@ -14,6 +14,7 @@ from comfy_api.input_impl import VideoFromComponents, VideoFromFile
 from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
 from comfy_api.latest import ComfyExtension, io, ui
 from comfy.cli_args import args
+import comfy.utils
 
 class EncodeVideo(io.ComfyNode):
     @classmethod
@@ -49,6 +50,7 @@ class EncodeVideo(io.ComfyNode):
 
     @classmethod
     def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
+
         t, c, h, w = video.shape
         b = 1
         batch_size = b * t
@@ -71,10 +73,15 @@
 
         outputs = []
         total = data.shape[0]
-        for i in range(0, total, batch_size):
-            chunk = data[i : i + batch_size]
-            out = vae.encode(chunk)
-            outputs.append(out)
+        pbar = comfy.utils.ProgressBar(total/batch_size)
+        with torch.inference_mode():
+            for i in range(0, total, batch_size):
+                chunk = data[i : i + batch_size]
+                out = vae.encode(chunk)
+                outputs.append(out)
+                del out, chunk
+                torch.cuda.empty_cache()
+                pbar.update(1)
 
         output = torch.cat(outputs)
 
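The rewritten loop wraps the chunked VAE encode in torch.inference_mode(), frees each chunk eagerly, and reports progress through comfy.utils.ProgressBar. A self-contained sketch of the same pattern, assuming a vae object that exposes encode(); note the diff passes total/batch_size (a float) to ProgressBar, where ceil-division is the more usual idiom:

import math
import torch
import comfy.utils

def encode_in_chunks(vae, data: torch.Tensor, batch_size: int) -> torch.Tensor:
    # One progress tick per chunk; ceil so a trailing partial chunk is counted.
    pbar = comfy.utils.ProgressBar(math.ceil(data.shape[0] / batch_size))
    outputs = []
    with torch.inference_mode():  # no autograd bookkeeping -> lower peak memory
        for i in range(0, data.shape[0], batch_size):
            outputs.append(vae.encode(data[i : i + batch_size]))
            pbar.update(1)
    return torch.cat(outputs)

Calling torch.cuda.empty_cache() every iteration, as the diff does, returns freed blocks to the driver at some synchronization cost; it can help when other allocations need the memory but is not required for correctness.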
@@ -109,7 +116,7 @@ class ResampleVideo(io.ComfyNode):
             for frame in packet.decode():
                 arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
                 frames.append(arr)
-        return torch.stack(frames)
+        return io.NodeOutput(torch.stack(frames))
 
         stream.thread_type = "AUTO"
 
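The one-line ResampleVideo fix wraps the bare tensor in io.NodeOutput, which is what the V3 schema expects execute() to return. A hypothetical minimal node showing the convention (define_schema elided for brevity):

from comfy_api.latest import io

class Passthrough(io.ComfyNode):
    @classmethod
    def execute(cls, value):
        # V3 nodes return io.NodeOutput, not bare values.
        return io.NodeOutput(value)

Note that the context lines after the return (stream.thread_type = "AUTO") are unreachable as shown; this commit leaves them untouched.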