Merge branch 'master' into v3-improvements

Jedrzej Kosinski, 2025-12-19 16:18:31 -08:00
commit 847c278790
6 changed files with 122 additions and 64 deletions

View File

```diff
@@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
+    def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
         super().__init__()
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(
@@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
             operations=operations
         )
+        self.use_additional_t_cond = use_additional_t_cond
+        if self.use_additional_t_cond:
+            self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)

-    def forward(self, timestep, hidden_states):
+    def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
+            timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
         return timesteps_emb
```
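In effect, the new conditioning path embeds a per-sample binary flag into the timestep-embedding space and adds it on top; when no flag is passed, index 0 is used. A minimal sketch with plain `torch.nn` (the real module goes through ComfyUI's `operations.Embedding`, which also accepts `out_dtype`):

```python
import torch
import torch.nn as nn

embedding_dim = 8
timesteps_emb = torch.randn(4, embedding_dim)          # stand-in for the timestep_embedder output
addition_t_embedding = nn.Embedding(2, embedding_dim)  # two learned states: flag 0 or flag 1

# When no condition is supplied, the model falls back to all zeros (flag 0).
addition_t_cond = torch.tensor([0, 1, 1, 0], dtype=torch.long)
timesteps_emb = timesteps_emb + addition_t_embedding(addition_t_cond)
print(timesteps_emb.shape)  # torch.Size([4, 8])
```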
```diff
@@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
         num_attention_heads: int = 24,
         joint_attention_dim: int = 3584,
         pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         default_ref_method="index",
         image_model=None,
         final_layer=True,
+        use_additional_t_cond=False,
         dtype=None,
         device=None,
         operations=None,
@@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module):
         self.time_text_embed = QwenTimestepProjEmbeddings(
             embedding_dim=self.inner_dim,
             pooled_projection_dim=pooled_projection_dim,
+            use_additional_t_cond=use_additional_t_cond,
             dtype=dtype,
             device=device,
             operations=operations
```
```diff
@@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
         patch_size = self.patch_size
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
         orig_shape = hidden_states.shape

-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+        hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)

+        t_len = t
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
         h_offset = ((h_offset + (patch_size // 2)) // patch_size)
         w_offset = ((w_offset + (patch_size // 2)) // patch_size)

-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
-        img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
-        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
+        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
+        if t_len > 1:
+            img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
+        else:
+            img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
+        return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape

-    def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
+        ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)

     def _forward(
         self,
```
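The positional ids gain a leading time axis: for multi-frame latents, channel 0 now carries the frame index, while single-frame (reference) latents keep the scalar `index`. A standalone sketch of the grid this builds, with small illustrative sizes:

```python
import torch

t_len, h_len, w_len = 3, 4, 4
img_ids = torch.zeros((t_len, h_len, w_len, 3))
# channel 0: frame index (broadcast over h and w)
img_ids[:, :, :, 0] += torch.linspace(0, t_len - 1, steps=t_len).view(-1, 1, 1)
# channels 1 and 2: centered spatial coordinates used for RoPE
img_ids[:, :, :, 1] += torch.linspace(0, h_len - 1, steps=h_len).view(1, -1, 1) - (h_len // 2)
img_ids[:, :, :, 2] += torch.linspace(0, w_len - 1, steps=w_len).view(1, 1, -1) - (w_len // 2)
print(img_ids.flatten(0, 2).shape)  # (t*h*w, 3), matching "t h w c -> b (t h w) c"
```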
```diff
@@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
         timesteps,
         context,
         attention_mask=None,
-        guidance: torch.Tensor = None,
         ref_latents=None,
+        additional_t_cond=None,
         transformer_options={},
         control=None,
         **kwargs
@@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module):
             index = 0
             ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
             index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            negative_ref_method = ref_method == "negative_index"
             timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
                     h_offset = 0
                     w_offset = 0
+                elif negative_ref_method:
+                    index -= 1
+                    h_offset = 0
+                    w_offset = 0
                 else:
                     index = 1
                     h_offset = 0
```
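A small sketch of how the three reference-latent methods assign RoPE time indices (names from this diff; the `else` branch in the real code also derives spatial offsets, omitted here):

```python
def ref_indices(n_refs, ref_method):
    # Behavior inferred from the surrounding loop, not a library API.
    indices = []
    index = 0
    for _ in range(n_refs):
        if ref_method in ("index", "index_timestep_zero"):
            index += 1   # refs at t = 1, 2, 3, ...
        elif ref_method == "negative_index":
            index -= 1   # refs at t = -1, -2, -3, ... (new "Layered" models)
        else:
            index = 1    # all refs share t = 1; positioning comes from offsets
        indices.append(index)
    return indices

print(ref_indices(3, "negative_index"))  # [-1, -2, -3]
```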
```diff
@@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)

-        if guidance is not None:
-            guidance = guidance * 1000
-
-        temb = (
-            self.time_text_embed(timestep, hidden_states)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, hidden_states)
-        )
+        temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)

         patches_replace = transformer_options.get("patches_replace", {})
         patches = transformer_options.get("patches", {})
@@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

-        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
-        hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
```
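The unpatchify at the end mirrors the new 3D patchify. A roundtrip sanity check of the two reshape/permute pairs, with the transformer blocks elided:

```python
import torch

b, c, t, h, w = 1, 16, 3, 8, 8
x = torch.randn(b, c, t, h, w)
orig_shape = x.shape

# patchify: fold t and the 2x2 spatial patches into the token dimension
tokens = x.view(b, c, t, h // 2, 2, w // 2, 2).permute(0, 2, 3, 5, 1, 4, 6)
tokens = tokens.reshape(b, t * (h // 2) * (w // 2), c * 4)

# unpatchify: the inverse permute (0, 4, 1, 2, 5, 3, 6) restores (b, c, t, h, w)
out = tokens.view(b, t, h // 2, w // 2, c, 2, 2).permute(0, 4, 1, 2, 5, 3, 6)
out = out.reshape(orig_shape)
assert torch.equal(out, x)
```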

View File

```diff
@@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
         if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
             dit_config["default_ref_method"] = "index_timestep_zero"
+        if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
+            dit_config["use_additional_t_cond"] = True
+            dit_config["default_ref_method"] = "negative_index"
         return dit_config

     if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
```
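Detection here follows the file's usual pattern: the presence of a model-specific weight key flips config flags. A toy sketch with a hypothetical prefix:

```python
# Hypothetical state dict contents for illustration only.
state_dict_keys = {"model.diffusion_model.time_text_embed.addition_t_embedding.weight"}
key_prefix = "model.diffusion_model."

dit_config = {}
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys:
    dit_config["use_additional_t_cond"] = True
    dit_config["default_ref_method"] = "negative_index"
print(dit_config)  # {'use_additional_t_cond': True, 'default_ref_method': 'negative_index'}
```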

View File

```diff
@@ -26,6 +26,7 @@ import importlib
 import platform
 import weakref
 import gc
+import os

 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
```
```diff
@@ -333,13 +334,15 @@ except:
 SUPPORT_FP8_OPS = args.supports_fp8_compute

 AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
+AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'

 try:
     if is_amd():
         arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
-            torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
-            logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
+            if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
+                torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
+                logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
         try:
             rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
```
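The practical effect: on newer AMD GPUs, cuDNN/MIOpen is now disabled only as a default, and setting `COMFYUI_ENABLE_MIOPEN=1` in the environment before launch keeps it enabled. A sketch of the gate in isolation:

```python
import os

# Set in the shell before launching ComfyUI, e.g. COMFYUI_ENABLE_MIOPEN=1 python main.py
os.environ["COMFYUI_ENABLE_MIOPEN"] = "1"

if os.getenv("COMFYUI_ENABLE_MIOPEN") != "1":
    print("would set torch.backends.cudnn.enabled = False")
else:
    print("cuDNN/MIOpen left enabled")
```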

View File

```diff
@@ -1,10 +1,8 @@
-from inspect import cleandoc
 import torch
 from pydantic import BaseModel
 from typing_extensions import override

-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api_nodes.apis.bfl_api import (
     BFLFluxExpandImageRequest,
     BFLFluxFillImageRequest,
@@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
 )

-def convert_mask_to_image(mask: torch.Tensor):
+def convert_mask_to_image(mask: Input.Image):
     """
     Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
     """
```
```diff
@@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):
 class FluxProUltraImageNode(IO.ComfyNode):
-    """
-    Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
             node_id="FluxProUltraImageNode",
             display_name="Flux 1.1 [pro] Ultra Image",
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
         prompt_upsampling: bool = False,
         raw: bool = False,
         seed: int = 0,
-        image_prompt: torch.Tensor | None = None,
+        image_prompt: Input.Image | None = None,
         image_prompt_strength: float = 0.1,
     ) -> IO.NodeOutput:
         if image_prompt is None:
```
```diff
@@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
 class FluxKontextProImageNode(IO.ComfyNode):
-    """
-    Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
             node_id=cls.NODE_ID,
             display_name=cls.DISPLAY_NAME,
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
         aspect_ratio: str,
         guidance: float,
         steps: int,
-        input_image: torch.Tensor | None = None,
+        input_image: Input.Image | None = None,
         seed=0,
         prompt_upsampling=False,
     ) -> IO.NodeOutput:
```
```diff
@@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):
 class FluxKontextMaxImageNode(FluxKontextProImageNode):
-    """
-    Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
-    """
-
-    DESCRIPTION = cleandoc(__doc__ or "")
+    DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
     BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
     NODE_ID = "FluxKontextMaxImageNode"
     DISPLAY_NAME = "Flux.1 Kontext [max] Image"

 class FluxProExpandNode(IO.ComfyNode):
-    """
-    Outpaints image based on prompt.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
```
```diff
@@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
             node_id="FluxProExpandNode",
             display_name="Flux.1 Expand Image",
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Outpaints image based on prompt.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.String.Input(
@@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
     @classmethod
     async def execute(
         cls,
-        image: torch.Tensor,
+        image: Input.Image,
         prompt: str,
         prompt_upsampling: bool,
         top: int,
```
```diff
@@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
 class FluxProFillNode(IO.ComfyNode):
-    """
-    Inpaints image based on mask and prompt.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
             node_id="FluxProFillNode",
             display_name="Flux.1 Fill Image",
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Inpaints image based on mask and prompt.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.Mask.Input("mask"),
@@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
     @classmethod
     async def execute(
         cls,
-        image: torch.Tensor,
-        mask: torch.Tensor,
+        image: Input.Image,
+        mask: Input.Image,
         prompt: str,
         prompt_upsampling: bool,
         steps: int,
```
```diff
@@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
 class Flux2ProImageNode(IO.ComfyNode):
+    NODE_ID = "Flux2ProImageNode"
+    DISPLAY_NAME = "Flux.2 [pro] Image"
+    API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
+
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
-            node_id="Flux2ProImageNode",
-            display_name="Flux.2 [pro] Image",
+            node_id=cls.NODE_ID,
+            display_name=cls.DISPLAY_NAME,
             category="api node/image/BFL",
             description="Generates images synchronously based on prompt and resolution.",
             inputs=[
@@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
                 ),
                 IO.Boolean.Input(
                     "prompt_upsampling",
-                    default=False,
+                    default=True,
                     tooltip="Whether to perform upsampling on the prompt. "
-                    "If active, automatically modifies the prompt for more creative generation, "
-                    "but results are nondeterministic (same seed will not produce exactly the same result).",
+                    "If active, automatically modifies the prompt for more creative generation.",
                 ),
-                IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
+                IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
             ],
             outputs=[IO.Image.Output()],
             hidden=[
```
```diff
@@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
         height: int,
         seed: int,
         prompt_upsampling: bool,
-        images: torch.Tensor | None = None,
+        images: Input.Image | None = None,
     ) -> IO.NodeOutput:
         reference_images = {}
         if images is not None:
@@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
             reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
         initial_response = await sync_op(
             cls,
-            ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
+            ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
             response_model=BFLFluxProGenerateResponse,
             data=Flux2ProGenerateRequest(
                 prompt=prompt,
@@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
         return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))

+class Flux2MaxImageNode(Flux2ProImageNode):
+    NODE_ID = "Flux2MaxImageNode"
+    DISPLAY_NAME = "Flux.2 [max] Image"
+    API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
+
 class BFLExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
```
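The new node reuses the base class wholesale: `define_schema` and `execute` now read `cls.NODE_ID`, `cls.DISPLAY_NAME`, and `cls.API_ENDPOINT`, so the [max] variant only overrides constants. The pattern in miniature (a simplified sketch, not the actual node API):

```python
class Base:
    NODE_ID = "Flux2ProImageNode"
    API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"

    @classmethod
    def describe(cls):
        # classmethods resolve attributes on the subclass, not the base
        return f"{cls.NODE_ID} -> {cls.API_ENDPOINT}"

class Max(Base):
    NODE_ID = "Flux2MaxImageNode"
    API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"

print(Max.describe())  # Flux2MaxImageNode -> /proxy/bfl/flux-2-max/generate
```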
```diff
@@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
             FluxProExpandNode,
             FluxProFillNode,
             Flux2ProImageNode,
+            Flux2MaxImageNode,
         ]
```

View File

```diff
@@ -5,6 +5,7 @@ import nodes
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
 import logging
+import math

 def reshape_latent_to(target_shape, latent, repeat_batch=True):
     if latent.shape[1:] != target_shape[1:]:
```
```diff
@@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode):
         samples_out["samples"] = torch.narrow(s1, dim, index, amount)
         return io.NodeOutput(samples_out)

+class LatentCutToBatch(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentCutToBatch",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples"),
+                io.Combo.Input("dim", options=["t", "x", "y"]),
+                io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
+        samples_out = samples.copy()
+        s1 = samples["samples"]
+
+        if "x" in dim:
+            dim = s1.ndim - 1
+        elif "y" in dim:
+            dim = s1.ndim - 2
+        elif "t" in dim:
+            dim = s1.ndim - 3
+
+        if dim < 2:
+            return io.NodeOutput(samples)
+
+        s = s1.movedim(dim, 1)
+        if s.shape[1] < slice_size:
+            slice_size = s.shape[1]
+        elif s.shape[1] % slice_size != 0:
+            s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
+
+        new_shape = [-1, slice_size] + list(s.shape[2:])
+        samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
+        return io.NodeOutput(samples_out)
+
 class LatentBatch(io.ComfyNode):
     @classmethod
     def define_schema(cls):
```
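What `LatentCutToBatch` does to shapes: it folds chunks of `slice_size` along the chosen axis into the batch dimension, trimming any remainder. A standalone sketch with illustrative shapes:

```python
import torch

latent = torch.randn(1, 4, 8, 32, 32)  # (b, c, t, y, x), e.g. a video latent
dim, slice_size = 2, 2                  # the "t" axis, cut into 2-frame chunks

s = latent.movedim(dim, 1)                            # (1, 8, 4, 32, 32)
s = s[:, :(s.shape[1] // slice_size) * slice_size]    # drop any remainder frames
s = s.reshape([-1, slice_size] + list(s.shape[2:]))   # (4, 2, 4, 32, 32)
out = s.movedim(1, dim)                               # (4, 4, 2, 32, 32)
print(out.shape)  # 4 batch items of 2 frames each
```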
```diff
@@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension):
             LatentInterpolate,
             LatentConcat,
             LatentCut,
+            LatentCutToBatch,
             LatentBatch,
             LatentBatchSeedBehavior,
             LatentApplyOperation,
```

View File

```diff
@@ -1 +1 @@
-comfyui_manager==4.0.3b5
+comfyui_manager==4.0.3b7
```