Merge branch 'master' into v3-dynamic-combo
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled

This commit is contained in:
Jedrzej Kosinski 2025-12-01 16:59:38 -08:00
commit 3a5daf7dce
57 changed files with 5706 additions and 1910 deletions

View File

@ -1,5 +1,5 @@
As of the time of writing this you need this preview driver for best results: As of the time of writing this you need this driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
HOW TO RUN: HOW TO RUN:
@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor. Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.

View File

@ -65,11 +65,11 @@ jobs:
contents: "write" contents: "write"
packages: "write" packages: "write"
pull-requests: "read" pull-requests: "read"
name: "Release AMD ROCm 6.4.4" name: "Release AMD ROCm 7.1.1"
uses: ./.github/workflows/stable-release.yml uses: ./.github/workflows/stable-release.yml
with: with:
git_tag: ${{ inputs.git_tag }} git_tag: ${{ inputs.git_tag }}
cache_tag: "rocm644" cache_tag: "rocm711"
python_minor: "12" python_minor: "12"
python_patch: "10" python_patch: "10"
rel_name: "amd" rel_name: "amd"

View File

@ -67,6 +67,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Image Editing Models - Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

View File

@ -59,6 +59,9 @@ class UserManager():
user = "default" user = "default"
if args.multi_user and "comfy-user" in request.headers: if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"] user = request.headers["comfy-user"]
# Block System Users (use same error message to prevent probing)
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise KeyError("Unknown user: " + user)
if user not in self.users: if user not in self.users:
raise KeyError("Unknown user: " + user) raise KeyError("Unknown user: " + user)
@ -66,15 +69,16 @@ class UserManager():
return user return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory()
if type == "userdata": if type == "userdata":
root_dir = user_directory root_dir = folder_paths.get_user_directory()
else: else:
raise KeyError("Unknown filepath type:" + type) raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request) user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user)) user_root = folder_paths.get_public_user_directory(user)
if user_root is None:
return None
path = user_root
# prevent leaving /{type} # prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir: if os.path.commonpath((root_dir, user_root)) != root_dir:
@ -101,7 +105,11 @@ class UserManager():
name = name.strip() name = name.strip()
if not name: if not name:
raise ValueError("username not provided") raise ValueError("username not provided")
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise ValueError("System User prefix not allowed")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise ValueError("System User prefix not allowed")
user_id = user_id + "_" + str(uuid.uuid4()) user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name self.users[user_id] = name
@ -132,7 +140,10 @@ class UserManager():
if username in self.users.values(): if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400) return web.json_response({"error": "Duplicate username."}, status=400)
user_id = self.add_user(username) try:
user_id = self.add_user(username)
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
return web.json_response(user_id) return web.json_response(user_id)
@routes.get("/userdata") @routes.get("/userdata")
@ -424,7 +435,7 @@ class UserManager():
return source return source
dest = get_user_data_path(request, check_exists=False, param="dest") dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str): if not isinstance(dest, str):
return dest return dest
overwrite = request.query.get("overwrite", 'true') != "false" overwrite = request.query.get("overwrite", 'true') != "false"

View File

@ -131,7 +131,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

View File

@ -6,6 +6,7 @@ class LatentFormat:
latent_dimensions = 2 latent_dimensions = 2
latent_rgb_factors = None latent_rgb_factors = None
latent_rgb_factors_bias = None latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None taesd_decoder_name = None
def process_in(self, latent): def process_in(self, latent):
@ -178,6 +179,54 @@ class Flux(SD3):
def process_out(self, latent): def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor return (latent / self.scale_factor) + self.shift_factor
class Flux2(LatentFormat):
latent_channels = 128
def __init__(self):
self.latent_rgb_factors =[
[0.0058, 0.0113, 0.0073],
[0.0495, 0.0443, 0.0836],
[-0.0099, 0.0096, 0.0644],
[0.2144, 0.3009, 0.3652],
[0.0166, -0.0039, -0.0054],
[0.0157, 0.0103, -0.0160],
[-0.0398, 0.0902, -0.0235],
[-0.0052, 0.0095, 0.0109],
[-0.3527, -0.2712, -0.1666],
[-0.0301, -0.0356, -0.0180],
[-0.0107, 0.0078, 0.0013],
[0.0746, 0.0090, -0.0941],
[0.0156, 0.0169, 0.0070],
[-0.0034, -0.0040, -0.0114],
[0.0032, 0.0181, 0.0080],
[-0.0939, -0.0008, 0.0186],
[0.0018, 0.0043, 0.0104],
[0.0284, 0.0056, -0.0127],
[-0.0024, -0.0022, -0.0030],
[0.1207, -0.0026, 0.0065],
[0.0128, 0.0101, 0.0142],
[0.0137, -0.0072, -0.0007],
[0.0095, 0.0092, -0.0059],
[0.0000, -0.0077, -0.0049],
[-0.0465, -0.0204, -0.0312],
[0.0095, 0.0012, -0.0066],
[0.0290, -0.0034, 0.0025],
[0.0220, 0.0169, -0.0048],
[-0.0332, -0.0457, -0.0468],
[-0.0085, 0.0389, 0.0609],
[-0.0076, 0.0003, -0.0043],
[-0.0111, -0.0460, -0.0614],
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent
class Mochi(LatentFormat): class Mochi(LatentFormat):
latent_channels = 12 latent_channels = 12
latent_dimensions = 3 latent_dimensions = 3
@ -382,6 +431,7 @@ class HunyuanVideo(LatentFormat):
] ]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
taesd_decoder_name = "taehv"
class Cosmos1CV8x8x8(LatentFormat): class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16 latent_channels = 16
@ -445,7 +495,7 @@ class Wan21(LatentFormat):
]).view(1, self.latent_channels, 1, 1, 1) ]).view(1, self.latent_channels, 1, 1, 1)
self.taesd_decoder_name = None #TODO self.taesd_decoder_name = "lighttaew2_1"
def process_in(self, latent): def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_mean = self.latents_mean.to(latent.device, latent.dtype)
@ -516,6 +566,7 @@ class Wan22(Wan21):
def __init__(self): def __init__(self):
self.scale_factor = 1.0 self.scale_factor = 1.0
self.taesd_decoder_name = "lighttaew2_2"
self.latents_mean = torch.tensor([ self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
@ -670,6 +721,7 @@ class HunyuanVideo15(LatentFormat):
latent_channels = 32 latent_channels = 32
latent_dimensions = 3 latent_dimensions = 3
scale_factor = 1.03682 scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"
class Hunyuan3Dv2(LatentFormat): class Hunyuan3Dv2(LatentFormat):
latent_channels = 64 latent_channels = 64

View File

@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
return embedding return embedding
class MLPEmbedder(nn.Module): class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
self.silu = nn.SiLU() self.silu = nn.SiLU()
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x))) return self.out_layer(self.silu(self.in_layer(x)))
@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module):
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
@dataclass @dataclass
@ -98,11 +98,11 @@ class ModulationOut:
class Modulation(nn.Module): class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.is_double = double self.is_double = double
self.multiplier = 6 if double else 3 self.multiplier = 6 if double else 3
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple: def forward(self, vec: Tensor) -> tuple:
if vec.ndim == 2: if vec.ndim == 2:
@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor return tensor
class SiLUActivation(nn.Module):
def __init__(self):
super().__init__()
self.gate_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return self.gate_fn(x1) * x2
class DoubleStreamBlock(nn.Module): class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio) mlp_hidden_dim = int(hidden_size * mlp_ratio)
@ -142,27 +152,44 @@ class DoubleStreamBlock(nn.Module):
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), if mlp_silu_act:
nn.GELU(approximate="tanh"), self.img_mlp = nn.Sequential(
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
) SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
if self.modulation: if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), if mlp_silu_act:
nn.GELU(approximate="tanh"), self.txt_mlp = nn.Sequential(
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
) SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
@ -246,6 +273,8 @@ class SingleStreamBlock(nn.Module):
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
qk_scale: float = None, qk_scale: float = None,
modulation=True, modulation=True,
mlp_silu_act=False,
bias=True,
dtype=None, dtype=None,
device=None, device=None,
operations=None operations=None
@ -257,17 +286,24 @@ class SingleStreamBlock(nn.Module):
self.scale = qk_scale or head_dim**-0.5 self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim_first = self.mlp_hidden_dim
if mlp_silu_act:
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
self.mlp_act = SiLUActivation()
else:
self.mlp_act = nn.GELU(approximate="tanh")
# qkv and mlp_in # qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out # proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
if modulation: if modulation:
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
else: else:
@ -279,7 +315,7 @@ class SingleStreamBlock(nn.Module):
else: else:
mod = vec mod = vec
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del qkv del qkv
@ -298,11 +334,11 @@ class SingleStreamBlock(nn.Module):
class LastLayer(nn.Module): class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
if vec.ndim == 2: if vec.ndim == 2:

View File

@ -15,6 +15,7 @@ from .layers import (
MLPEmbedder, MLPEmbedder,
SingleStreamBlock, SingleStreamBlock,
timestep_embedding, timestep_embedding,
Modulation
) )
@dataclass @dataclass
@ -33,6 +34,11 @@ class FluxParams:
patch_size: int patch_size: int
qkv_bias: bool qkv_bias: bool
guidance_embed: bool guidance_embed: bool
global_modulation: bool = False
mlp_silu_act: bool = False
ops_bias: bool = True
default_ref_method: str = "offset"
ref_index_scale: float = 1.0
class Flux(nn.Module): class Flux(nn.Module):
@ -58,13 +64,17 @@ class Flux(nn.Module):
self.hidden_size = params.hidden_size self.hidden_size = params.hidden_size
self.num_heads = params.num_heads self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) if params.vec_in_dim is not None:
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.vector_in = None
self.guidance_in = ( self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
) )
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList( self.double_blocks = nn.ModuleList(
[ [
@ -73,6 +83,9 @@ class Flux(nn.Module):
self.num_heads, self.num_heads,
mlp_ratio=params.mlp_ratio, mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias, qkv_bias=params.qkv_bias,
modulation=params.global_modulation is False,
mlp_silu_act=params.mlp_silu_act,
proj_bias=params.ops_bias,
dtype=dtype, device=device, operations=operations dtype=dtype, device=device, operations=operations
) )
for _ in range(params.depth) for _ in range(params.depth)
@ -81,13 +94,30 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList( self.single_blocks = nn.ModuleList(
[ [
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks) for _ in range(params.depth_single_blocks)
] ]
) )
if final_layer: if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
if params.global_modulation:
self.double_stream_modulation_img = Modulation(
self.hidden_size,
double=True,
bias=False,
dtype=dtype, device=device, operations=operations
)
self.double_stream_modulation_txt = Modulation(
self.hidden_size,
double=True,
bias=False,
dtype=dtype, device=device, operations=operations
)
self.single_stream_modulation = Modulation(
self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
)
def forward_orig( def forward_orig(
self, self,
@ -103,9 +133,6 @@ class Flux(nn.Module):
attn_mask: Tensor = None, attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3: if img.ndim != 3 or txt.ndim != 3:
@ -118,9 +145,17 @@ class Flux(nn.Module):
if guidance is not None: if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) if self.vector_in is not None:
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt) txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches: if "post_input" in patches:
for p in patches["post_input"]: for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
@ -136,7 +171,10 @@ class Flux(nn.Module):
pe = None pe = None
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks): for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -177,7 +215,13 @@ class Flux(nn.Module):
img = torch.cat((txt, img), 1) img = torch.cat((txt, img), 1)
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -207,7 +251,7 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@ -234,10 +278,10 @@ class Flux(nn.Module):
h_offset += rope_options.get("shift_y", 0.0) h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0) w_offset += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
img_ids[:, :, 0] = img_ids[:, :, 1] + index img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@ -259,10 +303,10 @@ class Flux(nn.Module):
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
ref_latents_method = kwargs.get("ref_latents_method", "offset") ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
for ref in ref_latents: for ref in ref_latents:
if ref_latents_method == "index": if ref_latents_method == "index":
index += 1 index += self.params.ref_index_scale
h_offset = 0 h_offset = 0
w_offset = 0 w_offset = 0
elif ref_latents_method == "uxo": elif ref_latents_method == "uxo":
@ -286,7 +330,11 @@ class Flux(nn.Module):
img = torch.cat([img, kontext], dim=1) img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
if len(self.params.axes_dim) == 4: # Flux 2
txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens] out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@ -11,6 +11,7 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension import comfy.patcher_extension
@ -31,6 +32,7 @@ class JointAttention(nn.Module):
n_heads: int, n_heads: int,
n_kv_heads: Optional[int], n_kv_heads: Optional[int],
qk_norm: bool, qk_norm: bool,
out_bias: bool = False,
operation_settings={}, operation_settings={},
): ):
""" """
@ -59,7 +61,7 @@ class JointAttention(nn.Module):
self.out = operation_settings.get("operations").Linear( self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim, n_heads * self.head_dim,
dim, dim,
bias=False, bias=out_bias,
device=operation_settings.get("device"), device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"), dtype=operation_settings.get("dtype"),
) )
@ -70,35 +72,6 @@ class JointAttention(nn.Module):
else: else:
self.q_norm = self.k_norm = nn.Identity() self.q_norm = self.k_norm = nn.Identity()
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape)
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
@ -134,8 +107,7 @@ class JointAttention(nn.Module):
xq = self.q_norm(xq) xq = self.q_norm(xq)
xk = self.k_norm(xk) xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) xq, xk = apply_rope(xq, xk, freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1: if n_rep >= 1:
@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module):
norm_eps: float, norm_eps: float,
qk_norm: bool, qk_norm: bool,
modulation=True, modulation=True,
z_image_modulation=False,
attn_out_bias=False,
operation_settings={}, operation_settings={},
) -> None: ) -> None:
""" """
@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.head_dim = dim // n_heads self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
self.feed_forward = FeedForward( self.feed_forward = FeedForward(
dim=dim, dim=dim,
hidden_dim=4 * dim, hidden_dim=dim,
multiple_of=multiple_of, multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier, ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings, operation_settings=operation_settings,
@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module):
self.modulation = modulation self.modulation = modulation
if modulation: if modulation:
self.adaLN_modulation = nn.Sequential( if z_image_modulation:
nn.SiLU(), self.adaLN_modulation = nn.Sequential(
operation_settings.get("operations").Linear( operation_settings.get("operations").Linear(
min(dim, 1024), min(dim, 256),
4 * dim, 4 * dim,
bias=True, bias=True,
device=operation_settings.get("device"), device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"), dtype=operation_settings.get("dtype"),
), ),
) )
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
def forward( def forward(
self, self,
@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
The final layer of NextDiT. The final layer of NextDiT.
""" """
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
super().__init__() super().__init__()
self.norm_final = operation_settings.get("operations").LayerNorm( self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size, hidden_size,
@ -340,10 +325,15 @@ class FinalLayer(nn.Module):
dtype=operation_settings.get("dtype"), dtype=operation_settings.get("dtype"),
) )
if z_image_modulation:
min_mod = 256
else:
min_mod = 1024
self.adaLN_modulation = nn.Sequential( self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.SiLU(),
operation_settings.get("operations").Linear( operation_settings.get("operations").Linear(
min(hidden_size, 1024), min(hidden_size, min_mod),
hidden_size, hidden_size,
bias=True, bias=True,
device=operation_settings.get("device"), device=operation_settings.get("device"),
@ -373,12 +363,16 @@ class NextDiT(nn.Module):
n_heads: int = 32, n_heads: int = 32,
n_kv_heads: Optional[int] = None, n_kv_heads: Optional[int] = None,
multiple_of: int = 256, multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None, ffn_dim_multiplier: float = 4.0,
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
qk_norm: bool = False, qk_norm: bool = False,
cap_feat_dim: int = 5120, cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56), axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512), axes_lens: List[int] = (1, 512, 512),
rope_theta=10000.0,
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
image_model=None, image_model=None,
device=None, device=None,
dtype=None, dtype=None,
@ -390,6 +384,8 @@ class NextDiT(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = in_channels self.out_channels = in_channels
self.patch_size = patch_size self.patch_size = patch_size
self.time_scale = time_scale
self.pad_tokens_multiple = pad_tokens_multiple
self.x_embedder = operation_settings.get("operations").Linear( self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels, in_features=patch_size * patch_size * in_channels,
@ -411,6 +407,7 @@ class NextDiT(nn.Module):
norm_eps, norm_eps,
qk_norm, qk_norm,
modulation=True, modulation=True,
z_image_modulation=z_image_modulation,
operation_settings=operation_settings, operation_settings=operation_settings,
) )
for layer_id in range(n_refiner_layers) for layer_id in range(n_refiner_layers)
@ -434,7 +431,7 @@ class NextDiT(nn.Module):
] ]
) )
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
self.cap_embedder = nn.Sequential( self.cap_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear( operation_settings.get("operations").Linear(
@ -457,18 +454,24 @@ class NextDiT(nn.Module):
ffn_dim_multiplier, ffn_dim_multiplier,
norm_eps, norm_eps,
qk_norm, qk_norm,
z_image_modulation=z_image_modulation,
attn_out_bias=False,
operation_settings=operation_settings, operation_settings=operation_settings,
) )
for layer_id in range(n_layers) for layer_id in range(n_layers)
] ]
) )
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None:
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
assert (dim // n_heads) == sum(axes_dims) assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims self.axes_dims = axes_dims
self.axes_lens = axes_lens self.axes_lens = axes_lens
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims) self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
self.dim = dim self.dim = dim
self.n_heads = n_heads self.n_heads = n_heads
@ -503,108 +506,54 @@ class NextDiT(nn.Module):
bsz = len(x) bsz = len(x)
pH = pW = self.patch_size pH = pW = self.patch_size
device = x[0].device device = x[0].device
dtype = x[0].dtype
if cap_mask is not None: if self.pad_tokens_multiple is not None:
l_effective_cap_len = cap_mask.sum(dim=1).tolist() pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
else: cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
l_effective_cap_len = [num_tokens] * bsz
if cap_mask is not None and not torch.is_floating_point(cap_mask): cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
img_sizes = [(img.size(1), img.size(2)) for img in x] B, C, H, W = x.shape
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
max_seq_len = max( rope_options = transformer_options.get("rope_options", None)
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) h_scale = 1.0
) w_scale = 1.0
max_cap_len = max(l_effective_cap_len) h_start = 0
max_img_len = max(l_effective_img_len) w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device) h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
for i in range(bsz): H_tokens, W_tokens = H // pH, W // pW
cap_len = l_effective_cap_len[i] x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
img_len = l_effective_img_len[i] x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
H, W = img_sizes[i] x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
H_tokens, W_tokens = H // pH, W // pW x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
assert H_tokens * W_tokens == img_len
rope_options = transformer_options.get("rope_options", None) if self.pad_tokens_multiple is not None:
h_scale = 1.0 pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
w_scale = 1.0 x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
h_start = 0 x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
w_start = rope_options.get("shift_x", 0.0)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
# refine context # refine context
for layer in self.context_refiner: for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
# refine image padded_img_mask = None
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
for i in range(bsz):
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner: for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
else:
mask = None
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
@ -627,7 +576,7 @@ class NextDiT(nn.Module):
y: (N,) tensor of text tokens/features y: (N,) tensor of text tokens/features
""" """
t = self.t_embedder(t, dtype=x.dtype) # (N, D) t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute

