Support HunyuanVideo15 latent resampler

This commit is contained in:
kijai 2025-11-19 17:59:41 +02:00 committed by comfyanonymous
parent dc2e308422
commit 18ae40065a
2 changed files with 211 additions and 1 deletions

View File

@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.block = nn.Sequential(
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)
class SRModel3DV2(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int = 64,
num_blocks: int = 6,
global_residual: bool = False,
):
super().__init__()
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
self.global_residual = bool(global_residual)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
y = self.in_conv(x)
for blk in self.blocks:
y = blk(y)
y = self.out_conv(y)
if self.global_residual and (y.shape == residual.shape):
y = y + residual
return y
class Upsampler(nn.Module):
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: tuple[int, ...],
num_res_blocks: int = 2,
):
super().__init__()
self.num_res_blocks = num_res_blocks
self.block_out_channels = block_out_channels
self.z_channels = z_channels
ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
self.up = nn.ModuleList()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_shortcut=False,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
self.up.append(stage)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
def forward(self, z):
"""
Args:
z: (B, C, T, H, W)
target_shape: (H, W)
"""
# z to block_in
repeats = self.block_out_channels[0] // (self.z_channels)
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
# upsampling
for stage in self.up:
for blk in stage.block:
x = blk(x)
out = self.conv_out(F.silu(self.norm_out(x)))
return out
UPSAMPLERS = {
"720p": SRModel3DV2,
"1080p": Upsampler,
}
class HunyuanVideo15SRModel():
def __init__(self, model_type, config):
self.load_device = model_management.vae_device()
offload_device = model_management.vae_offload_device()
self.dtype = model_management.vae_dtype(self.load_device)
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
def get_sd(self):
return self.model.state_dict()
def resample_latent(self, latent):
model_management.load_model_gpu(self.patcher)
return self.model(latent.to(self.load_device))

View File

@ -4,7 +4,8 @@ import torch
import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
import folder_paths
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
@classmethod
@ -169,6 +170,93 @@ class HunyuanVideo15RefinerLatent(io.ComfyNode):
return io.NodeOutput(positive, negative, latent)
class LatentUpscaleModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LatentUpscaleModelLoader",
display_name="Load Latent Upscale Model",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
],
outputs=[
io.UpscaleModel.Output(),
],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "blocks.0.block.0.conv.weight" in sd:
config = {
"in_channels": sd["in_conv.conv.weight"].shape[1],
"out_channels": sd["out_conv.conv.weight"].shape[0],
"hidden_channels": sd["in_conv.conv.weight"].shape[0],
"num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]),
"global_residual": False,
}
model_type = "720p"
elif "up.0.block.0.conv1.conv.weight" in sd:
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
config = {
"z_channels": sd["conv_in.conv.weight"].shape[1],
"out_channels": sd["conv_out.conv.weight"].shape[0],
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
}
model_type = "1080p"
model = HunyuanVideo15SRModel(model_type, config)
model.load_sd(sd)
return io.NodeOutput(model)
load_model = execute # TODO: remove
class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="HunyuanVideo15LatentUpscaleWithModel",
display_name="Hunyuan Video 15 Latent Upscale With Model",
category="latent",
inputs=[
io.UpscaleModel.Input("model"),
io.Latent.Input("samples"),
io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"),
io.Int.Input("width", default=1280, min=0, max=16384, step=8),
io.Int.Input("height", default=720, min=0, max=16384, step=8),
io.Combo.Input("crop", options=["disabled", "center"]),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput:
if width == 0 and height == 0:
return io.NodeOutput(samples)
else:
if width == 0:
height = max(64, height)
width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
elif height == 0:
width = max(64, width)
height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
else:
width = max(64, width)
height = max(64, height)
s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop)
s = model.resample_latent(s)
return io.NodeOutput({"samples": s.cpu().float()})
upscale = execute # TODO: remove
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
@ -325,6 +413,8 @@ class HunyuanExtension(ComfyExtension):
EmptyHunyuanVideo15Latent,
HunyuanVideo15ImageToVideo,
HunyuanVideo15RefinerLatent,
HunyuanVideo15LatentUpscaleWithModel,
LatentUpscaleModelLoader,
HunyuanImageToVideo,
EmptyHunyuanImageLatent,
HunyuanRefinerLatent,