Merge branch 'comfyanonymous:master' into feature/custom-node-paths-cli-args

commit e39b882c50 — Sas van Gulik, 2025-11-27 11:45:16 +01:00, committed by GitHub
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
50 changed files with 4268 additions and 1860 deletions

View File

@@ -67,6 +67,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
 - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
 - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
 - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
+- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
+- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
 - Image Editing Models
 - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
 - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

View File

@@ -6,6 +6,7 @@ class LatentFormat:
     latent_dimensions = 2
     latent_rgb_factors = None
     latent_rgb_factors_bias = None
+    latent_rgb_factors_reshape = None
     taesd_decoder_name = None

     def process_in(self, latent):
@@ -178,6 +179,54 @@ class Flux(SD3):
     def process_out(self, latent):
         return (latent / self.scale_factor) + self.shift_factor

+class Flux2(LatentFormat):
+    latent_channels = 128
+
+    def __init__(self):
+        self.latent_rgb_factors = [
+            [0.0058, 0.0113, 0.0073],
+            [0.0495, 0.0443, 0.0836],
+            [-0.0099, 0.0096, 0.0644],
+            [0.2144, 0.3009, 0.3652],
+            [0.0166, -0.0039, -0.0054],
+            [0.0157, 0.0103, -0.0160],
+            [-0.0398, 0.0902, -0.0235],
+            [-0.0052, 0.0095, 0.0109],
+            [-0.3527, -0.2712, -0.1666],
+            [-0.0301, -0.0356, -0.0180],
+            [-0.0107, 0.0078, 0.0013],
+            [0.0746, 0.0090, -0.0941],
+            [0.0156, 0.0169, 0.0070],
+            [-0.0034, -0.0040, -0.0114],
+            [0.0032, 0.0181, 0.0080],
+            [-0.0939, -0.0008, 0.0186],
+            [0.0018, 0.0043, 0.0104],
+            [0.0284, 0.0056, -0.0127],
+            [-0.0024, -0.0022, -0.0030],
+            [0.1207, -0.0026, 0.0065],
+            [0.0128, 0.0101, 0.0142],
+            [0.0137, -0.0072, -0.0007],
+            [0.0095, 0.0092, -0.0059],
+            [0.0000, -0.0077, -0.0049],
+            [-0.0465, -0.0204, -0.0312],
+            [0.0095, 0.0012, -0.0066],
+            [0.0290, -0.0034, 0.0025],
+            [0.0220, 0.0169, -0.0048],
+            [-0.0332, -0.0457, -0.0468],
+            [-0.0085, 0.0389, 0.0609],
+            [-0.0076, 0.0003, -0.0043],
+            [-0.0111, -0.0460, -0.0614],
+        ]
+        self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
+        self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
+
+    def process_in(self, latent):
+        return latent
+
+    def process_out(self, latent):
+        return latent
+
 class Mochi(LatentFormat):
     latent_channels = 12
     latent_dimensions = 3
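The new `latent_rgb_factors_reshape` hook exists because Flux 2's 128-channel latent packs a 2x2 pixel patch into the channel dimension. A minimal sketch of what the lambda above computes (shapes are illustrative, not taken from the diff):

```python
import torch

# Unshuffle a (B, 128, H, W) Flux 2 latent into (B, 32, 2H, 2W) so the
# 32-row latent_rgb_factors matrix can be applied per channel for previews.
t = torch.randn(1, 128, 8, 8)
unshuffled = (
    t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1])
     .permute(0, 1, 4, 2, 5, 3)
     .reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
)
print(unshuffled.shape)  # torch.Size([1, 32, 16, 16])
```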

View File

@@ -179,7 +179,10 @@ class Chroma(nn.Module):
         pe = self.pe_embedder(ids)

         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if i not in self.skip_mmdit:
                 double_mod = (
                     self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -222,7 +225,10 @@ class Chroma(nn.Module):

         img = torch.cat((txt, img), 1)

+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if i not in self.skip_dit:
                 single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:

View File

@@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     return embedding

 class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
+    def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
         self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)

     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
@@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module):

 class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads

         self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)

 @dataclass
@@ -98,11 +98,11 @@ class ModulationOut:

 class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.is_double = double
         self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
+        self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)

     def forward(self, vec: Tensor) -> tuple:
         if vec.ndim == 2:
@@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
         return tensor

+class SiLUActivation(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gate_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x1, x2 = x.chunk(2, dim=-1)
+        return self.gate_fn(x1) * x2
+
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -142,27 +152,44 @@ class DoubleStreamBlock(nn.Module):
             self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
         self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)

         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+        if mlp_silu_act:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )

         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
         self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)

         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+        if mlp_silu_act:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )

         self.flipped_img_txt = flipped_img_txt

     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
@@ -246,6 +273,8 @@ class SingleStreamBlock(nn.Module):
         mlp_ratio: float = 4.0,
         qk_scale: float = None,
         modulation=True,
+        mlp_silu_act=False,
+        bias=True,
         dtype=None,
         device=None,
         operations=None
@@ -257,17 +286,24 @@ class SingleStreamBlock(nn.Module):
         self.scale = qk_scale or head_dim**-0.5

         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.mlp_hidden_dim_first = self.mlp_hidden_dim
+
+        if mlp_silu_act:
+            self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
+            self.mlp_act = SiLUActivation()
+        else:
+            self.mlp_act = nn.GELU(approximate="tanh")
+
         # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)

         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

         self.hidden_size = hidden_size
         self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

-        self.mlp_act = nn.GELU(approximate="tanh")
         if modulation:
             self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
         else:
@@ -279,7 +315,7 @@ class SingleStreamBlock(nn.Module):
         else:
             mod = vec

-        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         del qkv
@@ -298,11 +334,11 @@ class SingleStreamBlock(nn.Module):

 class LastLayer(nn.Module):
-    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
-        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
+        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))

     def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
         if vec.ndim == 2:
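SiLUActivation above is a SwiGLU-style gate: the first projection is widened to 2x, split in half, and one half gates the other, which is why the `mlp_silu_act` branches project to `mlp_hidden_dim * 2` while the output Linear still takes `mlp_hidden_dim`. A self-contained sketch with plain `nn.Linear` standing in for ComfyUI's `operations` wrapper:

```python
import torch
import torch.nn as nn

class SiLUActivation(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=-1)         # split the doubled projection
        return nn.functional.silu(x1) * x2  # gate one half with the other

hidden, mlp_hidden = 64, 192                # e.g. mlp_ratio = 3.0 as detected for Flux 2
mlp = nn.Sequential(
    nn.Linear(hidden, mlp_hidden * 2, bias=False),
    SiLUActivation(),
    nn.Linear(mlp_hidden, hidden, bias=False),
)
print(mlp(torch.randn(1, 4, hidden)).shape)  # torch.Size([1, 4, 64])
```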

View File

@@ -15,6 +15,7 @@ from .layers import (
     MLPEmbedder,
     SingleStreamBlock,
     timestep_embedding,
+    Modulation
 )

 @dataclass
@@ -33,6 +34,11 @@ class FluxParams:
     patch_size: int
     qkv_bias: bool
     guidance_embed: bool
+    global_modulation: bool = False
+    mlp_silu_act: bool = False
+    ops_bias: bool = True
+    default_ref_method: str = "offset"
+    ref_index_scale: float = 1.0

 class Flux(nn.Module):
@@ -58,13 +64,17 @@ class Flux(nn.Module):
         self.hidden_size = params.hidden_size
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
-        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+        if params.vec_in_dim is not None:
+            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.vector_in = None
         self.guidance_in = (
-            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
         )
-        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

         self.double_blocks = nn.ModuleList(
             [
@@ -73,6 +83,9 @@ class Flux(nn.Module):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=params.global_modulation is False,
+                    mlp_silu_act=params.mlp_silu_act,
+                    proj_bias=params.ops_bias,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -81,13 +94,30 @@ class Flux(nn.Module):
         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )

         if final_layer:
-            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+
+        if params.global_modulation:
+            self.double_stream_modulation_img = Modulation(
+                self.hidden_size,
+                double=True,
+                bias=False,
+                dtype=dtype, device=device, operations=operations
+            )
+            self.double_stream_modulation_txt = Modulation(
+                self.hidden_size,
+                double=True,
+                bias=False,
+                dtype=dtype, device=device, operations=operations
+            )
+            self.single_stream_modulation = Modulation(
+                self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
+            )

     def forward_orig(
         self,
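With `global_modulation`, the per-block Modulation layers are disabled (`modulation=params.global_modulation is False`) and one shared triple of Modulation modules feeds every block. An illustrative sketch of the idea, using a plain Linear in place of the flux `Modulation` module:

```python
import torch
import torch.nn as nn

dim, n_blocks = 64, 4
vec = torch.randn(2, dim)                     # pooled conditioning vector
shared = nn.Linear(dim, 6 * dim, bias=False)  # stands in for Modulation(double=True)

params = shared(nn.functional.silu(vec))      # computed once per step...
for _ in range(n_blocks):                     # ...then reused by every block
    shift, scale, gate = params[:, :dim], params[:, dim:2 * dim], params[:, 2 * dim:3 * dim]
    # each block applies the same (shift, scale, gate) instead of owning its own adaLN
```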
@@ -103,9 +133,6 @@ class Flux(nn.Module):
         attn_mask: Tensor = None,
     ) -> Tensor:

-        if y is None:
-            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
-
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -118,9 +145,17 @@ class Flux(nn.Module):
         if guidance is not None:
             vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

-        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+        if self.vector_in is not None:
+            if y is None:
+                y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+
         txt = self.txt_in(txt)

+        vec_orig = vec
+        if self.params.global_modulation:
+            vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
+
         if "post_input" in patches:
             for p in patches["post_input"]:
                 out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
@@ -136,7 +171,10 @@ class Flux(nn.Module):
             pe = None

         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -177,7 +215,13 @@ class Flux(nn.Module):

         img = torch.cat((txt, img), 1)

+        if self.params.global_modulation:
+            vec, _ = self.single_stream_modulation(vec_orig)
+
+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -207,7 +251,7 @@ class Flux(nn.Module):
         img = img[:, txt.shape[1] :, ...]

-        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+        img = self.final_layer(img, vec_orig)  # (N, T, patch_size ** 2 * out_channels)
         return img

     def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@@ -234,10 +278,10 @@ class Flux(nn.Module):
             h_offset += rope_options.get("shift_y", 0.0)
             w_offset += rope_options.get("shift_x", 0.0)

-        img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
+        img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
         img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
         return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)

     def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -259,10 +303,10 @@ class Flux(nn.Module):
             h = 0
             w = 0
             index = 0
-            ref_latents_method = kwargs.get("ref_latents_method", "offset")
+            ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
             for ref in ref_latents:
                 if ref_latents_method == "index":
-                    index += 1
+                    index += self.params.ref_index_scale
                     h_offset = 0
                     w_offset = 0
                 elif ref_latents_method == "uxo":
@@ -286,7 +330,11 @@ class Flux(nn.Module):
             img = torch.cat([img, kontext], dim=1)
             img_ids = torch.cat([img_ids, kontext_ids], dim=1)

-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
+        if len(self.params.axes_dim) == 4:  # Flux 2
+            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
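Flux 2 is detected with four rope axes (`axes_dim = [32, 32, 32, 32]` in the detection code later in this merge), so position ids grow a fourth column that stays zero for image tokens and counts token position for text. A small sketch of the resulting layout (batch of 1, toy sizes):

```python
import torch

steps_h, steps_w, txt_len, index = 2, 3, 4, 0

img_ids = torch.zeros((steps_h, steps_w, 4), dtype=torch.float32)
img_ids[:, :, 0] = index                                                     # reference-latent index
img_ids[:, :, 1] = torch.arange(steps_h, dtype=torch.float32).unsqueeze(1)   # row
img_ids[:, :, 2] = torch.arange(steps_w, dtype=torch.float32).unsqueeze(0)   # col
                                                                             # axis 3 stays 0 for images
txt_ids = torch.zeros((1, txt_len, 4), dtype=torch.float32)
txt_ids[:, :, 3] = torch.linspace(0, txt_len - 1, steps=txt_len)             # text position, 4th axis
```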

View File

@@ -389,7 +389,10 @@ class HunyuanVideo(nn.Module):
         attn_mask = None

         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -411,7 +414,10 @@ class HunyuanVideo(nn.Module):

         img = torch.cat((img, txt), 1)

+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

View File

@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope

 import comfy.patcher_extension
@@ -31,6 +32,7 @@ class JointAttention(nn.Module):
         n_heads: int,
         n_kv_heads: Optional[int],
         qk_norm: bool,
+        out_bias: bool = False,
         operation_settings={},
     ):
         """
@@ -59,7 +61,7 @@ class JointAttention(nn.Module):
         self.out = operation_settings.get("operations").Linear(
             n_heads * self.head_dim,
             dim,
-            bias=False,
+            bias=out_bias,
             device=operation_settings.get("device"),
             dtype=operation_settings.get("dtype"),
         )
@@ -70,35 +72,6 @@ class JointAttention(nn.Module):
         else:
             self.q_norm = self.k_norm = nn.Identity()

-    @staticmethod
-    def apply_rotary_emb(
-        x_in: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Apply rotary embeddings to input tensors using the given frequency
-        tensor.
-
-        This function applies rotary embeddings to the given query 'xq' and
-        key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
-        input tensors are reshaped as complex numbers, and the frequency tensor
-        is reshaped for broadcasting compatibility. The resulting tensors
-        contain rotary embeddings and are returned as real tensors.
-
-        Args:
-            x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
-                exponentials.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
-                and key tensor with rotary embeddings.
-        """
-        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
-        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-        return t_out.reshape(*x_in.shape)
-
     def forward(
         self,
         x: torch.Tensor,
@@ -134,8 +107,7 @@ class JointAttention(nn.Module):
         xq = self.q_norm(xq)
         xk = self.k_norm(xk)

-        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
-        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
+        xq, xk = apply_rope(xq, xk, freqs_cis)

         n_rep = self.n_local_heads // self.n_local_kv_heads
         if n_rep >= 1:
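The removed helper and flux's `apply_rope` perform the same rotation; the replacement just handles q and k in one call instead of two. A minimal sketch of the math, taken from the deleted body (each feature pair is treated as a 2-vector and multiplied by a per-position 2x2 rotation stored in `freqs_cis`):

```python
import torch

def rotate(x, freqs_cis):
    t = x.reshape(*x.shape[:-1], -1, 1, 2)  # feature pairs as 2-vectors
    out = freqs_cis[..., 0] * t[..., 0] + freqs_cis[..., 1] * t[..., 1]
    return out.reshape(*x.shape)

x = torch.randn(1, 8, 4, 64)            # (batch, heads, seq, head_dim)
freqs = torch.randn(1, 1, 4, 32, 2, 2)  # per-position rotation table (toy values)
print(rotate(x, freqs).shape)           # torch.Size([1, 8, 4, 64])
```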
@@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module):
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
+        z_image_modulation=False,
+        attn_out_bias=False,
         operation_settings={},
     ) -> None:
         """
@@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module):
         super().__init__()
         self.dim = dim
         self.head_dim = dim // n_heads
-        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
+        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
         self.feed_forward = FeedForward(
             dim=dim,
-            hidden_dim=4 * dim,
+            hidden_dim=dim,
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
             operation_settings=operation_settings,
@@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module):
         self.modulation = modulation
         if modulation:
-            self.adaLN_modulation = nn.Sequential(
-                nn.SiLU(),
-                operation_settings.get("operations").Linear(
-                    min(dim, 1024),
-                    4 * dim,
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
+            if z_image_modulation:
+                self.adaLN_modulation = nn.Sequential(
+                    operation_settings.get("operations").Linear(
+                        min(dim, 256),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )
+            else:
+                self.adaLN_modulation = nn.Sequential(
+                    nn.SiLU(),
+                    operation_settings.get("operations").Linear(
+                        min(dim, 1024),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )

     def forward(
         self,
@@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
     The final layer of NextDiT.
     """

-    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
+    def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
         super().__init__()
         self.norm_final = operation_settings.get("operations").LayerNorm(
             hidden_size,
@@ -340,10 +325,15 @@ class FinalLayer(nn.Module):
             dtype=operation_settings.get("dtype"),
         )

+        if z_image_modulation:
+            min_mod = 256
+        else:
+            min_mod = 1024
+
         self.adaLN_modulation = nn.Sequential(
             nn.SiLU(),
             operation_settings.get("operations").Linear(
-                min(hidden_size, 1024),
+                min(hidden_size, min_mod),
                 hidden_size,
                 bias=True,
                 device=operation_settings.get("device"),
@@ -373,12 +363,16 @@ class NextDiT(nn.Module):
         n_heads: int = 32,
         n_kv_heads: Optional[int] = None,
         multiple_of: int = 256,
-        ffn_dim_multiplier: Optional[float] = None,
+        ffn_dim_multiplier: float = 4.0,
         norm_eps: float = 1e-5,
         qk_norm: bool = False,
         cap_feat_dim: int = 5120,
         axes_dims: List[int] = (16, 56, 56),
         axes_lens: List[int] = (1, 512, 512),
+        rope_theta=10000.0,
+        z_image_modulation=False,
+        time_scale=1.0,
+        pad_tokens_multiple=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -390,6 +384,8 @@ class NextDiT(nn.Module):
         self.in_channels = in_channels
         self.out_channels = in_channels
         self.patch_size = patch_size
+        self.time_scale = time_scale
+        self.pad_tokens_multiple = pad_tokens_multiple

         self.x_embedder = operation_settings.get("operations").Linear(
             in_features=patch_size * patch_size * in_channels,
@@ -411,6 +407,7 @@ class NextDiT(nn.Module):
                     norm_eps,
                     qk_norm,
                     modulation=True,
+                    z_image_modulation=z_image_modulation,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_refiner_layers)
@@ -434,7 +431,7 @@ class NextDiT(nn.Module):
             ]
         )

-        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
+        self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)

         self.cap_embedder = nn.Sequential(
             operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
@@ -457,18 +454,24 @@ class NextDiT(nn.Module):
                     ffn_dim_multiplier,
                     norm_eps,
                     qk_norm,
+                    z_image_modulation=z_image_modulation,
+                    attn_out_bias=False,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_layers)
             ]
         )
         self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
+        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
+
+        if self.pad_tokens_multiple is not None:
+            self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+            self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))

         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
+        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
         self.dim = dim
         self.n_heads = n_heads
@@ -503,108 +506,42 @@ class NextDiT(nn.Module):
         bsz = len(x)
         pH = pW = self.patch_size
         device = x[0].device
-        dtype = x[0].dtype

-        if cap_mask is not None:
-            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
-        else:
-            l_effective_cap_len = [num_tokens] * bsz
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
+            cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)

-        if cap_mask is not None and not torch.is_floating_point(cap_mask):
-            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
+        cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
+        cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0

-        img_sizes = [(img.size(1), img.size(2)) for img in x]
-        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
+        B, C, H, W = x.shape
+        x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+        H_tokens, W_tokens = H // pH, W // pW

-        max_seq_len = max(
-            (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
-        )
-        max_cap_len = max(l_effective_cap_len)
-        max_img_len = max(l_effective_img_len)
+        x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
+        x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
+        x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+        x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()

-        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
+            x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
+            x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))

-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // pH, W // pW
-            assert H_tokens * W_tokens == img_len
-
-            rope_options = transformer_options.get("rope_options", None)
-            h_scale = 1.0
-            w_scale = 1.0
-            h_start = 0
-            w_start = 0
-            if rope_options is not None:
-                h_scale = rope_options.get("scale_y", 1.0)
-                w_scale = rope_options.get("scale_x", 1.0)
-                h_start = rope_options.get("shift_y", 0.0)
-                w_start = rope_options.get("shift_x", 0.0)
-
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
-            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
-            row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
-            col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
-            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
-            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
-
-        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
-
-        # build freqs_cis for cap and image individually
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        # cap_freqs_cis_shape[1] = max_cap_len
-        cap_freqs_cis_shape[1] = cap_feats.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
+        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)

         # refine context
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
+            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)

-        # refine image
-        flat_x = []
-        for i in range(bsz):
-            img = x[i]
-            C, H, W = img.size()
-            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_x.append(img)
-        x = flat_x
-        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
-        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
-        for i in range(bsz):
-            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
-            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
-
-        padded_img_embed = self.x_embedder(padded_img_embed)
-        padded_img_mask = padded_img_mask.unsqueeze(1)
+        padded_img_mask = None
         for layer in self.noise_refiner:
-            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
+            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)

-        if cap_mask is not None:
-            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
-            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
-        else:
-            mask = None
-
-        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-
-            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
-            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
+        padded_full_embed = torch.cat((cap_feats, x), dim=1)
+        mask = None
+        img_sizes = [(H, W)] * bsz
+        l_effective_cap_len = [cap_feats.shape[1]] * bsz

         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
@@ -627,7 +564,7 @@ class NextDiT(nn.Module):
             y: (N,) tensor of text tokens/features
         """
-        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
+        t = self.t_embedder(t * self.time_scale, dtype=x.dtype)  # (N, D)
         adaln_input = t

         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute
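Two quick numeric notes on the new NextDiT knobs, as a worked example (values taken from the Z Image detection later in this merge):

```python
t = 0.25              # sampler-side timestep in [0, 1]
print(t * 1000.0)     # 250.0 -> what t_embedder embeds when time_scale=1000.0

# pad_tokens_multiple rounds sequence lengths up, e.g. to a multiple of 32:
seq_len, multiple = 273, 32
pad_extra = (-seq_len) % multiple
print(pad_extra, seq_len + pad_extra)  # 15 288
```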

View File

@@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
 from comfy.ldm.util import get_obj_from_str, instantiate_from_config
 from comfy.ldm.modules.ema import LitEma
 import comfy.ops
+from einops import rearrange
+import comfy.model_management

 class DiagonalGaussianRegularizer(torch.nn.Module):
     def __init__(self, sample: bool = False):
@@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
             self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
         self.embed_dim = embed_dim

+        if ddconfig.get("batch_norm_latent", False):
+            self.bn_eps = 1e-4
+            self.bn_momentum = 0.1
+            self.ps = [2, 2]
+            self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
+                                           eps=self.bn_eps,
+                                           momentum=self.bn_momentum,
+                                           affine=False,
+                                           track_running_stats=True,
+                                           )
+            self.bn.eval()
+        else:
+            self.bn = None
+
     def get_autoencoder_params(self) -> list:
         params = super().get_autoencoder_params()
         return params
@@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
             z = torch.cat(z, 0)
         z, reg_log = self.regularization(z)
+
+        if self.bn is not None:
+            z = rearrange(z,
+                          "... c (i pi) (j pj) -> ... (c pi pj) i j",
+                          pi=self.ps[0],
+                          pj=self.ps[1],
+                          )
+            z = torch.nn.functional.batch_norm(z,
+                                               comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
+                                               comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
+                                               momentum=self.bn_momentum,
+                                               eps=self.bn_eps)
+
         if return_reg_log:
             return z, reg_log
         return z

     def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+        if self.bn is not None:
+            s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
+            m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
+            z = z * s + m
+            z = rearrange(
+                z,
+                "... (c pi pj) i j -> ... c (i pi) (j pj)",
+                pi=self.ps[0],
+                pj=self.ps[1],
+            )
+
         if self.max_batch_size is None:
             dec = self.post_quant_conv(z)
             dec = self.decoder(dec, **decoder_kwargs)
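The encode path packs 2x2 latent patches into channels and normalizes with frozen running stats; decode applies the exact inverse. A round-trip sketch under those assumptions (identity stats used so the check is easy to follow):

```python
import torch
from einops import rearrange

z = torch.randn(1, 32, 8, 8)                   # raw VAE latent
mean, var, eps = torch.zeros(128), torch.ones(128), 1e-4

# encode side: pack 2x2 patches into channels, then normalize
packed = rearrange(z, "b c (i pi) (j pj) -> b (c pi pj) i j", pi=2, pj=2)
normed = (packed - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + eps)

# decode side: de-normalize with the same stats, then unpack
denorm = normed * torch.sqrt(var.view(1, -1, 1, 1) + eps) + mean.view(1, -1, 1, 1)
restored = rearrange(denorm, "b (c pi pj) i j -> b c (i pi) (j pj)", pi=2, pj=2)
print(torch.allclose(restored, z, atol=1e-5))  # True
```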

View File

@@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
     Embeds scalar timesteps into vector representations.
     """

-    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
         super().__init__()
+        if output_size is None:
+            output_size = hidden_size
         self.mlp = nn.Sequential(
             operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
             nn.SiLU(),
-            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+            operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
         )
         self.frequency_embedding_size = frequency_embedding_size
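`output_size` decouples the embedder's output width from its hidden width; for Z Image it is set to 256 so the timestep embedding matches the `min(dim, 256)` adaLN inputs added earlier in this merge. A sketch with plain `nn.Linear` in place of the `operations` wrapper:

```python
import torch
import torch.nn as nn

hidden_size, output_size, freq_size = 1024, 256, 256
mlp = nn.Sequential(
    nn.Linear(freq_size, hidden_size),
    nn.SiLU(),
    nn.Linear(hidden_size, output_size),  # final width no longer tied to hidden_size
)
print(mlp(torch.randn(2, freq_size)).shape)  # torch.Size([2, 256])
```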

View File

@@ -439,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module):
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})

+        transformer_options["total_blocks"] = len(self.transformer_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.transformer_blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

View File

@@ -898,12 +898,13 @@ class Flux(BaseModel):
         attention_mask = kwargs.get("attention_mask", None)
         if attention_mask is not None:
             shape = kwargs["noise"].shape
-            mask_ref_size = kwargs["attention_mask_img_shape"]
-            # the model will pad to the patch size, and then divide
-            # essentially dividing and rounding up
-            (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
-            attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
-            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+            mask_ref_size = kwargs.get("attention_mask_img_shape", None)
+            if mask_ref_size is not None:
+                # the model will pad to the patch size, and then divide
+                # essentially dividing and rounding up
+                (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
+                attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
+                out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)

         guidance = kwargs.get("guidance", 3.5)
         if guidance is not None:
@@ -925,9 +926,19 @@ class Flux(BaseModel):
         out = {}
         ref_latents = kwargs.get("reference_latents", None)
         if ref_latents is not None:
-            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
         return out

+class Flux2(Flux):
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            target_text_len = 512
+            if cross_attn.shape[1] < target_text_len:
+                cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        return out
+
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
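The Flux2 `extra_conds` pads short prompts on the front of the token axis up to 512. Worth noting how `F.pad`'s tuple maps to dims here, since the token axis is second-to-last:

```python
import torch

cross_attn = torch.randn(1, 77, 4096)  # (batch, tokens, dim)
target_text_len = 512
if cross_attn.shape[1] < target_text_len:
    # (0, 0) pads the last dim; (n, 0) pads the token dim at the front
    cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
print(cross_attn.shape)  # torch.Size([1, 512, 4096])
```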
@@ -1103,9 +1114,13 @@ class Lumina2(BaseModel):
             if torch.numel(attention_mask) != attention_mask.sum():
                 out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
                 out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+            if 'num_tokens' not in out:
+                out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
+
         return out

 class WAN21(BaseModel):

View File

@@ -200,26 +200,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
         dit_config = {}
-        dit_config["image_model"] = "flux"
+        if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
+            dit_config["image_model"] = "flux2"
+            dit_config["axes_dim"] = [32, 32, 32, 32]
+            dit_config["num_heads"] = 48
+            dit_config["mlp_ratio"] = 3.0
+            dit_config["theta"] = 2000
+            dit_config["out_channels"] = 128
+            dit_config["global_modulation"] = True
+            dit_config["vec_in_dim"] = None
+            dit_config["mlp_silu_act"] = True
+            dit_config["qkv_bias"] = False
+            dit_config["ops_bias"] = False
+            dit_config["default_ref_method"] = "index"
+            dit_config["ref_index_scale"] = 10.0
+            patch_size = 1
+        else:
+            dit_config["image_model"] = "flux"
+            dit_config["axes_dim"] = [16, 56, 56]
+            dit_config["num_heads"] = 24
+            dit_config["mlp_ratio"] = 4.0
+            dit_config["theta"] = 10000
+            dit_config["out_channels"] = 16
+            dit_config["qkv_bias"] = True
+            patch_size = 2
-        dit_config["in_channels"] = 16
-        patch_size = 2
+        dit_config["hidden_size"] = 3072
+        dit_config["context_in_dim"] = 4096

         dit_config["patch_size"] = patch_size
         in_key = "{}img_in.weight".format(key_prefix)
         if in_key in state_dict_keys:
-            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
-            dit_config["out_channels"] = 16
+            w = state_dict[in_key]
+            dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
+            dit_config["hidden_size"] = w.shape[0]
+        txt_in_key = "{}txt_in.weight".format(key_prefix)
+        if txt_in_key in state_dict_keys:
+            w = state_dict[txt_in_key]
+            dit_config["context_in_dim"] = w.shape[1]
+            dit_config["hidden_size"] = w.shape[0]
         vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
         if vec_in_key in state_dict_keys:
             dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
-        dit_config["context_in_dim"] = 4096
-        dit_config["hidden_size"] = 3072
-        dit_config["mlp_ratio"] = 4.0
-        dit_config["num_heads"] = 24
         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
-        dit_config["axes_dim"] = [16, 56, 56]
-        dit_config["theta"] = 10000
-        dit_config["qkv_bias"] = True
         if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
             dit_config["image_model"] = "chroma"
             dit_config["in_channels"] = 64
@@ -388,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "lumina2"
         dit_config["patch_size"] = 2
         dit_config["in_channels"] = 16
-        dit_config["dim"] = 2304
-        dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
+        w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
+        dit_config["dim"] = w.shape[0]
+        dit_config["cap_feat_dim"] = w.shape[1]
         dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
-        dit_config["n_heads"] = 24
-        dit_config["n_kv_heads"] = 8
         dit_config["qk_norm"] = True
-        dit_config["axes_dims"] = [32, 32, 32]
-        dit_config["axes_lens"] = [300, 512, 512]
+
+        if dit_config["dim"] == 2304:  # Original Lumina 2
+            dit_config["n_heads"] = 24
+            dit_config["n_kv_heads"] = 8
+            dit_config["axes_dims"] = [32, 32, 32]
+            dit_config["axes_lens"] = [300, 512, 512]
+            dit_config["rope_theta"] = 10000.0
+            dit_config["ffn_dim_multiplier"] = 4.0
+        elif dit_config["dim"] == 3840:  # Z image
+            dit_config["n_heads"] = 30
+            dit_config["n_kv_heads"] = 30
+            dit_config["axes_dims"] = [32, 48, 48]
+            dit_config["axes_lens"] = [1536, 512, 512]
+            dit_config["rope_theta"] = 256.0
+            dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
+            dit_config["z_image_modulation"] = True
+            dit_config["time_scale"] = 1000.0
+            if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
+                dit_config["pad_tokens_multiple"] = 32
         return dit_config

     if '{}head.modulation'.format(key_prefix) in state_dict_keys:  # Wan 2.1
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1

View File

@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
loaded_memory = loaded_model.model_loaded_memory() loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = lowvram_model_memory - loaded_memory lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0: if lowvram_model_memory == 0:
@ -1012,7 +1012,7 @@ def force_channels_last():
STREAMS = {} STREAMS = {}
NUM_STREAMS = 1 NUM_STREAMS = 0
if args.async_offload: if args.async_offload:
NUM_STREAMS = 2 NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
@ -1030,7 +1030,7 @@ def current_stream(device):
stream_counters = {} stream_counters = {}
def get_offload_stream(device): def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0) stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS <= 1: if NUM_STREAMS == 0:
return None return None
if device in STREAMS: if device in STREAMS:
@ -1098,13 +1098,14 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def pin_memory(tensor): def pin_memory(tensor):
global TOTAL_PINNED_MEMORY global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0: if MAX_PINNED_MEMORY <= 0:
return False return False
if type(tensor) is not torch.nn.parameter.Parameter: if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
return False return False
if not is_device_cpu(tensor.device): if not is_device_cpu(tensor.device):
@ -1124,6 +1125,9 @@ def pin_memory(tensor):
return False return False
ptr = tensor.data_ptr() ptr = tensor.data_ptr()
if ptr == 0:
return False
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size TOTAL_PINNED_MEMORY += size

View File

@ -132,7 +132,7 @@ class LowVramPatch:
def __call__(self, weight): def __call__(self, weight):
intermediate_dtype = weight.dtype intermediate_dtype = weight.dtype
if self.convert_func is not None: if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True) weight = self.convert_func(weight, inplace=False)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32 intermediate_dtype = torch.float32
@ -148,6 +148,15 @@ class LowVramPatch:
else: else:
return out return out
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key): def get_key_weight(model, key):
set_func = None set_func = None
convert_func = None convert_func = None
@ -231,7 +240,6 @@ class ModelPatcher:
self.object_patches_backup = {} self.object_patches_backup = {}
self.weight_wrapper_patches = {} self.weight_wrapper_patches = {}
self.model_options = {"transformer_options":{}} self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device self.load_device = load_device
self.offload_device = offload_device self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update self.weight_inplace_update = weight_inplace_update
@ -270,6 +278,9 @@ class ModelPatcher:
if not hasattr(self.model, 'current_weight_patches_uuid'): if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None self.model.current_weight_patches_uuid = None
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
def model_size(self): def model_size(self):
if self.size > 0: if self.size > 0:
return self.size return self.size
@ -286,7 +297,7 @@ class ModelPatcher:
return self.model.lowvram_patch_counter return self.model.lowvram_patch_counter
def clone(self): def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@ -663,7 +674,16 @@ class ModelPatcher:
skip = True # skip random weights in non leaf modules skip = True # skip random weights in non leaf modules
break break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((comfy.model_management.module_size(m), n, m, params)) module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
loading.append((module_offload_mem, module_mem, n, m, params))
return loading return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@ -677,20 +697,22 @@ class ModelPatcher:
load_completely = [] load_completely = []
offloaded = [] offloaded = []
offload_buffer = 0
loading.sort(reverse=True) loading.sort(reverse=True)
for x in loading: for x in loading:
n = x[1] module_offload_mem, module_mem, n, m, params = x
m = x[2]
params = x[3]
module_mem = x[0]
lowvram_weight = False lowvram_weight = False
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n) weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n) bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"): if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory: if not lowvram_fits:
offload_buffer = potential_offload
lowvram_weight = True lowvram_weight = True
lowvram_counter += 1 lowvram_counter += 1
lowvram_mem_counter += module_mem lowvram_mem_counter += module_mem
@ -724,9 +746,11 @@ class ModelPatcher:
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m) wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory: if full_load or lowvram_fits:
mem_counter += module_mem mem_counter += module_mem
load_completely.append((module_mem, n, m, params)) load_completely.append((module_mem, n, m, params))
else:
offload_buffer = potential_offload
if cast_weight and hasattr(m, "comfy_cast_weights"): if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
@ -767,7 +791,7 @@ class ModelPatcher:
self.pin_weight_to_device("{}.{}".format(n, param)) self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0: if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True self.model.model_lowvram = True
else: else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
@ -779,6 +803,7 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter self.model.model_loaded_weight_memory = mem_counter
self.model.model_offload_buffer_memory = offload_buffer
self.model.current_weight_patches_uuid = self.patches_uuid self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@ -832,6 +857,7 @@ class ModelPatcher:
self.model.to(device_to) self.model.to(device_to)
self.model.device = device_to self.model.device = device_to
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
for m in self.model.modules(): for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"): if hasattr(m, "comfy_patched_weights"):
@ -850,13 +876,14 @@ class ModelPatcher:
patch_counter = 0 patch_counter = 0
unload_list = self._load_list() unload_list = self._load_list()
unload_list.sort() unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
for unload in unload_list: for unload in unload_list:
if memory_to_free < memory_freed: if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break break
module_mem = unload[0] module_offload_mem, module_mem, n, m, params = unload
n = unload[1]
m = unload[2] potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
params = unload[3]
lowvram_possible = hasattr(m, "comfy_cast_weights") lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@ -907,15 +934,18 @@ class ModelPatcher:
m.comfy_cast_weights = True m.comfy_cast_weights = True
m.comfy_patched_weights = False m.comfy_patched_weights = False
memory_freed += module_mem memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
logging.debug("freed {}".format(n)) logging.debug("freed {}".format(n))
for param in params: for param in params:
self.pin_weight_to_device("{}.{}".format(n, param)) self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed self.model.model_loaded_weight_memory -= memory_freed
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):

View File

@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if weight_has_function or weight.dtype != dtype: if weight_has_function or weight.dtype != dtype:
with wf_context: with wf_context:
weight = weight.to(dtype=dtype) weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function: for f in s.weight_function:
weight = f(weight) weight = f(weight)
@ -502,7 +504,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight return weight
else: else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
@ -540,115 +542,136 @@ if CUBLAS_IS_AVAILABLE:
# ============================================================================== # ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS from .quant_ops import QuantizedTensor, QUANT_ALGOS
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {}
_compute_dtype = torch.bfloat16
class Linear(torch.nn.Module, CastWeightBiasOp): def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
def __init__( class MixedPrecisionOps(manual_cast):
self, _layer_quant_config = layer_quant_config
in_features: int, _compute_dtype = compute_dtype
out_features: int, _full_precision_mm = full_precision_mm
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} class Linear(torch.nn.Module, CastWeightBiasOp):
# self.factory_kwargs = {"device": device, "dtype": dtype} def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.in_features = in_features self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
self.out_features = out_features # self.factory_kwargs = {"device": device, "dtype": dtype}
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
def reset_parameters(self): self.tensor_class = None
return None self._full_precision_mm = MixedPrecisionOps._full_precision_mm
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def reset_parameters(self):
strict, missing_keys, unexpected_keys, error_msgs): return None
device = self.factory_kwargs["device"] def _load_from_state_dict(self, state_dict, prefix, local_metadata,
layer_name = prefix.rstrip('.') strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key] device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
if layer_name not in MixedPrecisionOps._layer_quant_config: manually_loaded_keys = [weight_key]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[quant_format] if layer_name not in MixedPrecisionOps._layer_quant_config:
self.layout_type = qconfig["comfy_tensor_layout"] self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
weight_scale_key = f"{prefix}weight_scale" qconfig = QUANT_ALGOS[quant_format]
layout_params = { self.layout_type = qconfig["comfy_tensor_layout"]
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter( weight_scale_key = f"{prefix}weight_scale"
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), layout_params = {
requires_grad=False 'scale': state_dict.pop(weight_scale_key, None),
) 'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
for param_name in qconfig["parameters"]: self.weight = torch.nn.Parameter(
param_key = f"{prefix}{param_name}" QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
_v = state_dict.pop(param_key, None) requires_grad=False
if _v is None: )
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
for key in manually_loaded_keys: super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
if key in missing_keys:
missing_keys.remove(key)
def _forward(self, input, weight, bias): for key in manually_loaded_keys:
return torch.nn.functional.linear(input, weight, bias) if key in missing_keys:
missing_keys.remove(key)
def forward_comfy_cast_weights(self, input): def _forward(self, input, weight, bias):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) return torch.nn.functional.linear(input, weight, bias)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, input, *args, **kwargs): def forward_comfy_cast_weights(self, input):
run_every_op() weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: def forward(self, input, *args, **kwargs):
return self.forward_comfy_cast_weights(input, *args, **kwargs) run_every_op()
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
return weight.dequantize()
else:
return weight
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
return weight
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
MixedPrecisionOps._compute_dtype = compute_dtype if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None: if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

View File

@ -1,6 +1,7 @@
import torch import torch
import logging import logging
from typing import Tuple, Dict from typing import Tuple, Dict
import comfy.float
_LAYOUT_REGISTRY = {} _LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {} _GENERIC_UTILS = {}
@ -228,6 +229,14 @@ class QuantizedTensor(torch.Tensor):
new_kwargs = dequant_arg(kwargs) new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs) return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self):
return self._qdata.is_contiguous()
# ============================================================================== # ==============================================================================
# Generic Utilities (Layout-Agnostic Operations) # Generic Utilities (Layout-Agnostic Operations)
@ -338,6 +347,18 @@ def generic_copy_(func, args, kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs): def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True return True
@ -373,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
- orig_dtype: Original dtype before quantization (for casting back) - orig_dtype: Original dtype before quantization (for casting back)
""" """
@classmethod @classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn): def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype orig_dtype = tensor.dtype
if scale is None: if scale is None:
@ -383,17 +404,23 @@ class TensorCoreFP8Layout(QuantizedLayout):
scale = torch.tensor(scale) scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32) scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) if inplace_ops:
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality' tensor *= (1.0 / scale).to(tensor.dtype)
# lp_amax = torch.finfo(dtype).max else:
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) tensor = tensor * (1.0 / scale).to(tensor.dtype)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
else:
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
layout_params = { layout_params = {
'scale': scale, 'scale': scale,
'orig_dtype': orig_dtype 'orig_dtype': orig_dtype
} }
return qdata, layout_params return tensor, layout_params
@staticmethod @staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs): def dequantize(qdata, scale, orig_dtype, **kwargs):

View File

@ -52,6 +52,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2 import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -356,7 +357,7 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32: elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -382,6 +383,17 @@ class VAE:
self.upscale_ratio = 4 self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'decoder.post_quant_conv.weight' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
if 'bn.running_mean' in sd:
ddconfig["batch_norm_latent"] = True
self.downscale_ratio *= 2
self.upscale_ratio *= 2
self.latent_channels *= 4
old_memory_used_decode = self.memory_used_decode
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
if 'post_quant_conv.weight' in sd: if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else: else:
@ -917,7 +929,12 @@ class CLIPType(Enum):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = [] clip_data = []
for p in ckpt_paths: for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if metadata is not None:
quant_metadata = metadata.get("_quantization_metadata", None)
if quant_metadata is not None:
sd["_quantization_metadata"] = quant_metadata
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
@ -935,6 +952,10 @@ class TEModel(Enum):
QWEN25_7B = 11 QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12 BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13 GEMMA_3_4B = 13
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
def detect_te_model(sd): def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -967,6 +988,15 @@ def detect_te_model(sd):
if weight.shape[0] == 512: if weight.shape[0] == 512:
return TEModel.QWEN25_7B return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd: if "model.layers.0.post_attention_layernorm.weight" in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
return TEModel.LLAMA3_8 return TEModel.LLAMA3_8
return None return None
@ -1081,6 +1111,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: else:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
else: else:
# clip_l # clip_l
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:
@ -1142,6 +1179,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters = 0 parameters = 0
for c in clip_data: for c in clip_data:
if "_quantization_metadata" in c:
c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c) parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

View File

@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32 return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS
if textmodel_json_config is None: if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@ -109,13 +108,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
operations = model_options.get("custom_operations", None) operations = model_options.get("custom_operations", None)
scaled_fp8 = None scaled_fp8 = None
quantization_metadata = model_options.get("quantization_metadata", None)
if operations is None: if operations is None:
scaled_fp8 = model_options.get("scaled_fp8", None) layer_quant_config = None
if scaled_fp8 is not None: if quantization_metadata is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) layer_quant_config = json.loads(quantization_metadata).get("layers", None)
if layer_quant_config is not None:
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
else: else:
operations = comfy.ops.manual_cast # Fallback to scaled_fp8_ops for backward compatibility
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast
self.operations = operations self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations) self.transformer = model_class(config, dtype, device, self.operations)
@ -154,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options): def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx) layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if self.layer == "all": if isinstance(self.layer, list) or self.layer == "all":
pass pass
elif layer_idx is None or abs(layer_idx) > self.num_layers: elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last" self.layer = "last"
@ -256,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if self.enable_attention_masks: if self.enable_attention_masks:
attention_mask_model = attention_mask attention_mask_model = attention_mask
if self.layer == "all": if isinstance(self.layer, list):
intermediate_output = self.layer
elif self.layer == "all":
intermediate_output = "all" intermediate_output = "all"
else: else:
intermediate_output = self.layer_idx intermediate_output = self.layer_idx

View File

@ -21,6 +21,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2 import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
@ -741,6 +742,37 @@ class FluxSchnell(Flux):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out return out
class Flux2(Flux):
unet_config = {
"image_model": "flux2",
}
sampling_settings = {
"shift": 2.02,
}
unet_extra_config = {}
latent_format = latent_formats.Flux2
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None # TODO
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class GenmoMochi(supported_models_base.BASE): class GenmoMochi(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "mochi_preview", "image_model": "mochi_preview",
@ -963,7 +995,7 @@ class Lumina2(supported_models_base.BASE):
"shift": 6.0, "shift": 6.0,
} }
memory_usage_factor = 1.2 memory_usage_factor = 1.4
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.Flux latent_format = latent_formats.Flux
@ -982,6 +1014,24 @@ class Lumina2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
class ZImage(Lumina2):
unet_config = {
"image_model": "lumina2",
"dim": 3840,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
memory_usage_factor = 1.7
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
class WAN21_T2V(supported_models_base.BASE): class WAN21_T2V(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -1422,6 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -1,10 +1,13 @@
from comfy import sd1_clip from comfy import sd1_clip
import comfy.text_encoders.t5 import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip import comfy.text_encoders.sd3_clip
import comfy.text_encoders.llama
import comfy.model_management import comfy.model_management
from transformers import T5TokenizerFast from transformers import T5TokenizerFast, LlamaTokenizerFast
import torch import torch
import os import os
import json
import base64
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -68,3 +71,106 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_ return FluxClipModel_
def load_mistral_tokenizer(data):
if torch.is_tensor(data):
data = data.numpy().tobytes()
try:
from transformers.integrations.mistral import MistralConverter
except ModuleNotFoundError:
from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter
mistral_vocab = json.loads(data)
special_tokens = {}
vocab = {}
max_vocab = mistral_vocab["config"]["default_vocab_size"]
max_vocab -= len(mistral_vocab["special_tokens"])
for w in mistral_vocab["vocab"]:
r = w["rank"]
if r >= max_vocab:
continue
vocab[base64.b64decode(w["token_bytes"])] = r
for w in mistral_vocab["special_tokens"]:
if "token_bytes" in w:
special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"]
else:
special_tokens[w["token_str"]] = w["rank"]
all_special = []
for v in special_tokens:
all_special.append(v)
special_tokens.update(vocab)
vocab = special_tokens
return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False}
class MistralTokenizerClass:
@staticmethod
def from_pretrained(path, **kwargs):
return LlamaTokenizerFast(**kwargs)
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.tekken_data = tokenizer_data.get("tekken_model", None)
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
def state_dict(self):
return {"tekken_model": self.tekken_data}
class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer)
self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Mistral3_24BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = {}
num_layers = model_options.get("num_layers", None)
if num_layers is not None:
textmodel_json_config["num_hidden_layers"] = num_layers
if num_layers < 40:
textmodel_json_config["final_norm"] = False
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Flux2TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
out = out.movedim(1, 2)
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
if pruned:
model_options = model_options.copy()
model_options["num_layers"] = 30
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Flux2TEModel_

View File

@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
if scaled_fp8_key in state_dict: if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
if "_quantization_metadata" in state_dict:
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
return out return out

View File

@ -34,6 +34,28 @@ class Llama2Config:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
@dataclass
class Mistral3Small24BConfig:
vocab_size: int = 131072
hidden_size: int = 5120
intermediate_size: int = 32768
num_hidden_layers: int = 40
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 1000000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass @dataclass
class Qwen25_3BConfig: class Qwen25_3BConfig:
vocab_size: int = 151936 vocab_size: int = 151936
@ -56,6 +78,28 @@ class Qwen25_3BConfig:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
hidden_size: int = 2560
intermediate_size: int = 9728
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass @dataclass
class Qwen25_7BVLI_Config: class Qwen25_7BVLI_Config:
vocab_size: int = 152064 vocab_size: int = 152064
@ -412,8 +456,12 @@ class Llama2_(nn.Module):
intermediate = None intermediate = None
all_intermediate = None all_intermediate = None
only_layers = None
if intermediate_output is not None: if intermediate_output is not None:
if intermediate_output == "all": if isinstance(intermediate_output, list):
all_intermediate = []
only_layers = set(intermediate_output)
elif intermediate_output == "all":
all_intermediate = [] all_intermediate = []
intermediate_output = None intermediate_output = None
elif intermediate_output < 0: elif intermediate_output < 0:
@ -421,7 +469,8 @@ class Llama2_(nn.Module):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
if all_intermediate is not None: if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone()) if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
x = layer( x = layer(
x=x, x=x,
attention_mask=mask, attention_mask=mask,
@ -435,7 +484,8 @@ class Llama2_(nn.Module):
x = self.norm(x) x = self.norm(x)
if all_intermediate is not None: if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone()) if only_layers is None or ((i + 1) in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None: if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1) intermediate = torch.cat(all_intermediate, dim=1)
@ -465,6 +515,15 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Mistral3Small24B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Mistral3Small24BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module): class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
@ -474,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module): class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()

View File

@ -0,0 +1,48 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Qwen3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ZImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ZImageTEModel_

View File

@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
lora_diff = torch.mm( lora_diff = torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
).reshape(weight.shape) ).reshape(weight.shape)
del mat1, mat2
if dora_scale is not None: if dora_scale is not None:
weight = weight_decompose( weight = weight_decompose(
dora_scale, dora_scale,

View File

@ -8,7 +8,7 @@ import os
import textwrap import textwrap
import threading import threading
from enum import Enum from enum import Enum
from typing import Optional, Type, get_origin, get_args from typing import Optional, Type, get_origin, get_args, get_type_hints
class TypeTracker: class TypeTracker:
@ -220,11 +220,18 @@ class AsyncToSyncConverter:
        self._async_instance = async_class(*args, **kwargs)

        # Handle annotated class attributes (like execution: Execution)
-       # Get all annotations from the class hierarchy
-       all_annotations = {}
-       for base_class in reversed(inspect.getmro(async_class)):
-           if hasattr(base_class, "__annotations__"):
-               all_annotations.update(base_class.__annotations__)
+       # Get all annotations from the class hierarchy and resolve string annotations
+       try:
+           # get_type_hints resolves string annotations to actual type objects
+           # This handles classes using 'from __future__ import annotations'
+           all_annotations = get_type_hints(async_class)
+       except Exception:
+           # Fallback to raw annotations if get_type_hints fails
+           # (e.g., for undefined forward references)
+           all_annotations = {}
+           for base_class in reversed(inspect.getmro(async_class)):
+               if hasattr(base_class, "__annotations__"):
+                   all_annotations.update(base_class.__annotations__)

        # For each annotated attribute, check if it needs to be created or wrapped
        for attr_name, attr_type in all_annotations.items():
@ -625,15 +632,19 @@ class AsyncToSyncConverter:
"""Extract class attributes that are classes themselves.""" """Extract class attributes that are classes themselves."""
class_attributes = [] class_attributes = []
# Get resolved type hints to handle string annotations
try:
type_hints = get_type_hints(async_class)
except Exception:
type_hints = {}
# Look for class attributes that are classes # Look for class attributes that are classes
for name, attr in sorted(inspect.getmembers(async_class)): for name, attr in sorted(inspect.getmembers(async_class)):
if isinstance(attr, type) and not name.startswith("_"): if isinstance(attr, type) and not name.startswith("_"):
class_attributes.append((name, attr)) class_attributes.append((name, attr))
elif ( elif name in type_hints:
hasattr(async_class, "__annotations__") # Use resolved type hint instead of raw annotation
and name in async_class.__annotations__ annotation = type_hints[name]
):
annotation = async_class.__annotations__[name]
if isinstance(annotation, type): if isinstance(annotation, type):
class_attributes.append((name, annotation)) class_attributes.append((name, annotation))
@ -908,11 +919,15 @@ class AsyncToSyncConverter:
        attribute_mappings = {}

        # First check annotations for typed attributes (including from parent classes)
-       # Collect all annotations from the class hierarchy
-       all_annotations = {}
-       for base_class in reversed(inspect.getmro(async_class)):
-           if hasattr(base_class, "__annotations__"):
-               all_annotations.update(base_class.__annotations__)
+       # Resolve string annotations to actual types
+       try:
+           all_annotations = get_type_hints(async_class)
+       except Exception:
+           # Fallback to raw annotations
+           all_annotations = {}
+           for base_class in reversed(inspect.getmro(async_class)):
+               if hasattr(base_class, "__annotations__"):
+                   all_annotations.update(base_class.__annotations__)

        for attr_name, attr_type in sorted(all_annotations.items()):
            for class_name, class_type in class_attributes:

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
+from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
@ -72,6 +73,33 @@ class VideoInput(ABC):
        frame_count = components.images.shape[0]
        return float(frame_count / components.frame_rate)
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video.
Default implementation uses :meth:`get_components`, which may require
loading all frames into memory. File-based implementations should
override this method and use container/stream metadata instead.
Returns:
Total number of frames as an integer.
"""
return int(self.get_components().images.shape[0])
def get_frame_rate(self) -> Fraction:
"""
Returns the frame rate of the video.
Default implementation materializes the video into memory via
`get_components()`. Subclasses that can inspect the underlying
container (e.g. `VideoFromFile`) should override this with a more
efficient implementation.
Returns:
Frame rate as a Fraction.
"""
return self.get_components().frame_rate
    def get_container_format(self) -> str:
        """
        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

View File

@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
raise ValueError(f"Could not determine duration for file '{self.__file}'") raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video without materializing them as
torch tensors.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count
def get_frame_rate(self) -> Fraction:
"""
Returns the average frame rate of the video using container metadata
without decoding all frames.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
if video_stream.average_rate:
return Fraction(video_stream.average_rate)
# Fallback: estimate from frames + duration if available
if video_stream.frames and container.duration:
duration_seconds = float(container.duration / av.time_base)
if duration_seconds > 0:
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
# Last resort: match get_components_internal default
return Fraction(1)
    def get_container_format(self) -> str:
        """
        Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
                packet.stream = stream_map[packet.stream]
                output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
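A quick usage sketch of the new metadata-based accessors (the file path is hypothetical; both calls avoid decoding frames unless every metadata fallback fails):

from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromFile

video = VideoFromFile("input/clip.mp4")  # hypothetical path
fps: Fraction = video.get_frame_rate()   # stream metadata, no decode
frames: int = video.get_frame_count()    # frames field, duration estimate, or streaming count
print(f"{frames} frames at {float(fps):.3f} fps")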
class VideoFromComponents(VideoInput):
    """
    Class representing video input from tensors.

View File

@ -70,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel):
    # )
class Flux2ProGenerateRequest(BaseModel):
prompt: str = Field(...)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
seed: int | None = Field(None)
prompt_upsampling: bool | None = Field(None)
input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
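The numbered `input_image_*` fields let a batch of reference images ride along on one request. A sketch of how a caller might populate them, mirroring the key-naming scheme used by the Flux2 node further down (`reference_image_fields` is a hypothetical helper, not part of the API):

def reference_image_fields(encoded: list[str]) -> dict[str, str]:
    # Map base64 images onto input_image, input_image_2, ..., input_image_9.
    if len(encoded) > 9:
        raise ValueError("The API accepts at most 9 reference images.")
    return {
        ("input_image" if i == 0 else f"input_image_{i + 1}"): img
        for i, img in enumerate(encoded)
    }

request = Flux2ProGenerateRequest(prompt="a red fox", **reference_image_fields([]))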
class BFLFluxKontextProGenerateRequest(BaseModel):
    prompt: str = Field(..., description='The text prompt for what you want to edit.')
    input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
@ -109,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel):
class BFLFluxProGenerateResponse(BaseModel):
-   id: str = Field(..., description='The unique identifier for the generation task.')
-   polling_url: str = Field(..., description='URL to poll for the generation result.')
+   id: str = Field(..., description="The unique identifier for the generation task.")
+   polling_url: str = Field(..., description="URL to poll for the generation result.")
+   cost: float | None = Field(None, description="Price in cents")


class BFLStatus(str, Enum):

View File

@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel):
    mimeType: GeminiMimeType | None = Field(None)


+class GeminiFileData(BaseModel):
+    fileUri: str | None = Field(None)
+    mimeType: GeminiMimeType | None = Field(None)


class GeminiPart(BaseModel):
    inlineData: GeminiInlineData | None = Field(None)
+   fileData: GeminiFileData | None = Field(None)
    text: str | None = Field(None)
@ -113,9 +119,9 @@ class GeminiGenerationConfig(BaseModel):
    maxOutputTokens: int | None = Field(None, ge=16, le=8192)
    seed: int | None = Field(None)
    stopSequences: list[str] | None = Field(None)
-   temperature: float | None = Field(1, ge=0.0, le=2.0)
-   topK: int | None = Field(40, ge=1)
-   topP: float | None = Field(0.95, ge=0.0, le=1.0)
+   temperature: float | None = Field(None, ge=0.0, le=2.0)
+   topK: int | None = Field(None, ge=1)
+   topP: float | None = Field(None, ge=0.0, le=1.0)


class GeminiImageConfig(BaseModel):

View File

@ -1,34 +1,21 @@
-from typing import Optional, Union
-from enum import Enum
+from typing import Optional

from pydantic import BaseModel, Field


-class Image2(BaseModel):
-    bytesBase64Encoded: str
-    gcsUri: Optional[str] = None
-    mimeType: Optional[str] = None
-
-
-class Image3(BaseModel):
-    bytesBase64Encoded: Optional[str] = None
-    gcsUri: str
-    mimeType: Optional[str] = None
-
-
-class Instance1(BaseModel):
-    image: Optional[Union[Image2, Image3]] = Field(
-        None, description='Optional image to guide video generation'
-    )
+class VeoRequestInstanceImage(BaseModel):
+    bytesBase64Encoded: str | None = Field(None)
+    gcsUri: str | None = Field(None)
+    mimeType: str | None = Field(None)
+
+
+class VeoRequestInstance(BaseModel):
+    image: VeoRequestInstanceImage | None = Field(None)
+    lastFrame: VeoRequestInstanceImage | None = Field(None)
    prompt: str = Field(..., description='Text description of the video')


-class PersonGeneration1(str, Enum):
-    ALLOW = 'ALLOW'
-    BLOCK = 'BLOCK'
-
-
-class Parameters1(BaseModel):
+class VeoRequestParameters(BaseModel):
    aspectRatio: Optional[str] = Field(None, examples=['16:9'])
    durationSeconds: Optional[int] = None
    enhancePrompt: Optional[bool] = None
@ -37,17 +24,18 @@ class Parameters1(BaseModel):
        description='Generate audio for the video. Only supported by veo 3 models.',
    )
    negativePrompt: Optional[str] = None
-   personGeneration: Optional[PersonGeneration1] = None
+   personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
    sampleCount: Optional[int] = None
    seed: Optional[int] = None
    storageUri: Optional[str] = Field(
        None, description='Optional Cloud Storage URI to upload the video'
    )
+   resolution: str | None = Field(None)


class VeoGenVidRequest(BaseModel):
-   instances: Optional[list[Instance1]] = None
-   parameters: Optional[Parameters1] = None
+   instances: list[VeoRequestInstance] | None = Field(None)
+   parameters: VeoRequestParameters | None = Field(None)


class VeoGenVidResponse(BaseModel):

View File

@ -1,7 +1,7 @@
from inspect import cleandoc
-from typing import Optional

import torch
+from pydantic import BaseModel
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension
@ -9,15 +9,16 @@ from comfy_api_nodes.apis.bfl_api import (
    BFLFluxExpandImageRequest,
    BFLFluxFillImageRequest,
    BFLFluxKontextProGenerateRequest,
-   BFLFluxProGenerateRequest,
    BFLFluxProGenerateResponse,
    BFLFluxProUltraGenerateRequest,
    BFLFluxStatusResponse,
    BFLStatus,
+   Flux2ProGenerateRequest,
)
from comfy_api_nodes.util import (
    ApiEndpoint,
    download_url_to_image_tensor,
+   get_number_of_images,
    poll_op,
    resize_mask_to_image,
    sync_op,
@ -116,7 +117,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
        prompt_upsampling: bool = False,
        raw: bool = False,
        seed: int = 0,
-       image_prompt: Optional[torch.Tensor] = None,
+       image_prompt: torch.Tensor | None = None,
        image_prompt_strength: float = 0.1,
    ) -> IO.NodeOutput:
        if image_prompt is None:
@ -230,7 +231,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
        aspect_ratio: str,
        guidance: float,
        steps: int,
-       input_image: Optional[torch.Tensor] = None,
+       input_image: torch.Tensor | None = None,
        seed=0,
        prompt_upsampling=False,
    ) -> IO.NodeOutput:
@ -280,124 +281,6 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode):
DISPLAY_NAME = "Flux.1 Kontext [max] Image" DISPLAY_NAME = "Flux.1 Kontext [max] Image"
class FluxProImageNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxProImageNode",
display_name="Flux 1.1 [pro] Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Int.Input(
"width",
default=1024,
min=256,
max=1440,
step=32,
),
IO.Int.Input(
"height",
default=768,
min=256,
max=1440,
step=32,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image_prompt",
optional=True,
),
# "image_prompt_strength": (
# IO.FLOAT,
# {
# "default": 0.1,
# "min": 0.0,
# "max": 1.0,
# "step": 0.01,
# "tooltip": "Blend between the prompt and the image prompt.",
# },
# ),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
prompt_upsampling,
width: int,
height: int,
seed=0,
image_prompt=None,
# image_prompt_strength=0.1,
) -> IO.NodeOutput:
image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)
initial_response = await sync_op(
cls,
ApiEndpoint(
path="/proxy/bfl/flux-pro-1.1/generate",
method="POST",
),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
width=width,
height=height,
seed=seed,
image_prompt=image_prompt,
),
)
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxProExpandNode(IO.ComfyNode):
    """
    Outpaints image based on prompt.
@ -640,16 +523,125 @@ class FluxProFillNode(IO.ComfyNode):
        return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Flux2ProImageNode",
display_name="Flux.2 [pro] Image",
category="api node/image/BFL",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation or edit",
),
IO.Int.Input(
"width",
default=1024,
min=256,
max=2048,
step=32,
),
IO.Int.Input(
"height",
default=768,
min=256,
max=2048,
step=32,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
width: int,
height: int,
seed: int,
prompt_upsampling: bool,
images: torch.Tensor | None = None,
) -> IO.NodeOutput:
reference_images = {}
if images is not None:
if get_number_of_images(images) > 9:
raise ValueError("The current maximum number of supported images is 9.")
for image_index in range(images.shape[0]):
key_name = f"input_image_{image_index + 1}" if image_index else "input_image"
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest(
prompt=prompt,
width=width,
height=height,
seed=seed,
prompt_upsampling=prompt_upsampling,
**reference_images,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class BFLExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            FluxProUltraImageNode,
-           # FluxProImageNode,
            FluxKontextProImageNode,
            FluxKontextMaxImageNode,
            FluxProExpandNode,
            FluxProFillNode,
+           Flux2ProImageNode,
        ]

View File

@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
""" """
import base64
-import json
import os
-import time
-import uuid
from enum import Enum
from io import BytesIO
from typing import Literal
@ -20,6 +17,7 @@ from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis.gemini_api import (
    GeminiContent,
+   GeminiFileData,
    GeminiGenerateContentRequest,
    GeminiGenerateContentResponse,
    GeminiImageConfig,
@ -38,10 +36,10 @@ from comfy_api_nodes.util import (
    get_number_of_images,
    sync_op,
    tensor_to_base64_string,
+   upload_images_to_comfyapi,
    validate_string,
    video_to_base64_string,
)
-from server import PromptServer

GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024  # 20 MB
@ -68,24 +66,43 @@ class GeminiImageModel(str, Enum):
    gemini_2_5_flash_image = "gemini-2.5-flash-image"
-def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
-    """
-    Convert image tensor input to Gemini API compatible parts.
-
-    Args:
-        image_input: Batch of image tensors from ComfyUI.
-
-    Returns:
-        List of GeminiPart objects containing the encoded images.
-    """
+async def create_image_parts(
+    cls: type[IO.ComfyNode],
+    images: torch.Tensor,
+    image_limit: int = 0,
+) -> list[GeminiPart]:
    image_parts: list[GeminiPart] = []
-   for image_index in range(image_input.shape[0]):
-       image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
+   if image_limit < 0:
+       raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
+   total_images = get_number_of_images(images)
+   if total_images <= 0:
+       raise ValueError("No images provided to create_image_parts; at least one image is required.")
+   # If image_limit == 0 --> use all images; otherwise clamp to image_limit.
+   effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
+   # Number of images we'll send as URLs (fileData)
+   num_url_images = min(effective_max, 10)  # Vertex API max number of image links
+   reference_images_urls = await upload_images_to_comfyapi(
+       cls,
+       images,
+       max_images=num_url_images,
+   )
+   for reference_image_url in reference_images_urls:
+       image_parts.append(
+           GeminiPart(
+               fileData=GeminiFileData(
+                   mimeType=GeminiMimeType.image_png,
+                   fileUri=reference_image_url,
+               )
+           )
+       )
+   for idx in range(num_url_images, effective_max):
        image_parts.append(
            GeminiPart(
                inlineData=GeminiInlineData(
                    mimeType=GeminiMimeType.image_png,
-                   data=image_as_b64,
+                   data=tensor_to_base64_string(images[idx]),
                )
            )
        )
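So with the rewrite above, a 14-image batch and the default `image_limit=0` yields ten `fileData` URL parts plus four inline base64 parts. A tiny standalone check of that split logic (a sketch, not the actual helper):

def split_counts(total_images: int, image_limit: int = 0, url_cap: int = 10) -> tuple[int, int]:
    # Return (url_parts, inline_parts) the way create_image_parts divides a batch.
    effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
    num_url = min(effective_max, url_cap)
    return num_url, effective_max - num_url

assert split_counts(14) == (10, 4)
assert split_counts(3) == (3, 0)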
@ -104,14 +121,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
        List of response parts matching the requested type.
    """
    if response.candidates is None:
-       if response.promptFeedback.blockReason:
+       if response.promptFeedback and response.promptFeedback.blockReason:
            feedback = response.promptFeedback
            raise ValueError(
                f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
            )
-       raise NotImplementedError(
-           "Gemini returned no response candidates. "
-           "Please report to ComfyUI repository with the example of workflow to reproduce this."
-       )
+       raise ValueError(
+           "Gemini API returned no response candidates. If you are using the `IMAGE` modality, "
+           "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
+       )
    parts = []
    for part in response.candidates[0].content.parts:
@ -182,11 +199,12 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
    else:
        return None
    final_price = response.usageMetadata.promptTokenCount * input_tokens_price
-   for i in response.usageMetadata.candidatesTokensDetails:
-       if i.modality == Modality.IMAGE:
-           final_price += output_image_tokens_price * i.tokenCount  # for Nano Banana models
-       else:
-           final_price += output_text_tokens_price * i.tokenCount
+   if response.usageMetadata.candidatesTokensDetails:
+       for i in response.usageMetadata.candidatesTokensDetails:
+           if i.modality == Modality.IMAGE:
+               final_price += output_image_tokens_price * i.tokenCount  # for Nano Banana models
+           else:
+               final_price += output_text_tokens_price * i.tokenCount
    if response.usageMetadata.thoughtsTokenCount:
        final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
    return final_price / 1_000_000.0
@ -337,8 +355,7 @@ class GeminiNode(IO.ComfyNode):
        # Add other modal parts
        if images is not None:
-           image_parts = create_image_parts(images)
-           parts.extend(image_parts)
+           parts.extend(await create_image_parts(cls, images))
        if audio is not None:
            parts.extend(cls.create_audio_parts(audio))
        if video is not None:
@ -363,29 +380,6 @@ class GeminiNode(IO.ComfyNode):
        )

        output_text = get_text_from_response(response)
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
        return IO.NodeOutput(output_text or "Empty response from Gemini model...")
@ -561,8 +555,7 @@ class GeminiImage(IO.ComfyNode):
        image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
        if images is not None:
-           image_parts = create_image_parts(images)
-           parts.extend(image_parts)
+           parts.extend(await create_image_parts(cls, images))
        if files is not None:
            parts.extend(files)
@ -581,30 +574,7 @@ class GeminiImage(IO.ComfyNode):
            response_model=GeminiGenerateContentResponse,
            price_extractor=calculate_tokens_price,
        )
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
output_text = get_text_from_response(response)
if output_text:
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(get_image_from_response(response), output_text)
class GeminiImage2(IO.ComfyNode):
@ -645,7 +615,7 @@ class GeminiImage2(IO.ComfyNode):
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
default="auto", default="auto",
tooltip="If set to 'auto', matches your input image's aspect ratio; " tooltip="If set to 'auto', matches your input image's aspect ratio; "
"if no image is provided, generates a 1:1 square.", "if no image is provided, a 16:9 square is usually generated.",
), ),
IO.Combo.Input( IO.Combo.Input(
"resolution", "resolution",
@ -701,7 +671,7 @@ class GeminiImage2(IO.ComfyNode):
        if images is not None:
            if get_number_of_images(images) > 14:
                raise ValueError("The current maximum number of supported images is 14.")
-           parts.extend(create_image_parts(images))
+           parts.extend(await create_image_parts(cls, images))
        if files is not None:
            parts.extend(files)
@ -724,30 +694,7 @@ class GeminiImage2(IO.ComfyNode):
            response_model=GeminiGenerateContentResponse,
            price_extractor=calculate_tokens_price,
        )
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
output_text = get_text_from_response(response)
if output_text:
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(get_image_from_response(response), output_text)
class GeminiExtension(ComfyExtension):

View File

@ -1,15 +1,10 @@
from io import BytesIO
-from typing import Optional, Union
-import json
import os
-import time
-import uuid
from enum import Enum
from inspect import cleandoc

import numpy as np
import torch
from PIL import Image
-from server import PromptServer

import folder_paths
import base64
from comfy_api.latest import IO, ComfyExtension
@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode):
    def create_input_message_contents(
        cls,
        prompt: str,
-       image: Optional[torch.Tensor] = None,
-       files: Optional[list[InputFileContent]] = None,
+       image: torch.Tensor | None = None,
+       files: list[InputFileContent] | None = None,
    ) -> InputMessageContentList:
        """Create a list of input message contents from prompt and optional image."""
-       content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [
+       content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
            InputTextContent(text=prompt, type="input_text"),
        ]
        if image is not None:
@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode):
        prompt: str,
        persist_context: bool = False,
        model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
-       images: Optional[torch.Tensor] = None,
-       files: Optional[list[InputFileContent]] = None,
-       advanced_options: Optional[CreateModelResponseProperties] = None,
+       images: torch.Tensor | None = None,
+       files: list[InputFileContent] | None = None,
+       advanced_options: CreateModelResponseProperties | None = None,
    ) -> IO.NodeOutput:
        validate_string(prompt, strip_whitespace=False)
@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode):
            status_extractor=lambda response: response.status,
            completed_statuses=["incomplete", "completed"]
        )
-       output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))
+       return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
# Update history
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(output_text)
class OpenAIInputFiles(IO.ComfyNode):
@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode):
    def execute(
        cls,
        truncation: bool,
-       instructions: Optional[str] = None,
-       max_output_tokens: Optional[int] = None,
+       instructions: str | None = None,
+       max_output_tokens: int | None = None,
    ) -> IO.NodeOutput:
        """
        Configure advanced options for the OpenAI Chat Node.

View File

@ -5,8 +5,7 @@ import aiohttp
import torch
from typing_extensions import override

-from comfy_api.input.video_types import VideoInput
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import topaz_api
from comfy_api_nodes.util import (
    ApiEndpoint,
@ -282,7 +281,7 @@ class TopazVideoEnhance(IO.ComfyNode):
    @classmethod
    async def execute(
        cls,
-       video: VideoInput,
+       video: Input.Video,
        upscaler_enabled: bool,
        upscaler_model: str,
        upscaler_resolution: str,
@ -297,12 +296,10 @@ class TopazVideoEnhance(IO.ComfyNode):
    ) -> IO.NodeOutput:
        if upscaler_enabled is False and interpolation_enabled is False:
            raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
-       src_width, src_height = video.get_dimensions()
-       video_components = video.get_components()
-       src_frame_rate = int(video_components.frame_rate)
-       duration_sec = video.get_duration()
-       estimated_frames = int(duration_sec * src_frame_rate)
        validate_container_format_is_mp4(video)
+       src_width, src_height = video.get_dimensions()
+       src_frame_rate = int(video.get_frame_rate())
+       duration_sec = video.get_duration()
        src_video_stream = video.get_stream_source()
        target_width = src_width
        target_height = src_height
@ -338,7 +335,7 @@ class TopazVideoEnhance(IO.ComfyNode):
container="mp4", container="mp4",
size=get_fs_object_size(src_video_stream), size=get_fs_object_size(src_video_stream),
duration=int(duration_sec), duration=int(duration_sec),
frameCount=estimated_frames, frameCount=video.get_frame_count(),
frameRate=src_frame_rate, frameRate=src_frame_rate,
resolution=topaz_api.Resolution(width=src_width, height=src_height), resolution=topaz_api.Resolution(width=src_width, height=src_height),
), ),

View File

@ -1,6 +1,7 @@
import base64
from io import BytesIO

+import torch
from typing_extensions import override

from comfy_api.input_impl.video_types import VideoFromFile
@ -10,6 +11,9 @@ from comfy_api_nodes.apis.veo_api import (
    VeoGenVidPollResponse,
    VeoGenVidRequest,
    VeoGenVidResponse,
+   VeoRequestInstance,
+   VeoRequestInstanceImage,
+   VeoRequestParameters,
)
from comfy_api_nodes.util import (
    ApiEndpoint,
@ -346,12 +350,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
        )
class Veo3FirstLastFrameNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Veo3FirstLastFrameNode",
display_name="Google Veo 3 First-Last-Frame to Video",
category="api node/video/Veo",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the video",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
default="16:9",
tooltip="Aspect ratio of the output video",
),
IO.Int.Input(
"duration",
default=8,
min=4,
max=8,
step=2,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFF,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation",
),
IO.Image.Input("first_frame", tooltip="Start frame"),
IO.Image.Input("last_frame", tooltip="End frame"),
IO.Combo.Input(
"model",
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
default="veo-3.1-fast-generate",
),
IO.Boolean.Input(
"generate_audio",
default=True,
tooltip="Generate audio for the video.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
resolution: str,
aspect_ratio: str,
duration: int,
seed: int,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
model: str,
generate_audio: bool,
):
model = MODELS_MAP[model]
initial_response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
response_model=VeoGenVidResponse,
data=VeoGenVidRequest(
instances=[
VeoRequestInstance(
prompt=prompt,
image=VeoRequestInstanceImage(
bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png"
),
lastFrame=VeoRequestInstanceImage(
bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png"
),
),
],
parameters=VeoRequestParameters(
aspectRatio=aspect_ratio,
personGeneration="ALLOW",
durationSeconds=duration,
enhancePrompt=True, # cannot be False for Veo3
seed=seed,
generateAudio=generate_audio,
negativePrompt=negative_prompt,
resolution=resolution,
),
),
)
poll_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
response_model=VeoGenVidPollResponse,
status_extractor=lambda r: "completed" if r.done else "pending",
data=VeoGenVidPollRequest(
operationName=initial_response.name,
),
poll_interval=5.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
if poll_response.error:
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
response = poll_response.response
filtered_count = response.raiMediaFilteredCount
if filtered_count:
reasons = response.raiMediaFilteredReasons or []
reason_part = f": {reasons[0]}" if reasons else ""
raise Exception(
f"Content blocked by Google's Responsible AI filters{reason_part} "
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
)
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")
raise Exception("Video generation completed but no video was returned")
class VeoExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            VeoVideoGenerationNode,
            Veo3VideoGenerationNode,
+           Veo3FirstLastFrameNode,
        ]

View File

@ -36,6 +36,7 @@ from .upload_helpers import (
    upload_video_to_comfyapi,
)
from .validation_utils import (
+   get_image_dimensions,
    get_number_of_images,
    validate_aspect_ratio_string,
    validate_audio_duration,
@ -82,6 +83,7 @@ __all__ = [
"trim_video", "trim_video",
"video_to_base64_string", "video_to_base64_string",
# Validation utilities # Validation utilities
"get_image_dimensions",
"get_number_of_images", "get_number_of_images",
"validate_aspect_ratio_string", "validate_aspect_ratio_string",
"validate_audio_duration", "validate_audio_duration",

View File

@ -4,7 +4,7 @@ import logging
import time
import uuid
from io import BytesIO
-from typing import Optional, Union
+from typing import Optional
from urllib.parse import urlparse

import aiohttp
@ -48,8 +48,9 @@ async def upload_images_to_comfyapi(
    image: torch.Tensor,
    *,
    max_images: int = 8,
-   mime_type: Optional[str] = None,
-   wait_label: Optional[str] = "Uploading",
+   mime_type: str | None = None,
+   wait_label: str | None = "Uploading",
+   show_batch_index: bool = True,
) -> list[str]:
    """
    Uploads images to ComfyUI API and returns download URLs.
@ -59,11 +60,18 @@ async def upload_images_to_comfyapi(
    download_urls: list[str] = []
    is_batch = len(image.shape) > 3
    batch_len = image.shape[0] if is_batch else 1
+   num_to_upload = min(batch_len, max_images)
+   batch_start_ts = time.monotonic()

-   for idx in range(min(batch_len, max_images)):
+   for idx in range(num_to_upload):
        tensor = image[idx] if is_batch else image
        img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
-       url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
+       effective_label = wait_label
+       if wait_label and show_batch_index and num_to_upload > 1:
+           effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
+       url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
        download_urls.append(url)

    return download_urls
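With `show_batch_index` enabled, a three-image batch renders labels `Uploading (1/3)` through `Uploading (3/3)` against one shared start timestamp, so the progress clock no longer resets per file. The label choice in isolation (a sketch, not the actual helper):

def batch_labels(wait_label, n, show_batch_index=True):
    # Reproduce the per-file label choice made by upload_images_to_comfyapi.
    if not wait_label or not show_batch_index or n <= 1:
        return [wait_label] * n
    return [f"{wait_label} ({i + 1}/{n})" for i in range(n)]

assert batch_labels("Uploading", 3) == ["Uploading (1/3)", "Uploading (2/3)", "Uploading (3/3)"]
assert batch_labels("Uploading", 1) == ["Uploading"]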
@ -126,8 +134,9 @@ async def upload_file_to_comfyapi(
    cls: type[IO.ComfyNode],
    file_bytes_io: BytesIO,
    filename: str,
-   upload_mime_type: Optional[str],
-   wait_label: Optional[str] = "Uploading",
+   upload_mime_type: str | None,
+   wait_label: str | None = "Uploading",
+   progress_origin_ts: float | None = None,
) -> str:
    """Uploads a single file to ComfyUI API and returns its download URL."""
    if upload_mime_type is None:
@ -148,6 +157,7 @@ async def upload_file_to_comfyapi(
        file_bytes_io,
        content_type=upload_mime_type,
        wait_label=wait_label,
+       progress_origin_ts=progress_origin_ts,
    )
    return create_resp.download_url
@ -155,27 +165,18 @@ async def upload_file_to_comfyapi(
async def upload_file(
    cls: type[IO.ComfyNode],
    upload_url: str,
-   file: Union[BytesIO, str],
+   file: BytesIO | str,
    *,
-   content_type: Optional[str] = None,
+   content_type: str | None = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    retry_backoff: float = 2.0,
-   wait_label: Optional[str] = None,
+   wait_label: str | None = None,
+   progress_origin_ts: float | None = None,
) -> None:
    """
    Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.

-   Args:
-       cls: Node class (provides auth context + UI progress hooks).
-       upload_url: Pre-signed PUT URL.
-       file: BytesIO or path string.
-       content_type: Explicit MIME type. If None, we *suppress* Content-Type.
-       max_retries: Maximum retry attempts.
-       retry_delay: Initial delay in seconds.
-       retry_backoff: Exponential backoff factor.
-       wait_label: Progress label shown in Comfy UI.

    Raises:
        ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
    """
@ -198,7 +199,7 @@ async def upload_file(
    attempt = 0
    delay = retry_delay
-   start_ts = time.monotonic()
+   start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
    op_uuid = uuid.uuid4().hex[:8]
    while True:
        attempt += 1
File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,10 @@ import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
+import comfy.model_management
+import torch
+import math
+import nodes


class CLIPTextEncodeFlux(io.ComfyNode):
    @classmethod
@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode):
    encode = execute  # TODO: remove
class EmptyFlux2LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyFlux2LatentImage",
display_name="Empty Flux 2 Latent",
category="latent",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
class FluxGuidance(io.ComfyNode):
    @classmethod
@ -154,6 +178,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
    append = execute  # TODO: remove
def generalized_time_snr_shift(t, mu: float, sigma: float):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
mu = compute_empirical_mu(image_seq_len, num_steps)
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
return timesteps
class Flux2Scheduler(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Flux2Scheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=4096),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Sigmas.Output(),
],
)
@classmethod
def execute(cls, steps, width, height) -> io.NodeOutput:
seq_len = (width * height / (16 * 16))
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
class FluxExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -163,6 +239,8 @@ class FluxExtension(ComfyExtension):
            FluxDisableGuidance,
            FluxKontextImageScale,
            FluxKontextMultiReferenceLatentMethod,
+           EmptyFlux2LatentImage,
+           Flux2Scheduler,
        ]

View File

@ -7,6 +7,10 @@ from comfy_api.input_impl import VideoFromFile
from pathlib import Path
from PIL import Image
import numpy as np
import uuid
def normalize_path(path):
    return path.replace('\\', '/')
@ -34,58 +38,6 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
class Load3DAnimation():
@classmethod
def INPUT_TYPES(s):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
"image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
@ -120,7 +72,8 @@ class Preview3D():
"model_file": ("STRING", {"default": "", "multiline": False}), "model_file": ("STRING", {"default": "", "multiline": False}),
}, },
"optional": { "optional": {
"camera_info": ("LOAD3D_CAMERA", {}) "camera_info": ("LOAD3D_CAMERA", {}),
"bg_image": ("IMAGE", {})
}} }}
OUTPUT_NODE = True OUTPUT_NODE = True
@ -133,50 +86,33 @@ class Preview3D():
    def process(self, model_file, **kwargs):
        camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None)
bg_image_path = None
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory()
filename = f"bg_{uuid.uuid4().hex}.png"
bg_image_path = os.path.join(temp_dir, filename)
img.save(bg_image_path, compress_level=1)
bg_image_path = f"temp/{filename}"
        return {
            "ui": {
-               "result": [model_file, camera_info]
+               "result": [model_file, camera_info, bg_image_path]
            }
        }
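A note on the background-image path above: the tensor is written once into the temp directory and only a relative `temp/<name>.png` path is shipped to the frontend. The conversion step in isolation (a sketch; `temp_dir` stands in for `folder_paths.get_temp_directory()`):

import os
import uuid

import numpy as np
from PIL import Image

def save_bg_image(bg_image, temp_dir):
    # Convert the first image of a [B, H, W, C] float batch (0..1) to a temp PNG.
    img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
    filename = f"bg_{uuid.uuid4().hex}.png"
    Image.fromarray(img_array).save(os.path.join(temp_dir, filename), compress_level=1)
    return f"temp/{filename}"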
class Preview3DAnimation():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
RETURN_TYPES = ()
CATEGORY = "3d"
FUNCTION = "process"
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
} }
} }
NODE_CLASS_MAPPINGS = {
    "Load3D": Load3D,
-   "Load3DAnimation": Load3DAnimation,
    "Preview3D": Preview3D,
-   "Preview3DAnimation": Preview3DAnimation
}

NODE_DISPLAY_NAME_MAPPINGS = {
-   "Load3D": "Load 3D",
-   "Load3DAnimation": "Load 3D - Animation",
-   "Preview3D": "Preview 3D",
-   "Preview3DAnimation": "Preview 3D - Animation"
+   "Load3D": "Load 3D & Animation",
+   "Preview3D": "Preview 3D & Animation",
}

File diff suppressed because it is too large

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.71"
+__version__ = "0.3.75"

View File

@ -37,13 +37,16 @@ class TAESDPreviewerImpl(LatentPreviewer):
class Latent2RGBPreviewer(LatentPreviewer):
-   def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
+   def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None):
        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
        self.latent_rgb_factors_bias = None
        if latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
+       self.latent_rgb_factors_reshape = latent_rgb_factors_reshape

    def decode_latent_to_preview(self, x0):
+       if self.latent_rgb_factors_reshape is not None:
+           x0 = self.latent_rgb_factors_reshape(x0)
        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
        if self.latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
@ -85,7 +88,7 @@ def get_previewer(device, latent_format):
if previewer is None: if previewer is None:
if latent_format.latent_rgb_factors is not None: if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape)
return previewer return previewer
def prepare_callback(model, steps, x0_output_dict=None): def prepare_callback(model, steps, x0_output_dict=None):
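The reshape hook above exists for latent formats that pack several spatial positions into the channel dimension; the callable unpacks them before the per-channel RGB projection. A minimal sketch with a hypothetical 4-channel latent packed 2x2 (so 16 stored channels); the factor values and shapes here are illustrative, not any model's real ones:

import torch

reshape = lambda t: t.reshape(t.shape[0], 4, 2, 2, t.shape[-2], t.shape[-1]) \
                     .permute(0, 1, 4, 2, 5, 3) \
                     .reshape(t.shape[0], 4, t.shape[-2] * 2, t.shape[-1] * 2)

x0 = torch.randn(1, 16, 8, 8)   # packed latent: 16 = 4 channels * 2 * 2
unpacked = reshape(x0)          # (1, 4, 16, 16), ready for the 4->3 RGB projection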

View File

@@ -929,7 +929,7 @@ class CLIPLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
+                             "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),

@@ -2278,6 +2278,7 @@ async def init_builtin_extra_nodes():
        "nodes_images.py",
        "nodes_video_model.py",
        "nodes_train.py",
+       "nodes_dataset.py",
        "nodes_sag.py",
        "nodes_perpneg.py",
        "nodes_stable3d.py",

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.71" version = "0.3.75"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.30.6 comfyui-frontend-package==1.32.9
comfyui-workflow-templates==0.7.9 comfyui-workflow-templates==0.7.20
comfyui-embedded-docs==0.3.1 comfyui-embedded-docs==0.3.1
torch torch
torchsde torchsde
@ -7,7 +7,7 @@ torchvision
torchaudio torchaudio
numpy>=1.25.0 numpy>=1.25.0
einops einops
transformers>=4.37.2 transformers>=4.50.3
tokenizers>=0.13.3 tokenizers>=0.13.3
sentencepiece sentencepiece
safetensors>=0.4.2 safetensors>=0.4.2

View File

@@ -174,7 +174,7 @@ def create_block_external_middleware():
        else:
            response = await handler(request)

-       response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
+       response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
        return response

    return block_external_middleware
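The only change here is adding 'unsafe-eval' to script-src, presumably required by the newer frontend package. A quick way to confirm the served policy against a locally running instance (assumes the block-external middleware is enabled, the default port 8188, and the requests package installed):

import requests

csp = requests.get("http://127.0.0.1:8188/").headers.get("Content-Security-Policy", "")
script_src = csp.split("script-src", 1)[-1].split(";", 1)[0]
assert "'unsafe-eval'" in script_src, "middleware not active or old policy served"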

View File

@@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
    def test_all_layers_standard(self):
        """Test that model with no quantization works normally"""
-       # Configure no quantization
-       ops.MixedPrecisionOps._layer_quant_config = {}
-
        # Create model
-       model = SimpleModel(operations=ops.MixedPrecisionOps)
+       model = SimpleModel(operations=ops.mixed_precision_ops({}))

        # Initialize weights manually
        model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))

@@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                "params": {}
            }
        }
-       ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict with mixed precision
        fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)

@@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
        }

        # Create model and load state dict (strict=False because custom loading pops keys)
-       model = SimpleModel(operations=ops.MixedPrecisionOps)
+       model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
        model.load_state_dict(state_dict, strict=False)

        # Verify weights are wrapped in QuantizedTensor

@@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                "params": {}
            }
        }
-       ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)

@@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

-       model = SimpleModel(operations=ops.MixedPrecisionOps)
+       model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
        model.load_state_dict(state_dict1, strict=False)

        # Save state dict

@@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                "params": {}
            }
        }
-       ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)

@@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

-       model = SimpleModel(operations=ops.MixedPrecisionOps)
+       model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
        model.load_state_dict(state_dict, strict=False)

        # Add a weight function (simulating LoRA)

@@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                "params": {}
            }
        }
-       ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict
        state_dict = {

@@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
        }

        # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
-       model = SimpleModel(operations=ops.MixedPrecisionOps)
+       model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
        with self.assertRaises(KeyError):
            model.load_state_dict(state_dict, strict=False)
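These test changes track an API shift in comfy.ops: per-layer quantization config is no longer assigned to a shared class attribute but passed to a mixed_precision_ops(...) factory that returns an ops class bound to that config. A minimal sketch of the idiom, with illustrative names rather than the actual implementation:

def make_ops(layer_quant_config):
    """Return a fresh ops class closed over one config, avoiding shared global state."""
    class _Ops:
        _layer_quant_config = layer_quant_config
    return _Ops

# Two models can now carry different quantization maps without clobbering each other.
ops_a = make_ops({"layer1.weight": {"format": "float8_e4m3fn"}})
ops_b = make_ops({})
assert ops_a._layer_quant_config != ops_b._layer_quant_config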

View File

@ -0,0 +1,153 @@
"""
Tests for public ComfyAPI and ComfyAPISync functions.
These tests verify that the public API methods work correctly in both sync and async contexts,
ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly
handles string annotations from 'from __future__ import annotations'.
"""
import pytest
import time
import subprocess
import torch
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient
@pytest.mark.execution
class TestPublicAPI:
"""Test suite for public ComfyAPI and ComfyAPISync methods."""
@fixture(scope="class", autouse=True)
def _server(self, args_pytest):
"""Start ComfyUI server for testing."""
pargs = [
'python', 'main.py',
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
p = subprocess.Popen(pargs)
yield
p.kill()
torch.cuda.empty_cache()
@fixture(scope="class", autouse=True)
def shared_client(self, args_pytest, _server):
"""Create shared client with connection retry."""
client = ComfyClient()
n_tries = 5
for i in range(n_tries):
time.sleep(4)
try:
client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
break
except ConnectionRefusedError:
if i == n_tries - 1:
raise
yield client
del client
torch.cuda.empty_cache()
@fixture
def client(self, shared_client, request):
"""Set test name for each test."""
shared_client.set_test_name(f"public_api[{request.node.name}]")
yield shared_client
@fixture
def builder(self, request):
"""Create GraphBuilder for each test."""
yield GraphBuilder(prefix=request.node.name)
def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
"""Test that TestSyncProgressUpdate executes without errors.
This test validates that api_sync.execution.set_progress() works correctly,
which is the primary code path fixed by adding get_type_hints() to async_to_sync.py.
"""
g = builder
image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
# Use TestSyncProgressUpdate with short sleep
progress_node = g.node("TestSyncProgressUpdate",
value=image.out(0),
sleep_seconds=0.5)
output = g.node("SaveImage", images=progress_node.out(0))
# Execute workflow
result = client.run(g)
# Verify execution
assert result.did_run(progress_node), "Progress node should have executed"
assert result.did_run(output), "Output node should have executed"
# Verify output
images = result.get_images(output)
assert len(images) == 1, "Should have produced 1 image"
def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder):
"""Test that TestAsyncProgressUpdate executes without errors.
This test validates that await api.execution.set_progress() works correctly
in async contexts.
"""
g = builder
image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
# Use TestAsyncProgressUpdate with short sleep
progress_node = g.node("TestAsyncProgressUpdate",
value=image.out(0),
sleep_seconds=0.5)
output = g.node("SaveImage", images=progress_node.out(0))
# Execute workflow
result = client.run(g)
# Verify execution
assert result.did_run(progress_node), "Async progress node should have executed"
assert result.did_run(output), "Output node should have executed"
# Verify output
images = result.get_images(output)
assert len(images) == 1, "Should have produced 1 image"
def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder):
"""Test both sync and async progress updates in same workflow.
This test ensures that both ComfyAPISync and ComfyAPI can coexist and work
correctly in the same workflow execution.
"""
g = builder
image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
# Use both types of progress nodes
sync_progress = g.node("TestSyncProgressUpdate",
value=image1.out(0),
sleep_seconds=0.3)
async_progress = g.node("TestAsyncProgressUpdate",
value=image2.out(0),
sleep_seconds=0.3)
# Create outputs
output1 = g.node("SaveImage", images=sync_progress.out(0))
output2 = g.node("SaveImage", images=async_progress.out(0))
# Execute workflow
result = client.run(g)
# Both should execute successfully
assert result.did_run(sync_progress), "Sync progress node should have executed"
assert result.did_run(async_progress), "Async progress node should have executed"
assert result.did_run(output1), "First output node should have executed"
assert result.did_run(output2), "Second output node should have executed"
# Verify outputs
images1 = result.get_images(output1)
images2 = result.get_images(output2)
assert len(images1) == 1, "Should have produced 1 image from sync node"
assert len(images2) == 1, "Should have produced 1 image from async node"
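The module docstring's point about string annotations is easy to reproduce outside ComfyUI: under from __future__ import annotations, raw __annotations__ entries are plain strings, and typing.get_type_hints() is what resolves them back into real types, which is exactly what a sync-wrapper generator needs when rebuilding signatures. The function below is a stand-in, not ComfyUI's actual API:

from __future__ import annotations
import typing

async def set_progress(value: float) -> None: ...

assert set_progress.__annotations__["value"] == "float"        # stored as a string
assert typing.get_type_hints(set_progress)["value"] is float   # resolved to the type

Since the suite carries the execution marker, it can be selected with pytest -m execution.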