View File

@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
from comfy.ldm.util import get_obj_from_str, instantiate_from_config from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma from comfy.ldm.modules.ema import LitEma
import comfy.ops import comfy.ops
from einops import rearrange
import comfy.model_management
class DiagonalGaussianRegularizer(torch.nn.Module): class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = False): def __init__(self, sample: bool = False):
@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1) self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim self.embed_dim = embed_dim
if ddconfig.get("batch_norm_latent", False):
self.bn_eps = 1e-4
self.bn_momentum = 0.1
self.ps = [2, 2]
self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
eps=self.bn_eps,
momentum=self.bn_momentum,
affine=False,
track_running_stats=True,
)
self.bn.eval()
else:
self.bn = None
def get_autoencoder_params(self) -> list: def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params() params = super().get_autoencoder_params()
return params return params
@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
z = torch.cat(z, 0) z = torch.cat(z, 0)
z, reg_log = self.regularization(z) z, reg_log = self.regularization(z)
if self.bn is not None:
z = rearrange(z,
"... c (i pi) (j pj) -> ... (c pi pj) i j",
pi=self.ps[0],
pj=self.ps[1],
)
z = torch.nn.functional.batch_norm(z,
comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
momentum=self.bn_momentum,
eps=self.bn_eps)
if return_reg_log: if return_reg_log:
return z, reg_log return z, reg_log
return z return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.bn is not None:
s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
z = z * s + m
z = rearrange(
z,
"... (c pi pj) i j -> ... c (i pi) (j pj)",
pi=self.ps[0],
pj=self.ps[1],
)
if self.max_batch_size is None: if self.max_batch_size is None:
dec = self.post_quant_conv(z) dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs) dec = self.decoder(dec, **decoder_kwargs)

View File

@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations. Embeds scalar timesteps into vector representations.
""" """
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
if output_size is None:
output_size = hidden_size
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
) )
self.frequency_embedding_size = frequency_embedding_size self.frequency_embedding_size = frequency_embedding_size

View File

@ -313,6 +313,15 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k key_map["transformer.{}".format(key_lora)] = k
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.Lumina2):
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
return key_map return key_map

View File

@ -898,12 +898,13 @@ class Flux(BaseModel):
attention_mask = kwargs.get("attention_mask", None) attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None: if attention_mask is not None:
shape = kwargs["noise"].shape shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"] mask_ref_size = kwargs.get("attention_mask_img_shape", None)
# the model will pad to the patch size, and then divide if mask_ref_size is not None:
# essentially dividing and rounding up # the model will pad to the patch size, and then divide
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) # essentially dividing and rounding up
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
guidance = kwargs.get("guidance", 3.5) guidance = kwargs.get("guidance", 3.5)
if guidance is not None: if guidance is not None:
@ -925,9 +926,19 @@ class Flux(BaseModel):
out = {} out = {}
ref_latents = kwargs.get("reference_latents", None) ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None: if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out return out
class Flux2(Flux):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
target_text_len = 512
if cross_attn.shape[1] < target_text_len:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class GenmoMochi(BaseModel): class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@ -1103,9 +1114,13 @@ class Lumina2(BaseModel):
if torch.numel(attention_mask) != attention_mask.sum(): if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
return out return out
class WAN21(BaseModel): class WAN21(BaseModel):

View File

@ -200,26 +200,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {} dit_config = {}
dit_config["image_model"] = "flux" if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
dit_config["axes_dim"] = [32, 32, 32, 32]
dit_config["num_heads"] = 48
dit_config["mlp_ratio"] = 3.0
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
dit_config["vec_in_dim"] = None
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
patch_size = 1
else:
dit_config["image_model"] = "flux"
dit_config["axes_dim"] = [16, 56, 56]
dit_config["num_heads"] = 24
dit_config["mlp_ratio"] = 4.0
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
patch_size = 2
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
patch_size = 2 dit_config["hidden_size"] = 3072
dit_config["context_in_dim"] = 4096
dit_config["patch_size"] = patch_size dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix) in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys: if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) w = state_dict[in_key]
dit_config["out_channels"] = 16 dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
dit_config["hidden_size"] = w.shape[0]
txt_in_key = "{}txt_in.weight".format(key_prefix)
if txt_in_key in state_dict_keys:
w = state_dict[txt_in_key]
dit_config["context_in_dim"] = w.shape[1]
dit_config["hidden_size"] = w.shape[0]
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys: if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
dit_config["image_model"] = "chroma" dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64 dit_config["in_channels"] = 64
@ -388,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2" dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2 dit_config["patch_size"] = 2
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
dit_config["dim"] = 2304 w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1] dit_config["dim"] = w.shape[0]
dit_config["cap_feat_dim"] = w.shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True dit_config["qk_norm"] = True
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512] if dit_config["dim"] == 2304: # Original Lumina 2
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
dit_config["axes_dims"] = [32, 48, 48]
dit_config["axes_lens"] = [1536, 512, 512]
dit_config["rope_theta"] = 256.0
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
dit_config["z_image_modulation"] = True
dit_config["time_scale"] = 1000.0
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
return dit_config return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1

View File

@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
loaded_memory = loaded_model.model_loaded_memory() loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = lowvram_model_memory - loaded_memory lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0: if lowvram_model_memory == 0:
@ -1012,9 +1012,18 @@ def force_channels_last():
STREAMS = {} STREAMS = {}
NUM_STREAMS = 1 NUM_STREAMS = 0
if args.async_offload: if args.async_offload is not None:
NUM_STREAMS = 2 NUM_STREAMS = args.async_offload
else:
# Enable by default on Nvidia
if is_nvidia():
NUM_STREAMS = 2
if args.disable_async_offload:
NUM_STREAMS = 0
if NUM_STREAMS > 0:
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device): def current_stream(device):
@ -1030,7 +1039,10 @@ def current_stream(device):
stream_counters = {} stream_counters = {}
def get_offload_stream(device): def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0) stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS <= 1: if NUM_STREAMS == 0:
return None
if torch.compiler.is_compiling():
return None return None
if device in STREAMS: if device in STREAMS:
@ -1043,7 +1055,9 @@ def get_offload_stream(device):
elif is_device_cuda(device): elif is_device_cuda(device):
ss = [] ss = []
for k in range(NUM_STREAMS): for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0)) s1 = torch.cuda.Stream(device=device, priority=0)
s1.as_context = torch.cuda.stream
ss.append(s1)
STREAMS[device] = ss STREAMS[device] = ss
s = ss[stream_counter] s = ss[stream_counter]
stream_counters[device] = stream_counter stream_counters[device] = stream_counter
@ -1051,7 +1065,9 @@ def get_offload_stream(device):
elif is_device_xpu(device): elif is_device_xpu(device):
ss = [] ss = []
for k in range(NUM_STREAMS): for k in range(NUM_STREAMS):
ss.append(torch.xpu.Stream(device=device, priority=0)) s1 = torch.xpu.Stream(device=device, priority=0)
s1.as_context = torch.xpu.stream
ss.append(s1)
STREAMS[device] = ss STREAMS[device] = ss
s = ss[stream_counter] s = ss[stream_counter]
stream_counters[device] = stream_counter stream_counters[device] = stream_counter
@ -1069,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None or weight.dtype == dtype: if dtype is None or weight.dtype == dtype:
return weight return weight
if stream is not None: if stream is not None:
with stream: wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy)
if stream is not None: if stream is not None:
with stream: wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
r = torch.empty_like(weight, dtype=dtype, device=device) r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking) r.copy_(weight, non_blocking=non_blocking)
else: else:
@ -1098,13 +1121,14 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def pin_memory(tensor): def pin_memory(tensor):
global TOTAL_PINNED_MEMORY global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0: if MAX_PINNED_MEMORY <= 0:
return False return False
if type(tensor) is not torch.nn.parameter.Parameter: if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
return False return False
if not is_device_cpu(tensor.device): if not is_device_cpu(tensor.device):
@ -1124,6 +1148,9 @@ def pin_memory(tensor):
return False return False
ptr = tensor.data_ptr() ptr = tensor.data_ptr()
if ptr == 0:
return False
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size TOTAL_PINNED_MEMORY += size

View File

@ -132,7 +132,7 @@ class LowVramPatch:
def __call__(self, weight): def __call__(self, weight):
intermediate_dtype = weight.dtype intermediate_dtype = weight.dtype
if self.convert_func is not None: if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True) weight = self.convert_func(weight, inplace=False)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32 intermediate_dtype = torch.float32
@ -148,6 +148,15 @@ class LowVramPatch:
else: else:
return out return out
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key): def get_key_weight(model, key):
set_func = None set_func = None
convert_func = None convert_func = None
@ -231,7 +240,6 @@ class ModelPatcher:
self.object_patches_backup = {} self.object_patches_backup = {}
self.weight_wrapper_patches = {} self.weight_wrapper_patches = {}
self.model_options = {"transformer_options":{}} self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device self.load_device = load_device
self.offload_device = offload_device self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update self.weight_inplace_update = weight_inplace_update
@ -270,6 +278,9 @@ class ModelPatcher:
if not hasattr(self.model, 'current_weight_patches_uuid'): if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None self.model.current_weight_patches_uuid = None
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
def model_size(self): def model_size(self):
if self.size > 0: if self.size > 0:
return self.size return self.size
@ -286,7 +297,7 @@ class ModelPatcher:
return self.model.lowvram_patch_counter return self.model.lowvram_patch_counter
def clone(self): def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@ -663,7 +674,16 @@ class ModelPatcher:
skip = True # skip random weights in non leaf modules skip = True # skip random weights in non leaf modules
break break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((comfy.model_management.module_size(m), n, m, params)) module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
loading.append((module_offload_mem, module_mem, n, m, params))
return loading return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@ -677,20 +697,22 @@ class ModelPatcher:
load_completely = [] load_completely = []
offloaded = [] offloaded = []
offload_buffer = 0
loading.sort(reverse=True) loading.sort(reverse=True)
for x in loading: for x in loading:
n = x[1] module_offload_mem, module_mem, n, m, params = x
m = x[2]
params = x[3]
module_mem = x[0]
lowvram_weight = False lowvram_weight = False
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n) weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n) bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"): if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory: if not lowvram_fits:
offload_buffer = potential_offload
lowvram_weight = True lowvram_weight = True
lowvram_counter += 1 lowvram_counter += 1
lowvram_mem_counter += module_mem lowvram_mem_counter += module_mem
@ -724,9 +746,11 @@ class ModelPatcher:
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m) wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory: if full_load or lowvram_fits:
mem_counter += module_mem mem_counter += module_mem
load_completely.append((module_mem, n, m, params)) load_completely.append((module_mem, n, m, params))
else:
offload_buffer = potential_offload
if cast_weight and hasattr(m, "comfy_cast_weights"): if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
@ -767,7 +791,7 @@ class ModelPatcher:
self.pin_weight_to_device("{}.{}".format(n, param)) self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0: if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True self.model.model_lowvram = True
else: else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
@ -779,6 +803,7 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter self.model.model_loaded_weight_memory = mem_counter
self.model.model_offload_buffer_memory = offload_buffer
self.model.current_weight_patches_uuid = self.patches_uuid self.model.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@ -832,6 +857,7 @@ class ModelPatcher:
self.model.to(device_to) self.model.to(device_to)
self.model.device = device_to self.model.device = device_to
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
for m in self.model.modules(): for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"): if hasattr(m, "comfy_patched_weights"):
@ -850,13 +876,14 @@ class ModelPatcher:
patch_counter = 0 patch_counter = 0
unload_list = self._load_list() unload_list = self._load_list()
unload_list.sort() unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
for unload in unload_list: for unload in unload_list:
if memory_to_free < memory_freed: if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break break
module_mem = unload[0] module_offload_mem, module_mem, n, m, params = unload
n = unload[1]
m = unload[2] potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
params = unload[3]
lowvram_possible = hasattr(m, "comfy_cast_weights") lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@ -907,15 +934,18 @@ class ModelPatcher:
m.comfy_cast_weights = True m.comfy_cast_weights = True
m.comfy_patched_weights = False m.comfy_patched_weights = False
memory_freed += module_mem memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
logging.debug("freed {}".format(n)) logging.debug("freed {}".format(n))
for param in params: for param in params:
self.pin_weight_to_device("{}.{}".format(n, param)) self.pin_weight_to_device("{}.{}".format(n, param))
self.model.model_lowvram = True self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed self.model.model_loaded_weight_memory -= memory_freed
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):

