mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
Support HunyuanVideo15 latent resampler
This commit is contained in:
parent
dc2e308422
commit
18ae40065a
120
comfy/ldm/hunyuan_video/upsampler.py
Normal file
120
comfy/ldm/hunyuan_video/upsampler.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
|
||||||
|
import model_management, model_patcher
|
||||||
|
|
||||||
|
class SRResidualCausalBlock3D(nn.Module):
    """Residual block of three causal video convolutions with SiLU in between.

    The channel count is preserved end to end so the input can be added
    back onto the block output (identity skip connection).
    """

    def __init__(self, channels: int):
        super().__init__()
        layers = []
        for i in range(3):
            layers.append(VideoConv3d(channels, channels, kernel_size=3))
            if i < 2:
                # SiLU between convs only; the last conv output is left linear
                # so the residual sum is not squashed.
                layers.append(nn.SiLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv→SiLU→conv→SiLU→conv and add the input back."""
        out = self.block(x)
        return out + x
class SRModel3DV2(nn.Module):
    """Stacked-residual 3D super-resolution model (the "720p" upsampler).

    Input projection → ``num_blocks`` residual causal blocks → output
    projection. When ``global_residual`` is enabled and the output shape
    matches the input, the input is added onto the result.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        num_blocks: int = 6,
        global_residual: bool = False,
    ):
        super().__init__()
        self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
        self.blocks = nn.ModuleList(
            SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)
        )
        self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
        self.global_residual = bool(global_residual)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the residual stack, optionally with a global skip connection."""
        skip = x
        h = self.in_conv(x)
        for block in self.blocks:
            h = block(h)
        h = self.out_conv(h)
        # The global residual only applies when shapes agree (they differ
        # whenever in_channels != out_channels).
        if self.global_residual and h.shape == skip.shape:
            h = h + skip
        return h
class Upsampler(nn.Module):
    """ResnetBlock-based latent resampler (the "1080p" upsampler).

    Projects the latent into the first stage width, runs every stage's
    ResnetBlocks in sequence, then RMS-normalizes, applies SiLU and projects
    to ``out_channels``.
    """

    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: tuple[int, ...],
        num_res_blocks: int = 2,
    ):
        super().__init__()
        self.num_res_blocks = num_res_blocks
        self.block_out_channels = block_out_channels
        self.z_channels = z_channels

        width = block_out_channels[0]
        self.conv_in = VideoConv3d(z_channels, width, kernel_size=3)

        self.up = nn.ModuleList()
        for target in block_out_channels:
            stage = nn.Module()
            blocks = []
            # num_res_blocks + 1 ResnetBlocks per stage; only the first one
            # changes the channel width (width -> target).
            for j in range(num_res_blocks + 1):
                blocks.append(ResnetBlock(in_channels=width if j == 0 else target,
                                          out_channels=target,
                                          temb_channels=0,
                                          conv_shortcut=False,
                                          conv_op=VideoConv3d, norm_op=RMS_norm))
            stage.block = nn.ModuleList(blocks)
            width = target
            self.up.append(stage)

        self.norm_out = RMS_norm(width)
        self.conv_out = VideoConv3d(width, out_channels, kernel_size=3)

    def forward(self, z):
        """Resample latent ``z`` of shape (B, C, T, H, W)."""
        # Channel-repeat skip connection: assumes block_out_channels[0] is an
        # exact multiple of z_channels so the repeated input matches conv_in's
        # output width.
        repeats = self.block_out_channels[0] // self.z_channels
        x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

        for stage in self.up:
            for block in stage.block:
                x = block(x)

        return self.conv_out(F.silu(self.norm_out(x)))
# Maps a model-type tag to the upsampler class it instantiates:
#   "720p"  -> SRModel3DV2 (residual conv stack)
#   "1080p" -> Upsampler (ResnetBlock stages)
UPSAMPLERS = {
    "720p": SRModel3DV2,
    "1080p": Upsampler,
}
class HunyuanVideo15SRModel():
    """Device-management wrapper around a HunyuanVideo 1.5 latent upsampler.

    Picks VAE load/offload devices from model_management, instantiates one of
    the UPSAMPLERS classes and wraps it in a ModelPatcher so it participates
    in ComfyUI model (un)loading.
    """

    def __init__(self, model_type, config):
        """Build the upsampler named ``model_type`` with kwargs ``config``.

        Raises:
            ValueError: if ``model_type`` is not a key of UPSAMPLERS.
        """
        self.load_device = model_management.vae_device()
        offload_device = model_management.vae_offload_device()
        self.dtype = model_management.vae_dtype(self.load_device)
        self.model_class = UPSAMPLERS.get(model_type)
        if self.model_class is None:
            # Fail with a clear message instead of the cryptic
            # "'NoneType' object is not callable" the .get() would cause.
            raise ValueError("unknown latent upsampler type: {} (expected one of {})".format(model_type, list(UPSAMPLERS)))
        self.model = self.model_class(**config).eval()

        self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        """Load a state dict strictly so missing/unexpected keys surface early."""
        return self.model.load_state_dict(sd, strict=True)

    def get_sd(self):
        """Return the wrapped model's state dict."""
        return self.model.state_dict()

    def resample_latent(self, latent):
        """Run the upsampler on ``latent`` after ensuring weights are loaded."""
        model_management.load_model_gpu(self.patcher)
        return self.model(latent.to(self.load_device))
