Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-12-17 13:57:30 +03:00 committed by GitHub
commit dc574cdc47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 413855 additions and 122 deletions

View File

@ -157,16 +157,23 @@ vae_conversion_map_attn = [
]
def reshape_weight_for_sd(w):
def reshape_weight_for_sd(w, conv3d=False):
# convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1)
if conv3d:
return w.reshape(*w.shape, 1, 1, 1)
else:
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
mapping = {k: k for k in vae_state_dict.keys()}
conv3d = False
for k, v in mapping.items():
for sd_part, hf_part in vae_conversion_map:
v = v.replace(hf_part, sd_part)
if v.endswith(".conv.weight"):
if not conv3d and vae_state_dict[k].ndim == 5:
conv3d = True
mapping[k] = v
for k, v in mapping.items():
if "attentions" in k:
@ -179,7 +186,7 @@ def convert_vae_state_dict(vae_state_dict):
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
logging.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
return new_state_dict

View File

@ -352,3 +352,7 @@ class LTXV(LatentFormat):
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class HunyuanVideo(LatentFormat):
latent_channels = 16
scale_factor = 0.476986

View File

@ -114,7 +114,7 @@ class Modulation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@ -141,8 +141,9 @@ class DoubleStreamBlock(nn.Module):
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@ -160,12 +161,22 @@ class DoubleStreamBlock(nn.Module):
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2), pe=pe)
if self.flipped_img_txt:
# run actual attention
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@ -217,7 +228,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@ -226,7 +237,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output

View File

@ -1,14 +1,15 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True)
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
return x
@ -33,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -4,6 +4,8 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
from .layers import (
DoubleStreamBlock,
@ -14,9 +16,6 @@ from .layers import (
timestep_embedding,
)
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
@ -98,8 +97,9 @@ class Flux(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@ -124,14 +124,27 @@ class Flux(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
@ -146,13 +159,20 @@ class Flux(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
@ -181,5 +201,5 @@ class Flux(nn.Module):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@ -0,0 +1,330 @@
#Based on Flux code because of weird hunyuan video code license.
import torch
import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention
from dataclasses import dataclass
from einops import repeat
from torch import Tensor, nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding
)
import comfy.ldm.common_dit
@dataclass
class HunyuanVideoParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: list
qkv_bias: bool
guidance_embed: bool
class SelfAttentionRef(nn.Module):
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
class TokenRefinerBlock(nn.Module):
def __init__(
self,
hidden_size,
heads,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.heads = heads
mlp_hidden_dim = hidden_size * 4
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
)
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
return x
class IndividualTokenRefiner(nn.Module):
def __init__(
self,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.blocks = nn.ModuleList(
[
TokenRefinerBlock(
hidden_size=hidden_size,
heads=heads,
dtype=dtype,
device=device,
operations=operations
)
for _ in range(num_blocks)
]
)
def forward(self, x, c, mask):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m)
return x
class TokenRefiner(nn.Module):
def __init__(
self,
text_dim,
hidden_size,
heads,
num_blocks,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
def forward(
self,
x,
timesteps,
mask,
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
return x
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
attn_mask[:, 0, img_len:] = txt_mask
else:
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((img, txt), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, : img_len] += add
img = img[:, : img_len]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape)
return img
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
return out

View File

@ -162,12 +162,19 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
**kwargs,
)
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
if ddconfig.get("conv3d", False):
conv_op = comfy.ops.disable_weight_init.Conv3d
else:
conv_op = comfy.ops.disable_weight_init.Conv2d
self.quant_conv = conv_op(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:

View File

@ -340,12 +340,9 @@ except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
disabled_xformers = False
if BROKEN_XFORMERS:
@ -360,35 +357,44 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
# b h k d -> b k h d
q, k, v = map(
lambda t: t.permute(0, 2, 1, 3),
(q, k, v),
)
# actually do the reshaping
else:
dim_head //= heads
q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head),
(q, k, v),
)
if mask is not None:
# add a singleton batch dimension
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a singleton heads dimension
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# pad to a multiple of 8
pad = 8 - mask.shape[-1] % 8
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
# the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
# in flux, this matrix ends up being over 1GB
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
mask_out[..., :mask.shape[-1]] = mask
# doesn't this remove the padding again??
mask = mask_out[..., :mask.shape[-1]]
mask = mask.expand(b, heads, -1, -1)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if skip_reshape:
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
else:
out = (
out.reshape(b, -1, heads * dim_head)
)
out = (
out.reshape(b, -1, heads * dim_head)
)
return out
@ -410,15 +416,34 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
(q, k, v),
)
if SDP_BATCH_LIMIT >= q.shape[0]:
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=mask, dropout_p=0.0, is_causal=False).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
for i in range(0, b, SDP_BATCH_LIMIT):
m = mask
if mask is not None:
if mask.shape[0] > 1:
m = mask[i : i + SDP_BATCH_LIMIT]
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
q[i : i + SDP_BATCH_LIMIT],
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out