View File

@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if offload_stream is not None: if offload_stream is not None:
wf_context = offload_stream wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else: else:
wf_context = contextlib.nullcontext() wf_context = contextlib.nullcontext()
@ -117,6 +119,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if weight_has_function or weight.dtype != dtype: if weight_has_function or weight.dtype != dtype:
with wf_context: with wf_context:
weight = weight.to(dtype=dtype) weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function: for f in s.weight_function:
weight = f(weight) weight = f(weight)
@ -502,7 +506,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight return weight
else: else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
@ -540,115 +544,136 @@ if CUBLAS_IS_AVAILABLE:
# ============================================================================== # ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS from .quant_ops import QuantizedTensor, QUANT_ALGOS
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {}
_compute_dtype = torch.bfloat16
class Linear(torch.nn.Module, CastWeightBiasOp): def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
def __init__( class MixedPrecisionOps(manual_cast):
self, _layer_quant_config = layer_quant_config
in_features: int, _compute_dtype = compute_dtype
out_features: int, _full_precision_mm = full_precision_mm
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} class Linear(torch.nn.Module, CastWeightBiasOp):
# self.factory_kwargs = {"device": device, "dtype": dtype} def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.in_features = in_features self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
self.out_features = out_features # self.factory_kwargs = {"device": device, "dtype": dtype}
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
def reset_parameters(self): self.tensor_class = None
return None self._full_precision_mm = MixedPrecisionOps._full_precision_mm
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def reset_parameters(self):
strict, missing_keys, unexpected_keys, error_msgs): return None
device = self.factory_kwargs["device"] def _load_from_state_dict(self, state_dict, prefix, local_metadata,
layer_name = prefix.rstrip('.') strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key] device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
if layer_name not in MixedPrecisionOps._layer_quant_config: manually_loaded_keys = [weight_key]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[quant_format] if layer_name not in MixedPrecisionOps._layer_quant_config:
self.layout_type = qconfig["comfy_tensor_layout"] self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
weight_scale_key = f"{prefix}weight_scale" qconfig = QUANT_ALGOS[quant_format]
layout_params = { self.layout_type = qconfig["comfy_tensor_layout"]
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter( weight_scale_key = f"{prefix}weight_scale"
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), layout_params = {
requires_grad=False 'scale': state_dict.pop(weight_scale_key, None),
) 'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
for param_name in qconfig["parameters"]: self.weight = torch.nn.Parameter(
param_key = f"{prefix}{param_name}" QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
_v = state_dict.pop(param_key, None) requires_grad=False
if _v is None: )
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
for key in manually_loaded_keys: super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
if key in missing_keys:
missing_keys.remove(key)
def _forward(self, input, weight, bias): for key in manually_loaded_keys:
return torch.nn.functional.linear(input, weight, bias) if key in missing_keys:
missing_keys.remove(key)
def forward_comfy_cast_weights(self, input): def _forward(self, input, weight, bias):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) return torch.nn.functional.linear(input, weight, bias)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, input, *args, **kwargs): def forward_comfy_cast_weights(self, input):
run_every_op() weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: def forward(self, input, *args, **kwargs):
return self.forward_comfy_cast_weights(input, *args, **kwargs) run_every_op()
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
return weight.dequantize()
else:
return weight
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
return weight
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
MixedPrecisionOps._compute_dtype = compute_dtype if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None: if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

View File

@ -1,6 +1,7 @@
import torch import torch
import logging import logging
from typing import Tuple, Dict from typing import Tuple, Dict
import comfy.float
_LAYOUT_REGISTRY = {} _LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {} _GENERIC_UTILS = {}
@ -228,6 +229,14 @@ class QuantizedTensor(torch.Tensor):
new_kwargs = dequant_arg(kwargs) new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs) return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
# ============================================================================== # ==============================================================================
# Generic Utilities (Layout-Agnostic Operations) # Generic Utilities (Layout-Agnostic Operations)
@ -338,6 +347,18 @@ def generic_copy_(func, args, kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs): def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True return True
@ -373,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
- orig_dtype: Original dtype before quantization (for casting back) - orig_dtype: Original dtype before quantization (for casting back)
""" """
@classmethod @classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn): def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype orig_dtype = tensor.dtype
if scale is None: if scale is None:
@ -383,22 +404,29 @@ class TensorCoreFP8Layout(QuantizedLayout):
scale = torch.tensor(scale) scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32) scale = scale.to(device=tensor.device, dtype=torch.float32)
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) if inplace_ops:
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality' tensor *= (1.0 / scale).to(tensor.dtype)
# lp_amax = torch.finfo(dtype).max else:
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) tensor = tensor * (1.0 / scale).to(tensor.dtype)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
else:
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
layout_params = { layout_params = {
'scale': scale, 'scale': scale,
'orig_dtype': orig_dtype 'orig_dtype': orig_dtype
} }
return qdata, layout_params return tensor, layout_params
@staticmethod @staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs): def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
return plain_tensor * scale plain_tensor.mul_(scale)
return plain_tensor
@classmethod @classmethod
def get_plain_tensors(cls, qtensor): def get_plain_tensors(cls, qtensor):

View File

@ -52,6 +52,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2 import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -59,6 +60,8 @@ import comfy.lora_convert
import comfy.hooks import comfy.hooks
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
import comfy.taesd.taesd import comfy.taesd.taesd
import comfy.taesd.taehv
import comfy.latent_formats
import comfy.ldm.flux.redux import comfy.ldm.flux.redux
@ -356,7 +359,7 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32: elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -382,6 +385,17 @@ class VAE:
self.upscale_ratio = 4 self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'decoder.post_quant_conv.weight' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
if 'bn.running_mean' in sd:
ddconfig["batch_norm_latent"] = True
self.downscale_ratio *= 2
self.upscale_ratio *= 2
self.latent_channels *= 4
old_memory_used_decode = self.memory_used_decode
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
if 'post_quant_conv.weight' in sd: if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else: else:
@ -496,13 +510,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE else: # Wan 2.1 VAE
dim = sd["decoder.head.0.gamma"].shape[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8) self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8) self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3 self.latent_dim = 3
self.latent_channels = 16 self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
@ -572,6 +587,35 @@ class VAE:
self.process_input = lambda audio: audio self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32] self.working_dtypes = [torch.float32]
self.crop_input = False self.crop_input = False
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
self.latent_channels = sd["decoder.1.weight"].shape[1]
self.latent_dim = 3
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
if self.latent_channels == 48: # Wan 2.2
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
latent_format=comfy.latent_formats.HunyuanVideo
else:
latent_format=None # lighttaew2_1 doesn't need scaling
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
self.process_input = self.process_output = lambda image: image
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None
@ -917,7 +961,12 @@ class CLIPType(Enum):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = [] clip_data = []
for p in ckpt_paths: for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if metadata is not None:
quant_metadata = metadata.get("_quantization_metadata", None)
if quant_metadata is not None:
sd["_quantization_metadata"] = quant_metadata
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
@ -935,6 +984,10 @@ class TEModel(Enum):
QWEN25_7B = 11 QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12 BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13 GEMMA_3_4B = 13
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
def detect_te_model(sd): def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -967,6 +1020,15 @@ def detect_te_model(sd):
if weight.shape[0] == 512: if weight.shape[0] == 512:
return TEModel.QWEN25_7B return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd: if "model.layers.0.post_attention_layernorm.weight" in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
return TEModel.LLAMA3_8 return TEModel.LLAMA3_8
return None return None
@ -1081,6 +1143,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: else:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
else: else:
# clip_l # clip_l
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:
@ -1142,6 +1211,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters = 0 parameters = 0
for c in clip_data: for c in clip_data:
if "_quantization_metadata" in c:
c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c) parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

View File

@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32 return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS
if textmodel_json_config is None: if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@ -109,13 +108,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
operations = model_options.get("custom_operations", None) operations = model_options.get("custom_operations", None)
scaled_fp8 = None scaled_fp8 = None
quantization_metadata = model_options.get("quantization_metadata", None)
if operations is None: if operations is None:
scaled_fp8 = model_options.get("scaled_fp8", None) layer_quant_config = None
if scaled_fp8 is not None: if quantization_metadata is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) layer_quant_config = json.loads(quantization_metadata).get("layers", None)
if layer_quant_config is not None:
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
else: else:
operations = comfy.ops.manual_cast # Fallback to scaled_fp8_ops for backward compatibility
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast
self.operations = operations self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations) self.transformer = model_class(config, dtype, device, self.operations)
@ -154,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options): def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx) layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if self.layer == "all": if isinstance(self.layer, list) or self.layer == "all":
pass pass
elif layer_idx is None or abs(layer_idx) > self.num_layers: elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last" self.layer = "last"
@ -256,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if self.enable_attention_masks: if self.enable_attention_masks:
attention_mask_model = attention_mask attention_mask_model = attention_mask
if self.layer == "all": if isinstance(self.layer, list):
intermediate_output = self.layer
elif self.layer == "all":
intermediate_output = "all" intermediate_output = "all"
else: else:
intermediate_output = self.layer_idx intermediate_output = self.layer_idx

View File

