diff --git a/README.md b/README.md index 28beec427..b9300ab07 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) + - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 204fc048d..e98c7d6d8 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -178,6 +178,15 @@ class Flux(SD3): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor +class Flux2(LatentFormat): + latent_channels = 128 + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 67bf70eb1..a72f8cc47 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -179,7 +179,10 @@ class Chroma(nn.Module): pe = self.pe_embedder(ids) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if i not in self.skip_mmdit: double_mod = ( self.get_modulations(mod_vectors, "double_img", idx=i), @@ -222,7 +225,10 @@ class Chroma(nn.Module): img = torch.cat((txt, img), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if i not in self.skip_dit: single_mod = self.get_modulations(mod_vectors, "single", idx=i) if ("single_block", i) in blocks_replace: diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 23150a712..2472ab79c 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 return embedding class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): + def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None): super().__init__() - self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device) self.silu = nn.SiLU() - self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) @@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module): class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device) @dataclass @@ -98,11 +98,11 @@ class ModulationOut: class Modulation(nn.Module): - def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): + def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 - self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) + self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device) def forward(self, vec: Tensor) -> tuple: if vec.ndim == 2: @@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor +class SiLUActivation(nn.Module): + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -142,27 +152,44 @@ class DoubleStreamBlock(nn.Module): self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + if mlp_silu_act: + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) if self.modulation: self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + if mlp_silu_act: + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.flipped_img_txt = flipped_img_txt def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): @@ -246,6 +273,8 @@ class SingleStreamBlock(nn.Module): mlp_ratio: float = 4.0, qk_scale: float = None, modulation=True, + mlp_silu_act=False, + bias=True, dtype=None, device=None, operations=None @@ -257,17 +286,24 @@ class SingleStreamBlock(nn.Module): self.scale = qk_scale or head_dim**-0.5 self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp_hidden_dim_first = self.mlp_hidden_dim + if mlp_silu_act: + self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) + self.mlp_act = SiLUActivation() + else: + self.mlp_act = nn.GELU(approximate="tanh") + # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.hidden_size = hidden_size self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.mlp_act = nn.GELU(approximate="tanh") if modulation: self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) else: @@ -279,7 +315,7 @@ class SingleStreamBlock(nn.Module): else: mod = vec - qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) del qkv @@ -298,11 +334,11 @@ class SingleStreamBlock(nn.Module): class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None): super().__init__() self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device)) def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: if vec.ndim == 2: diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b9d36f202..1a24e6d95 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -15,6 +15,7 @@ from .layers import ( MLPEmbedder, SingleStreamBlock, timestep_embedding, + Modulation ) @dataclass @@ -33,6 +34,11 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + global_modulation: bool = False + mlp_silu_act: bool = False + ops_bias: bool = True + default_ref_method: str = "offset" + ref_index_scale: float = 1.0 class Flux(nn.Module): @@ -58,13 +64,17 @@ class Flux(nn.Module): self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) - self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) self.double_blocks = nn.ModuleList( [ @@ -73,6 +83,9 @@ class Flux(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=params.global_modulation is False, + mlp_silu_act=params.mlp_silu_act, + proj_bias=params.ops_bias, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -81,13 +94,30 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) if final_layer: - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + + if params.global_modulation: + self.double_stream_modulation_img = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.double_stream_modulation_txt = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.single_stream_modulation = Modulation( + self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations + ) def forward_orig( self, @@ -103,9 +133,6 @@ class Flux(nn.Module): attn_mask: Tensor = None, ) -> Tensor: - if y is None: - y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) - patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: @@ -118,9 +145,17 @@ class Flux(nn.Module): if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + txt = self.txt_in(txt) + vec_orig = vec + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) + if "post_input" in patches: for p in patches["post_input"]: out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) @@ -177,6 +212,9 @@ class Flux(nn.Module): img = torch.cat((txt, img), 1) + if self.params.global_modulation: + vec, _ = self.single_stream_modulation(vec_orig) + for i, block in enumerate(self.single_blocks): if ("single_block", i) in blocks_replace: def block_wrap(args): @@ -207,7 +245,7 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) return img def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): @@ -234,10 +272,10 @@ class Flux(nn.Module): h_offset += rope_options.get("shift_y", 0.0) w_offset += rope_options.get("shift_x", 0.0) - img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) + img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): @@ -259,10 +297,10 @@ class Flux(nn.Module): h = 0 w = 0 index = 0 - ref_latents_method = kwargs.get("ref_latents_method", "offset") + ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) for ref in ref_latents: if ref_latents_method == "index": - index += 1 + index += self.params.ref_index_scale h_offset = 0 w_offset = 0 elif ref_latents_method == "uxo": @@ -286,7 +324,11 @@ class Flux(nn.Module): img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) + + if len(self.params.axes_dim) == 4: # Flux 2 + txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index f75c6e0e1..2749c53f5 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -389,7 +389,10 @@ class HunyuanVideo(nn.Module): attn_mask = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -411,7 +414,10 @@ class HunyuanVideo(nn.Module): img = torch.cat((img, txt), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index b4494a51d..c8643eb82 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -11,6 +11,7 @@ import comfy.ldm.common_dit from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension @@ -31,6 +32,7 @@ class JointAttention(nn.Module): n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, + out_bias: bool = False, operation_settings={}, ): """ @@ -59,7 +61,7 @@ class JointAttention(nn.Module): self.out = operation_settings.get("operations").Linear( n_heads * self.head_dim, dim, - bias=False, + bias=out_bias, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"), ) @@ -70,35 +72,6 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() - @staticmethod - def apply_rotary_emb( - x_in: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: - """ - Apply rotary embeddings to input tensors using the given frequency - tensor. - - This function applies rotary embeddings to the given query 'xq' and - key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The - input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors - contain rotary embeddings and are returned as real tensors. - - Args: - x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex - exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor - and key tensor with rotary embeddings. - """ - - t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x_in.shape) - def forward( self, x: torch.Tensor, @@ -134,8 +107,7 @@ class JointAttention(nn.Module): xq = self.q_norm(xq) xk = self.k_norm(xk) - xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + xq, xk = apply_rope(xq, xk, freqs_cis) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: @@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module): norm_eps: float, qk_norm: bool, modulation=True, + z_image_modulation=False, + attn_out_bias=False, operation_settings={}, ) -> None: """ @@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings) self.feed_forward = FeedForward( dim=dim, - hidden_dim=4 * dim, + hidden_dim=dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, operation_settings=operation_settings, @@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module): self.modulation = modulation if modulation: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operation_settings.get("operations").Linear( - min(dim, 1024), - 4 * dim, - bias=True, - device=operation_settings.get("device"), - dtype=operation_settings.get("dtype"), - ), - ) + if z_image_modulation: + self.adaLN_modulation = nn.Sequential( + operation_settings.get("operations").Linear( + min(dim, 256), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) def forward( self, @@ -323,7 +308,7 @@ class FinalLayer(nn.Module): The final layer of NextDiT. """ - def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): + def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}): super().__init__() self.norm_final = operation_settings.get("operations").LayerNorm( hidden_size, @@ -340,10 +325,15 @@ class FinalLayer(nn.Module): dtype=operation_settings.get("dtype"), ) + if z_image_modulation: + min_mod = 256 + else: + min_mod = 1024 + self.adaLN_modulation = nn.Sequential( nn.SiLU(), operation_settings.get("operations").Linear( - min(hidden_size, 1024), + min(hidden_size, min_mod), hidden_size, bias=True, device=operation_settings.get("device"), @@ -373,12 +363,16 @@ class NextDiT(nn.Module): n_heads: int = 32, n_kv_heads: Optional[int] = None, multiple_of: int = 256, - ffn_dim_multiplier: Optional[float] = None, + ffn_dim_multiplier: float = 4.0, norm_eps: float = 1e-5, qk_norm: bool = False, cap_feat_dim: int = 5120, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512), + rope_theta=10000.0, + z_image_modulation=False, + time_scale=1.0, + pad_tokens_multiple=None, image_model=None, device=None, dtype=None, @@ -390,6 +384,8 @@ class NextDiT(nn.Module): self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size + self.time_scale = time_scale + self.pad_tokens_multiple = pad_tokens_multiple self.x_embedder = operation_settings.get("operations").Linear( in_features=patch_size * patch_size * in_channels, @@ -411,6 +407,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, + z_image_modulation=z_image_modulation, operation_settings=operation_settings, ) for layer_id in range(n_refiner_layers) @@ -434,7 +431,7 @@ class NextDiT(nn.Module): ] ) - self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) + self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings) self.cap_embedder = nn.Sequential( operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear( @@ -457,18 +454,24 @@ class NextDiT(nn.Module): ffn_dim_multiplier, norm_eps, qk_norm, + z_image_modulation=z_image_modulation, + attn_out_bias=False, operation_settings=operation_settings, ) for layer_id in range(n_layers) ] ) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) + + if self.pad_tokens_multiple is not None: + self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) assert (dim // n_heads) == sum(axes_dims) self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims) self.dim = dim self.n_heads = n_heads @@ -503,108 +506,42 @@ class NextDiT(nn.Module): bsz = len(x) pH = pW = self.patch_size device = x[0].device - dtype = x[0].dtype - if cap_mask is not None: - l_effective_cap_len = cap_mask.sum(dim=1).tolist() - else: - l_effective_cap_len = [num_tokens] * bsz + if self.pad_tokens_multiple is not None: + pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) - if cap_mask is not None and not torch.is_floating_point(cap_mask): - cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max + cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + B, C, H, W = x.shape + x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - max_seq_len = max( - (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) - ) - max_cap_len = max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) + H_tokens, W_tokens = H // pH, W // pW + x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 + x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device) + if self.pad_tokens_multiple is not None: + pad_extra = (-x.shape[1]) % self.pad_tokens_multiple + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device) - position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, cap_len:cap_len+img_len, 1] = row_ids - position_ids[i, cap_len:cap_len+img_len, 2] = col_ids - - freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype) - - # build freqs_cis for cap and image individually - cap_freqs_cis_shape = list(freqs_cis.shape) - # cap_freqs_cis_shape[1] = max_cap_len - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) - # refine image - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device) - for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = x[i] - padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max - - padded_img_embed = self.x_embedder(padded_img_embed) - padded_img_mask = padded_img_mask.unsqueeze(1) + padded_img_mask = None for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) - - if cap_mask is not None: - mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) - mask[:, :max_cap_len] = cap_mask[:, :max_cap_len] - else: - mask = None - - padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] + x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + padded_full_embed = torch.cat((cap_feats, x), dim=1) + mask = None + img_sizes = [(H, W)] * bsz + l_effective_cap_len = [cap_feats.shape[1]] * bsz return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): @@ -627,7 +564,7 @@ class NextDiT(nn.Module): y: (N,) tensor of text tokens/features """ - t = self.t_embedder(t, dtype=x.dtype) # (N, D) + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 611d36a1b..4f50810dc 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri from comfy.ldm.util import get_obj_from_str, instantiate_from_config from comfy.ldm.modules.ema import LitEma import comfy.ops +from einops import rearrange +import comfy.model_management class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = False): @@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine): self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim + if ddconfig.get("batch_norm_latent", False): + self.bn_eps = 1e-4 + self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"], + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + self.bn.eval() + else: + self.bn = None + + def get_autoencoder_params(self) -> list: params = super().get_autoencoder_params() return params @@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine): z = torch.cat(z, 0) z, reg_log = self.regularization(z) + + if self.bn is not None: + z = rearrange(z, + "... c (i pi) (j pj) -> ... (c pi pj) i j", + pi=self.ps[0], + pj=self.ps[1], + ) + + z = torch.nn.functional.batch_norm(z, + comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device), + comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device), + momentum=self.bn_momentum, + eps=self.bn_eps) + if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.bn is not None: + s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps) + m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + z = z * s + m + z = rearrange( + z, + "... (c pi pj) i j -> ... c (i pi) (j pj)", + pi=self.ps[0], + pj=self.ps[1], + ) + if self.max_batch_size is None: dec = self.post_quant_conv(z) dec = self.decoder(dec, **decoder_kwargs) diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 42f406f1a..0dc8fe789 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): super().__init__() + if output_size is None: + output_size = hidden_size self.mlp = nn.Sequential( operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), nn.SiLU(), - operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 427ea19c1..8c75670cd 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -439,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.transformer_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/model_base.py b/comfy/model_base.py index e14b552c5..cc21b1de9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -898,12 +898,13 @@ class Flux(BaseModel): attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: shape = kwargs["noise"].shape - mask_ref_size = kwargs["attention_mask_img_shape"] - # the model will pad to the patch size, and then divide - # essentially dividing and rounding up - (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) - attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) - out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + mask_ref_size = kwargs.get("attention_mask_img_shape", None) + if mask_ref_size is not None: + # the model will pad to the patch size, and then divide + # essentially dividing and rounding up + (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) + attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) guidance = kwargs.get("guidance", 3.5) if guidance is not None: @@ -928,6 +929,16 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out +class Flux2(Flux): + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + target_text_len = 512 + if cross_attn.shape[1] < target_text_len: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0)) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): @@ -1103,9 +1114,13 @@ class Lumina2(BaseModel): if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) + cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + if 'num_tokens' not in out: + out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) + return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0131ca25a..7afe4a798 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -200,26 +200,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} - dit_config["image_model"] = "flux" + if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "flux2" + dit_config["axes_dim"] = [32, 32, 32, 32] + dit_config["num_heads"] = 48 + dit_config["mlp_ratio"] = 3.0 + dit_config["theta"] = 2000 + dit_config["out_channels"] = 128 + dit_config["global_modulation"] = True + dit_config["vec_in_dim"] = None + dit_config["mlp_silu_act"] = True + dit_config["qkv_bias"] = False + dit_config["ops_bias"] = False + dit_config["default_ref_method"] = "index" + dit_config["ref_index_scale"] = 10.0 + patch_size = 1 + else: + dit_config["image_model"] = "flux" + dit_config["axes_dim"] = [16, 56, 56] + dit_config["num_heads"] = 24 + dit_config["mlp_ratio"] = 4.0 + dit_config["theta"] = 10000 + dit_config["out_channels"] = 16 + dit_config["qkv_bias"] = True + patch_size = 2 + dit_config["in_channels"] = 16 - patch_size = 2 + dit_config["hidden_size"] = 3072 + dit_config["context_in_dim"] = 4096 + dit_config["patch_size"] = patch_size in_key = "{}img_in.weight".format(key_prefix) if in_key in state_dict_keys: - dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) - dit_config["out_channels"] = 16 + w = state_dict[in_key] + dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size) + dit_config["hidden_size"] = w.shape[0] + + txt_in_key = "{}txt_in.weight".format(key_prefix) + if txt_in_key in state_dict_keys: + w = state_dict[txt_in_key] + dit_config["context_in_dim"] = w.shape[1] + dit_config["hidden_size"] = w.shape[0] + vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 - dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] - dit_config["theta"] = 10000 - dit_config["qkv_bias"] = True if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 @@ -388,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 - dit_config["dim"] = 2304 - dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1] + w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)] + dit_config["dim"] = w.shape[0] + dit_config["cap_feat_dim"] = w.shape[1] dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') - dit_config["n_heads"] = 24 - dit_config["n_kv_heads"] = 8 dit_config["qk_norm"] = True - dit_config["axes_dims"] = [32, 32, 32] - dit_config["axes_lens"] = [300, 512, 512] + + if dit_config["dim"] == 2304: # Original Lumina 2 + dit_config["n_heads"] = 24 + dit_config["n_kv_heads"] = 8 + dit_config["axes_dims"] = [32, 32, 32] + dit_config["axes_lens"] = [300, 512, 512] + dit_config["rope_theta"] = 10000.0 + dit_config["ffn_dim_multiplier"] = 4.0 + elif dit_config["dim"] == 3840: # Z image + dit_config["n_heads"] = 30 + dit_config["n_kv_heads"] = 30 + dit_config["axes_dims"] = [32, 48, 48] + dit_config["axes_lens"] = [1536, 512, 512] + dit_config["rope_theta"] = 256.0 + dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) + dit_config["z_image_modulation"] = True + dit_config["time_scale"] = 1000.0 + if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: + dit_config["pad_tokens_multiple"] = 32 + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/model_management.py b/comfy/model_management.py index a21df54b3..a9327ac80 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1098,13 +1098,14 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) +PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: return False - if type(tensor) is not torch.nn.parameter.Parameter: + if type(tensor).__name__ not in PINNING_ALLOWED_TYPES: return False if not is_device_cpu(tensor.device): @@ -1124,6 +1125,9 @@ def pin_memory(tensor): return False ptr = tensor.data_ptr() + if ptr == 0: + return False + if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index cf1b0d441..6551ced5a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -231,7 +231,6 @@ class ModelPatcher: self.object_patches_backup = {} self.weight_wrapper_patches = {} self.model_options = {"transformer_options":{}} - self.model_size() self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update @@ -286,7 +285,7 @@ class ModelPatcher: return self.model.lowvram_patch_counter def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) + n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] diff --git a/comfy/ops.py b/comfy/ops.py index 640622fd1..785aa1c9f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -540,115 +540,118 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== from .quant_ops import QuantizedTensor, QUANT_ALGOS -class MixedPrecisionOps(disable_weight_init): - _layer_quant_config = {} - _compute_dtype = torch.bfloat16 - class Linear(torch.nn.Module, CastWeightBiasOp): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super().__init__() +def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): + class MixedPrecisionOps(manual_cast): + _layer_quant_config = layer_quant_config + _compute_dtype = compute_dtype + _full_precision_mm = full_precision_mm - self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() - self.in_features = in_features - self.out_features = out_features - if bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) - else: - self.register_parameter("bias", None) + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} - self.tensor_class = None + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) - def reset_parameters(self): - return None + self.tensor_class = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): + def reset_parameters(self): + return None - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip('.') - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - raise ValueError(f"Missing weight for layer {layer_name}") + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): - manually_loaded_keys = [weight_key] + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") - if layer_name not in MixedPrecisionOps._layer_quant_config: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) - else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) - if quant_format is None: - raise ValueError(f"Unknown quantization format for layer {layer_name}") + manually_loaded_keys = [weight_key] - qconfig = QUANT_ALGOS[quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") - weight_scale_key = f"{prefix}weight_scale" - layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), - 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } - if layout_params['scale'] is not None: - manually_loaded_keys.append(weight_scale_key) + qconfig = QUANT_ALGOS[quant_format] + self.layout_type = qconfig["comfy_tensor_layout"] - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), - requires_grad=False - ) + weight_scale_key = f"{prefix}weight_scale" + layout_params = { + 'scale': state_dict.pop(weight_scale_key, None), + 'orig_dtype': MixedPrecisionOps._compute_dtype, + 'block_size': qconfig.get("group_size", None), + } + if layout_params['scale'] is not None: + manually_loaded_keys.append(weight_scale_key) - for param_name in qconfig["parameters"]: - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + requires_grad=False + ) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for param_name in qconfig["parameters"]: + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) - for key in manually_loaded_keys: - if key in missing_keys: - missing_keys.remove(key) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - def _forward(self, input, weight, bias): - return torch.nn.functional.linear(input, weight, bias) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - x = self._forward(input, weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) - def forward(self, input, *args, **kwargs): - run_every_op() + def forward_comfy_cast_weights(self, input): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) - if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and - not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) - return self._forward(input, self.weight, self.bias) + def forward(self, input, *args, **kwargs): + run_every_op() + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): - if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: - MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config - MixedPrecisionOps._compute_dtype = compute_dtype - logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return MixedPrecisionOps + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: + logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") + return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 1d058bece..0c16bcf8d 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -228,6 +228,14 @@ class QuantizedTensor(torch.Tensor): new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) + def data_ptr(self): + return self._qdata.data_ptr() + + def is_pinned(self): + return self._qdata.is_pinned() + + def is_contiguous(self): + return self._qdata.is_contiguous() # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) @@ -338,6 +346,18 @@ def generic_copy_(func, args, kwargs): return func(*args, **kwargs) +@register_generic_util(torch.ops.aten.to.dtype) +def generic_to_dtype(func, args, kwargs): + """Handle .to(dtype) calls - dtype conversion only.""" + src = args[0] + if isinstance(src, QuantizedTensor): + # For dtype-only conversion, just change the orig_dtype, no real cast is needed + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') + src._layout_params["orig_dtype"] = target_dtype + return src + return func(*args, **kwargs) + + @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) def generic_has_compatible_shallow_copy_type(func, args, kwargs): return True @@ -385,8 +405,8 @@ class TensorCoreFP8Layout(QuantizedLayout): tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) # TODO: uncomment this if it's actually needed because the clamp has a small performance penality' - # lp_amax = torch.finfo(dtype).max - # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + lp_amax = torch.finfo(dtype).max + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) layout_params = { diff --git a/comfy/sd.py b/comfy/sd.py index dc0905ada..350fae92b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -52,6 +52,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.z_image import comfy.model_patcher import comfy.lora @@ -356,7 +357,7 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif sd['decoder.conv_in.weight'].shape[1] == 32: + elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -382,6 +383,17 @@ class VAE: self.upscale_ratio = 4 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'decoder.post_quant_conv.weight' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."}) + + if 'bn.running_mean' in sd: + ddconfig["batch_norm_latent"] = True + self.downscale_ratio *= 2 + self.upscale_ratio *= 2 + self.latent_channels *= 4 + old_memory_used_decode = self.memory_used_decode + self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 + if 'post_quant_conv.weight' in sd: self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) else: @@ -917,7 +929,12 @@ class CLIPType(Enum): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] for p in ckpt_paths: - clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) + sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) + if metadata is not None: + quant_metadata = metadata.get("_quantization_metadata", None) + if quant_metadata is not None: + sd["_quantization_metadata"] = quant_metadata + clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -935,6 +952,10 @@ class TEModel(Enum): QWEN25_7B = 11 BYT5_SMALL_GLYPH = 12 GEMMA_3_4B = 13 + MISTRAL3_24B = 14 + MISTRAL3_24B_PRUNED_FLUX2 = 15 + QWEN3_4B = 16 + def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -967,6 +988,15 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: + if 'model.layers.0.self_attn.q_norm.weight' in sd: + return TEModel.QWEN3_4B + weight = sd['model.layers.0.post_attention_layernorm.weight'] + if weight.shape[0] == 5120: + if "model.layers.39.post_attention_layernorm.weight" in sd: + return TEModel.MISTRAL3_24B + else: + return TEModel.MISTRAL3_24B_PRUNED_FLUX2 + return TEModel.LLAMA3_8 return None @@ -1081,6 +1111,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer + elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2: + clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) + clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer + tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.QWEN3_4B: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: @@ -1142,6 +1179,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters = 0 for c in clip_data: + if "_quantization_metadata" in c: + c.pop("_quantization_metadata") parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 3066de2d7..0fc9ab3db 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32 super().__init__() - assert layer in self.LAYERS if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") @@ -109,13 +108,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): operations = model_options.get("custom_operations", None) scaled_fp8 = None + quantization_metadata = model_options.get("quantization_metadata", None) if operations is None: - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + layer_quant_config = None + if quantization_metadata is not None: + layer_quant_config = json.loads(quantization_metadata).get("layers", None) + + if layer_quant_config is not None: + operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) + logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") else: - operations = comfy.ops.manual_cast + # Fallback to scaled_fp8_ops for backward compatibility + scaled_fp8 = model_options.get("scaled_fp8", None) + if scaled_fp8 is not None: + operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + else: + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) @@ -154,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if self.layer == "all": + if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" @@ -256,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - if self.layer == "all": + if isinstance(self.layer, list): + intermediate_output = self.layer + elif self.layer == "all": intermediate_output = "all" else: intermediate_output = self.layer_idx diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2e64b85e8..af8120400 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -21,6 +21,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.z_image from . import supported_models_base from . import latent_formats @@ -741,6 +742,37 @@ class FluxSchnell(Flux): out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) return out +class Flux2(Flux): + unet_config = { + "image_model": "flux2", + } + + sampling_settings = { + "shift": 2.02, + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Flux2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -963,7 +995,7 @@ class Lumina2(supported_models_base.BASE): "shift": 6.0, } - memory_usage_factor = 1.2 + memory_usage_factor = 1.4 unet_extra_config = {} latent_format = latent_formats.Flux @@ -982,6 +1014,24 @@ class Lumina2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect)) +class ZImage(Lumina2): + unet_config = { + "image_model": "lumina2", + "dim": 3840, + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + memory_usage_factor = 1.7 + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1422,6 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] + models += [SVD_img2vid] diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index d61ef6668..99f4812bb 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -1,10 +1,13 @@ from comfy import sd1_clip import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip +import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast import torch import os +import json +import base64 class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -68,3 +71,106 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ + +def load_mistral_tokenizer(data): + if torch.is_tensor(data): + data = data.numpy().tobytes() + + try: + from transformers.integrations.mistral import MistralConverter + except ModuleNotFoundError: + from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter + + mistral_vocab = json.loads(data) + + special_tokens = {} + vocab = {} + + max_vocab = mistral_vocab["config"]["default_vocab_size"] + max_vocab -= len(mistral_vocab["special_tokens"]) + + for w in mistral_vocab["vocab"]: + r = w["rank"] + if r >= max_vocab: + continue + + vocab[base64.b64decode(w["token_bytes"])] = r + + for w in mistral_vocab["special_tokens"]: + if "token_bytes" in w: + special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"] + else: + special_tokens[w["token_str"]] = w["rank"] + + all_special = [] + for v in special_tokens: + all_special.append(v) + + special_tokens.update(vocab) + vocab = special_tokens + return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False} + +class MistralTokenizerClass: + @staticmethod + def from_pretrained(path, **kwargs): + return LlamaTokenizerFast(**kwargs) + +class Mistral3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.tekken_data = tokenizer_data.get("tekken_model", None) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"tekken_model": self.tekken_data} + +class Flux2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer) + self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]' + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Mistral3_24BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + textmodel_json_config = {} + num_layers = model_options.get("num_layers", None) + if num_layers is not None: + textmodel_json_config["num_hidden_layers"] = num_layers + if num_layers < 40: + textmodel_json_config["final_norm"] = False + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Flux2TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + + out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1) + out = out.movedim(1, 2) + out = out.reshape(out.shape[0], out.shape[1], -1) + return out, pooled, extra + +def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + if pruned: + model_options = model_options.copy() + model_options["num_layers"] = 30 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Flux2TEModel_ diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 557094f49..0110517bb 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""): if scaled_fp8_key in state_dict: out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + if "_quantization_metadata" in state_dict: + out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + return out diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index feb44bbb0..cd4b5f76c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -34,6 +34,28 @@ class Llama2Config: rope_scale = None final_norm: bool = True +@dataclass +class Mistral3Small24BConfig: + vocab_size: int = 131072 + hidden_size: int = 5120 + intermediate_size: int = 32768 + num_hidden_layers: int = 40 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 8192 + rms_norm_eps: float = 1e-5 + rope_theta: float = 1000000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_3BConfig: vocab_size: int = 151936 @@ -56,6 +78,28 @@ class Qwen25_3BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_4BConfig: + vocab_size: int = 151936 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_7BVLI_Config: vocab_size: int = 152064 @@ -412,8 +456,12 @@ class Llama2_(nn.Module): intermediate = None all_intermediate = None + only_layers = None if intermediate_output is not None: - if intermediate_output == "all": + if isinstance(intermediate_output, list): + all_intermediate = [] + only_layers = set(intermediate_output) + elif intermediate_output == "all": all_intermediate = [] intermediate_output = None elif intermediate_output < 0: @@ -421,7 +469,8 @@ class Llama2_(nn.Module): for i, layer in enumerate(self.layers): if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or (i in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -435,7 +484,8 @@ class Llama2_(nn.Module): x = self.norm(x) if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or ((i + 1) in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) if all_intermediate is not None: intermediate = torch.cat(all_intermediate, dim=1) @@ -465,6 +515,15 @@ class Llama2(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Mistral3Small24B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Mistral3Small24BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_3B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -474,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_4B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py new file mode 100644 index 000000000..bb9273b20 --- /dev/null +++ b/comfy/text_encoders/z_image.py @@ -0,0 +1,48 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class ZImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ZImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) + + +def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): + class ZImageTEModel_(ZImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ZImageTEModel_ diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index a335df4d0..87c81d73a 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -1,5 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod +from fractions import Fraction from typing import Optional, Union, IO import io import av @@ -72,6 +73,33 @@ class VideoInput(ABC): frame_count = components.images.shape[0] return float(frame_count / components.frame_rate) + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video. + + Default implementation uses :meth:`get_components`, which may require + loading all frames into memory. File-based implementations should + override this method and use container/stream metadata instead. + + Returns: + Total number of frames as an integer. + """ + return int(self.get_components().images.shape[0]) + + def get_frame_rate(self) -> Fraction: + """ + Returns the frame rate of the video. + + Default implementation materializes the video into memory via + `get_components()`. Subclasses that can inspect the underlying + container (e.g. `VideoFromFile`) should override this with a more + efficient implementation. + + Returns: + Frame rate as a Fraction. + """ + return self.get_components().frame_rate + def get_container_format(self) -> str: """ Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index f646504c8..bde37f90a 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -121,6 +121,71 @@ class VideoFromFile(VideoInput): raise ValueError(f"Could not determine duration for file '{self.__file}'") + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video without materializing them as + torch tensors. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # 1. Prefer the frames field if available + if video_stream.frames and video_stream.frames > 0: + return int(video_stream.frames) + + # 2. Try to estimate from duration and average_rate using only metadata + if container.duration is not None and video_stream.average_rate: + duration_seconds = float(container.duration / av.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + if ( + getattr(video_stream, "duration", None) is not None + and getattr(video_stream, "time_base", None) is not None + and video_stream.average_rate + ): + duration_seconds = float(video_stream.duration * video_stream.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + # 3. Last resort: decode frames and count them (streaming) + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + + if frame_count == 0: + raise ValueError(f"Could not determine frame count for file '{self.__file}'") + return frame_count + + def get_frame_rate(self) -> Fraction: + """ + Returns the average frame rate of the video using container metadata + without decoding all frames. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # Preferred: use PyAV's average_rate (usually already a Fraction-like) + if video_stream.average_rate: + return Fraction(video_stream.average_rate) + + # Fallback: estimate from frames + duration if available + if video_stream.frames and container.duration: + duration_seconds = float(container.duration / av.time_base) + if duration_seconds > 0: + return Fraction(video_stream.frames / duration_seconds).limit_denominator() + + # Last resort: match get_components_internal default + return Fraction(1) + def get_container_format(self) -> str: """ Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). @@ -238,6 +303,13 @@ class VideoFromFile(VideoInput): packet.stream = stream_map[packet.stream] output_container.mux(packet) + def _get_first_video_stream(self, container: InputContainer): + video_stream = next((s for s in container.streams if s.type == "video"), None) + if video_stream is None: + raise ValueError(f"No video stream found in file '{self.__file}'") + return video_stream + + class VideoFromComponents(VideoInput): """ Class representing video input from tensors. diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py index 0fc8c0607..d8d3557b3 100644 --- a/comfy_api_nodes/apis/bfl_api.py +++ b/comfy_api_nodes/apis/bfl_api.py @@ -70,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel): # ) +class Flux2ProGenerateRequest(BaseModel): + prompt: str = Field(...) + width: int = Field(1024, description="Must be a multiple of 32.") + height: int = Field(768, description="Must be a multiple of 32.") + seed: int | None = Field(None) + prompt_upsampling: bool | None = Field(None) + input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + safety_tolerance: int | None = Field( + 5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 + ) + output_format: str | None = Field( + "png", description="Output format for the generated image. Can be 'jpeg' or 'png'." + ) + + class BFLFluxKontextProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for what you wannt to edit.') input_image: Optional[str] = Field(None, description='Image to edit in base64 format') @@ -109,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel): class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description='The unique identifier for the generation task.') - polling_url: str = Field(..., description='URL to poll for the generation result.') + id: str = Field(..., description="The unique identifier for the generation task.") + polling_url: str = Field(..., description="URL to poll for the generation result.") + cost: float | None = Field(None, description="Price in cents") class BFLStatus(str, Enum): diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index 710f173f1..d34590d28 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -113,9 +113,9 @@ class GeminiGenerationConfig(BaseModel): maxOutputTokens: int | None = Field(None, ge=16, le=8192) seed: int | None = Field(None) stopSequences: list[str] | None = Field(None) - temperature: float | None = Field(1, ge=0.0, le=2.0) - topK: int | None = Field(40, ge=1) - topP: float | None = Field(0.95, ge=0.0, le=1.0) + temperature: float | None = Field(None, ge=0.0, le=2.0) + topK: int | None = Field(None, ge=1) + topP: float | None = Field(None, ge=0.0, le=1.0) class GeminiImageConfig(BaseModel): diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 1740fb377..8826dea0c 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,7 +1,7 @@ from inspect import cleandoc -from typing import Optional import torch +from pydantic import BaseModel from typing_extensions import override from comfy_api.latest import IO, ComfyExtension @@ -9,15 +9,16 @@ from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, - BFLFluxProGenerateRequest, BFLFluxProGenerateResponse, BFLFluxProUltraGenerateRequest, BFLFluxStatusResponse, BFLStatus, + Flux2ProGenerateRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, + get_number_of_images, poll_op, resize_mask_to_image, sync_op, @@ -116,7 +117,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt_upsampling: bool = False, raw: bool = False, seed: int = 0, - image_prompt: Optional[torch.Tensor] = None, + image_prompt: torch.Tensor | None = None, image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: @@ -230,7 +231,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: Optional[torch.Tensor] = None, + input_image: torch.Tensor | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -280,124 +281,6 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode): DISPLAY_NAME = "Flux.1 Kontext [max] Image" -class FluxProImageNode(IO.ComfyNode): - """ - Generates images synchronously based on prompt and resolution. - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProImageNode", - display_name="Flux 1.1 [pro] Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. " - "If active, automatically modifies the prompt for more creative generation, " - "but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Int.Input( - "width", - default=1024, - min=256, - max=1440, - step=32, - ), - IO.Int.Input( - "height", - default=768, - min=256, - max=1440, - step=32, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - IO.Image.Input( - "image_prompt", - optional=True, - ), - # "image_prompt_strength": ( - # IO.FLOAT, - # { - # "default": 0.1, - # "min": 0.0, - # "max": 1.0, - # "step": 0.01, - # "tooltip": "Blend between the prompt and the image prompt.", - # }, - # ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - prompt: str, - prompt_upsampling, - width: int, - height: int, - seed=0, - image_prompt=None, - # image_prompt_strength=0.1, - ) -> IO.NodeOutput: - image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt) - initial_response = await sync_op( - cls, - ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1/generate", - method="POST", - ), - response_model=BFLFluxProGenerateResponse, - data=BFLFluxProGenerateRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - width=width, - height=height, - seed=seed, - image_prompt=image_prompt, - ), - ) - response = await poll_op( - cls, - ApiEndpoint(initial_response.polling_url), - response_model=BFLFluxStatusResponse, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - completed_statuses=[BFLStatus.ready], - failed_statuses=[ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - BFLStatus.error, - BFLStatus.task_not_found, - ], - queued_statuses=[], - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) - - class FluxProExpandNode(IO.ComfyNode): """ Outpaints image based on prompt. @@ -640,16 +523,125 @@ class FluxProFillNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class Flux2ProImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Flux2ProImageNode", + display_name="Flux.2 [pro] Image", + category="api node/image/BFL", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation or edit", + ), + IO.Int.Input( + "width", + default=1024, + min=256, + max=2048, + step=32, + ), + IO.Int.Input( + "height", + default=768, + min=256, + max=2048, + step=32, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", + ), + IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + width: int, + height: int, + seed: int, + prompt_upsampling: bool, + images: torch.Tensor | None = None, + ) -> IO.NodeOutput: + reference_images = {} + if images is not None: + if get_number_of_images(images) > 9: + raise ValueError("The current maximum number of supported images is 9.") + for image_index in range(images.shape[0]): + key_name = f"input_image_{image_index + 1}" if image_index else "input_image" + reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=Flux2ProGenerateRequest( + prompt=prompt, + width=width, + height=height, + seed=seed, + prompt_upsampling=prompt_upsampling, + **reference_images, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + class BFLExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ FluxProUltraImageNode, - # FluxProImageNode, FluxKontextProImageNode, FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, + Flux2ProImageNode, ] diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index be752c885..938a20f84 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -104,14 +104,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera List of response parts matching the requested type. """ if response.candidates is None: - if response.promptFeedback.blockReason: + if response.promptFeedback and response.promptFeedback.blockReason: feedback = response.promptFeedback raise ValueError( f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})" ) - raise NotImplementedError( - "Gemini returned no response candidates. " - "Please report to ComfyUI repository with the example of workflow to reproduce this." + raise ValueError( + "Gemini API returned no response candidates. If you are using the `IMAGE` modality, " + "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed." ) parts = [] for part in response.candidates[0].content.parts: @@ -182,11 +182,12 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price - for i in response.usageMetadata.candidatesTokensDetails: - if i.modality == Modality.IMAGE: - final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models - else: - final_price += output_text_tokens_price * i.tokenCount + if response.usageMetadata.candidatesTokensDetails: + for i in response.usageMetadata.candidatesTokensDetails: + if i.modality == Modality.IMAGE: + final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models + else: + final_price += output_text_tokens_price * i.tokenCount if response.usageMetadata.thoughtsTokenCount: final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount return final_price / 1_000_000.0 @@ -645,7 +646,7 @@ class GeminiImage2(IO.ComfyNode): options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], default="auto", tooltip="If set to 'auto', matches your input image's aspect ratio; " - "if no image is provided, generates a 1:1 square.", + "if no image is provided, a 16:9 square is usually generated.", ), IO.Combo.Input( "resolution", diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index 79c7bf43d..f522756e5 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -5,8 +5,7 @@ import aiohttp import torch from typing_extensions import override -from comfy_api.input.video_types import VideoInput -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import topaz_api from comfy_api_nodes.util import ( ApiEndpoint, @@ -282,7 +281,7 @@ class TopazVideoEnhance(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, + video: Input.Video, upscaler_enabled: bool, upscaler_model: str, upscaler_resolution: str, @@ -297,12 +296,10 @@ class TopazVideoEnhance(IO.ComfyNode): ) -> IO.NodeOutput: if upscaler_enabled is False and interpolation_enabled is False: raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") - src_width, src_height = video.get_dimensions() - video_components = video.get_components() - src_frame_rate = int(video_components.frame_rate) - duration_sec = video.get_duration() - estimated_frames = int(duration_sec * src_frame_rate) validate_container_format_is_mp4(video) + src_width, src_height = video.get_dimensions() + src_frame_rate = int(video.get_frame_rate()) + duration_sec = video.get_duration() src_video_stream = video.get_stream_source() target_width = src_width target_height = src_height @@ -338,7 +335,7 @@ class TopazVideoEnhance(IO.ComfyNode): container="mp4", size=get_fs_object_size(src_video_stream), duration=int(duration_sec), - frameCount=estimated_frames, + frameCount=video.get_frame_count(), frameRate=src_frame_rate, resolution=topaz_api.Resolution(width=src_width, height=src_height), ), diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 21013b591..80292fb3c 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -36,6 +36,7 @@ from .upload_helpers import ( upload_video_to_comfyapi, ) from .validation_utils import ( + get_image_dimensions, get_number_of_images, validate_aspect_ratio_string, validate_audio_duration, @@ -82,6 +83,7 @@ __all__ = [ "trim_video", "video_to_base64_string", # Validation utilities + "get_image_dimensions", "get_number_of_images", "validate_aspect_ratio_string", "validate_audio_duration", diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index ce1b2e89f..d9c4bba81 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -2,7 +2,10 @@ import node_helpers import comfy.utils from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import comfy.model_management +import torch +import math +import nodes class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode): encode = execute # TODO: remove +class EmptyFlux2LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyFlux2LatentImage", + display_name="Empty Flux 2 Latent", + category="latent", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) class FluxGuidance(io.ComfyNode): @classmethod @@ -154,6 +178,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): append = execute # TODO: remove +def generalized_time_snr_shift(t, mu: float, sigma: float): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +def get_schedule(num_steps: int, image_seq_len: int) -> list[float]: + mu = compute_empirical_mu(image_seq_len, num_steps) + timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = generalized_time_snr_shift(timesteps, mu, 1.0) + return timesteps + + +class Flux2Scheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Flux2Scheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=4096), + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) + + @classmethod + def execute(cls, steps, width, height) -> io.NodeOutput: + seq_len = (width * height / (16 * 16)) + sigmas = get_schedule(steps, round(seq_len)) + return io.NodeOutput(sigmas) + + class FluxExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -163,6 +239,8 @@ class FluxExtension(ComfyExtension): FluxDisableGuidance, FluxKontextImageScale, FluxKontextMultiReferenceLatentMethod, + EmptyFlux2LatentImage, + Flux2Scheduler, ] diff --git a/comfyui_version.py b/comfyui_version.py index b4655d553..f8818838e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.71" +__version__ = "0.3.73" diff --git a/nodes.py b/nodes.py index 6f6a34fef..902468895 100644 --- a/nodes.py +++ b/nodes.py @@ -932,7 +932,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), diff --git a/pyproject.toml b/pyproject.toml index 280dbaf53..7e4bac12d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.71" +version = "0.3.73" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 33ba0ce15..650cfbbbc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.30.6 -comfyui-workflow-templates==0.7.9 +comfyui-workflow-templates==0.7.20 comfyui-embedded-docs==0.3.1 comfyui_manager==4.0.3b1 torch @@ -8,7 +8,7 @@ torchvision torchaudio numpy>=1.25.0 einops -transformers>=4.37.2 +transformers>=4.50.3 tokenizers>=0.13.3 sentencepiece safetensors>=0.4.2 diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index f8d1fd04e..63361309f 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase): def test_all_layers_standard(self): """Test that model with no quantization works normally""" - # Configure no quantization - ops.MixedPrecisionOps._layer_quant_config = {} - # Create model - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops({})) # Initialize weights manually model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) @@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } # Create model and load state dict (strict=False because custom loading pops keys) - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict, strict=False) # Verify weights are wrapped in QuantizedTensor @@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict1, strict=False) # Save state dict @@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) @@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict state_dict = { @@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False)