mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 03:50:50 +08:00
Current behavior: comfy/ldm/hunyuan_video/upsampler.py imports model_management directly as a top-level module, which causes double initialization and log spamming. Fix: Changed the imports to absolute imports (from comfy import ...) to align with the rest of the codebase and ensure the singleton instances are used correctly.
122 lines
4.1 KiB
Python
122 lines
4.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
|
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
|
from comfy import model_management, model_patcher
|
|
|
|
class SRResidualCausalBlock3D(nn.Module):
    """Channel-preserving residual block made of three ``VideoConv3d`` layers.

    The stack is conv -> SiLU -> conv -> SiLU -> conv (kernel size 3 each,
    ``channels`` in and out); its output is added back onto the input.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Leading conv, then two (activation, conv) pairs. The resulting
        # Sequential indices are 0..4, with the convs at 0, 2 and 4.
        layers = [VideoConv3d(channels, channels, kernel_size=3)]
        for _ in range(2):
            layers.append(nn.SiLU(inplace=True))
            layers.append(VideoConv3d(channels, channels, kernel_size=3))
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the convolution stack applied to ``x``."""
        delta = self.block(x)
        return x + delta
|
|
|
|
class SRModel3DV2(nn.Module):
    """Convolutional model: input conv, a chain of residual blocks, output conv.

    Args:
        in_channels: channel count expected by the input convolution.
        out_channels: channel count produced by the output convolution.
        hidden_channels: width used by the input conv and every residual block.
        num_blocks: number of ``SRResidualCausalBlock3D`` modules in the chain.
        global_residual: when true, the model input is added to the output,
            but only if both tensors have exactly the same shape.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        num_blocks: int = 6,
        global_residual: bool = False,
    ):
        super().__init__()
        self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)

        residual_chain = []
        for _ in range(num_blocks):
            residual_chain.append(SRResidualCausalBlock3D(hidden_channels))
        self.blocks = nn.ModuleList(residual_chain)

        self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
        self.global_residual = bool(global_residual)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model on ``x``; see the class docstring for the skip rule."""
        features = self.in_conv(x)
        for block in self.blocks:
            features = block(features)
        output = self.out_conv(features)

        # The global skip is silently dropped when it is disabled or when the
        # output shape differs from the input shape.
        if not self.global_residual or output.shape != x.shape:
            return output
        return output + x
|
|
|
|
|
|
class Upsampler(nn.Module):
    """Stack of ``ResnetBlock`` stages between an input and an output conv.

    Args:
        z_channels: channel count of the latent fed to ``conv_in``.
        out_channels: channel count produced by ``conv_out``.
        block_out_channels: output width of each stage, in order. The first
            entry is also the width produced by ``conv_in``.
        num_res_blocks: each stage holds ``num_res_blocks + 1`` ResnetBlocks.
    """

    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: tuple[int, ...],
        num_res_blocks: int = 2,
    ):
        super().__init__()
        self.num_res_blocks = num_res_blocks
        self.block_out_channels = block_out_channels
        self.z_channels = z_channels

        ch = block_out_channels[0]
        self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)

        self.up = nn.ModuleList()

        for tgt in block_out_channels:
            # Bare container so the blocks register under `up.<n>.block.<m>`.
            stage = nn.Module()
            # Only the first block of a stage changes the width (ch -> tgt);
            # the remaining blocks run tgt -> tgt.
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_shortcut=False,
                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            self.up.append(stage)

        self.norm_out = RMS_norm(ch)
        self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)

    def forward(self, z):
        """Run the latent through the input conv, all stages, and the output head.

        Args:
            z: (B, C, T, H, W)

        Returns:
            The result of ``conv_out(silu(norm_out(x)))``, where ``x`` is the
            output of the last stage.
        """
        # z to block_in: add a channel-repeated copy of z as a shortcut around
        # conv_in. NOTE(review): the sum only lines up when
        # block_out_channels[0] is an exact multiple of z_channels (and z
        # carries z_channels channels) -- confirm configs guarantee this.
        repeats = self.block_out_channels[0] // (self.z_channels)
        x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

        # Apply every block of every stage in order; no explicit resize op is
        # applied in this method.
        for stage in self.up:
            for blk in stage.block:
                x = blk(x)

        out = self.conv_out(F.silu(self.norm_out(x)))
        return out
|
|
|
|
# Lookup table from a model_type key to the network class to build;
# HunyuanVideo15SRModel.__init__ resolves its model class through this mapping.
UPSAMPLERS = {
    "720p": SRModel3DV2,
    "1080p": Upsampler,
}
|
|
|
|
class HunyuanVideo15SRModel():
    """Wrapper that builds an ``UPSAMPLERS`` network and manages its device placement.

    Args:
        model_type: key into ``UPSAMPLERS`` selecting the network class.
        config: keyword arguments forwarded to that class's constructor.

    Raises:
        ValueError: if ``model_type`` is not a key of ``UPSAMPLERS``.
    """

    def __init__(self, model_type, config):
        # Resolve the architecture before touching any device state, so an
        # unsupported key fails with a clear message instead of the opaque
        # "'NoneType' object is not callable" it used to produce.
        model_class = UPSAMPLERS.get(model_type)
        if model_class is None:
            raise ValueError(
                f"Unknown upsampler model_type {model_type!r}; "
                f"expected one of {sorted(UPSAMPLERS)}"
            )

        self.load_device = model_management.vae_device()
        offload_device = model_management.vae_offload_device()
        # NOTE(review): dtype is recorded here, but nothing in this class
        # casts the model or the latent to it -- confirm that is intended.
        self.dtype = model_management.vae_dtype(self.load_device)
        self.model_class = model_class
        self.model = self.model_class(**config).eval()

        self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        """Load ``sd`` into the network with ``strict=True`` and return the result of ``load_state_dict``."""
        return self.model.load_state_dict(sd, strict=True)

    def get_sd(self):
        """Return the network's state dict."""
        return self.model.state_dict()

    def resample_latent(self, latent):
        """Load the model via ``model_management`` and run it on ``latent`` moved to ``load_device``."""
        model_management.load_model_gpu(self.patcher)
        return self.model(latent.to(self.load_device))
|