@ -21,6 +21,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2 import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
@ -741,6 +742,37 @@ class FluxSchnell(Flux):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out return out
class Flux2(Flux):
unet_config = {
"image_model": "flux2",
}
sampling_settings = {
"shift": 2.02,
}
unet_extra_config = {}
latent_format = latent_formats.Flux2
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None # TODO
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class GenmoMochi(supported_models_base.BASE): class GenmoMochi(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "mochi_preview", "image_model": "mochi_preview",
@ -963,7 +995,7 @@ class Lumina2(supported_models_base.BASE):
"shift": 6.0, "shift": 6.0,
} }
memory_usage_factor = 1.2 memory_usage_factor = 1.4
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.Flux latent_format = latent_formats.Flux
@ -982,6 +1014,24 @@ class Lumina2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
class ZImage(Lumina2):
unet_config = {
"image_model": "lumina2",
"dim": 3840,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
memory_usage_factor = 1.7
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
class WAN21_T2V(supported_models_base.BASE): class WAN21_T2V(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -1422,6 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
models += [SVD_img2vid] models += [SVD_img2vid]

171
comfy/taesd/taehv.py Normal file
View File

@ -0,0 +1,171 @@
# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple, deque
import comfy.ops
operations=comfy.ops.disable_weight_init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out, act_func):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = act_func
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
B, T, C, H, W = x.shape
if parallel:
x = x.reshape(B*T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
BT, C, H, W = x.shape
T = BT // B
_x = x.reshape(B, T, C, H, W)
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
BT, C, H, W = x.shape
T = BT // B
x = x.view(B, T, C, H, W)
else:
out = []
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
progress_bar = tqdm(range(T), disable=not show_progress_bar)
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.popleft()
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
del xt
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt.detach().clone()
else:
xt_new = b(xt, mem[i])
mem[i] = xt.detach().clone()
del xt
work_queue.appendleft(TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt.detach().clone())
if len(mem[i]) == b.stride:
B, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
mem[i] = []
work_queue.appendleft(TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
work_queue.appendleft(TWorkItem(xt_next, i+1))
del xt
else:
xt = b(xt)
work_queue.appendleft(TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
super().__init__()
self.image_channels = 3
self.patch_size = 1
self.latent_channels = latent_channels
self.parallel = parallel
self.latent_format = latent_format
self.show_progress_bar = show_progress_bar
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
self.patch_size = 2
if self.latent_channels == 32: # HunyuanVideo1.5
act_func = nn.LeakyReLU(0.2, inplace=True)
else: # HunyuanVideo, Wan 2.1
act_func = nn.ReLU(inplace=True)
self.encoder = nn.Sequential(
conv(self.image_channels*self.patch_size**2, 64), act_func,
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
def encode(self, x, **kwargs):
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
n_pad = 4 - x.shape[1] % 4
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
return self.process_out(x)
def decode(self, x, **kwargs):
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@ -1,10 +1,13 @@
from comfy import sd1_clip from comfy import sd1_clip
import comfy.text_encoders.t5 import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip import comfy.text_encoders.sd3_clip
import comfy.text_encoders.llama
import comfy.model_management import comfy.model_management
from transformers import T5TokenizerFast from transformers import T5TokenizerFast, LlamaTokenizerFast
import torch import torch
import os import os
import json
import base64
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -68,3 +71,106 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_ return FluxClipModel_
def load_mistral_tokenizer(data):
if torch.is_tensor(data):
data = data.numpy().tobytes()
try:
from transformers.integrations.mistral import MistralConverter
except ModuleNotFoundError:
from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter
mistral_vocab = json.loads(data)
special_tokens = {}
vocab = {}
max_vocab = mistral_vocab["config"]["default_vocab_size"]
max_vocab -= len(mistral_vocab["special_tokens"])
for w in mistral_vocab["vocab"]:
r = w["rank"]
if r >= max_vocab:
continue
vocab[base64.b64decode(w["token_bytes"])] = r
for w in mistral_vocab["special_tokens"]:
if "token_bytes" in w:
special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"]
else:
special_tokens[w["token_str"]] = w["rank"]
all_special = []
for v in special_tokens:
all_special.append(v)
special_tokens.update(vocab)
vocab = special_tokens
return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False}
class MistralTokenizerClass:
@staticmethod
def from_pretrained(path, **kwargs):
return LlamaTokenizerFast(**kwargs)
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.tekken_data = tokenizer_data.get("tekken_model", None)
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
def state_dict(self):
return {"tekken_model": self.tekken_data}
class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer)
self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Mistral3_24BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = {}
num_layers = model_options.get("num_layers", None)
if num_layers is not None:
textmodel_json_config["num_hidden_layers"] = num_layers
if num_layers < 40:
textmodel_json_config["final_norm"] = False
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Flux2TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
out = out.movedim(1, 2)
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
if pruned:
model_options = model_options.copy()
model_options["num_layers"] = 30
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Flux2TEModel_

View File

@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
if scaled_fp8_key in state_dict: if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
if "_quantization_metadata" in state_dict:
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
return out return out

View File

@ -34,6 +34,28 @@ class Llama2Config:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
@dataclass
class Mistral3Small24BConfig:
vocab_size: int = 131072
hidden_size: int = 5120
intermediate_size: int = 32768
num_hidden_layers: int = 40
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 1000000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass @dataclass
class Qwen25_3BConfig: class Qwen25_3BConfig:
vocab_size: int = 151936 vocab_size: int = 151936
@ -56,6 +78,28 @@ class Qwen25_3BConfig:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
hidden_size: int = 2560
intermediate_size: int = 9728
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass @dataclass
class Qwen25_7BVLI_Config: class Qwen25_7BVLI_Config:
vocab_size: int = 152064 vocab_size: int = 152064
@ -412,8 +456,12 @@ class Llama2_(nn.Module):
intermediate = None intermediate = None
all_intermediate = None all_intermediate = None
only_layers = None
if intermediate_output is not None: if intermediate_output is not None:
if intermediate_output == "all": if isinstance(intermediate_output, list):
all_intermediate = []
only_layers = set(intermediate_output)
elif intermediate_output == "all":
all_intermediate = [] all_intermediate = []
intermediate_output = None intermediate_output = None
elif intermediate_output < 0: elif intermediate_output < 0:
@ -421,7 +469,8 @@ class Llama2_(nn.Module):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
if all_intermediate is not None: if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone()) if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
x = layer( x = layer(
x=x, x=x,
attention_mask=mask, attention_mask=mask,
@ -435,7 +484,8 @@ class Llama2_(nn.Module):
x = self.norm(x) x = self.norm(x)
if all_intermediate is not None: if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone()) if only_layers is None or ((i + 1) in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None: if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1) intermediate = torch.cat(all_intermediate, dim=1)
@ -465,6 +515,15 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Mistral3Small24B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Mistral3Small24BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module): class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
@ -474,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module): class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()

View File

@ -179,36 +179,36 @@
"special": false "special": false
}, },
"151665": { "151665": {
"content": "<|img|>", "content": "<tool_response>",
"lstrip": false, "lstrip": false,
"normalized": false, "normalized": false,
"rstrip": false, "rstrip": false,
"single_word": false, "single_word": false,
"special": true "special": false
}, },
"151666": { "151666": {
"content": "<|endofimg|>", "content": "</tool_response>",
"lstrip": false, "lstrip": false,
"normalized": false, "normalized": false,
"rstrip": false, "rstrip": false,
"single_word": false, "single_word": false,
"special": true "special": false
}, },
"151667": { "151667": {
"content": "<|meta|>", "content": "<think>",
"lstrip": false, "lstrip": false,
"normalized": false, "normalized": false,
"rstrip": false, "rstrip": false,
"single_word": false, "single_word": false,
"special": true "special": false
}, },
"151668": { "151668": {
"content": "<|endofmeta|>", "content": "</think>",
"lstrip": false, "lstrip": false,
"normalized": false, "normalized": false,
"rstrip": false, "rstrip": false,
"single_word": false, "single_word": false,
"special": true "special": false
} }
}, },
"additional_special_tokens": [ "additional_special_tokens": [

View File

@ -0,0 +1,48 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Qwen3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ZImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ZImageTEModel_

View File

@ -675,6 +675,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
return key_map return key_map
def z_image_to_diffusers(mmdit_config, output_prefix=""):
n_layers = mmdit_config.get("n_layers", 0)
hidden_size = mmdit_config.get("dim", 0)
n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
key_map = {}
def add_block_keys(prefix_from, prefix_to, has_adaln=True):
for end in ("weight", "bias"):
k = "{}.attention.".format(prefix_from)
qkv = "{}.attention.qkv.{}".format(prefix_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {
"attention.norm_q.weight": "attention.q_norm.weight",
"attention.norm_k.weight": "attention.k_norm.weight",
"attention.to_out.0.weight": "attention.out.weight",
"attention.to_out.0.bias": "attention.out.bias",
"attention_norm1.weight": "attention_norm1.weight",
"attention_norm2.weight": "attention_norm2.weight",
"feed_forward.w1.weight": "feed_forward.w1.weight",
"feed_forward.w2.weight": "feed_forward.w2.weight",
"feed_forward.w3.weight": "feed_forward.w3.weight",
"ffn_norm1.weight": "ffn_norm1.weight",
"ffn_norm2.weight": "ffn_norm2.weight",
}
if has_adaln:
block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
for k, v in block_map.items():
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
for i in range(n_layers):
add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
for i in range(n_context_refiner):
add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
for i in range(n_noise_refiner):
add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
MAP_BASIC = [
("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
("x_embedder.weight", "all_x_embedder.2-1.weight"),
("x_embedder.bias", "all_x_embedder.2-1.bias"),
("x_pad_token", "x_pad_token"),
("cap_embedder.0.weight", "cap_embedder.0.weight"),
("cap_embedder.1.weight", "cap_embedder.1.weight"),
("cap_embedder.1.bias", "cap_embedder.1.bias"),
("cap_pad_token", "cap_pad_token"),
("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
]
for c, diffusers in MAP_BASIC:
key_map[diffusers] = "{}{}".format(output_prefix, c)
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0): def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size: if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size) return tensor.narrow(dim, 0, batch_size)

View File

@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
lora_diff = torch.mm( lora_diff = torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
).reshape(weight.shape) ).reshape(weight.shape)
del mat1, mat2
if dora_scale is not None: if dora_scale is not None:
weight = weight_decompose( weight = weight_decompose(
dora_scale, dora_scale,

View File

@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput):
raise ValueError("Only MP4 format is supported for now") raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264: if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now") raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: extra_kwargs = {}
if format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams # Add metadata before writing any streams
if metadata is not None: if metadata is not None:
for key, value in metadata.items(): for key, value in metadata.items():

View File

@ -70,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel):
# ) # )
class Flux2ProGenerateRequest(BaseModel):
prompt: str = Field(...)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
seed: int | None = Field(None)
prompt_upsampling: bool | None = Field(None)
input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
class BFLFluxKontextProGenerateRequest(BaseModel): class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.') prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format') input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
@ -109,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel):
class BFLFluxProGenerateResponse(BaseModel): class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description='The unique identifier for the generation task.') id: str = Field(..., description="The unique identifier for the generation task.")
polling_url: str = Field(..., description='URL to poll for the generation result.') polling_url: str = Field(..., description="URL to poll for the generation result.")
cost: float | None = Field(None, description="Price in cents")
class BFLStatus(str, Enum): class BFLStatus(str, Enum):

View File

@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel):
mimeType: GeminiMimeType | None = Field(None) mimeType: GeminiMimeType | None = Field(None)
class GeminiFileData(BaseModel):
fileUri: str | None = Field(None)
mimeType: GeminiMimeType | None = Field(None)
class GeminiPart(BaseModel): class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None) inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None) text: str | None = Field(None)

View File

@ -0,0 +1,66 @@
from pydantic import BaseModel, Field
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
class OmniParamImage(BaseModel):
image_url: str = Field(...)
type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
class OmniParamVideo(BaseModel):
video_url: str = Field(...)
refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
keep_original_sound: str = Field(..., description="'yes' or 'no'")
class OmniProFirstLastFrameRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
class OmniProReferences2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
image_list: list[OmniParamImage] | None = Field(
None, max_length=7, description="Max length 4 when video is present."
)
video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
class TaskStatusVideoResult(BaseModel):
duration: str | None = Field(None, description="Total video duration")
id: str | None = Field(None, description="Generated video ID")
url: str | None = Field(None, description="URL for generated video")
class TaskStatusVideoResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
class TaskStatusVideoResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
task_result: TaskStatusVideoResults | None = Field(None)
class TaskStatusVideoResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
data: TaskStatusVideoResponseData | None = Field(None)

View File

@ -1,34 +1,21 @@
from typing import Optional, Union from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class Image2(BaseModel): class VeoRequestInstanceImage(BaseModel):
bytesBase64Encoded: str bytesBase64Encoded: str | None = Field(None)
gcsUri: Optional[str] = None gcsUri: str | None = Field(None)
mimeType: Optional[str] = None mimeType: str | None = Field(None)
class Image3(BaseModel): class VeoRequestInstance(BaseModel):
bytesBase64Encoded: Optional[str] = None image: VeoRequestInstanceImage | None = Field(None)
gcsUri: str lastFrame: VeoRequestInstanceImage | None = Field(None)
mimeType: Optional[str] = None
class Instance1(BaseModel):
image: Optional[Union[Image2, Image3]] = Field(
None, description='Optional image to guide video generation'
)
prompt: str = Field(..., description='Text description of the video') prompt: str = Field(..., description='Text description of the video')
class PersonGeneration1(str, Enum): class VeoRequestParameters(BaseModel):
ALLOW = 'ALLOW'
BLOCK = 'BLOCK'
class Parameters1(BaseModel):
aspectRatio: Optional[str] = Field(None, examples=['16:9']) aspectRatio: Optional[str] = Field(None, examples=['16:9'])
durationSeconds: Optional[int] = None durationSeconds: Optional[int] = None
enhancePrompt: Optional[bool] = None enhancePrompt: Optional[bool] = None
@ -37,17 +24,18 @@ class Parameters1(BaseModel):
description='Generate audio for the video. Only supported by veo 3 models.', description='Generate audio for the video. Only supported by veo 3 models.',
) )
negativePrompt: Optional[str] = None negativePrompt: Optional[str] = None
personGeneration: Optional[PersonGeneration1] = None personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
sampleCount: Optional[int] = None sampleCount: Optional[int] = None
seed: Optional[int] = None seed: Optional[int] = None
storageUri: Optional[str] = Field( storageUri: Optional[str] = Field(
None, description='Optional Cloud Storage URI to upload the video' None, description='Optional Cloud Storage URI to upload the video'
) )
resolution: str | None = Field(None)
class VeoGenVidRequest(BaseModel): class VeoGenVidRequest(BaseModel):
instances: Optional[list[Instance1]] = None instances: list[VeoRequestInstance] | None = Field(None)
parameters: Optional[Parameters1] = None parameters: VeoRequestParameters | None = Field(None)
class VeoGenVidResponse(BaseModel): class VeoGenVidResponse(BaseModel):

View File

@ -1,7 +1,7 @@
from inspect import cleandoc from inspect import cleandoc
from typing import Optional
import torch import torch
from pydantic import BaseModel
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension
@ -9,15 +9,16 @@ from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest, BFLFluxExpandImageRequest,
BFLFluxFillImageRequest, BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest, BFLFluxKontextProGenerateRequest,
BFLFluxProGenerateRequest,
BFLFluxProGenerateResponse, BFLFluxProGenerateResponse,
BFLFluxProUltraGenerateRequest, BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse, BFLFluxStatusResponse,
BFLStatus, BFLStatus,
Flux2ProGenerateRequest,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
download_url_to_image_tensor, download_url_to_image_tensor,
get_number_of_images,
poll_op, poll_op,
resize_mask_to_image, resize_mask_to_image,
sync_op, sync_op,
@ -116,7 +117,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt_upsampling: bool = False, prompt_upsampling: bool = False,
raw: bool = False, raw: bool = False,
seed: int = 0, seed: int = 0,
image_prompt: Optional[torch.Tensor] = None, image_prompt: torch.Tensor | None = None,
image_prompt_strength: float = 0.1, image_prompt_strength: float = 0.1,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if image_prompt is None: if image_prompt is None:
@ -230,7 +231,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
guidance: float, guidance: float,
steps: int, steps: int,
input_image: Optional[torch.Tensor] = None, input_image: torch.Tensor | None = None,
seed=0, seed=0,
prompt_upsampling=False, prompt_upsampling=False,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
@ -280,124 +281,6 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode):
DISPLAY_NAME = "Flux.1 Kontext [max] Image" DISPLAY_NAME = "Flux.1 Kontext [max] Image"
class FluxProImageNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxProImageNode",
display_name="Flux 1.1 [pro] Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Int.Input(
"width",
default=1024,
min=256,
max=1440,
step=32,
),
IO.Int.Input(
"height",
default=768,
min=256,
max=1440,
step=32,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image_prompt",
optional=True,
),
# "image_prompt_strength": (
# IO.FLOAT,
# {
# "default": 0.1,
# "min": 0.0,
# "max": 1.0,
# "step": 0.01,
# "tooltip": "Blend between the prompt and the image prompt.",
# },
# ),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
prompt_upsampling,
width: int,
height: int,
seed=0,
image_prompt=None,
# image_prompt_strength=0.1,
) -> IO.NodeOutput:
image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)
initial_response = await sync_op(
cls,
ApiEndpoint(
path="/proxy/bfl/flux-pro-1.1/generate",
method="POST",
),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
width=width,
height=height,
seed=seed,
image_prompt=image_prompt,
),
)
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxProExpandNode(IO.ComfyNode): class FluxProExpandNode(IO.ComfyNode):
""" """
Outpaints image based on prompt. Outpaints image based on prompt.
@ -640,16 +523,125 @@ class FluxProFillNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Flux2ProImageNode",
display_name="Flux.2 [pro] Image",
category="api node/image/BFL",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation or edit",
),
IO.Int.Input(
"width",
default=1024,
min=256,
max=2048,
step=32,
),
IO.Int.Input(
"height",
default=768,
min=256,
max=2048,
step=32,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, "
"but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
width: int,
height: int,
seed: int,
prompt_upsampling: bool,
images: torch.Tensor | None = None,
) -> IO.NodeOutput:
reference_images = {}
if images is not None:
if get_number_of_images(images) > 9:
raise ValueError("The current maximum number of supported images is 9.")
for image_index in range(images.shape[0]):
key_name = f"input_image_{image_index + 1}" if image_index else "input_image"
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest(
prompt=prompt,
width=width,
height=height,
seed=seed,
prompt_upsampling=prompt_upsampling,
**reference_images,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class BFLExtension(ComfyExtension): class BFLExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
FluxProUltraImageNode, FluxProUltraImageNode,
# FluxProImageNode,
FluxKontextProImageNode, FluxKontextProImageNode,
FluxKontextMaxImageNode, FluxKontextMaxImageNode,
FluxProExpandNode, FluxProExpandNode,
FluxProFillNode, FluxProFillNode,
Flux2ProImageNode,
] ]