@ -4,7 +4,8 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -169,6 +170,93 @@ class HunyuanVideo15RefinerLatent(io.ComfyNode):
|
|||||||
return io.NodeOutput(positive, negative, latent)
|
return io.NodeOutput(positive, negative, latent)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentUpscaleModelLoader(io.ComfyNode):
    """Loader node that detects the latent-upscale model variant from its
    state dict and instantiates the matching HunyuanVideo15SRModel."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LatentUpscaleModelLoader",
            display_name="Load Latent Upscale Model",
            category="loaders",
            inputs=[
                io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
            ],
            outputs=[
                io.UpscaleModel.Output(),
            ],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        """Load ``model_name``, infer its config from weight shapes and wrap it.

        Raises:
            RuntimeError: if the checkpoint matches neither known layout.
        """
        model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
        sd = comfy.utils.load_torch_file(model_path, safe_load=True)

        if "blocks.0.block.0.conv.weight" in sd:
            # SRModel3DV2 ("720p") layout: derive channel counts and block
            # count from the weight shapes/keys.
            config = {
                "in_channels": sd["in_conv.conv.weight"].shape[1],
                "out_channels": sd["out_conv.conv.weight"].shape[0],
                "hidden_channels": sd["in_conv.conv.weight"].shape[0],
                "num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]),
                "global_residual": False,
            }
            model_type = "720p"
        elif "up.0.block.0.conv1.conv.weight" in sd:
            # Upsampler ("1080p") layout: rename shortcut keys to match our
            # module structure, then derive per-stage widths.
            sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
            config = {
                "z_channels": sd["conv_in.conv.weight"].shape[1],
                "out_channels": sd["conv_out.conv.weight"].shape[0],
                "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
            }
            model_type = "1080p"
        else:
            # Previously fell through with config/model_type unbound -> NameError.
            raise RuntimeError("Unsupported latent upscale model: unrecognized state dict layout in {}".format(model_name))

        model = HunyuanVideo15SRModel(model_type, config)
        model.load_sd(sd)

        return io.NodeOutput(model)

    load_model = execute  # TODO: remove
class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
    """Node that spatially rescales a HunyuanVideo 1.5 latent and then runs it
    through a loaded latent upscale model."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="HunyuanVideo15LatentUpscaleWithModel",
            display_name="Hunyuan Video 15 Latent Upscale With Model",
            category="latent",
            inputs=[
                io.UpscaleModel.Input("model"),
                io.Latent.Input("samples"),
                io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"),
                io.Int.Input("width", default=1280, min=0, max=16384, step=8),
                io.Int.Input("height", default=720, min=0, max=16384, step=8),
                io.Combo.Input("crop", options=["disabled", "center"]),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput:
        """Rescale the latent to (width, height) pixels and resample it."""
        # Both dimensions zero: pass the latent through untouched.
        if width == 0 and height == 0:
            return io.NodeOutput(samples)

        latent = samples["samples"]
        if width == 0:
            # Derive width from height, keeping the latent's aspect ratio.
            height = max(64, height)
            width = max(64, round(latent.shape[-1] * height / latent.shape[-2]))
        elif height == 0:
            # Derive height from width, keeping the latent's aspect ratio.
            width = max(64, width)
            height = max(64, round(latent.shape[-2] * width / latent.shape[-1]))
        else:
            width = max(64, width)
            height = max(64, height)

        # Pixel dims -> latent dims (16x spatial compression), then refine
        # through the upscale model.
        scaled = comfy.utils.common_upscale(latent, width // 16, height // 16, upscale_method, crop)
        resampled = model.resample_latent(scaled)
        return io.NodeOutput({"samples": resampled.cpu().float()})

    upscale = execute  # TODO: remove
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||||
"1. The main content and theme of the video."
|
"1. The main content and theme of the video."
|
||||||
@ -325,6 +413,8 @@ class HunyuanExtension(ComfyExtension):
|
|||||||
EmptyHunyuanVideo15Latent,
|
EmptyHunyuanVideo15Latent,
|
||||||
HunyuanVideo15ImageToVideo,
|
HunyuanVideo15ImageToVideo,
|
||||||
HunyuanVideo15RefinerLatent,
|
HunyuanVideo15RefinerLatent,
|
||||||
|
HunyuanVideo15LatentUpscaleWithModel,
|
||||||
|
LatentUpscaleModelLoader,
|
||||||
HunyuanImageToVideo,
|
HunyuanImageToVideo,
|
||||||
EmptyHunyuanImageLatent,
|
EmptyHunyuanImageLatent,
|
||||||
HunyuanRefinerLatent,
|
HunyuanRefinerLatent,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user