View File

@ -43,51 +43,100 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
self.padding_mode = padding_mode
if padding != 0:
padding = (padding, padding, padding, padding, kernel_size - 1, 0)
else:
kwargs["padding"] = padding
self.padding = padding
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
if self.padding != 0:
x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
return self.conv(x)
def interpolate_up(x, scale_factor):
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
super().__init__()
self.with_conv = with_conv
self.scale_factor = scale_factor
if self.with_conv:
self.conv = ops.Conv2d(in_channels,
self.conv = conv_op(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
try:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
except: #operation not implemented for bf16
b, c, h, w = x.shape
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
del x
x = out
scale_factor = self.scale_factor
if not isinstance(scale_factor, tuple):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
t = x.shape[2]
if t > 1:
a, b = x.split((1, t - 1), dim=2)
del x
b = interpolate_up(b, scale_factor)
else:
a = x
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
else:
x = interpolate_up(x, self.scale_factor)
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = ops.Conv2d(in_channels,
self.conv = conv_op(in_channels,
in_channels,
kernel_size=3,
stride=2,
stride=stride,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
if x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
@ -96,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
dropout, temb_channels=512, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@ -105,7 +154,7 @@ class ResnetBlock(nn.Module):
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = ops.Conv2d(in_channels,
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
stride=1,
@ -115,20 +164,20 @@ class ResnetBlock(nn.Module):
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = ops.Conv2d(out_channels,
self.conv2 = conv_op(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = ops.Conv2d(in_channels,
self.conv_shortcut = conv_op(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = ops.Conv2d(in_channels,
self.nin_shortcut = conv_op(in_channels,
out_channels,
kernel_size=1,
stride=1,
@ -194,21 +243,25 @@ def slice_attention(q, k, v):
def normal_attention(q, k, v):
# compute attention
b,c,h,w = q.shape
orig_shape = q.shape
b = orig_shape[0]
c = orig_shape[1]
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
q = q.reshape(b, c, -1)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, -1) # b,c,hw
v = v.reshape(b, c, -1)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
h_ = r1.reshape(orig_shape)
del r1
return h_
def xformers_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
@ -216,14 +269,16 @@ def xformers_attention(q, k, v):
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
out = out.transpose(1, 2).reshape(orig_shape)
except NotImplementedError:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
def pytorch_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@ -231,35 +286,35 @@ def pytorch_attention(q, k, v):
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
class AttnBlock(nn.Module):
def __init__(self, in_channels):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = ops.Conv2d(in_channels,
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = ops.Conv2d(in_channels,
self.k = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = ops.Conv2d(in_channels,
self.v = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = ops.Conv2d(in_channels,
self.proj_out = conv_op(in_channels,
in_channels,
kernel_size=1,
stride=1,
@ -289,8 +344,8 @@ class AttnBlock(nn.Module):
return x+h_
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
return AttnBlock(in_channels)
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
return AttnBlock(in_channels, conv_op=conv_op)
class Model(nn.Module):
@ -449,6 +504,7 @@ class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
conv3d=False, time_compress=None,
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
@ -459,8 +515,15 @@ class Encoder(nn.Module):
self.resolution = resolution
self.in_channels = in_channels
if conv3d:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# downsampling
self.conv_in = ops.Conv2d(in_channels,
self.conv_in = conv_op(in_channels,
self.ch,
kernel_size=3,
stride=1,
@ -479,15 +542,20 @@ class Encoder(nn.Module):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
dropout=dropout,
conv_op=conv_op))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
stride = 2
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
@ -496,16 +564,18 @@ class Encoder(nn.Module):
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
dropout=dropout,
conv_op=conv_op)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
dropout=dropout,
conv_op=conv_op)
# end
self.norm_out = Normalize(block_in)
self.conv_out = ops.Conv2d(block_in,
self.conv_out = conv_op(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
@ -543,6 +613,8 @@ class Decoder(nn.Module):
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
conv3d=False,
time_compress=None,
**ignorekwargs):
super().__init__()
self.ch = ch
@ -554,6 +626,14 @@ class Decoder(nn.Module):
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
if conv3d:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
mid_attn_conv_op = ops.Conv2d
# compute block_in and curr_res at lowest res
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
@ -562,7 +642,7 @@ class Decoder(nn.Module):
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = ops.Conv2d(z_channels,
self.conv_in = conv_op(z_channels,
block_in,
kernel_size=3,
stride=1,
@ -573,12 +653,14 @@ class Decoder(nn.Module):
self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = attn_op(block_in)
dropout=dropout,
conv_op=conv_op)
self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
dropout=dropout,
conv_op=conv_op)
# upsampling
self.up = nn.ModuleList()
@ -590,15 +672,21 @@ class Decoder(nn.Module):
block.append(resnet_op(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
dropout=dropout,
conv_op=conv_op))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(attn_op(block_in))
attn.append(attn_op(block_in, conv_op=conv_op))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
scale_factor = 2.0
if time_compress is not None:
if i_level > math.log2(time_compress):
scale_factor = (1.0, 2.0, 2.0)
up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order

View File

@ -194,6 +194,7 @@ def make_time_attn(
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
conv_op=ops.Conv2d,
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy

View File

@ -31,6 +31,7 @@ import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.model_management
import comfy.patcher_extension
@ -686,6 +687,7 @@ class StableAudio1(BaseModel):
sd["{}{}".format(k, l)] = s[l]
return sd
class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@ -766,6 +768,16 @@ class Flux(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# upscale the attention mask, since now we
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"]
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out
@ -807,3 +819,21 @@ class LTXV(BaseModel):
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
return out
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
return out

View File

@ -133,6 +133,26 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["image_model"] = "hydit1"
return unet_config
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = 16
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {}
dit_config["image_model"] = "flux"

View File

@ -31,6 +31,7 @@ import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.model_patcher
import comfy.lora
@ -306,8 +307,8 @@ class VAE:
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
@ -344,6 +345,17 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
ddconfig["time_compress"] = 4
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -544,6 +556,7 @@ class CLIPType(Enum):
FLUX = 6
MOCHI = 7
LTXV = 8
HUNYUAN_VIDEO = 9
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
@ -559,6 +572,7 @@ class TEModel(Enum):
T5_XXL = 4
T5_XL = 5
T5_BASE = 6
LLAMA3_8 = 7
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -575,6 +589,8 @@ def detect_te_model(sd):
return TEModel.T5_XL
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@ -587,6 +603,14 @@ def t5xxl_detect(clip_data):
return {}
def llama_detect(clip_data):
weight_name = "model.layers.0.self_attn.k_proj.weight"
for sd in clip_data:
if weight_name in sd:
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts
@ -652,6 +676,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.FLUX:
clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
elif clip_type == CLIPType.HUNYUAN_VIDEO:
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer

View File

@ -420,7 +420,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data={}):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
@ -433,11 +433,16 @@ class SDTokenizer:
self.tokens_start = 1
self.start_token = empty[0]
if has_end_token:
self.end_token = empty[1]
if end_token is not None:
self.end_token = end_token
else:
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
if has_end_token:
if end_token is not None:
self.end_token = end_token
else:
self.end_token = empty[0]
if pad_token is not None:

View File

@ -12,6 +12,7 @@ import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
from . import supported_models_base
from . import latent_formats
@ -738,6 +739,54 @@ class LTXV(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
class HunyuanVideo(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan_video",
}
sampling_settings = {
"shift": 7.0,
}
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 2.0 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo(self, device=device)
return out
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
key_out = k
key_out = key_out.replace("txt_in.t_embedder.mlp.0.", "txt_in.t_embedder.in_layer.").replace("txt_in.t_embedder.mlp.2.", "txt_in.t_embedder.out_layer.")
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
out_sd[key_out] = state_dict[k]
return out_sd
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model.model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
models += [SVD_img2vid]

View File

@ -0,0 +1,111 @@
from comfy import sd1_clip
import comfy.model_management
import comfy.text_encoders.llama
from transformers import LlamaTokenizerFast
import torch
import os
def llama_detect(state_dict, prefix=""):
out = {}
t5_key = "{}model.norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
return out
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.llama_template = """<|start_header_id|>system<|end_header_id|>
Describe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>""" # 93 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
llama_text = "{}{}".format(self.llama_template, text)
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_l.untokenize(token_weight_pair)
def state_dict(self):
return {}
class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
self.dtypes = set([dtype, dtype_llama])
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.llama.set_clip_options(options)
def reset_clip_options(self):
self.clip_l.reset_clip_options()
self.llama.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_llama = token_weight_pairs["llama"]
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
template_end = 0
for i, v in enumerate(token_weight_pairs_llama[0]):
if v[0] == 128007: # <|end_header_id|>
template_end = i
llama_out = llama_out[:, template_end:]
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return llama_out, l_pooled, llama_extra_out
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return self.llama.load_sd(sd)
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HunyuanVideoClipModel_

View File

@ -0,0 +1,221 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Any
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.ldm.common_dit
import comfy.model_management
@dataclass
class Llama2Config:
vocab_size: int = 128320
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 500000.0
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x: torch.Tensor):
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return (cos, sin)
def apply_rope(xq, xk, freqs_cis):
cos = freqs_cis[0].unsqueeze(1)
sin = freqs_cis[1].unsqueeze(1)
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed, k_embed
class Attention(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.hidden_size = config.hidden_size
self.head_dim = self.hidden_size // self.num_heads
ops = ops or nn
self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
ops = ops or nn
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
)
x = residual + x
# MLP
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
self.layers = nn.ModuleList([
TransformerBlock(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embed_tokens(x, out_dtype=dtype)
freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
x.shape[1],
self.config.rope_theta,
device=x.device)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
intermediate = None
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
for i, layer in enumerate(self.layers):
x = layer(
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
)
if i == intermediate_output:
intermediate = x.clone()
x = self.norm(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)
return x, intermediate
class Llama2(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Llama2Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, embeddings):
self.model.embed_tokens = embeddings
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -26,6 +26,8 @@ import numpy as np
from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
@ -873,5 +875,46 @@ def reshape_mask(input_mask, output_shape):
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
if mask.shape[1] < output_shape[1]:
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
mask = repeat_to_batch_size(mask, output_shape[0])
return mask
def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
hi, wi = img_size_in
ho, wo = img_size_out
# if it's already the correct size, no need to do anything
if (hi, wi) == (ho, wo):
return mask
if mask.ndim == 2:
mask = mask.unsqueeze(0)
if mask.ndim != 3:
raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
txt_tokens = mask.shape[1] - (hi * wi)
# quadrants of the mask
txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
txt_to_img = mask[:, :txt_tokens, txt_tokens:]
img_to_img = mask[:, txt_tokens:, txt_tokens:]
img_to_txt = mask[:, txt_tokens:, :txt_tokens]
# convert to 1d x 2d, interpolate, then back to 1d x 1d
txt_to_img = rearrange (txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
txt_to_img = rearrange (txt_to_img, "b t h w -> b t (h w)")
# this one is hard because we have to do it twice
# convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
img_to_img = rearrange (img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange (img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
img_to_img = rearrange (img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
# convert to 2d x 1d, interpolate, then back to 1d x 1d
img_to_txt = rearrange (img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
img_to_txt = rearrange (img_to_txt, "b t h w -> b (h w) t")
# reassemble the mask from blocks
out = torch.cat([
torch.cat([txt_to_txt, txt_to_img], dim=2),
torch.cat([img_to_txt, img_to_img], dim=2)],
dim=1
)
return out

View File

@ -1,3 +1,8 @@
import nodes
import torch
import comfy.model_management
class CLIPTextEncodeHunyuanDiT:
@classmethod
def INPUT_TYPES(s):
@ -17,7 +22,23 @@ class CLIPTextEncodeHunyuanDiT:
return (clip.encode_from_tokens_scheduled(tokens), )
class EmptyHunyuanLatentVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
}

View File

@ -929,7 +929,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux"], ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -947,6 +947,8 @@ class DualCLIPLoader:
clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
elif type == "hunyuan_video":
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
@ -1008,23 +1010,58 @@ class StyleModelApply:
"style_model": ("STYLE_MODEL", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"strength_type": (["multiply"], ),
"strength_type": (["multiply", "attn_bias"], ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_stylemodel"
CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type):
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
if strength_type == "multiply":
cond *= strength
c = []
n = cond.shape[1]
c_out = []
for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
c.append(n)
return (c, )
(txt, keys) = t
keys = keys.copy()
if strength_type == "attn_bias" and strength != 1.0:
# math.log raises an error if the argument is zero
# torch.log returns -inf, which is what we want
attn_bias = torch.log(torch.Tensor([strength]))
# get the size of the mask image
mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
n_ref = mask_ref_size[0] * mask_ref_size[1]
n_txt = txt.shape[1]
# grab the existing mask
mask = keys.get("attention_mask", None)
# create a default mask if it doesn't exist
if mask is None:
mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
# convert the mask dtype, because it might be boolean
# we want it to be interpreted as a bias
if mask.dtype == torch.bool:
# log(True) = log(1) = 0
# log(False) = log(0) = -inf
mask = torch.log(mask.to(dtype=torch.float16))
# now we make the mask bigger to add space for our new tokens
new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
# copy over the old mask, in quandrants
new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
# now fill in the attention bias to our redux tokens
new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
keys["attention_mask"] = new_mask.to(txt.device)
keys["attention_mask_img_shape"] = mask_ref_size
c_out.append([torch.cat((txt, cond), dim=1), keys])
return (c_out,)
class unCLIPConditioning:
@classmethod