View File

@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
""" """
import base64 import base64
import json
import os import os
import time
import uuid
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from typing import Literal from typing import Literal
@ -20,6 +17,7 @@ from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis.gemini_api import ( from comfy_api_nodes.apis.gemini_api import (
GeminiContent, GeminiContent,
GeminiFileData,
GeminiGenerateContentRequest, GeminiGenerateContentRequest,
GeminiGenerateContentResponse, GeminiGenerateContentResponse,
GeminiImageConfig, GeminiImageConfig,
@ -38,10 +36,10 @@ from comfy_api_nodes.util import (
get_number_of_images, get_number_of_images,
sync_op, sync_op,
tensor_to_base64_string, tensor_to_base64_string,
upload_images_to_comfyapi,
validate_string, validate_string,
video_to_base64_string, video_to_base64_string,
) )
from server import PromptServer
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
@ -68,24 +66,43 @@ class GeminiImageModel(str, Enum):
gemini_2_5_flash_image = "gemini-2.5-flash-image" gemini_2_5_flash_image = "gemini-2.5-flash-image"
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: async def create_image_parts(
""" cls: type[IO.ComfyNode],
Convert image tensor input to Gemini API compatible parts. images: torch.Tensor,
image_limit: int = 0,
Args: ) -> list[GeminiPart]:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = [] image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]): if image_limit < 0:
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0)) raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
total_images = get_number_of_images(images)
if total_images <= 0:
raise ValueError("No images provided to create_image_parts; at least one image is required.")
# If image_limit == 0 --> use all images; otherwise clamp to image_limit.
effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
# Number of images we'll send as URLs (fileData)
num_url_images = min(effective_max, 10) # Vertex API max number of image links
reference_images_urls = await upload_images_to_comfyapi(
cls,
images,
max_images=num_url_images,
)
for reference_image_url in reference_images_urls:
image_parts.append(
GeminiPart(
fileData=GeminiFileData(
mimeType=GeminiMimeType.image_png,
fileUri=reference_image_url,
)
)
)
for idx in range(num_url_images, effective_max):
image_parts.append( image_parts.append(
GeminiPart( GeminiPart(
inlineData=GeminiInlineData( inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png, mimeType=GeminiMimeType.image_png,
data=image_as_b64, data=tensor_to_base64_string(images[idx]),
) )
) )
) )
@ -338,8 +355,7 @@ class GeminiNode(IO.ComfyNode):
# Add other modal parts # Add other modal parts
if images is not None: if images is not None:
image_parts = create_image_parts(images) parts.extend(await create_image_parts(cls, images))
parts.extend(image_parts)
if audio is not None: if audio is not None:
parts.extend(cls.create_audio_parts(audio)) parts.extend(cls.create_audio_parts(audio))
if video is not None: if video is not None:
@ -364,29 +380,6 @@ class GeminiNode(IO.ComfyNode):
) )
output_text = get_text_from_response(response) output_text = get_text_from_response(response)
if output_text:
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(output_text or "Empty response from Gemini model...") return IO.NodeOutput(output_text or "Empty response from Gemini model...")
@ -562,8 +555,7 @@ class GeminiImage(IO.ComfyNode):
image_config = GeminiImageConfig(aspectRatio=aspect_ratio) image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
if images is not None: if images is not None:
image_parts = create_image_parts(images) parts.extend(await create_image_parts(cls, images))
parts.extend(image_parts)
if files is not None: if files is not None:
parts.extend(files) parts.extend(files)
@ -582,30 +574,7 @@ class GeminiImage(IO.ComfyNode):
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
) )
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
output_text = get_text_from_response(response)
if output_text:
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(get_image_from_response(response), output_text)
class GeminiImage2(IO.ComfyNode): class GeminiImage2(IO.ComfyNode):
@ -702,7 +671,7 @@ class GeminiImage2(IO.ComfyNode):
if images is not None: if images is not None:
if get_number_of_images(images) > 14: if get_number_of_images(images) > 14:
raise ValueError("The current maximum number of supported images is 14.") raise ValueError("The current maximum number of supported images is 14.")
parts.extend(create_image_parts(images)) parts.extend(await create_image_parts(cls, images))
if files is not None: if files is not None:
parts.extend(files) parts.extend(files)
@ -725,30 +694,7 @@ class GeminiImage2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
) )
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
output_text = get_text_from_response(response)
if output_text:
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(get_image_from_response(response), output_text)
class GeminiExtension(ComfyExtension): class GeminiExtension(ComfyExtension):

View File

@ -4,15 +4,13 @@ For source of truth on the allowed permutations of request fields, please refere
- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
""" """
from __future__ import annotations
from typing import Optional, TypeVar
import math
import logging import logging
import math
from typing_extensions import override
import torch import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
KlingCameraControl, KlingCameraControl,
KlingCameraConfig, KlingCameraConfig,
@ -50,25 +48,31 @@ from comfy_api_nodes.apis import (
KlingCharacterEffectModelName, KlingCharacterEffectModelName,
KlingSingleImageEffectModelName, KlingSingleImageEffectModelName,
) )
from comfy_api_nodes.apis.kling_api import (
OmniParamImage,
OmniParamVideo,
OmniProFirstLastFrameRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
TaskStatusVideoResponse,
)
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
validate_image_dimensions, ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_image_aspect_ratio, validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
validate_video_dimensions, validate_video_dimensions,
validate_video_duration, validate_video_duration,
tensor_to_base64_string,
validate_string,
upload_audio_to_comfyapi,
download_url_to_image_tensor,
upload_video_to_comfyapi,
download_url_to_video_output,
sync_op,
ApiEndpoint,
poll_op,
) )
from comfy_api.input_impl import VideoFromFile
from comfy_api.input.basic_types import AudioInput
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import ComfyExtension, IO
KLING_API_VERSION = "v1" KLING_API_VERSION = "v1"
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
@ -94,8 +98,6 @@ AVERAGE_DURATION_IMAGE_GEN = 32
AVERAGE_DURATION_VIDEO_EFFECTS = 320 AVERAGE_DURATION_VIDEO_EFFECTS = 320
AVERAGE_DURATION_VIDEO_EXTEND = 320 AVERAGE_DURATION_VIDEO_EXTEND = 320
R = TypeVar("R")
MODE_TEXT2VIDEO = { MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
@ -130,6 +132,8 @@ MODE_START_END_FRAME = {
"pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"),
"pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"),
"pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"),
"pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"),
} }
""" """
Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples.
@ -206,6 +210,20 @@ VOICES_CONFIG = {
} }
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusVideoResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
def is_valid_camera_control_configs(configs: list[float]) -> bool: def is_valid_camera_control_configs(configs: list[float]) -> bool:
"""Verifies that at least one camera control configuration is non-zero.""" """Verifies that at least one camera control configuration is non-zero."""
return any(not math.isclose(value, 0.0) for value in configs) return any(not math.isclose(value, 0.0) for value in configs)
@ -296,7 +314,7 @@ def get_video_from_response(response) -> KlingVideoResult:
return video return video
def get_video_url_from_response(response) -> Optional[str]: def get_video_url_from_response(response) -> str | None:
"""Returns the first video url from the Kling video generation task result. """Returns the first video url from the Kling video generation task result.
Will not raise an error if the response is not valid. Will not raise an error if the response is not valid.
""" """
@ -315,7 +333,7 @@ def get_images_from_response(response) -> list[KlingImageResult]:
return images return images
def get_images_urls_from_response(response) -> Optional[str]: def get_images_urls_from_response(response) -> str | None:
"""Returns the list of image urls from the Kling image generation task result. """Returns the list of image urls from the Kling image generation task result.
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls. Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
""" """
@ -349,7 +367,7 @@ async def execute_text2video(
model_mode: str, model_mode: str,
duration: str, duration: str,
aspect_ratio: str, aspect_ratio: str,
camera_control: Optional[KlingCameraControl] = None, camera_control: KlingCameraControl | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
task_creation_response = await sync_op( task_creation_response = await sync_op(
@ -394,8 +412,8 @@ async def execute_image2video(
model_mode: str, model_mode: str,
aspect_ratio: str, aspect_ratio: str,
duration: str, duration: str,
camera_control: Optional[KlingCameraControl] = None, camera_control: KlingCameraControl | None = None,
end_frame: Optional[torch.Tensor] = None, end_frame: torch.Tensor | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
validate_input_image(start_frame) validate_input_image(start_frame)
@ -451,9 +469,9 @@ async def execute_video_effect(
model_name: str, model_name: str,
duration: KlingVideoGenDuration, duration: KlingVideoGenDuration,
image_1: torch.Tensor, image_1: torch.Tensor,
image_2: Optional[torch.Tensor] = None, image_2: torch.Tensor | None = None,
model_mode: Optional[KlingVideoGenMode] = None, model_mode: KlingVideoGenMode | None = None,
) -> tuple[VideoFromFile, str, str]: ) -> tuple[InputImpl.VideoFromFile, str, str]:
if dual_character: if dual_character:
request_input_field = KlingDualCharacterEffectInput( request_input_field = KlingDualCharacterEffectInput(
model_name=model_name, model_name=model_name,
@ -499,13 +517,13 @@ async def execute_video_effect(
async def execute_lipsync( async def execute_lipsync(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
video: VideoInput, video: Input.Video,
audio: Optional[AudioInput] = None, audio: Input.Audio | None = None,
voice_language: Optional[str] = None, voice_language: str | None = None,
model_mode: Optional[str] = None, model_mode: str | None = None,
text: Optional[str] = None, text: str | None = None,
voice_speed: Optional[float] = None, voice_speed: float | None = None,
voice_id: Optional[str] = None, voice_id: str | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if text: if text:
validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC) validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC)
@ -740,6 +758,386 @@ class KlingTextToVideoNode(IO.ComfyNode):
) )
class OmniProTextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProTextToVideoNode",
display_name="Kling Omni Text to Video (Pro)",
category="api node/video/Kling",
description="Use text prompts to generate videos with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
aspect_ratio: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
),
)
return await finish_omni_video_task(cls, response)
class OmniProFirstLastFrameNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode",
display_name="Kling Omni First-Last-Frame to Video (Pro)",
category="api node/video/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("duration", options=["5", "10"]),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
optional=True,
tooltip="An optional end frame for the video. "
"This cannot be used simultaneously with 'reference_images'.",
),
IO.Image.Input(
"reference_images",
optional=True,
tooltip="Up to 6 additional reference images.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
duration: int,
first_frame: Input.Image,
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [
OmniParamImage(
image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
type="first_frame",
)
]
if end_frame is not None:
validate_image_dimensions(end_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
image_list.append(
OmniParamImage(
image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
type="end_frame",
)
)
if reference_images is not None:
if get_number_of_images(reference_images) > 6:
raise ValueError("The maximum number of reference images allowed is 6.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
image_list.append(OmniParamImage(image_url=i))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
duration=str(duration),
image_list=image_list,
),
)
return await finish_omni_video_task(cls, response)
class OmniProImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageToVideoNode",
display_name="Kling Omni Image to Video (Pro)",
category="api node/video/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Image.Input(
"reference_images",
tooltip="Up to 7 reference images.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
aspect_ratio: str,
duration: int,
reference_images: Input.Image,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
if get_number_of_images(reference_images) > 7:
raise ValueError("The maximum number of reference images is 7.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = []
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
image_list=image_list,
),
)
return await finish_omni_video_task(cls, response)
class OmniProVideoToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProVideoToVideoNode",
display_name="Kling Omni Video to Video (Pro)",
category="api node/video/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Image.Input(
"reference_images",
tooltip="Up to 4 additional reference images.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
aspect_ratio: str,
duration: int,
reference_video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
image_list: list[OmniParamImage] = []
if reference_images is not None:
if get_number_of_images(reference_images) > 4:
raise ValueError("The maximum number of reference images allowed with a video input is 4.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i))
video_list = [
OmniParamVideo(
video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
refer_type="feature",
keep_original_sound="yes" if keep_original_sound else "no",
)
]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
image_list=image_list if image_list else None,
video_list=video_list,
),
)
return await finish_omni_video_task(cls, response)
class OmniProEditVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProEditVideoNode",
display_name="Kling Omni Edit Video (Pro)",
category="api node/video/Kling",
description="Edit an existing video with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Image.Input(
"reference_images",
tooltip="Up to 4 additional reference images.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
image_list: list[OmniParamImage] = []
if reference_images is not None:
if get_number_of_images(reference_images) > 4:
raise ValueError("The maximum number of reference images allowed with a video input is 4.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i))
video_list = [
OmniParamVideo(
video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
refer_type="base",
keep_original_sound="yes" if keep_original_sound else "no",
)
]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusVideoResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=None,
duration=None,
image_list=image_list if image_list else None,
video_list=video_list,
),
)
return await finish_omni_video_task(cls, response)
class KlingCameraControlT2VNode(IO.ComfyNode): class KlingCameraControlT2VNode(IO.ComfyNode):
""" """
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@ -787,7 +1185,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
negative_prompt: str, negative_prompt: str,
cfg_scale: float, cfg_scale: float,
aspect_ratio: str, aspect_ratio: str,
camera_control: Optional[KlingCameraControl] = None, camera_control: KlingCameraControl | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
return await execute_text2video( return await execute_text2video(
cls, cls,
@ -854,8 +1252,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
mode: str, mode: str,
aspect_ratio: str, aspect_ratio: str,
duration: str, duration: str,
camera_control: Optional[KlingCameraControl] = None, camera_control: KlingCameraControl | None = None,
end_frame: Optional[torch.Tensor] = None, end_frame: torch.Tensor | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
return await execute_image2video( return await execute_image2video(
cls, cls,
@ -965,15 +1363,11 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"),
IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0),
IO.Combo.Input( IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
"aspect_ratio",
options=[i.value for i in KlingVideoGenAspectRatio],
default="16:9",
),
IO.Combo.Input( IO.Combo.Input(
"mode", "mode",
options=modes, options=modes,
default=modes[2], default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
), ),
], ],
@ -1170,7 +1564,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
category="api node/video/Kling", category="api node/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.", description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[ inputs=[
IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"), IO.Image.Input(
"image",
tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
),
IO.Combo.Input( IO.Combo.Input(
"effect_scene", "effect_scene",
options=[i.value for i in KlingSingleImageEffectsScene], options=[i.value for i in KlingSingleImageEffectsScene],
@ -1254,8 +1651,8 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
video: VideoInput, video: Input.Video,
audio: AudioInput, audio: Input.Audio,
voice_language: str, voice_language: str,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
return await execute_lipsync( return await execute_lipsync(
@ -1314,7 +1711,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
video: VideoInput, video: Input.Video,
text: str, text: str,
voice: str, voice: str,
voice_speed: float, voice_speed: float,
@ -1471,7 +1868,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
human_fidelity: float, human_fidelity: float,
n: int, n: int,
aspect_ratio: KlingImageGenAspectRatio, aspect_ratio: KlingImageGenAspectRatio,
image: Optional[torch.Tensor] = None, image: torch.Tensor | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
@ -1533,6 +1930,11 @@ class KlingExtension(ComfyExtension):
KlingImageGenerationNode, KlingImageGenerationNode,
KlingSingleImageVideoEffectNode, KlingSingleImageVideoEffectNode,
KlingDualCharacterVideoEffectNode, KlingDualCharacterVideoEffectNode,
OmniProTextToVideoNode,
OmniProFirstLastFrameNode,
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
] ]

View File

@ -1,15 +1,10 @@
from io import BytesIO from io import BytesIO
from typing import Optional, Union
import json
import os import os
import time
import uuid
from enum import Enum from enum import Enum
from inspect import cleandoc from inspect import cleandoc
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from server import PromptServer
import folder_paths import folder_paths
import base64 import base64
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension
@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode):
def create_input_message_contents( def create_input_message_contents(
cls, cls,
prompt: str, prompt: str,
image: Optional[torch.Tensor] = None, image: torch.Tensor | None = None,
files: Optional[list[InputFileContent]] = None, files: list[InputFileContent] | None = None,
) -> InputMessageContentList: ) -> InputMessageContentList:
"""Create a list of input message contents from prompt and optional image.""" """Create a list of input message contents from prompt and optional image."""
content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [ content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
InputTextContent(text=prompt, type="input_text"), InputTextContent(text=prompt, type="input_text"),
] ]
if image is not None: if image is not None:
@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode):
prompt: str, prompt: str,
persist_context: bool = False, persist_context: bool = False,
model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
images: Optional[torch.Tensor] = None, images: torch.Tensor | None = None,
files: Optional[list[InputFileContent]] = None, files: list[InputFileContent] | None = None,
advanced_options: Optional[CreateModelResponseProperties] = None, advanced_options: CreateModelResponseProperties | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode):
status_extractor=lambda response: response.status, status_extractor=lambda response: response.status,
completed_statuses=["incomplete", "completed"] completed_statuses=["incomplete", "completed"]
) )
output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)) return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
# Update history
render_spec = {
"node_id": cls.hidden.unique_id,
"component": "ChatHistoryWidget",
"props": {
"history": json.dumps(
[
{
"prompt": prompt,
"response": output_text,
"response_id": str(uuid.uuid4()),
"timestamp": time.time(),
}
]
),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
return IO.NodeOutput(output_text)
class OpenAIInputFiles(IO.ComfyNode): class OpenAIInputFiles(IO.ComfyNode):
@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode):
def execute( def execute(
cls, cls,
truncation: bool, truncation: bool,
instructions: Optional[str] = None, instructions: str | None = None,
max_output_tokens: Optional[int] = None, max_output_tokens: int | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
""" """
Configure advanced options for the OpenAI Chat Node. Configure advanced options for the OpenAI Chat Node.

View File

@ -1,6 +1,7 @@
import base64 import base64
from io import BytesIO from io import BytesIO
import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.input_impl.video_types import VideoFromFile from comfy_api.input_impl.video_types import VideoFromFile
@ -10,6 +11,9 @@ from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollResponse, VeoGenVidPollResponse,
VeoGenVidRequest, VeoGenVidRequest,
VeoGenVidResponse, VeoGenVidResponse,
VeoRequestInstance,
VeoRequestInstanceImage,
VeoRequestParameters,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
@ -346,12 +350,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
) )
class Veo3FirstLastFrameNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Veo3FirstLastFrameNode",
display_name="Google Veo 3 First-Last-Frame to Video",
category="api node/video/Veo",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the video",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative text prompt to guide what to avoid in the video",
),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16"],
default="16:9",
tooltip="Aspect ratio of the output video",
),
IO.Int.Input(
"duration",
default=8,
min=4,
max=8,
step=2,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFF,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation",
),
IO.Image.Input("first_frame", tooltip="Start frame"),
IO.Image.Input("last_frame", tooltip="End frame"),
IO.Combo.Input(
"model",
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
default="veo-3.1-fast-generate",
),
IO.Boolean.Input(
"generate_audio",
default=True,
tooltip="Generate audio for the video.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
resolution: str,
aspect_ratio: str,
duration: int,
seed: int,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
model: str,
generate_audio: bool,
):
model = MODELS_MAP[model]
initial_response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
response_model=VeoGenVidResponse,
data=VeoGenVidRequest(
instances=[
VeoRequestInstance(
prompt=prompt,
image=VeoRequestInstanceImage(
bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png"
),
lastFrame=VeoRequestInstanceImage(
bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png"
),
),
],
parameters=VeoRequestParameters(
aspectRatio=aspect_ratio,
personGeneration="ALLOW",
durationSeconds=duration,
enhancePrompt=True, # cannot be False for Veo3
seed=seed,
generateAudio=generate_audio,
negativePrompt=negative_prompt,
resolution=resolution,
),
),
)
poll_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
response_model=VeoGenVidPollResponse,
status_extractor=lambda r: "completed" if r.done else "pending",
data=VeoGenVidPollRequest(
operationName=initial_response.name,
),
poll_interval=5.0,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
if poll_response.error:
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
response = poll_response.response
filtered_count = response.raiMediaFilteredCount
if filtered_count:
reasons = response.raiMediaFilteredReasons or []
reason_part = f": {reasons[0]}" if reasons else ""
raise Exception(
f"Content blocked by Google's Responsible AI filters{reason_part} "
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
)
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")
raise Exception("Video generation completed but no video was returned")
class VeoExtension(ComfyExtension): class VeoExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
VeoVideoGenerationNode, VeoVideoGenerationNode,
Veo3VideoGenerationNode, Veo3VideoGenerationNode,
Veo3FirstLastFrameNode,
] ]

View File

@ -36,6 +36,7 @@ from .upload_helpers import (
upload_video_to_comfyapi, upload_video_to_comfyapi,
) )
from .validation_utils import ( from .validation_utils import (
get_image_dimensions,
get_number_of_images, get_number_of_images,
validate_aspect_ratio_string, validate_aspect_ratio_string,
validate_audio_duration, validate_audio_duration,
@ -82,6 +83,7 @@ __all__ = [
"trim_video", "trim_video",
"video_to_base64_string", "video_to_base64_string",
# Validation utilities # Validation utilities
"get_image_dimensions",
"get_number_of_images", "get_number_of_images",
"validate_aspect_ratio_string", "validate_aspect_ratio_string",
"validate_audio_duration", "validate_audio_duration",

View File

@ -4,7 +4,7 @@ import logging
import time import time
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Optional, Union from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
@ -48,8 +48,9 @@ async def upload_images_to_comfyapi(
image: torch.Tensor, image: torch.Tensor,
*, *,
max_images: int = 8, max_images: int = 8,
mime_type: Optional[str] = None, mime_type: str | None = None,
wait_label: Optional[str] = "Uploading", wait_label: str | None = "Uploading",
show_batch_index: bool = True,
) -> list[str]: ) -> list[str]:
""" """
Uploads images to ComfyUI API and returns download URLs. Uploads images to ComfyUI API and returns download URLs.
@ -59,11 +60,18 @@ async def upload_images_to_comfyapi(
download_urls: list[str] = [] download_urls: list[str] = []
is_batch = len(image.shape) > 3 is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1 batch_len = image.shape[0] if is_batch else 1
num_to_upload = min(batch_len, max_images)
batch_start_ts = time.monotonic()
for idx in range(min(batch_len, max_images)): for idx in range(num_to_upload):
tensor = image[idx] if is_batch else image tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type) img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1:
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
download_urls.append(url) download_urls.append(url)
return download_urls return download_urls
@ -95,6 +103,7 @@ async def upload_video_to_comfyapi(
container: VideoContainer = VideoContainer.MP4, container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264, codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None, max_duration: Optional[int] = None,
wait_label: str | None = "Uploading",
) -> str: ) -> str:
""" """
Uploads a single video to ComfyUI API and returns its download URL. Uploads a single video to ComfyUI API and returns its download URL.
@ -119,15 +128,16 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec) video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0) video_bytes_io.seek(0)
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type) return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
async def upload_file_to_comfyapi( async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
file_bytes_io: BytesIO, file_bytes_io: BytesIO,
filename: str, filename: str,
upload_mime_type: Optional[str], upload_mime_type: str | None,
wait_label: Optional[str] = "Uploading", wait_label: str | None = "Uploading",
progress_origin_ts: float | None = None,
) -> str: ) -> str:
"""Uploads a single file to ComfyUI API and returns its download URL.""" """Uploads a single file to ComfyUI API and returns its download URL."""
if upload_mime_type is None: if upload_mime_type is None:
@ -148,6 +158,7 @@ async def upload_file_to_comfyapi(
file_bytes_io, file_bytes_io,
content_type=upload_mime_type, content_type=upload_mime_type,
wait_label=wait_label, wait_label=wait_label,
progress_origin_ts=progress_origin_ts,
) )
return create_resp.download_url return create_resp.download_url
@ -155,27 +166,18 @@ async def upload_file_to_comfyapi(
async def upload_file( async def upload_file(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
upload_url: str, upload_url: str,
file: Union[BytesIO, str], file: BytesIO | str,
*, *,
content_type: Optional[str] = None, content_type: str | None = None,
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff: float = 2.0, retry_backoff: float = 2.0,
wait_label: Optional[str] = None, wait_label: str | None = None,
progress_origin_ts: float | None = None,
) -> None: ) -> None:
""" """
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
Args:
cls: Node class (provides auth context + UI progress hooks).
upload_url: Pre-signed PUT URL.
file: BytesIO or path string.
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
max_retries: Maximum retry attempts.
retry_delay: Initial delay in seconds.
retry_backoff: Exponential backoff factor.
wait_label: Progress label shown in Comfy UI.
Raises: Raises:
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
""" """
@ -198,7 +200,7 @@ async def upload_file(
attempt = 0 attempt = 0
delay = retry_delay delay = retry_delay
start_ts = time.monotonic() start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
op_uuid = uuid.uuid4().hex[:8] op_uuid = uuid.uuid4().hex[:8]
while True: while True:
attempt += 1 attempt += 1

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,10 @@ import node_helpers
import comfy.utils import comfy.utils
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
import comfy.model_management
import torch
import math
import nodes
class CLIPTextEncodeFlux(io.ComfyNode): class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod @classmethod
@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode):
encode = execute # TODO: remove encode = execute # TODO: remove
class EmptyFlux2LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyFlux2LatentImage",
display_name="Empty Flux 2 Latent",
category="latent",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": latent})
class FluxGuidance(io.ComfyNode): class FluxGuidance(io.ComfyNode):
@classmethod @classmethod
@ -154,6 +178,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
append = execute # TODO: remove append = execute # TODO: remove
def generalized_time_snr_shift(t, mu: float, sigma: float):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
mu = compute_empirical_mu(image_seq_len, num_steps)
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
return timesteps
class Flux2Scheduler(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Flux2Scheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=4096),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Sigmas.Output(),
],
)
@classmethod
def execute(cls, steps, width, height) -> io.NodeOutput:
seq_len = (width * height / (16 * 16))
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
class FluxExtension(ComfyExtension): class FluxExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -163,6 +239,8 @@ class FluxExtension(ComfyExtension):
FluxDisableGuidance, FluxDisableGuidance,
FluxKontextImageScale, FluxKontextImageScale,
FluxKontextMultiReferenceLatentMethod, FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
] ]

View File

@ -7,6 +7,10 @@ from comfy_api.input_impl import VideoFromFile
from pathlib import Path from pathlib import Path
from PIL import Image
import numpy as np
import uuid
def normalize_path(path): def normalize_path(path):
return path.replace('\\', '/') return path.replace('\\', '/')
@ -34,58 +38,6 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
class Load3DAnimation():
@classmethod
def INPUT_TYPES(s):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
input_path = Path(input_dir)
base_path = Path(folder_paths.get_input_directory())
files = [
normalize_path(str(file_path.relative_to(base_path)))
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'}
]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
"image": ("LOAD_3D_ANIMATION", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
@ -120,7 +72,8 @@ class Preview3D():
"model_file": ("STRING", {"default": "", "multiline": False}), "model_file": ("STRING", {"default": "", "multiline": False}),
}, },
"optional": { "optional": {
"camera_info": ("LOAD3D_CAMERA", {}) "camera_info": ("LOAD3D_CAMERA", {}),
"bg_image": ("IMAGE", {})
}} }}
OUTPUT_NODE = True OUTPUT_NODE = True
@ -133,50 +86,33 @@ class Preview3D():
def process(self, model_file, **kwargs): def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None) camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None)
bg_image_path = None
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory()
filename = f"bg_{uuid.uuid4().hex}.png"
bg_image_path = os.path.join(temp_dir, filename)
img.save(bg_image_path, compress_level=1)
bg_image_path = f"temp/{filename}"
return { return {
"ui": { "ui": {
"result": [model_file, camera_info] "result": [model_file, camera_info, bg_image_path]
}
}
class Preview3DAnimation():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model_file": ("STRING", {"default": "", "multiline": False}),
},
"optional": {
"camera_info": ("LOAD3D_CAMERA", {})
}}
OUTPUT_NODE = True
RETURN_TYPES = ()
CATEGORY = "3d"
FUNCTION = "process"
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None)
return {
"ui": {
"result": [model_file, camera_info]
} }
} }
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"Load3D": Load3D, "Load3D": Load3D,
"Load3DAnimation": Load3DAnimation,
"Preview3D": Preview3D, "Preview3D": Preview3D,
"Preview3DAnimation": Preview3DAnimation
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"Load3D": "Load 3D", "Load3D": "Load 3D & Animation",
"Load3DAnimation": "Load 3D - Animation", "Preview3D": "Preview 3D & Animation",
"Preview3D": "Preview 3D",
"Preview3DAnimation": "Preview 3D - Animation"
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.71" __version__ = "0.3.75"

View File

@ -137,6 +137,71 @@ def set_user_directory(user_dir: str) -> None:
user_directory = user_dir user_directory = user_dir
# System User Protection - Protects system directories from HTTP endpoint access
# System Users are internal-only users that cannot be accessed via HTTP endpoints.
# They use the '__' prefix convention (similar to Python's private member convention).
SYSTEM_USER_PREFIX = "__"
def get_system_user_directory(name: str = "system") -> str:
"""
Get the path to a System User directory.
System User directories (prefixed with '__') are only accessible via internal API,
not through HTTP endpoints. Use this for storing system-internal data that
should not be exposed to users.
Args:
name: System user name (e.g., "system", "cache"). Must be alphanumeric
with underscores allowed, but cannot start with underscore.
Returns:
Absolute path to the system user directory.
Raises:
ValueError: If name is empty, invalid, or starts with underscore.
Example:
>>> get_system_user_directory("cache")
'/path/to/user/__cache'
"""
if not name or not isinstance(name, str):
raise ValueError("System user name cannot be empty")
if not name.replace("_", "").isalnum():
raise ValueError(f"Invalid system user name: '{name}'")
if name.startswith("_"):
raise ValueError("System user name should not start with underscore")
return os.path.join(get_user_directory(), f"{SYSTEM_USER_PREFIX}{name}")
def get_public_user_directory(user_id: str) -> str | None:
"""
Get the path to a Public User directory for HTTP endpoint access.
This function provides structural security by returning None for any
System User (prefixed with '__'). All HTTP endpoints should use this
function instead of directly constructing user paths.
Args:
user_id: User identifier from HTTP request.
Returns:
Absolute path to the user directory, or None if user_id is invalid
or refers to a System User.
Example:
>>> get_public_user_directory("default")
'/path/to/user/default'
>>> get_public_user_directory("__system")
None
"""
if not user_id or not isinstance(user_id, str):
return None
if user_id.startswith(SYSTEM_USER_PREFIX):
return None
return os.path.join(get_user_directory(), user_id)
#NOTE: used in http server so don't put folders that should not be accessed remotely #NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name: str) -> str | None: def get_directory_by_type(type_name: str) -> str | None:
if type_name == "output": if type_name == "output":

View File

@ -2,17 +2,24 @@ import torch
from PIL import Image from PIL import Image
from comfy.cli_args import args, LatentPreviewMethod from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD from comfy.taesd.taesd import TAESD
from comfy.sd import VAE
import comfy.model_management import comfy.model_management
import folder_paths import folder_paths
import comfy.utils import comfy.utils
import logging import logging
MAX_PREVIEW_RESOLUTION = args.preview_size MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
def preview_to_image(latent_image): def preview_to_image(latent_image, do_scale=True):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 if do_scale:
.mul(0xFF) # to 0..255 latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
) .mul(0xFF) # to 0..255
)
else:
latents_ubyte = (latent_image.clamp(0, 1)
.mul(0xFF) # to 0..255
)
if comfy.model_management.directml_enabled: if comfy.model_management.directml_enabled:
latents_ubyte = latents_ubyte.to(dtype=torch.uint8) latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device)) latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
@ -35,15 +42,22 @@ class TAESDPreviewerImpl(LatentPreviewer):
x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2) x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
return preview_to_image(x_sample) return preview_to_image(x_sample)
class TAEHVPreviewerImpl(TAESDPreviewerImpl):
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decode(x0[:1, :, :1])[0][0]
return preview_to_image(x_sample, do_scale=False)
class Latent2RGBPreviewer(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None): def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None: if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
self.latent_rgb_factors_reshape = latent_rgb_factors_reshape
def decode_latent_to_preview(self, x0): def decode_latent_to_preview(self, x0):
if self.latent_rgb_factors_reshape is not None:
x0 = self.latent_rgb_factors_reshape(x0)
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
if self.latent_rgb_factors_bias is not None: if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
@ -78,14 +92,19 @@ def get_previewer(device, latent_format):
if method == LatentPreviewMethod.TAESD: if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path: if taesd_decoder_path:
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) if latent_format.taesd_decoder_name in VIDEO_TAES:
previewer = TAESDPreviewerImpl(taesd) taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path))
taesd.first_stage_model.show_progress_bar = False
previewer = TAEHVPreviewerImpl(taesd)
else:
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
previewer = TAESDPreviewerImpl(taesd)
else: else:
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None: if previewer is None:
if latent_format.latent_rgb_factors is not None: if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape)
return previewer return previewer
def prepare_callback(model, steps, x0_output_dict=None): def prepare_callback(model, steps, x0_output_dict=None):

View File

@ -692,8 +692,10 @@ class LoraLoaderModelOnly(LoraLoader):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader: class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@staticmethod @staticmethod
def vae_list(): def vae_list(s):
vaes = folder_paths.get_filename_list("vae") vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx") approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False sdxl_taesd_enc = False
@ -722,6 +724,11 @@ class VAELoader:
f1_taesd_dec = True f1_taesd_dec = True
elif v.startswith("taef1_decoder."): elif v.startswith("taef1_decoder."):
f1_taesd_enc = True f1_taesd_enc = True
else:
for tae in s.video_taes:
if v.startswith(tae):
vaes.append(v)
if sd1_taesd_dec and sd1_taesd_enc: if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd") vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc: if sdxl_taesd_dec and sdxl_taesd_enc:
@ -765,7 +772,7 @@ class VAELoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(), )}} return {"required": { "vae_name": (s.vae_list(s), )}}
RETURN_TYPES = ("VAE",) RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae" FUNCTION = "load_vae"
@ -776,10 +783,13 @@ class VAELoader:
if vae_name == "pixel_space": if vae_name == "pixel_space":
sd = {} sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0) sd["pixel_space_vae"] = torch.tensor(1.0)
elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: elif vae_name in self.image_taes:
sd = self.load_taesd(vae_name) sd = self.load_taesd(vae_name)
else: else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) if os.path.splitext(vae_name)[0] in self.video_taes:
vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd) vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid() vae.throw_exception_if_invalid()
@ -929,7 +939,7 @@ class CLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ), "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
}, },
"optional": { "optional": {
"device": (["default", "cpu"], {"advanced": True}), "device": (["default", "cpu"], {"advanced": True}),
@ -2278,6 +2288,7 @@ async def init_builtin_extra_nodes():
"nodes_images.py", "nodes_images.py",
"nodes_video_model.py", "nodes_video_model.py",
"nodes_train.py", "nodes_train.py",
"nodes_dataset.py",
"nodes_sag.py", "nodes_sag.py",
"nodes_perpneg.py", "nodes_perpneg.py",
"nodes_stable3d.py", "nodes_stable3d.py",

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.71" version = "0.3.75"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.30.6 comfyui-frontend-package==1.32.10
comfyui-workflow-templates==0.7.9 comfyui-workflow-templates==0.7.25
comfyui-embedded-docs==0.3.1 comfyui-embedded-docs==0.3.1
torch torch
torchsde torchsde
@ -7,7 +7,7 @@ torchvision
torchaudio torchaudio
numpy>=1.25.0 numpy>=1.25.0
einops einops
transformers>=4.37.2 transformers>=4.50.3
tokenizers>=0.13.3 tokenizers>=0.13.3
sentencepiece sentencepiece
safetensors>=0.4.2 safetensors>=0.4.2

View File

@ -174,7 +174,7 @@ def create_block_external_middleware():
else: else:
response = await handler(request) response = await handler(request)
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
return response return response
return block_external_middleware return block_external_middleware

View File

@ -0,0 +1,193 @@
"""Tests for System User Protection in user_manager.py
Tests cover:
- get_request_user_id(): 1st defense layer - blocks System Users from HTTP headers
- get_request_user_filepath(): 2nd defense layer - structural blocking via get_public_user_directory()
- add_user(): 3rd defense layer - prevents creation of System User names
- Defense layers integration tests
"""
import pytest
from unittest.mock import MagicMock, patch
import tempfile
import folder_paths
from app.user_manager import UserManager
@pytest.fixture
def mock_user_directory():
"""Create a temporary user directory."""
with tempfile.TemporaryDirectory() as temp_dir:
original_dir = folder_paths.get_user_directory()
folder_paths.set_user_directory(temp_dir)
yield temp_dir
folder_paths.set_user_directory(original_dir)
@pytest.fixture
def user_manager(mock_user_directory):
"""Create a UserManager instance for testing."""
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
manager = UserManager()
# Add a default user for testing
manager.users = {"default": "default", "test_user_123": "Test User"}
yield manager
@pytest.fixture
def mock_request():
"""Create a mock request object."""
request = MagicMock()
request.headers = {}
return request
class TestGetRequestUserId:
"""Tests for get_request_user_id() - 1st defense layer.
Verifies:
- System Users (__ prefix) in HTTP header are rejected with KeyError
- Public Users pass through successfully
"""
def test_system_user_raises_error(self, user_manager, mock_request):
"""Test System User in header raises KeyError."""
mock_request.headers = {"comfy-user": "__system"}
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
with pytest.raises(KeyError, match="Unknown user"):
user_manager.get_request_user_id(mock_request)
def test_system_user_cache_raises_error(self, user_manager, mock_request):
"""Test System User cache raises KeyError."""
mock_request.headers = {"comfy-user": "__cache"}
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
with pytest.raises(KeyError, match="Unknown user"):
user_manager.get_request_user_id(mock_request)
def test_normal_user_works(self, user_manager, mock_request):
"""Test normal user access works."""
mock_request.headers = {"comfy-user": "default"}
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
user_id = user_manager.get_request_user_id(mock_request)
assert user_id == "default"
def test_unknown_user_raises_error(self, user_manager, mock_request):
"""Test unknown user raises KeyError."""
mock_request.headers = {"comfy-user": "unknown_user"}
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
with pytest.raises(KeyError, match="Unknown user"):
user_manager.get_request_user_id(mock_request)
class TestGetRequestUserFilepath:
"""Tests for get_request_user_filepath() - 2nd defense layer.
Verifies:
- Returns None when get_public_user_directory() returns None (System User)
- Acts as backup defense if 1st layer is bypassed
"""
def test_system_user_returns_none(self, user_manager, mock_request, mock_user_directory):
"""Test System User returns None (structural blocking)."""
# First, we need to mock get_request_user_id to return System User
# But actually, get_request_user_id will raise KeyError first
# So we test via get_public_user_directory returning None
mock_request.headers = {"comfy-user": "default"}
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
# Patch get_public_user_directory to return None for testing
with patch.object(folder_paths, 'get_public_user_directory', return_value=None):
result = user_manager.get_request_user_filepath(mock_request, "test.txt")
assert result is None
def test_normal_user_gets_path(self, user_manager, mock_request, mock_user_directory):
"""Test normal user gets valid filepath."""
mock_request.headers = {"comfy-user": "default"}
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
path = user_manager.get_request_user_filepath(mock_request, "test.txt")
assert path is not None
assert "default" in path
assert path.endswith("test.txt")
class TestAddUser:
"""Tests for add_user() - 3rd defense layer (creation-time blocking).
Verifies:
- System User name (__ prefix) creation is rejected with ValueError
- Sanitized usernames that become System User are also rejected
"""
def test_system_user_prefix_name_raises(self, user_manager):
"""Test System User prefix in name raises ValueError."""
with pytest.raises(ValueError, match="System User prefix not allowed"):
user_manager.add_user("__system")
def test_system_user_prefix_cache_raises(self, user_manager):
"""Test System User cache prefix raises ValueError."""
with pytest.raises(ValueError, match="System User prefix not allowed"):
user_manager.add_user("__cache")
def test_sanitized_system_user_prefix_raises(self, user_manager):
"""Test sanitized name becoming System User prefix raises ValueError (bypass prevention)."""
# "__test" directly starts with System User prefix
with pytest.raises(ValueError, match="System User prefix not allowed"):
user_manager.add_user("__test")
def test_normal_user_creation(self, user_manager, mock_user_directory):
"""Test normal user creation works."""
user_id = user_manager.add_user("Normal User")
assert user_id is not None
assert not user_id.startswith("__")
assert "Normal-User" in user_id or "Normal_User" in user_id
def test_empty_name_raises(self, user_manager):
"""Test empty name raises ValueError."""
with pytest.raises(ValueError, match="username not provided"):
user_manager.add_user("")
def test_whitespace_only_raises(self, user_manager):
"""Test whitespace-only name raises ValueError."""
with pytest.raises(ValueError, match="username not provided"):
user_manager.add_user(" ")
class TestDefenseLayers:
"""Integration tests for all three defense layers.
Verifies:
- Each defense layer blocks System Users independently
- System User bypass is impossible through any layer
"""
def test_layer1_get_request_user_id(self, user_manager, mock_request):
"""Test 1st defense layer blocks System Users."""
mock_request.headers = {"comfy-user": "__system"}
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
with pytest.raises(KeyError):
user_manager.get_request_user_id(mock_request)
def test_layer2_get_public_user_directory(self):
"""Test 2nd defense layer blocks System Users."""
result = folder_paths.get_public_user_directory("__system")
assert result is None
def test_layer3_add_user(self, user_manager):
"""Test 3rd defense layer blocks System User creation."""
with pytest.raises(ValueError):
user_manager.add_user("__system")

View File

@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
def test_all_layers_standard(self): def test_all_layers_standard(self):
"""Test that model with no quantization works normally""" """Test that model with no quantization works normally"""
# Configure no quantization
ops.MixedPrecisionOps._layer_quant_config = {}
# Create model # Create model
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops({}))
# Initialize weights manually # Initialize weights manually
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
"params": {} "params": {}
} }
} }
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict with mixed precision # Create state dict with mixed precision
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
} }
# Create model and load state dict (strict=False because custom loading pops keys) # Create model and load state dict (strict=False because custom loading pops keys)
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
# Verify weights are wrapped in QuantizedTensor # Verify weights are wrapped in QuantizedTensor
@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
"params": {} "params": {}
} }
} }
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create and load model # Create and load model
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
} }
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
model.load_state_dict(state_dict1, strict=False) model.load_state_dict(state_dict1, strict=False)
# Save state dict # Save state dict
@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
"params": {} "params": {}
} }
} }
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create and load model # Create and load model
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
} }
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
# Add a weight function (simulating LoRA) # Add a weight function (simulating LoRA)
@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
"params": {} "params": {}
} }
} }
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict # Create state dict
state_dict = { state_dict = {
@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
} }
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)

View File

@ -0,0 +1,206 @@
"""Tests for System User Protection in folder_paths.py
Tests cover:
- get_system_user_directory(): Internal API for custom nodes to access System User directories
- get_public_user_directory(): HTTP endpoint access with System User blocking
- Backward compatibility: Existing APIs unchanged
- Security: Path traversal and injection prevention
"""
import pytest
import os
import tempfile
from folder_paths import (
get_system_user_directory,
get_public_user_directory,
get_user_directory,
set_user_directory,
)
@pytest.fixture(scope="module")
def mock_user_directory():
"""Create a temporary user directory for testing."""
with tempfile.TemporaryDirectory() as temp_dir:
original_dir = get_user_directory()
set_user_directory(temp_dir)
yield temp_dir
set_user_directory(original_dir)
class TestGetSystemUserDirectory:
"""Tests for get_system_user_directory() - internal API for System User directories.
Verifies:
- Custom nodes can access System User directories via internal API
- Input validation prevents path traversal attacks
"""
def test_default_name(self, mock_user_directory):
"""Test default 'system' name."""
path = get_system_user_directory()
assert path.endswith("__system")
assert mock_user_directory in path
def test_custom_name(self, mock_user_directory):
"""Test custom system user name."""
path = get_system_user_directory("cache")
assert path.endswith("__cache")
assert "__cache" in path
def test_name_with_underscore(self, mock_user_directory):
"""Test name with underscore in middle."""
path = get_system_user_directory("my_cache")
assert "__my_cache" in path
def test_empty_name_raises(self):
"""Test empty name raises ValueError."""
with pytest.raises(ValueError, match="cannot be empty"):
get_system_user_directory("")
def test_none_name_raises(self):
"""Test None name raises ValueError."""
with pytest.raises(ValueError, match="cannot be empty"):
get_system_user_directory(None)
def test_name_starting_with_underscore_raises(self):
"""Test name starting with underscore raises ValueError."""
with pytest.raises(ValueError, match="should not start with underscore"):
get_system_user_directory("_system")
def test_path_traversal_raises(self):
"""Test path traversal attempt raises ValueError (security)."""
with pytest.raises(ValueError, match="Invalid system user name"):
get_system_user_directory("../escape")
def test_path_traversal_middle_raises(self):
"""Test path traversal in middle raises ValueError (security)."""
with pytest.raises(ValueError, match="Invalid system user name"):
get_system_user_directory("system/../other")
def test_special_chars_raise(self):
"""Test special characters raise ValueError (security)."""
with pytest.raises(ValueError, match="Invalid system user name"):
get_system_user_directory("system!")
def test_returns_absolute_path(self, mock_user_directory):
"""Test returned path is absolute."""
path = get_system_user_directory("test")
assert os.path.isabs(path)
class TestGetPublicUserDirectory:
"""Tests for get_public_user_directory() - HTTP endpoint access with System User blocking.
Verifies:
- System Users (__ prefix) return None, blocking HTTP access
- Public Users get valid paths
- New endpoints using this function are automatically protected
"""
def test_normal_user(self, mock_user_directory):
"""Test normal user returns valid path."""
path = get_public_user_directory("default")
assert path is not None
assert "default" in path
assert mock_user_directory in path
def test_system_user_returns_none(self):
"""Test System User (__ prefix) returns None - blocks HTTP access."""
assert get_public_user_directory("__system") is None
def test_system_user_cache_returns_none(self):
"""Test System User cache returns None."""
assert get_public_user_directory("__cache") is None
def test_empty_user_returns_none(self):
"""Test empty user returns None."""
assert get_public_user_directory("") is None
def test_none_user_returns_none(self):
"""Test None user returns None."""
assert get_public_user_directory(None) is None
def test_header_injection_returns_none(self):
"""Test header injection attempt returns None (security)."""
assert get_public_user_directory("__system\r\nX-Injected: true") is None
def test_null_byte_injection_returns_none(self):
"""Test null byte injection handling (security)."""
# Note: startswith check happens before any path operations
result = get_public_user_directory("user\x00__system")
# This should return a path since it doesn't start with __
# The actual security comes from the path not being __*
assert result is not None or result is None # Depends on validation
def test_path_traversal_attempt(self, mock_user_directory):
"""Test path traversal attempt handling."""
# This function doesn't validate paths, only reserved prefix
# Path traversal should be handled by the caller
path = get_public_user_directory("../../../etc/passwd")
# Returns path but doesn't start with __, so not None
# Actual path validation happens in user_manager
assert path is not None or "__" not in "../../../etc/passwd"
def test_returns_absolute_path(self, mock_user_directory):
"""Test returned path is absolute."""
path = get_public_user_directory("testuser")
assert path is not None
assert os.path.isabs(path)
class TestBackwardCompatibility:
"""Tests for backward compatibility with existing APIs.
Verifies:
- get_user_directory() API unchanged
- Existing user data remains accessible
"""
def test_get_user_directory_unchanged(self, mock_user_directory):
"""Test get_user_directory() still works as before."""
user_dir = get_user_directory()
assert user_dir is not None
assert os.path.isabs(user_dir)
assert user_dir == mock_user_directory
def test_existing_user_accessible(self, mock_user_directory):
"""Test existing users can access their directories."""
path = get_public_user_directory("default")
assert path is not None
assert "default" in path
class TestEdgeCases:
"""Tests for edge cases in System User detection.
Verifies:
- Only __ prefix is blocked (not _, not middle __)
- Bypass attempts are prevented
"""
def test_prefix_only(self):
"""Test prefix-only string is blocked."""
assert get_public_user_directory("__") is None
def test_single_underscore_allowed(self):
"""Test single underscore prefix is allowed (not System User)."""
path = get_public_user_directory("_system")
assert path is not None
assert "_system" in path
def test_triple_underscore_blocked(self):
"""Test triple underscore is blocked (starts with __)."""
assert get_public_user_directory("___system") is None
def test_underscore_in_middle_allowed(self):
"""Test underscore in middle is allowed."""
path = get_public_user_directory("my__system")
assert path is not None
assert "my__system" in path
def test_leading_space_allowed(self):
"""Test leading space + prefix is allowed (doesn't start with __)."""
path = get_public_user_directory(" __system")
assert path is not None

View File

@ -0,0 +1,375 @@
"""E2E Tests for System User Protection HTTP Endpoints
Tests cover:
- HTTP endpoint blocking: System Users cannot access /userdata (GET, POST, DELETE, move)
- User creation blocking: System User names cannot be created via POST /users
- Backward compatibility: Public Users work as before
- Custom node scenario: Internal API works while HTTP is blocked
- Structural security: get_public_user_directory() provides automatic protection
"""
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
from unittest.mock import patch
import folder_paths
@pytest.fixture
def mock_user_directory(tmp_path):
"""Create a temporary user directory."""
original_dir = folder_paths.get_user_directory()
folder_paths.set_user_directory(str(tmp_path))
yield tmp_path
folder_paths.set_user_directory(original_dir)
@pytest.fixture
def user_manager_multi_user(mock_user_directory):
"""Create UserManager in multi-user mode."""
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
um = UserManager()
# Add test users
um.users = {"default": "default", "test_user_123": "Test User"}
yield um
@pytest.fixture
def app_multi_user(user_manager_multi_user):
"""Create app with multi-user mode enabled."""
app = web.Application()
routes = web.RouteTableDef()
user_manager_multi_user.add_routes(routes)
app.add_routes(routes)
return app
class TestSystemUserEndpointBlocking:
"""E2E tests for System User blocking on all HTTP endpoints.
Verifies:
- GET /userdata blocked for System Users
- POST /userdata blocked for System Users
- DELETE /userdata blocked for System Users
- POST /userdata/.../move/... blocked for System Users
"""
@pytest.mark.asyncio
async def test_userdata_get_blocks_system_user(
self, aiohttp_client, app_multi_user, mock_user_directory
):
"""
GET /userdata with System User header should be blocked.
"""
# Create test directory for System User (simulating internal creation)
system_user_dir = mock_user_directory / "__system"
system_user_dir.mkdir()
(system_user_dir / "secret.txt").write_text("sensitive data")
client = await aiohttp_client(app_multi_user)
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
# Attempt to access System User's data via HTTP
resp = await client.get(
"/userdata?dir=.",
headers={"comfy-user": "__system"}
)
# Should be blocked (403 Forbidden or similar error)
assert resp.status in [400, 403, 500], \
f"System User access should be blocked, got {resp.status}"
@pytest.mark.asyncio
async def test_userdata_post_blocks_system_user(
self, aiohttp_client, app_multi_user, mock_user_directory
):
"""
POST /userdata with System User header should be blocked.
"""
client = await aiohttp_client(app_multi_user)
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
resp = await client.post(
"/userdata/test.txt",
headers={"comfy-user": "__system"},
data=b"malicious content"
)
assert resp.status in [400, 403, 500], \
f"System User write should be blocked, got {resp.status}"
# Verify no file was created
assert not (mock_user_directory / "__system" / "test.txt").exists()
@pytest.mark.asyncio
async def test_userdata_delete_blocks_system_user(
self, aiohttp_client, app_multi_user, mock_user_directory
):
"""
DELETE /userdata with System User header should be blocked.
"""
# Create a file in System User directory
system_user_dir = mock_user_directory / "__system"
system_user_dir.mkdir()
secret_file = system_user_dir / "secret.txt"
secret_file.write_text("do not delete")
client = await aiohttp_client(app_multi_user)
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
resp = await client.delete(
"/userdata/secret.txt",
headers={"comfy-user": "__system"}
)
assert resp.status in [400, 403, 500], \
f"System User delete should be blocked, got {resp.status}"
# Verify file still exists
assert secret_file.exists()
@pytest.mark.asyncio
async def test_v2_userdata_blocks_system_user(
self, aiohttp_client, app_multi_user, mock_user_directory
):
"""
GET /v2/userdata with System User header should be blocked.
"""
client = await aiohttp_client(app_multi_user)
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
resp = await client.get(
"/v2/userdata",
headers={"comfy-user": "__system"}
)
assert resp.status in [400, 403, 500], \
f"System User v2 access should be blocked, got {resp.status}"
@pytest.mark.asyncio
async def test_move_userdata_blocks_system_user(
self, aiohttp_client, app_multi_user, mock_user_directory
):
"""
POST /userdata/{file}/move/{dest} with System User header should be blocked.
"""
system_user_dir = mock_user_directory / "__system"
system_user_dir.mkdir()
(system_user_dir / "source.txt").write_text("sensitive data")
client = await aiohttp_client(app_multi_user)
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
resp = await client.post(
"/userdata/source.txt/move/dest.txt",
headers={"comfy-user": "__system"}
)
assert resp.status in [400, 403, 500], \
f"System User move should be blocked, got {resp.status}"
# Verify source file still exists (move was blocked)
assert (system_user_dir / "source.txt").exists()
class TestSystemUserCreationBlocking:
"""E2E tests for blocking System User name creation via POST /users.
Verifies:
- POST /users returns 400 for System User name (not 500)
"""
@pytest.mark.asyncio
async def test_post_users_blocks_system_user_name(
self, aiohttp_client, app_multi_user
):
"""POST /users with System User name should return 400 Bad Request."""
client = await aiohttp_client(app_multi_user)
resp = await client.post(
"/users",
json={"username": "__system"}
)
assert resp.status == 400, \
f"System User creation should return 400, got {resp.status}"
@pytest.mark.asyncio
async def test_post_users_blocks_system_user_prefix_variations(
self, aiohttp_client, app_multi_user
):
"""POST /users with any System User prefix variation should return 400 Bad Request."""
client = await aiohttp_client(app_multi_user)
system_user_names = ["__system", "__cache", "__config", "__anything"]
for name in system_user_names:
resp = await client.post("/users", json={"username": name})
assert resp.status == 400, \
f"System User name '{name}' should return 400, got {resp.status}"
class TestPublicUserStillWorks:
"""E2E tests for backward compatibility - Public Users should work as before.
Verifies:
- Public Users can access their data via HTTP
- Public Users can create files via HTTP
"""
@pytest.mark.asyncio
async def test_public_user_can_access_userdata(
self, aiohttp_client, app_multi_user, mock_user_directory
):
"""
Public Users should still be able to access their data.
"""
# Create test directory for Public User
user_dir = mock_user_directory / "default"
user_dir.mkdir()
test_dir = user_dir / "workflows"
test_dir.mkdir()
(test_dir / "test.json").write_text('{"test": true}')
client = await aiohttp_client(app_multi_user)
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
resp = await client.get(
"/userdata?dir=workflows",
headers={"comfy-user": "default"}
)
assert resp.status == 200
data = await resp.json()
assert "test.json" in data
@pytest.mark.asyncio
async def test_public_user_can_create_files(
self, aiohttp_client, app_multi_user, mock_user_directory
):
"""
Public Users should still be able to create files.
"""
# Create user directory
user_dir = mock_user_directory / "default"
user_dir.mkdir()
client = await aiohttp_client(app_multi_user)
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
resp = await client.post(
"/userdata/newfile.txt",
headers={"comfy-user": "default"},
data=b"user content"
)
assert resp.status == 200
assert (user_dir / "newfile.txt").exists()
class TestCustomNodeScenario:
"""Tests for custom node use case: internal API access vs HTTP blocking.
Verifies:
- Internal API (get_system_user_directory) works for custom nodes
- HTTP endpoint cannot access data created via internal API
"""
def test_internal_api_can_access_system_user(self, mock_user_directory):
"""
Internal API (get_system_user_directory) should work for custom nodes.
"""
# Custom node uses internal API
system_path = folder_paths.get_system_user_directory("mynode_config")
assert system_path is not None
assert "__mynode_config" in system_path
# Can create and write to System User directory
os.makedirs(system_path, exist_ok=True)
config_file = os.path.join(system_path, "settings.json")
with open(config_file, "w") as f:
f.write('{"api_key": "secret"}')
assert os.path.exists(config_file)
@pytest.mark.asyncio
async def test_http_cannot_access_internal_data(
self, aiohttp_client, app_multi_user, mock_user_directory
):
"""
HTTP endpoint cannot access data created via internal API.
"""
# Custom node creates data via internal API
system_path = folder_paths.get_system_user_directory("mynode_config")
os.makedirs(system_path, exist_ok=True)
with open(os.path.join(system_path, "secret.json"), "w") as f:
f.write('{"api_key": "secret"}')
client = await aiohttp_client(app_multi_user)
# Attacker tries to access via HTTP
with patch('app.user_manager.args') as mock_args:
mock_args.multi_user = True
resp = await client.get(
"/userdata/secret.json",
headers={"comfy-user": "__mynode_config"}
)
# Should be blocked
assert resp.status in [400, 403, 500]
class TestStructuralSecurity:
"""Tests for structural security pattern.
Verifies:
- get_public_user_directory() automatically blocks System Users
- New endpoints using this function are automatically protected
"""
def test_get_public_user_directory_blocks_system_user(self):
"""
Any code using get_public_user_directory() is automatically protected.
"""
# This is the structural security - any new endpoint using this function
# will automatically block System Users
assert folder_paths.get_public_user_directory("__system") is None
assert folder_paths.get_public_user_directory("__cache") is None
assert folder_paths.get_public_user_directory("__anything") is None
# Public Users work
assert folder_paths.get_public_user_directory("default") is not None
assert folder_paths.get_public_user_directory("user123") is not None
def test_structural_security_pattern(self, mock_user_directory):
"""
Demonstrate the structural security pattern for new endpoints.
Any new endpoint should follow this pattern:
1. Get user from request
2. Use get_public_user_directory() - automatically blocks System Users
3. If None, return error
"""
def new_endpoint_handler(user_id: str) -> str | None:
"""Example of how new endpoints should be implemented."""
user_path = folder_paths.get_public_user_directory(user_id)
if user_path is None:
return None # Blocked
return user_path
# System Users are automatically blocked
assert new_endpoint_handler("__system") is None
assert new_endpoint_handler("__secret") is None
# Public Users work
assert new_endpoint_handler("default") is not None