Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-03 10:40:51 +08:00)

Commit 847c278790: Merge branch 'master' into v3-improvements
@@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
 
 
 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
+    def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
         super().__init__()
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(
@@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
             operations=operations
         )
+        self.use_additional_t_cond = use_additional_t_cond
+        if self.use_additional_t_cond:
+            self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
 
-    def forward(self, timestep, hidden_states):
+    def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
 
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
+            timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
+
         return timesteps_emb
 
 
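
Note: the new use_additional_t_cond path adds a learned two-entry embedding onto the timestep embedding. A minimal sketch of that behavior, using a plain nn.Embedding in place of ComfyUI's operations wrapper (which also handles the out_dtype cast) and a hypothetical embedding_dim:

    import torch
    import torch.nn as nn

    embedding_dim = 3072  # hypothetical; in the model this is inner_dim
    addition_t_embedding = nn.Embedding(2, embedding_dim)  # two learned offsets, selected by condition 0 or 1

    timesteps_emb = torch.randn(4, embedding_dim)       # stand-in for the projected timestep embedding
    addition_t_cond = torch.zeros(4, dtype=torch.long)  # default used when no condition is supplied
    timesteps_emb = timesteps_emb + addition_t_embedding(addition_t_cond)
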
@@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
         num_attention_heads: int = 24,
         joint_attention_dim: int = 3584,
         pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         default_ref_method="index",
         image_model=None,
         final_layer=True,
+        use_additional_t_cond=False,
         dtype=None,
         device=None,
         operations=None,
@@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module):
         self.time_text_embed = QwenTimestepProjEmbeddings(
             embedding_dim=self.inner_dim,
             pooled_projection_dim=pooled_projection_dim,
+            use_additional_t_cond=use_additional_t_cond,
             dtype=dtype,
             device=device,
             operations=operations
@@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
         patch_size = self.patch_size
         hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
         orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
-        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
-        hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
+        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+        hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
+        t_len = t
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
 
         h_offset = ((h_offset + (patch_size // 2)) // patch_size)
         w_offset = ((w_offset + (patch_size // 2)) // patch_size)
 
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
-        img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
-        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
+        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
+        if t_len > 1:
+            img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
+        else:
+            img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
+
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
+        return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
 
-    def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
+        ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
 
     def _forward(
         self,
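
Note: process_img now keeps the latent's temporal axis instead of assuming a single frame; each 2x2 spatial patch becomes one token and the t dimension is folded into the token sequence. A minimal sketch with made-up dimensions:

    import torch

    b, c, t, h, w = 1, 16, 4, 8, 8  # hypothetical latent shape (batch, channels, t, height, width)
    x = torch.randn(b, c, t, h, w)

    tokens = x.view(b, c, t, h // 2, 2, w // 2, 2)
    tokens = tokens.permute(0, 2, 3, 5, 1, 4, 6)                # (b, t, h/2, w/2, c, 2, 2)
    tokens = tokens.reshape(b, t * (h // 2) * (w // 2), c * 4)  # one token per 2x2 patch per frame
    assert tokens.shape == (1, 64, 64)
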
@@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
         timesteps,
         context,
         attention_mask=None,
-        guidance: torch.Tensor = None,
         ref_latents=None,
+        additional_t_cond=None,
         transformer_options={},
         control=None,
         **kwargs
@@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module):
             index = 0
             ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
             index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            negative_ref_method = ref_method == "negative_index"
             timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
                     h_offset = 0
                     w_offset = 0
+                elif negative_ref_method:
+                    index -= 1
+                    h_offset = 0
+                    w_offset = 0
                 else:
                     index = 1
                     h_offset = 0
@@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
-        if guidance is not None:
-            guidance = guidance * 1000
-
-        temb = (
-            self.time_text_embed(timestep, hidden_states)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, hidden_states)
-        )
+        temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
 
         patches_replace = transformer_options.get("patches_replace", {})
         patches = transformer_options.get("patches", {})
@@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
-        hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
+        hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
+        hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
         return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
@@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
         if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
             dit_config["default_ref_method"] = "index_timestep_zero"
+        if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
+            dit_config["use_additional_t_cond"] = True
+            dit_config["default_ref_method"] = "negative_index"
         return dit_config
 
     if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
@@ -26,6 +26,7 @@ import importlib
 import platform
 import weakref
 import gc
+import os
 
 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
@@ -333,13 +334,15 @@ except:
     SUPPORT_FP8_OPS = args.supports_fp8_compute
 
 AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
+AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
 
 try:
     if is_amd():
         arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
-            torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
-            logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
+            if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
+                torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
+                logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
 
     try:
         rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
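
Note: cuDNN/MIOpen stays disabled by default on AMD GPUs newer than RDNA2, but the new COMFYUI_ENABLE_MIOPEN environment variable lets users opt back in. A minimal sketch of the gate (the launch command is an assumption, not taken from this diff):

    import os
    import torch

    # Opt back in by exporting COMFYUI_ENABLE_MIOPEN=1 before launch,
    # e.g. `COMFYUI_ENABLE_MIOPEN=1 python main.py`.
    if os.getenv('COMFYUI_ENABLE_MIOPEN') != '1':
        torch.backends.cudnn.enabled = False  # default behavior on these GPUs
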
@@ -1,10 +1,8 @@
-from inspect import cleandoc
-
 import torch
 from pydantic import BaseModel
 from typing_extensions import override
 
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api_nodes.apis.bfl_api import (
     BFLFluxExpandImageRequest,
     BFLFluxFillImageRequest,
@@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
 )
 
 
-def convert_mask_to_image(mask: torch.Tensor):
+def convert_mask_to_image(mask: Input.Image):
     """
     Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
     """
@@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):
 
 
 class FluxProUltraImageNode(IO.ComfyNode):
-    """
-    Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
-    """
 
     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
             node_id="FluxProUltraImageNode",
             display_name="Flux 1.1 [pro] Ultra Image",
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
         prompt_upsampling: bool = False,
         raw: bool = False,
         seed: int = 0,
-        image_prompt: torch.Tensor | None = None,
+        image_prompt: Input.Image | None = None,
         image_prompt_strength: float = 0.1,
     ) -> IO.NodeOutput:
         if image_prompt is None:
@@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
 
 
 class FluxKontextProImageNode(IO.ComfyNode):
-    """
-    Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
-    """
 
     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
             node_id=cls.NODE_ID,
             display_name=cls.DISPLAY_NAME,
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
         aspect_ratio: str,
         guidance: float,
         steps: int,
-        input_image: torch.Tensor | None = None,
+        input_image: Input.Image | None = None,
         seed=0,
         prompt_upsampling=False,
     ) -> IO.NodeOutput:
@@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):
 
 
 class FluxKontextMaxImageNode(FluxKontextProImageNode):
-    """
-    Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
-    """
 
-    DESCRIPTION = cleandoc(__doc__ or "")
+    DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
     BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
     NODE_ID = "FluxKontextMaxImageNode"
     DISPLAY_NAME = "Flux.1 Kontext [max] Image"
 
 
 class FluxProExpandNode(IO.ComfyNode):
-    """
-    Outpaints image based on prompt.
-    """
 
     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
             node_id="FluxProExpandNode",
             display_name="Flux.1 Expand Image",
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Outpaints image based on prompt.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.String.Input(
@@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
     @classmethod
     async def execute(
         cls,
-        image: torch.Tensor,
+        image: Input.Image,
         prompt: str,
         prompt_upsampling: bool,
         top: int,
@@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
 
 
 class FluxProFillNode(IO.ComfyNode):
-    """
-    Inpaints image based on mask and prompt.
-    """
 
     @classmethod
     def define_schema(cls) -> IO.Schema:
@@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
             node_id="FluxProFillNode",
             display_name="Flux.1 Fill Image",
             category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Inpaints image based on mask and prompt.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.Mask.Input("mask"),
@@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
     @classmethod
     async def execute(
         cls,
-        image: torch.Tensor,
-        mask: torch.Tensor,
+        image: Input.Image,
+        mask: Input.Image,
         prompt: str,
         prompt_upsampling: bool,
         steps: int,
@@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
 
 class Flux2ProImageNode(IO.ComfyNode):
 
+    NODE_ID = "Flux2ProImageNode"
+    DISPLAY_NAME = "Flux.2 [pro] Image"
+    API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
+
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
-            node_id="Flux2ProImageNode",
-            display_name="Flux.2 [pro] Image",
+            node_id=cls.NODE_ID,
+            display_name=cls.DISPLAY_NAME,
             category="api node/image/BFL",
             description="Generates images synchronously based on prompt and resolution.",
             inputs=[
@@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
                 ),
                 IO.Boolean.Input(
                     "prompt_upsampling",
-                    default=False,
+                    default=True,
                     tooltip="Whether to perform upsampling on the prompt. "
-                            "If active, automatically modifies the prompt for more creative generation, "
-                            "but results are nondeterministic (same seed will not produce exactly the same result).",
+                            "If active, automatically modifies the prompt for more creative generation.",
                 ),
-                IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
+                IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
             ],
             outputs=[IO.Image.Output()],
             hidden=[
@@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
         height: int,
         seed: int,
         prompt_upsampling: bool,
-        images: torch.Tensor | None = None,
+        images: Input.Image | None = None,
     ) -> IO.NodeOutput:
         reference_images = {}
         if images is not None:
@@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
                 reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
         initial_response = await sync_op(
             cls,
-            ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
+            ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
             response_model=BFLFluxProGenerateResponse,
             data=Flux2ProGenerateRequest(
                 prompt=prompt,
@@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
         return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
 
 
+class Flux2MaxImageNode(Flux2ProImageNode):
+
+    NODE_ID = "Flux2MaxImageNode"
+    DISPLAY_NAME = "Flux.2 [max] Image"
+    API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
+
+
 class BFLExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
             FluxProExpandNode,
             FluxProFillNode,
             Flux2ProImageNode,
+            Flux2MaxImageNode,
         ]
 
 
@@ -5,6 +5,7 @@ import nodes
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
 import logging
+import math
 
 def reshape_latent_to(target_shape, latent, repeat_batch=True):
     if latent.shape[1:] != target_shape[1:]:
@@ -207,6 +208,47 @@ class LatentCut(io.ComfyNode):
         samples_out["samples"] = torch.narrow(s1, dim, index, amount)
         return io.NodeOutput(samples_out)
 
+class LatentCutToBatch(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LatentCutToBatch",
+            category="latent/advanced",
+            inputs=[
+                io.Latent.Input("samples"),
+                io.Combo.Input("dim", options=["t", "x", "y"]),
+                io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
+        samples_out = samples.copy()
+
+        s1 = samples["samples"]
+
+        if "x" in dim:
+            dim = s1.ndim - 1
+        elif "y" in dim:
+            dim = s1.ndim - 2
+        elif "t" in dim:
+            dim = s1.ndim - 3
+
+        if dim < 2:
+            return io.NodeOutput(samples)
+
+        s = s1.movedim(dim, 1)
+        if s.shape[1] < slice_size:
+            slice_size = s.shape[1]
+        elif s.shape[1] % slice_size != 0:
+            s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
+        new_shape = [-1, slice_size] + list(s.shape[2:])
+        samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
+        return io.NodeOutput(samples_out)
+
 class LatentBatch(io.ComfyNode):
     @classmethod
     def define_schema(cls):
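
Note: LatentCutToBatch moves slices of the chosen axis into the batch dimension. A minimal sketch of the underlying reshape for dim="t" and slice_size=1, with a hypothetical video-latent shape:

    import torch

    latent = torch.randn(1, 16, 8, 64, 64)  # hypothetical (batch, channels, t, height, width)
    dim, slice_size = latent.ndim - 3, 1     # the "t" axis

    s = latent.movedim(dim, 1)                           # (1, 8, 16, 64, 64)
    s = s.reshape([-1, slice_size] + list(s.shape[2:]))  # (8, 1, 16, 64, 64)
    out = s.movedim(1, dim)                              # (8, 16, 1, 64, 64): eight single-frame latents
    assert out.shape == (8, 16, 1, 64, 64)
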
@@ -435,6 +477,7 @@ class LatentExtension(ComfyExtension):
             LatentInterpolate,
             LatentConcat,
             LatentCut,
+            LatentCutToBatch,
             LatentBatch,
             LatentBatchSeedBehavior,
             LatentApplyOperation,
@@ -1 +1 @@
-comfyui_manager==4.0.3b5
+comfyui_manager==4.0.3b7