Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-16 09:42:29 +08:00)

Implement hunyuan image refiner model. (#9817)

parent 18de0b2830 · commit 33bd9ed9cb
@@ -606,6 +606,11 @@ class HunyuanImage21(LatentFormat):
     latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
 
+class HunyuanImage21Refiner(LatentFormat):
+    latent_channels = 64
+    latent_dimensions = 3
+    scale_factor = 1.03682
+
 
 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
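For context on the scale_factor added above: ComfyUI's LatentFormat subclasses scale latents on the way into and out of the model. The following is a minimal sketch of that convention, assuming the base class's default process_in/process_out (simple multiply/divide by scale_factor); it is illustrative, not a quote of the repository code.

# Minimal sketch, assuming LatentFormat's default scaling behaviour.
class LatentFormatSketch:
    scale_factor = 1.03682  # value introduced for HunyuanImage21Refiner

    def process_in(self, latent):
        return latent * self.scale_factor

    def process_out(self, latent):
        return latent / self.scale_factor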
@@ -278,6 +278,7 @@ class HunyuanVideo(nn.Module):
         guidance: Tensor = None,
         guiding_frame_index=None,
         ref_latent=None,
+        disable_time_r=False,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -288,7 +289,7 @@ class HunyuanVideo(nn.Module):
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
 
-        if self.time_r_in is not None:
+        if (self.time_r_in is not None) and (not disable_time_r):
             w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0]  # This most likely could be improved
             if len(w) > 0:
                 timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
@@ -428,14 +429,14 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
         return repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
-    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
 
-    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
         bs = x.shape[0]
         if len(self.patch_size) == 3:
             img_ids = self.img_ids(x)
@@ -443,5 +444,5 @@ class HunyuanVideo(nn.Module):
         else:
             img_ids = self.img_ids_2d(x)
         txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
         return out
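A small, self-contained sketch of the sigma lookup that the new disable_time_r flag guards: the current sigma is matched against the full sample schedule and the next sigma in the schedule is taken as timesteps_r. The schedule values below are made up for illustration.

import torch

# Illustrative only: mimic the lookup in forward_orig with a dummy schedule.
sample_sigmas = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])  # full schedule
current_sigma = torch.tensor(0.6)                              # sigmas[0] for this step

w = torch.where(current_sigma == sample_sigmas)[0]
if len(w) > 0:
    timesteps_r = sample_sigmas[w[0] + 1]  # next sigma in the schedule
    print(timesteps_r)  # tensor(0.4000)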
comfy/ldm/hunyuan_video/vae_refiner.py (new file, 268 lines)
@@ -0,0 +1,268 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init


class RMS_norm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        shape = (dim, 1, 1, 1)
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.empty(shape))

    def forward(self, x):
        return F.normalize(x, dim=1) * self.scale * self.gamma


class DnSmpl(nn.Module):
    def __init__(self, ic, oc, tds=True):
        super().__init__()
        fct = 2 * 2 * 2 if tds else 1 * 2 * 2
        assert oc % fct == 0
        self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)

        self.tds = tds
        self.gs = fct * ic // oc

    def forward(self, x):
        r1 = 2 if self.tds else 1
        h = self.conv(x)

        if self.tds:
            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
            hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
            hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
            hf = torch.cat([hf, hf], dim=1)

            hn = h[:, :, 1:, :, :]
            b, c, frms, ht, wd = hn.shape
            nf = frms // r1
            hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
            hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
            hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)

            h = torch.cat([hf, hn], dim=2)

            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
            xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
            xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
            xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
            B, C, T, H, W = xf.shape
            xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)

            xn = x[:, :, 1:, :, :]
            b, ci, frms, ht, wd = xn.shape
            nf = frms // r1
            xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
            xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
            xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
            B, C, T, H, W = xn.shape
            xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
            sc = torch.cat([xf, xn], dim=2)
        else:
            b, c, frms, ht, wd = h.shape
            nf = frms // r1
            h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
            h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
            h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)

            b, ci, frms, ht, wd = x.shape
            nf = frms // r1
            sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
            sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
            sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
            B, C, T, H, W = sc.shape
            sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)

        return h + sc


class UpSmpl(nn.Module):
    def __init__(self, ic, oc, tus=True):
        super().__init__()
        fct = 2 * 2 * 2 if tus else 1 * 2 * 2
        self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)

        self.tus = tus
        self.rp = fct * oc // ic

    def forward(self, x):
        r1 = 2 if self.tus else 1
        h = self.conv(x)

        if self.tus:
            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            nc = c // (2 * 2)
            hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
            hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
            hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
            hf = hf[:, : hf.shape[1] // 2]

            hn = h[:, :, 1:, :, :]
            b, c, frms, ht, wd = hn.shape
            nc = c // (r1 * 2 * 2)
            hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
            hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
            hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)

            h = torch.cat([hf, hn], dim=2)

            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
            xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
            b, c, f, ht, wd = xf.shape
            nc = c // (2 * 2)
            xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
            xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
            xf = xf.reshape(b, nc, f, ht * 2, wd * 2)

            xn = x[:, :, 1:, :, :]
            xn = xn.repeat_interleave(repeats=self.rp, dim=1)
            b, c, frms, ht, wd = xn.shape
            nc = c // (r1 * 2 * 2)
            xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
            xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
            xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
            sc = torch.cat([xf, xn], dim=2)
        else:
            b, c, frms, ht, wd = h.shape
            nc = c // (r1 * 2 * 2)
            h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
            h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
            h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)

            sc = x.repeat_interleave(repeats=self.rp, dim=1)
            b, c, frms, ht, wd = sc.shape
            nc = c // (r1 * 2 * 2)
            sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
            sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
            sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)

        return h + sc


class Encoder(nn.Module):
    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
        super().__init__()
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)

        self.down = nn.ModuleList()
        ch = block_out_channels[0]
        depth = (ffactor_spatial >> 1).bit_length()
        depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
                                         for j in range(num_res_blocks)])
            ch = tgt
            if i < depth:
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
                stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
                ch = nxt
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)

        self.norm_out = RMS_norm(ch)
        self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)

        self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()

    def forward(self, x):
        x = x.unsqueeze(2)
        x = self.conv_in(x)

        for stage in self.down:
            for blk in stage.block:
                x = blk(x)
            if hasattr(stage, 'downsample'):
                x = stage.downsample(x)

        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        b, c, t, h, w = x.shape
        grp = c // (self.z_channels << 1)
        skip = x.view(b, c // grp, grp, t, h, w).mean(2)

        out = self.conv_out(F.silu(self.norm_out(x))) + skip
        out = self.regul(out)[0]

        out = torch.cat((out[:, :, :1], out), dim=2)
        out = out.permute(0, 2, 1, 3, 4)
        b, f_times_2, c, h, w = out.shape
        out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
        out = out.permute(0, 2, 1, 3, 4).contiguous()
        return out


class Decoder(nn.Module):
    def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                 ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
        super().__init__()
        block_out_channels = block_out_channels[::-1]
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        ch = block_out_channels[0]
        self.conv_in = VideoConv3d(z_channels, ch, 3)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)

        self.up = nn.ModuleList()
        depth = (ffactor_spatial >> 1).bit_length()
        depth_temporal = (ffactor_temporal >> 1).bit_length()

        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            if i < depth:
                nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
                stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
                ch = nxt
            self.up.append(stage)

        self.norm_out = RMS_norm(ch)
        self.conv_out = VideoConv3d(ch, out_channels, 3)

    def forward(self, z):
        z = z.permute(0, 2, 1, 3, 4)
        b, f, c, h, w = z.shape
        z = z.reshape(b, f, 2, c // 2, h, w)
        z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
        z = z.permute(0, 2, 1, 3, 4)
        z = z[:, :, 1:]

        x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

        for stage in self.up:
            for blk in stage.block:
                x = blk(x)
            if hasattr(stage, 'upsample'):
                x = stage.upsample(x)

        return self.conv_out(F.silu(self.norm_out(x)))
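The DnSmpl and UpSmpl blocks above downsample and upsample with reshape/permute pixel-shuffles plus an averaged skip path, rather than strided or transposed convolutions. A minimal, self-contained sketch of the 2x2 spatial space-to-channel step that DnSmpl applies to a (b, c, f, h, w) tensor (plain PyTorch, no comfy imports; names and sizes are illustrative):

import torch

# Illustrative only: the 2x2 spatial space-to-channel rearrangement used in DnSmpl.
b, c, f, h, w = 1, 8, 5, 64, 64
t = torch.randn(b, c, f, h, w)

t = t.reshape(b, c, f, h // 2, 2, w // 2, 2)
t = t.permute(0, 4, 6, 1, 2, 3, 5)            # move the 2x2 spatial factors next to batch
t = t.reshape(b, 2 * 2 * c, f, h // 2, w // 2)

print(t.shape)  # torch.Size([1, 32, 5, 32, 32]) -- 4x channels, half the height/width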
@@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
             z = posterior.mode()
         return z, None
 
+class EmptyRegularizer(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+        return z, None
 
 class AbstractAutoencoder(torch.nn.Module):
     """
@ -145,7 +145,7 @@ class Downsample(nn.Module):
|
|||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||||
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
|
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
@ -153,7 +153,7 @@ class ResnetBlock(nn.Module):
|
|||||||
self.use_conv_shortcut = conv_shortcut
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
|
||||||
self.swish = torch.nn.SiLU(inplace=True)
|
self.swish = torch.nn.SiLU(inplace=True)
|
||||||
self.norm1 = Normalize(in_channels)
|
self.norm1 = norm_op(in_channels)
|
||||||
self.conv1 = conv_op(in_channels,
|
self.conv1 = conv_op(in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
@ -162,7 +162,7 @@ class ResnetBlock(nn.Module):
|
|||||||
if temb_channels > 0:
|
if temb_channels > 0:
|
||||||
self.temb_proj = ops.Linear(temb_channels,
|
self.temb_proj = ops.Linear(temb_channels,
|
||||||
out_channels)
|
out_channels)
|
||||||
self.norm2 = Normalize(out_channels)
|
self.norm2 = norm_op(out_channels)
|
||||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||||
self.conv2 = conv_op(out_channels,
|
self.conv2 = conv_op(out_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
@ -305,11 +305,11 @@ def vae_attention():
|
|||||||
return normal_attention
|
return normal_attention
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
class AttnBlock(nn.Module):
|
||||||
def __init__(self, in_channels, conv_op=ops.Conv2d):
|
def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
self.norm = Normalize(in_channels)
|
self.norm = norm_op(in_channels)
|
||||||
self.q = conv_op(in_channels,
|
self.q = conv_op(in_channels,
|
||||||
in_channels,
|
in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
|
|||||||
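The norm_op hook added above is what lets the refiner VAE swap the usual Normalize (GroupNorm) for the channel-wise RMS_norm defined in vae_refiner.py. A minimal sketch of what that norm computes (plain PyTorch, illustrative tensor only): F.normalize over the channel dim scaled by sqrt(C) is exactly a per-position RMS normalization.

import torch
import torch.nn.functional as F

# Illustrative only: RMS_norm with gamma=1 reduces to dividing by the channel-wise RMS.
x = torch.randn(2, 16, 4, 8, 8)  # (b, c, f, h, w)
c = x.shape[1]

rms_a = F.normalize(x, dim=1) * (c ** 0.5)
rms_b = x / x.pow(2).mean(dim=1, keepdim=True).sqrt()

print(torch.allclose(rms_a, rms_b, atol=1e-5))  # True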
@@ -1432,3 +1432,23 @@ class HunyuanImage21(BaseModel):
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
 
         return out
+
+class HunyuanImage21Refiner(HunyuanImage21):
+    def concat_cond(self, **kwargs):
+        noise = kwargs.get("noise", None)
+        image = kwargs.get("concat_latent_image", None)
+        device = kwargs["device"]
+
+        if image is None:
+            shape_image = list(noise.shape)
+            image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+        else:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            image = self.process_latent_in(image)
+            image = utils.resize_to_batch_size(image, noise.shape[0])
+        return image
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        out['disable_time_r'] = comfy.conds.CONDConstant(True)
+        return out
comfy/sd.py (17 changed lines)
@@ -285,6 +285,7 @@ class VAE:
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
         self.disable_offload = False
+        self.not_video = False
 
         self.downscale_index_formula = None
         self.upscale_index_formula = None
@@ -409,6 +410,20 @@ class VAE:
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
                 self.downscale_index_formula = (8, 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
+            elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
+                ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
+                self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+                self.downscale_ratio = 16
+                self.upscale_ratio = 16
+                self.latent_dim = 3
+                self.not_video = True
+                self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+                self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
+                                                            encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
+                                                            decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
+
+                self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                 ddconfig["conv3d"] = True
@@ -669,7 +684,7 @@ class VAE:
         self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
-        if self.latent_dim == 3 and pixel_samples.ndim < 5:
+        if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5:
            pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
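For a sense of scale, the memory_used_decode estimate registered above is just a product over the latent shape. A small worked example for a single-frame 64x64 refiner latent (roughly a 1024x1024 output at the 16x downscale ratio set in this branch), assuming bf16 at 2 bytes per element:

# Illustrative arithmetic only, mirroring the lambda registered above.
dtype_size = 2                     # bytes per element for bf16 (assumption)
t, h, w = 1, 64, 64                # latent shape[-3:] for a 1024x1024 image at /16

memory_used_decode = 1400 * t * h * w * 16 * 16 * dtype_size
print(memory_used_decode / (1024 ** 3))  # ~2.73 GiB budgeted for the decode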
@@ -1321,6 +1321,23 @@ class HunyuanImage21(HunyuanVideo):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
+class HunyuanImage21Refiner(HunyuanVideo):
+    unet_config = {
+        "image_model": "hunyuan_video",
+        "patch_size": [1, 1, 1],
+        "vec_in_dim": None,
+    }
+
+    sampling_settings = {
+        "shift": 1.0,
+    }
+
+    latent_format = latent_formats.HunyuanImage21Refiner
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.HunyuanImage21Refiner(self, device=device)
+        return out
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
 
 models += [SVD_img2vid]
@@ -128,6 +128,28 @@ class EmptyHunyuanImageLatent:
         latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
         return ({"samples":latent}, )
 
+
+class HunyuanRefinerLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"positive": ("CONDITIONING", ),
+                             "negative": ("CONDITIONING", ),
+                             "latent": ("LATENT", ),
+                             }}
+
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    FUNCTION = "execute"
+
+    def execute(self, positive, negative, latent):
+        latent = latent["samples"]
+
+        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent})
+        negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent})
+        out_latent = {}
+        out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
+        return (positive, negative, out_latent)
+
 
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
@@ -135,4 +157,5 @@ NODE_CLASS_MAPPINGS = {
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,
     "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
+    "HunyuanRefinerLatent": HunyuanRefinerLatent,
 }
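A rough sketch of what the HunyuanRefinerLatent node does to its inputs, using a stand-in for node_helpers.conditioning_set_values and assuming ComfyUI conditioning is a list of [tensor, options-dict] pairs; all tensors and sizes below are illustrative, not the node's actual wiring.

import torch

def set_values(conditioning, values):
    # Stand-in for node_helpers.conditioning_set_values: merge new keys into a
    # copy of each conditioning entry's options dict.
    return [[c[0], {**c[1], **values}] for c in conditioning]

base_latent = torch.randn(1, 64, 1, 32, 32)   # latent produced by the base model pass
positive = [[torch.randn(1, 77, 4096), {}]]   # illustrative conditioning entries
negative = [[torch.randn(1, 77, 4096), {}]]

positive = set_values(positive, {"concat_latent_image": base_latent})
negative = set_values(negative, {"concat_latent_image": base_latent})
refiner_latent = torch.zeros(1, 32, base_latent.shape[-3], base_latent.shape[-2], base_latent.shape[-1])
print(refiner_latent.shape)  # torch.Size([1, 32, 1, 32, 32])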