mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-13 03:52:30 +08:00
Merge branch 'Comfy-Org:master' into enable-triton-comfy-kitchen
This commit is contained in:
commit
7ddcd87fef
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -16,7 +16,7 @@ body:
|
|||||||
|
|
||||||
## Very Important
|
## Very Important
|
||||||
|
|
||||||
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored.
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
id: custom-nodes-test
|
id: custom-nodes-test
|
||||||
attributes:
|
attributes:
|
||||||
|
|||||||
@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat
|
|||||||
|
|
||||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||||
|
|
||||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
|
||||||
|
|
||||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||||
|
|
||||||
#### How do I share models between another UI and ComfyUI?
|
#### How do I share models between another UI and ComfyUI?
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from importlib.metadata import version
|
|||||||
import requests
|
import requests
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
from utils.install_util import get_missing_requirements_message, get_required_packages_versions
|
||||||
|
|
||||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||||
import app.logger
|
import app.logger
|
||||||
@ -45,25 +45,7 @@ def get_installed_frontend_version():
|
|||||||
|
|
||||||
|
|
||||||
def get_required_frontend_version():
|
def get_required_frontend_version():
|
||||||
"""Get the required frontend version from requirements.txt."""
|
return get_required_packages_versions().get("comfyui-frontend-package", None)
|
||||||
try:
|
|
||||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if line.startswith("comfyui-frontend-package=="):
|
|
||||||
version_str = line.split("==")[-1]
|
|
||||||
if not is_valid_version(version_str):
|
|
||||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
|
||||||
return None
|
|
||||||
return version_str
|
|
||||||
logging.error("comfyui-frontend-package not found in requirements.txt")
|
|
||||||
return None
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.error("requirements.txt not found. Cannot determine required frontend version.")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error reading requirements.txt: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def check_frontend_version():
|
def check_frontend_version():
|
||||||
@ -217,25 +199,7 @@ class FrontendManager:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_required_templates_version(cls) -> str:
|
def get_required_templates_version(cls) -> str:
|
||||||
"""Get the required workflow templates version from requirements.txt."""
|
return get_required_packages_versions().get("comfyui-workflow-templates", None)
|
||||||
try:
|
|
||||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if line.startswith("comfyui-workflow-templates=="):
|
|
||||||
version_str = line.split("==")[-1]
|
|
||||||
if not is_valid_version(version_str):
|
|
||||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
|
||||||
return None
|
|
||||||
return version_str
|
|
||||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
|
||||||
return None
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error reading requirements.txt: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_frontend_path(cls) -> str:
|
def default_frontend_path(cls) -> str:
|
||||||
|
|||||||
@ -147,6 +147,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
|||||||
|
|
||||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||||
|
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||||
|
|
||||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||||
|
|
||||||
@ -160,7 +161,6 @@ class PerformanceFeature(enum.Enum):
|
|||||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||||
CublasOps = "cublas_ops"
|
CublasOps = "cublas_ops"
|
||||||
AutoTune = "autotune"
|
AutoTune = "autotune"
|
||||||
DynamicVRAM = "dynamic_vram"
|
|
||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||||
|
|
||||||
@ -261,4 +261,4 @@ else:
|
|||||||
args.fast = set(args.fast)
|
args.fast = set(args.fast)
|
||||||
|
|
||||||
def enables_dynamic_vram():
|
def enables_dynamic_vram():
|
||||||
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
|
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||||
|
|||||||
@ -4,6 +4,25 @@ import comfy.utils
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def is_equal(x, y):
|
||||||
|
if torch.is_tensor(x) and torch.is_tensor(y):
|
||||||
|
return torch.equal(x, y)
|
||||||
|
elif isinstance(x, dict) and isinstance(y, dict):
|
||||||
|
if x.keys() != y.keys():
|
||||||
|
return False
|
||||||
|
return all(is_equal(x[k], y[k]) for k in x)
|
||||||
|
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||||
|
if type(x) is not type(y) or len(x) != len(y):
|
||||||
|
return False
|
||||||
|
return all(is_equal(a, b) for a, b in zip(x, y))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return x == y
|
||||||
|
except Exception:
|
||||||
|
logging.warning("comparison issue with COND")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class CONDRegular:
|
class CONDRegular:
|
||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
|
|||||||
return self._copy_with(self.cond)
|
return self._copy_with(self.cond)
|
||||||
|
|
||||||
def can_concat(self, other):
|
def can_concat(self, other):
|
||||||
if self.cond != other.cond:
|
if not is_equal(self.cond, other.cond):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||||
matches = torch.nonzero(mask)
|
matches = torch.nonzero(mask)
|
||||||
if torch.numel(matches) == 0:
|
if torch.numel(matches) == 0:
|
||||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
return # substep from multi-step sampler: keep self._step from the last full step
|
||||||
self._step = int(matches[0].item())
|
self._step = int(matches[0].item())
|
||||||
|
|
||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||||
|
|||||||
@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
run_vx = transformer_options.get("run_vx", True)
|
run_vx = transformer_options.get("run_vx", True)
|
||||||
run_ax = transformer_options.get("run_ax", True)
|
run_ax = transformer_options.get("run_ax", True)
|
||||||
@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||||
del vshift_msa, vscale_msa
|
del vshift_msa, vscale_msa
|
||||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||||
del norm_vx
|
del norm_vx
|
||||||
# video cross-attention
|
# video cross-attention
|
||||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||||
@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||||
|
|
||||||
def _process_transformer_blocks(
|
def _process_transformer_blocks(
|
||||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
|
||||||
):
|
):
|
||||||
vx = x[0]
|
vx = x[0]
|
||||||
ax = x[1]
|
ax = x[1]
|
||||||
@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||||
transformer_options=args["transformer_options"],
|
transformer_options=args["transformer_options"],
|
||||||
|
self_attention_mask=args.get("self_attention_mask"),
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||||
"transformer_options": transformer_options,
|
"transformer_options": transformer_options,
|
||||||
|
"self_attention_mask": self_attention_mask,
|
||||||
},
|
},
|
||||||
{"original_block": block_wrap},
|
{"original_block": block_wrap},
|
||||||
)
|
)
|
||||||
@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
|
self_attention_mask=self_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [vx, ax]
|
return [vx, ax]
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import functools
|
import functools
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
@ -14,6 +15,8 @@ import comfy.ldm.common_dit
|
|||||||
|
|
||||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def _log_base(x, base):
|
def _log_base(x, base):
|
||||||
return np.log(x) / np.log(base)
|
return np.log(x) / np.log(base)
|
||||||
|
|
||||||
@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||||
|
|
||||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||||
x.addcmul_(attn1_input, gate_msa)
|
x.addcmul_(attn1_input, gate_msa)
|
||||||
del attn1_input
|
del attn1_input
|
||||||
|
|
||||||
@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
"""Process input data. Must be implemented by subclasses."""
|
"""Process input data. Must be implemented by subclasses."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||||
|
"""Build self-attention mask for per-guide attention attenuation.
|
||||||
|
|
||||||
|
Base implementation returns None (no attenuation). Subclasses that
|
||||||
|
support guide-based attention control should override this.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
|
||||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||||
|
|
||||||
|
# Build self-attention mask for per-guide attenuation
|
||||||
|
self_attention_mask = self._build_guide_self_attention_mask(
|
||||||
|
x, transformer_options, merged_args
|
||||||
|
)
|
||||||
|
|
||||||
# Process transformer blocks
|
# Process transformer blocks
|
||||||
x = self._process_transformer_blocks(
|
x = self._process_transformer_blocks(
|
||||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
x, context, attention_mask, timestep, pe,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
self_attention_mask=self_attention_mask,
|
||||||
|
**merged_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process output
|
# Process output
|
||||||
@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
|
|||||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||||
|
|
||||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||||
|
|
||||||
|
# Compute per-guide surviving token counts from guide_attention_entries.
|
||||||
|
# Each entry tracks one guide reference; they are appended in order and
|
||||||
|
# their pre_filter_counts partition the kf_grid_mask.
|
||||||
|
guide_entries = kwargs.get("guide_attention_entries", None)
|
||||||
|
if guide_entries:
|
||||||
|
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
|
||||||
|
if total_pfc != len(kf_grid_mask):
|
||||||
|
raise ValueError(
|
||||||
|
f"guide pre_filter_counts ({total_pfc}) != "
|
||||||
|
f"keyframe grid mask length ({len(kf_grid_mask)})"
|
||||||
|
)
|
||||||
|
resolved_entries = []
|
||||||
|
offset = 0
|
||||||
|
for entry in guide_entries:
|
||||||
|
pfc = entry["pre_filter_count"]
|
||||||
|
entry_mask = kf_grid_mask[offset:offset + pfc]
|
||||||
|
surviving = int(entry_mask.sum().item())
|
||||||
|
resolved_entries.append({
|
||||||
|
**entry,
|
||||||
|
"surviving_count": surviving,
|
||||||
|
})
|
||||||
|
offset += pfc
|
||||||
|
additional_args["resolved_guide_entries"] = resolved_entries
|
||||||
|
|
||||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||||
|
|
||||||
|
# Total surviving guide tokens (all guides)
|
||||||
|
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||||
|
|
||||||
x = self.patchify_proj(x)
|
x = self.patchify_proj(x)
|
||||||
return x, pixel_coords, additional_args
|
return x, pixel_coords, additional_args
|
||||||
|
|
||||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||||
|
"""Build self-attention mask for per-guide attention attenuation.
|
||||||
|
|
||||||
|
Reads resolved_guide_entries from merged_args (computed in _process_input)
|
||||||
|
to build a log-space additive bias mask that attenuates noisy ↔ guide
|
||||||
|
attention for each guide reference independently.
|
||||||
|
|
||||||
|
Returns None if no attenuation is needed (all strengths == 1.0 and no
|
||||||
|
spatial masks, or no guide tokens).
|
||||||
|
"""
|
||||||
|
if isinstance(x, list):
|
||||||
|
# AV model: x = [vx, ax]; use vx for token count and device
|
||||||
|
total_tokens = x[0].shape[1]
|
||||||
|
device = x[0].device
|
||||||
|
dtype = x[0].dtype
|
||||||
|
else:
|
||||||
|
total_tokens = x.shape[1]
|
||||||
|
device = x.device
|
||||||
|
dtype = x.dtype
|
||||||
|
|
||||||
|
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
|
||||||
|
if num_guide_tokens == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
resolved_entries = merged_args.get("resolved_guide_entries", None)
|
||||||
|
if not resolved_entries:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if any attenuation is actually needed
|
||||||
|
needs_attenuation = any(
|
||||||
|
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||||
|
for e in resolved_entries
|
||||||
|
)
|
||||||
|
if not needs_attenuation:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build per-guide-token weights for all tracked guide tokens.
|
||||||
|
# Guides are appended in order at the end of the sequence.
|
||||||
|
guide_start = total_tokens - num_guide_tokens
|
||||||
|
all_weights = []
|
||||||
|
total_tracked = 0
|
||||||
|
|
||||||
|
for entry in resolved_entries:
|
||||||
|
surviving = entry["surviving_count"]
|
||||||
|
if surviving == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
strength = entry["strength"]
|
||||||
|
pixel_mask = entry.get("pixel_mask")
|
||||||
|
latent_shape = entry.get("latent_shape")
|
||||||
|
|
||||||
|
if pixel_mask is not None and latent_shape is not None:
|
||||||
|
f_lat, h_lat, w_lat = latent_shape
|
||||||
|
per_token = self._downsample_mask_to_latent(
|
||||||
|
pixel_mask.to(device=device, dtype=dtype),
|
||||||
|
f_lat, h_lat, w_lat,
|
||||||
|
)
|
||||||
|
# per_token shape: (B, f_lat*h_lat*w_lat).
|
||||||
|
# Collapse batch dim — the mask is assumed identical across the
|
||||||
|
# batch; validate and take the first element to get (1, tokens).
|
||||||
|
if per_token.shape[0] > 1:
|
||||||
|
ref = per_token[0]
|
||||||
|
for bi in range(1, per_token.shape[0]):
|
||||||
|
if not torch.equal(ref, per_token[bi]):
|
||||||
|
logger.warning(
|
||||||
|
"pixel_mask differs across batch elements; "
|
||||||
|
"using first element only."
|
||||||
|
)
|
||||||
|
break
|
||||||
|
per_token = per_token[:1]
|
||||||
|
# `surviving` is the post-grid_mask token count.
|
||||||
|
# Clamp to surviving to handle any mismatch safely.
|
||||||
|
n_weights = min(per_token.shape[1], surviving)
|
||||||
|
weights = per_token[:, :n_weights] * strength # (1, n_weights)
|
||||||
|
else:
|
||||||
|
weights = torch.full(
|
||||||
|
(1, surviving), strength, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
all_weights.append(weights)
|
||||||
|
total_tracked += weights.shape[1]
|
||||||
|
|
||||||
|
if not all_weights:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Concatenate per-token weights for all tracked guides
|
||||||
|
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||||
|
|
||||||
|
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||||
|
if (tracked_weights >= 1.0).all():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build the mask: guide tokens are at the end of the sequence.
|
||||||
|
# Tracked guides come first (in order), untracked follow.
|
||||||
|
return self._build_self_attention_mask(
|
||||||
|
total_tokens, num_guide_tokens, total_tracked,
|
||||||
|
tracked_weights, guide_start, device, dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||||
|
"""Downsample a pixel-space mask to per-token latent weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
|
||||||
|
f_lat: Number of latent frames (pre-dilation original count).
|
||||||
|
h_lat: Latent height (pre-dilation original height).
|
||||||
|
w_lat: Latent width (pre-dilation original width).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(B, F_lat * H_lat * W_lat) flattened per-token weights.
|
||||||
|
"""
|
||||||
|
b = mask.shape[0]
|
||||||
|
f_pix = mask.shape[2]
|
||||||
|
|
||||||
|
# Spatial downsampling: area interpolation per frame
|
||||||
|
spatial_down = torch.nn.functional.interpolate(
|
||||||
|
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
|
||||||
|
size=(h_lat, w_lat),
|
||||||
|
mode="area",
|
||||||
|
)
|
||||||
|
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
|
||||||
|
|
||||||
|
# Temporal downsampling: first pixel frame maps to first latent frame,
|
||||||
|
# remaining pixel frames are averaged in groups for causal temporal structure.
|
||||||
|
first_frame = spatial_down[:, :, :1, :, :]
|
||||||
|
if f_pix > 1 and f_lat > 1:
|
||||||
|
remaining_pix = f_pix - 1
|
||||||
|
remaining_lat = f_lat - 1
|
||||||
|
t = remaining_pix // remaining_lat
|
||||||
|
if t < 1:
|
||||||
|
# Fewer pixel frames than latent frames — upsample by repeating
|
||||||
|
# the available pixel frames via nearest interpolation.
|
||||||
|
rest_flat = rearrange(
|
||||||
|
spatial_down[:, :, 1:, :, :],
|
||||||
|
"b 1 f h w -> (b h w) 1 f",
|
||||||
|
)
|
||||||
|
rest_up = torch.nn.functional.interpolate(
|
||||||
|
rest_flat, size=remaining_lat, mode="nearest",
|
||||||
|
)
|
||||||
|
rest = rearrange(
|
||||||
|
rest_up, "(b h w) 1 f -> b 1 f h w",
|
||||||
|
b=b, h=h_lat, w=w_lat,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Trim trailing pixel frames that don't fill a complete group
|
||||||
|
usable = remaining_lat * t
|
||||||
|
rest = rearrange(
|
||||||
|
spatial_down[:, :, 1:1 + usable, :, :],
|
||||||
|
"b 1 (f t) h w -> b 1 f t h w",
|
||||||
|
t=t,
|
||||||
|
)
|
||||||
|
rest = rest.mean(dim=3)
|
||||||
|
latent_mask = torch.cat([first_frame, rest], dim=2)
|
||||||
|
elif f_lat > 1:
|
||||||
|
# Single pixel frame but multiple latent frames — repeat the
|
||||||
|
# single frame across all latent frames.
|
||||||
|
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
|
||||||
|
else:
|
||||||
|
latent_mask = first_frame
|
||||||
|
|
||||||
|
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
||||||
|
tracked_weights, guide_start, device, dtype):
|
||||||
|
"""Build a log-space additive self-attention bias mask.
|
||||||
|
|
||||||
|
Attenuates attention between noisy tokens and tracked guide tokens.
|
||||||
|
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_tokens: Total sequence length.
|
||||||
|
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
||||||
|
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
||||||
|
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
||||||
|
guide_start: Index where guide tokens begin in the sequence.
|
||||||
|
device: Target device.
|
||||||
|
dtype: Target dtype.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(1, 1, total_tokens, total_tokens) additive bias mask.
|
||||||
|
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
||||||
|
"""
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
||||||
|
tracked_end = guide_start + tracked_count
|
||||||
|
|
||||||
|
# Convert weights to log-space bias
|
||||||
|
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
||||||
|
log_w = torch.full_like(w, finfo.min)
|
||||||
|
positive_mask = w > 0
|
||||||
|
if positive_mask.any():
|
||||||
|
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
||||||
|
|
||||||
|
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
||||||
|
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
||||||
|
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
||||||
|
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||||
"""Process transformer blocks for LTXV."""
|
"""Process transformer blocks for LTXV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
|
|||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(
|
x = block(
|
||||||
@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
|
|||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
pe=pe,
|
pe=pe,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
|
self_attention_mask=self_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -18,6 +18,8 @@ import comfy.patcher_extension
|
|||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
from ..sdpose import HeatmapHead
|
||||||
|
|
||||||
class TimestepBlock(nn.Module):
|
class TimestepBlock(nn.Module):
|
||||||
"""
|
"""
|
||||||
Any module where forward() takes timestep embeddings as a second argument.
|
Any module where forward() takes timestep embeddings as a second argument.
|
||||||
@ -441,6 +443,7 @@ class UNetModel(nn.Module):
|
|||||||
disable_temporal_crossattention=False,
|
disable_temporal_crossattention=False,
|
||||||
max_ddpm_temb_period=10000,
|
max_ddpm_temb_period=10000,
|
||||||
attn_precision=None,
|
attn_precision=None,
|
||||||
|
heatmap_head=False,
|
||||||
device=None,
|
device=None,
|
||||||
operations=ops,
|
operations=ops,
|
||||||
):
|
):
|
||||||
@ -827,6 +830,9 @@ class UNetModel(nn.Module):
|
|||||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if heatmap_head:
|
||||||
|
self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations)
|
||||||
|
|
||||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
self._forward,
|
self._forward,
|
||||||
|
|||||||
130
comfy/ldm/modules/sdpose.py
Normal file
130
comfy/ldm/modules/sdpose.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from scipy.ndimage import gaussian_filter
|
||||||
|
|
||||||
|
class HeatmapHead(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels=640,
|
||||||
|
out_channels=133,
|
||||||
|
input_size=(768, 1024),
|
||||||
|
heatmap_scale=4,
|
||||||
|
deconv_out_channels=(640,),
|
||||||
|
deconv_kernel_sizes=(4,),
|
||||||
|
conv_out_channels=(640,),
|
||||||
|
conv_kernel_sizes=(1,),
|
||||||
|
final_layer_kernel_size=1,
|
||||||
|
device=None, dtype=None, operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale)
|
||||||
|
self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32)
|
||||||
|
|
||||||
|
# Deconv layers
|
||||||
|
if deconv_out_channels:
|
||||||
|
deconv_layers = []
|
||||||
|
for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes):
|
||||||
|
if kernel_size == 4:
|
||||||
|
padding, output_padding = 1, 0
|
||||||
|
elif kernel_size == 3:
|
||||||
|
padding, output_padding = 1, 1
|
||||||
|
elif kernel_size == 2:
|
||||||
|
padding, output_padding = 0, 0
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unsupported kernel size {kernel_size}')
|
||||||
|
|
||||||
|
deconv_layers.extend([
|
||||||
|
operations.ConvTranspose2d(in_channels, out_ch, kernel_size,
|
||||||
|
stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype),
|
||||||
|
torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
|
||||||
|
torch.nn.SiLU(inplace=True)
|
||||||
|
])
|
||||||
|
in_channels = out_ch
|
||||||
|
self.deconv_layers = torch.nn.Sequential(*deconv_layers)
|
||||||
|
else:
|
||||||
|
self.deconv_layers = torch.nn.Identity()
|
||||||
|
|
||||||
|
# Conv layers
|
||||||
|
if conv_out_channels:
|
||||||
|
conv_layers = []
|
||||||
|
for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes):
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
conv_layers.extend([
|
||||||
|
operations.Conv2d(in_channels, out_ch, kernel_size,
|
||||||
|
stride=1, padding=padding, device=device, dtype=dtype),
|
||||||
|
torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype),
|
||||||
|
torch.nn.SiLU(inplace=True)
|
||||||
|
])
|
||||||
|
in_channels = out_ch
|
||||||
|
self.conv_layers = torch.nn.Sequential(*conv_layers)
|
||||||
|
else:
|
||||||
|
self.conv_layers = torch.nn.Identity()
|
||||||
|
|
||||||
|
self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x): # Decode heatmaps to keypoints
|
||||||
|
heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x)))
|
||||||
|
heatmaps_np = heatmaps.float().cpu().numpy() # (B, K, H, W)
|
||||||
|
B, K, H, W = heatmaps_np.shape
|
||||||
|
|
||||||
|
batch_keypoints = []
|
||||||
|
batch_scores = []
|
||||||
|
|
||||||
|
for b in range(B):
|
||||||
|
hm = heatmaps_np[b].copy() # (K, H, W)
|
||||||
|
|
||||||
|
# --- vectorised argmax ---
|
||||||
|
flat = hm.reshape(K, -1)
|
||||||
|
idx = np.argmax(flat, axis=1)
|
||||||
|
scores = flat[np.arange(K), idx].copy()
|
||||||
|
y_locs, x_locs = np.unravel_index(idx, (H, W))
|
||||||
|
keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32) # (K, 2) in heatmap space
|
||||||
|
invalid = scores <= 0.
|
||||||
|
keypoints[invalid] = -1
|
||||||
|
|
||||||
|
# --- DARK sub-pixel refinement (UDP) ---
|
||||||
|
# 1. Gaussian blur with max-preserving normalisation
|
||||||
|
border = 5 # (kernel-1)//2 for kernel=11
|
||||||
|
for k in range(K):
|
||||||
|
origin_max = np.max(hm[k])
|
||||||
|
dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32)
|
||||||
|
dr[border:-border, border:-border] = hm[k].copy()
|
||||||
|
dr = gaussian_filter(dr, sigma=2.0)
|
||||||
|
hm[k] = dr[border:-border, border:-border].copy()
|
||||||
|
cur_max = np.max(hm[k])
|
||||||
|
if cur_max > 0:
|
||||||
|
hm[k] *= origin_max / cur_max
|
||||||
|
# 2. Log-space for Taylor expansion
|
||||||
|
np.clip(hm, 1e-3, 50., hm)
|
||||||
|
np.log(hm, hm)
|
||||||
|
# 3. Hessian-based Newton step
|
||||||
|
hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
|
||||||
|
index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2)
|
||||||
|
index += (W + 2) * (H + 2) * np.arange(0, K)
|
||||||
|
index = index.astype(int).reshape(-1, 1)
|
||||||
|
i_ = hm_pad[index]
|
||||||
|
ix1 = hm_pad[index + 1]
|
||||||
|
iy1 = hm_pad[index + W + 2]
|
||||||
|
ix1y1 = hm_pad[index + W + 3]
|
||||||
|
ix1_y1_ = hm_pad[index - W - 3]
|
||||||
|
ix1_ = hm_pad[index - 1]
|
||||||
|
iy1_ = hm_pad[index - 2 - W]
|
||||||
|
dx = 0.5 * (ix1 - ix1_)
|
||||||
|
dy = 0.5 * (iy1 - iy1_)
|
||||||
|
derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1)
|
||||||
|
dxx = ix1 - 2 * i_ + ix1_
|
||||||
|
dyy = iy1 - 2 * i_ + iy1_
|
||||||
|
dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
|
||||||
|
hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2)
|
||||||
|
hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
|
||||||
|
keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1)
|
||||||
|
|
||||||
|
# --- restore to input image space ---
|
||||||
|
keypoints = keypoints * self.scale_factor
|
||||||
|
keypoints[invalid] = -1
|
||||||
|
|
||||||
|
batch_keypoints.append(keypoints)
|
||||||
|
batch_scores.append(scores)
|
||||||
|
|
||||||
|
return batch_keypoints, batch_scores
|
||||||
@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel):
|
|||||||
# unpatchify
|
# unpatchify
|
||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class SCAILWanModel(WanModel):
|
||||||
|
def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||||
|
super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||||
|
|
||||||
|
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||||
|
|
||||||
|
if reference_latent is not None:
|
||||||
|
x = torch.cat((reference_latent, x), dim=2)
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
|
grid_sizes = x.shape[2:]
|
||||||
|
transformer_options["grid_sizes"] = grid_sizes
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
scail_pose_seq_len = 0
|
||||||
|
if pose_latents is not None:
|
||||||
|
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||||
|
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||||
|
scail_pose_seq_len = scail_x.shape[1]
|
||||||
|
x = torch.cat([x, scail_x], dim=1)
|
||||||
|
del scail_x
|
||||||
|
|
||||||
|
# time embeddings
|
||||||
|
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||||
|
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||||
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
|
||||||
|
# context
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
context_img_len = None
|
||||||
|
if clip_fea is not None:
|
||||||
|
if self.img_emb is not None:
|
||||||
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
|
context = torch.cat([context_clip, context], dim=1)
|
||||||
|
context_img_len = clip_fea.shape[-2]
|
||||||
|
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
if scail_pose_seq_len > 0:
|
||||||
|
x = x[:, :-scail_pose_seq_len]
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, grid_sizes)
|
||||||
|
|
||||||
|
if reference_latent is not None:
|
||||||
|
x = x[:, :, reference_latent.shape[2]:]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||||
|
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
if pose_latents is None:
|
||||||
|
return main_freqs
|
||||||
|
|
||||||
|
ref_t_patches = 0
|
||||||
|
if reference_latent is not None:
|
||||||
|
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||||
|
|
||||||
|
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||||
|
|
||||||
|
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
|
||||||
|
h_scale = h / H_pose
|
||||||
|
w_scale = w / W_pose
|
||||||
|
|
||||||
|
# 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
|
||||||
|
h_shift = (h_scale - 1) / 2
|
||||||
|
w_shift = (w_scale - 1) / 2
|
||||||
|
pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||||
|
pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)
|
||||||
|
|
||||||
|
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||||
|
bs, c, t, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
|
||||||
|
if pose_latents is not None:
|
||||||
|
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||||
|
|
||||||
|
t_len = t
|
||||||
|
if time_dim_concat is not None:
|
||||||
|
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||||
|
x = torch.cat([x, time_dim_concat], dim=2)
|
||||||
|
t_len = x.shape[2]
|
||||||
|
|
||||||
|
reference_latent = None
|
||||||
|
if "reference_latent" in kwargs:
|
||||||
|
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||||
|
t_len += reference_latent.shape[2]
|
||||||
|
|
||||||
|
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||||
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||||
|
|||||||
@ -459,6 +459,7 @@ class WanVAE(nn.Module):
|
|||||||
attn_scales=[],
|
attn_scales=[],
|
||||||
temperal_downsample=[True, True, False],
|
temperal_downsample=[True, True, False],
|
||||||
image_channels=3,
|
image_channels=3,
|
||||||
|
conv_out_channels=3,
|
||||||
dropout=0.0):
|
dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@ -474,7 +475,7 @@ class WanVAE(nn.Module):
|
|||||||
attn_scales, self.temperal_downsample, dropout)
|
attn_scales, self.temperal_downsample, dropout)
|
||||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
|
||||||
attn_scales, self.temperal_upsample, dropout)
|
attn_scales, self.temperal_upsample, dropout)
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
@ -484,7 +485,7 @@ class WanVAE(nn.Module):
|
|||||||
iter_ = 1 + (t - 1) // 4
|
iter_ = 1 + (t - 1) // 4
|
||||||
feat_map = None
|
feat_map = None
|
||||||
if iter_ > 1:
|
if iter_ > 1:
|
||||||
feat_map = [None] * count_conv3d(self.decoder)
|
feat_map = [None] * count_conv3d(self.encoder)
|
||||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
|
|||||||
@ -337,6 +337,7 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
|
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
|
||||||
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
|
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
|
||||||
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||||
|
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|||||||
@ -76,6 +76,7 @@ class ModelType(Enum):
|
|||||||
FLUX = 8
|
FLUX = 8
|
||||||
IMG_TO_IMG = 9
|
IMG_TO_IMG = 9
|
||||||
FLOW_COSMOS = 10
|
FLOW_COSMOS = 10
|
||||||
|
IMG_TO_IMG_FLOW = 11
|
||||||
|
|
||||||
|
|
||||||
def model_sampling(model_config, model_type):
|
def model_sampling(model_config, model_type):
|
||||||
@ -108,6 +109,8 @@ def model_sampling(model_config, model_type):
|
|||||||
elif model_type == ModelType.FLOW_COSMOS:
|
elif model_type == ModelType.FLOW_COSMOS:
|
||||||
c = comfy.model_sampling.COSMOS_RFLOW
|
c = comfy.model_sampling.COSMOS_RFLOW
|
||||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||||
|
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||||
|
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||||
|
|
||||||
class ModelSampling(s, c):
|
class ModelSampling(s, c):
|
||||||
pass
|
pass
|
||||||
@ -922,6 +925,25 @@ class Flux(BaseModel):
|
|||||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class LongCatImage(Flux):
|
||||||
|
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
|
rope_opts = transformer_options.get("rope_options", {})
|
||||||
|
rope_opts = dict(rope_opts)
|
||||||
|
rope_opts.setdefault("shift_t", 1.0)
|
||||||
|
rope_opts.setdefault("shift_y", 512.0)
|
||||||
|
rope_opts.setdefault("shift_x", 512.0)
|
||||||
|
transformer_options["rope_options"] = rope_opts
|
||||||
|
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def encode_adm(self, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
out.pop('guidance', None)
|
||||||
|
return out
|
||||||
|
|
||||||
class Flux2(Flux):
|
class Flux2(Flux):
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
@ -971,6 +993,10 @@ class LTXV(BaseModel):
|
|||||||
if keyframe_idxs is not None:
|
if keyframe_idxs is not None:
|
||||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||||
|
|
||||||
|
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||||
|
if guide_attention_entries is not None:
|
||||||
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||||
@ -1023,6 +1049,10 @@ class LTXAV(BaseModel):
|
|||||||
if latent_shapes is not None:
|
if latent_shapes is not None:
|
||||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||||
|
|
||||||
|
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||||
|
if guide_attention_entries is not None:
|
||||||
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||||
@ -1466,6 +1496,50 @@ class WAN22(WAN21):
|
|||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
|
class WAN21_FlowRVS(WAN21):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
||||||
|
model_config.unet_config["model_type"] = "t2v"
|
||||||
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
|
class WAN21_SCAIL(WAN21):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
|
||||||
|
self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
|
||||||
|
self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
|
||||||
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
|
||||||
|
reference_latents = kwargs.get("reference_latents", None)
|
||||||
|
if reference_latents is not None:
|
||||||
|
ref_latent = self.process_latent_in(reference_latents[-1])
|
||||||
|
ref_mask = torch.ones_like(ref_latent[:, :4])
|
||||||
|
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
|
||||||
|
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
|
||||||
|
|
||||||
|
pose_latents = kwargs.get("pose_video_latent", None)
|
||||||
|
if pose_latents is not None:
|
||||||
|
pose_latents = self.process_latent_in(pose_latents)
|
||||||
|
pose_mask = torch.ones_like(pose_latents[:, :4])
|
||||||
|
pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
|
||||||
|
out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def extra_conds_shapes(self, **kwargs):
|
||||||
|
out = {}
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||||
|
|
||||||
|
pose_latents = kwargs.get("pose_video_latent", None)
|
||||||
|
if pose_latents is not None:
|
||||||
|
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||||
|
|||||||
@ -279,6 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||||
dit_config["txt_ids_dims"] = [1, 2]
|
dit_config["txt_ids_dims"] = [1, 2]
|
||||||
|
if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image
|
||||||
|
dit_config["txt_ids_dims"] = [1, 2]
|
||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
@ -496,6 +498,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["model_type"] = "humo"
|
dit_config["model_type"] = "humo"
|
||||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "animate"
|
dit_config["model_type"] = "animate"
|
||||||
|
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["model_type"] = "scail"
|
||||||
else:
|
else:
|
||||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "i2v"
|
dit_config["model_type"] = "i2v"
|
||||||
@ -509,6 +513,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
if ref_conv_weight is not None:
|
if ref_conv_weight is not None:
|
||||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||||
|
|
||||||
|
if metadata is not None and "config" in metadata:
|
||||||
|
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||||
@ -792,6 +799,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
unet_config["use_temporal_resblock"] = False
|
unet_config["use_temporal_resblock"] = False
|
||||||
unet_config["use_temporal_attention"] = False
|
unet_config["use_temporal_attention"] = False
|
||||||
|
|
||||||
|
heatmap_key = '{}heatmap_head.conv_layers.0.weight'.format(key_prefix)
|
||||||
|
if heatmap_key in state_dict_keys:
|
||||||
|
unet_config["heatmap_head"] = True
|
||||||
|
|
||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
def model_config_from_unet_config(unet_config, state_dict=None):
|
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||||
@ -1012,7 +1023,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||||||
|
|
||||||
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
|
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
|
||||||
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
|
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64,
|
||||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
|
|||||||
@ -180,6 +180,14 @@ def is_ixuca():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_wsl():
|
||||||
|
version = platform.uname().release
|
||||||
|
if version.endswith("-Microsoft"):
|
||||||
|
return True
|
||||||
|
elif version.endswith("microsoft-standard-WSL2"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_torch_device():
|
def get_torch_device():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
@ -350,7 +358,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN'
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if is_amd():
|
if is_amd():
|
||||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
|
||||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||||
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
|
||||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||||
@ -378,7 +386,7 @@ try:
|
|||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if rocm_version >= (7, 0):
|
if rocm_version >= (7, 0):
|
||||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||||
@ -631,12 +639,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
if not DISABLE_SMART_MEMORY:
|
if not DISABLE_SMART_MEMORY:
|
||||||
memory_to_free = memory_required - get_free_memory(device)
|
memory_to_free = memory_required - get_free_memory(device)
|
||||||
ram_to_free = ram_required - get_free_ram()
|
ram_to_free = ram_required - get_free_ram()
|
||||||
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
#don't actually unload dynamic models for the sake of other dynamic models
|
||||||
#don't actually unload dynamic models for the sake of other dynamic models
|
#as that works on-demand.
|
||||||
#as that works on-demand.
|
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
memory_to_free = 0
|
||||||
memory_to_free = 0
|
|
||||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
unloaded_model.append(i)
|
unloaded_model.append(i)
|
||||||
|
|||||||
@ -308,15 +308,22 @@ class ModelPatcher:
|
|||||||
def get_free_memory(self, device):
|
def get_free_memory(self, device):
|
||||||
return comfy.model_management.get_free_memory(device)
|
return comfy.model_management.get_free_memory(device)
|
||||||
|
|
||||||
def clone(self, disable_dynamic=False):
|
def get_clone_model_override(self):
|
||||||
|
return self.model, (self.backup, self.object_patches_backup, self.pinned)
|
||||||
|
|
||||||
|
def clone(self, disable_dynamic=False, model_override=None):
|
||||||
class_ = self.__class__
|
class_ = self.__class__
|
||||||
model = self.model
|
|
||||||
if self.is_dynamic() and disable_dynamic:
|
if self.is_dynamic() and disable_dynamic:
|
||||||
class_ = ModelPatcher
|
class_ = ModelPatcher
|
||||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
if model_override is None:
|
||||||
model = temp_model_patcher.model
|
if self.cached_patcher_init is None:
|
||||||
|
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||||
|
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||||
|
model_override = temp_model_patcher.get_clone_model_override()
|
||||||
|
if model_override is None:
|
||||||
|
model_override = self.get_clone_model_override()
|
||||||
|
|
||||||
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
for k in self.patches:
|
for k in self.patches:
|
||||||
n.patches[k] = self.patches[k][:]
|
n.patches[k] = self.patches[k][:]
|
||||||
@ -325,13 +332,12 @@ class ModelPatcher:
|
|||||||
n.object_patches = self.object_patches.copy()
|
n.object_patches = self.object_patches.copy()
|
||||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||||
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||||
n.backup = self.backup
|
|
||||||
n.object_patches_backup = self.object_patches_backup
|
|
||||||
n.parent = self
|
n.parent = self
|
||||||
n.pinned = self.pinned
|
|
||||||
|
|
||||||
n.force_cast_weights = self.force_cast_weights
|
n.force_cast_weights = self.force_cast_weights
|
||||||
|
|
||||||
|
n.backup, n.object_patches_backup, n.pinned = model_override[1]
|
||||||
|
|
||||||
# attachments
|
# attachments
|
||||||
n.attachments = {}
|
n.attachments = {}
|
||||||
for k in self.attachments:
|
for k in self.attachments:
|
||||||
@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
del self.model.model_loaded_weight_memory
|
del self.model.model_loaded_weight_memory
|
||||||
if not hasattr(self.model, "dynamic_vbars"):
|
if not hasattr(self.model, "dynamic_vbars"):
|
||||||
self.model.dynamic_vbars = {}
|
self.model.dynamic_vbars = {}
|
||||||
|
self.non_dynamic_delegate_model = None
|
||||||
assert load_device is not None
|
assert load_device is not None
|
||||||
|
|
||||||
def is_dynamic(self):
|
def is_dynamic(self):
|
||||||
@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_non_dynamic_delegate(self):
|
||||||
|
model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
|
||||||
|
self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
|
||||||
|
return model_patcher
|
||||||
|
|
||||||
|
|
||||||
CoreModelPatcher = ModelPatcher
|
CoreModelPatcher = ModelPatcher
|
||||||
|
|||||||
@ -83,6 +83,16 @@ class IMG_TO_IMG(X0):
|
|||||||
def calculate_input(self, sigma, noise):
|
def calculate_input(self, sigma, noise):
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
class IMG_TO_IMG_FLOW(CONST):
|
||||||
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||||
|
return latent_image
|
||||||
|
|
||||||
|
def inverse_noise_scaling(self, sigma, latent):
|
||||||
|
return 1.0 - latent
|
||||||
|
|
||||||
class COSMOS_RFLOW:
|
class COSMOS_RFLOW:
|
||||||
def calculate_input(self, sigma, noise):
|
def calculate_input(self, sigma, noise):
|
||||||
sigma = (sigma / (sigma + 1))
|
sigma = (sigma / (sigma + 1))
|
||||||
|
|||||||
25
comfy/ops.py
25
comfy/ops.py
@ -19,7 +19,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import json
|
import json
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
x = to_dequant(x, dtype)
|
x = to_dequant(x, dtype)
|
||||||
if not resident and lowvram_fn is not None:
|
if not resident and lowvram_fn is not None:
|
||||||
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
|
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
|
||||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
|
||||||
x = lowvram_fn(x)
|
x = lowvram_fn(x)
|
||||||
if (isinstance(orig, QuantizedTensor) and
|
if (want_requant and len(fns) == 0 or update_weight):
|
||||||
(want_requant and len(fns) == 0 or update_weight)):
|
|
||||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
if isinstance(orig, QuantizedTensor):
|
||||||
if want_requant and len(fns) == 0:
|
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||||
#The layer actually wants our freshly saved QT
|
else:
|
||||||
x = y
|
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||||
elif update_weight:
|
if want_requant and len(fns) == 0:
|
||||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
|
x = y
|
||||||
if update_weight:
|
if update_weight:
|
||||||
orig.copy_(y)
|
orig.copy_(y)
|
||||||
for f in fns:
|
for f in fns:
|
||||||
@ -296,7 +294,7 @@ class disable_weight_init:
|
|||||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||||
super().__init__(in_features, out_features, bias, device, dtype)
|
super().__init__(in_features, out_features, bias, device, dtype)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -317,7 +315,7 @@ class disable_weight_init:
|
|||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||||
@ -617,7 +615,8 @@ def fp8_linear(self, input):
|
|||||||
|
|
||||||
if input.ndim != 2:
|
if input.ndim != 2:
|
||||||
return None
|
return None
|
||||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
|
||||||
|
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
|
||||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
|
|
||||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
|
|||||||
@ -66,6 +66,18 @@ def convert_cond(cond):
|
|||||||
out.append(temp)
|
out.append(temp)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def cond_has_hooks(cond):
|
||||||
|
for c in cond:
|
||||||
|
temp = c[1]
|
||||||
|
if "hooks" in temp:
|
||||||
|
return True
|
||||||
|
if "control" in temp:
|
||||||
|
control = temp["control"]
|
||||||
|
extra_hooks = control.get_extra_hooks()
|
||||||
|
if len(extra_hooks) > 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_additional_models(conds, dtype):
|
def get_additional_models(conds, dtype):
|
||||||
"""loads additional models in conditioning"""
|
"""loads additional models in conditioning"""
|
||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
|
|||||||
@ -946,6 +946,8 @@ class CFGGuider:
|
|||||||
|
|
||||||
def inner_set_conds(self, conds):
|
def inner_set_conds(self, conds):
|
||||||
for k in conds:
|
for k in conds:
|
||||||
|
if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
|
||||||
|
self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
|
||||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
|||||||
46
comfy/sd.py
46
comfy/sd.py
@ -60,6 +60,7 @@ import comfy.text_encoders.jina_clip_2
|
|||||||
import comfy.text_encoders.newbie
|
import comfy.text_encoders.newbie
|
||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
|
import comfy.text_encoders.longcat_image
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -203,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip
|
|||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
|
||||||
if no_init:
|
if no_init:
|
||||||
return
|
return
|
||||||
params = target.params.copy()
|
params = target.params.copy()
|
||||||
@ -232,7 +233,8 @@ class CLIP:
|
|||||||
model_management.archive_model_dtypes(self.cond_stage_model)
|
model_management.archive_model_dtypes(self.cond_stage_model)
|
||||||
|
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||||
|
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
#Match torch.float32 hardcode upcast in TE implemention
|
#Match torch.float32 hardcode upcast in TE implemention
|
||||||
self.patcher.set_model_compute_dtype(torch.float32)
|
self.patcher.set_model_compute_dtype(torch.float32)
|
||||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||||
@ -266,9 +268,9 @@ class CLIP:
|
|||||||
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||||
self.tokenizer_options = {}
|
self.tokenizer_options = {}
|
||||||
|
|
||||||
def clone(self):
|
def clone(self, disable_dynamic=False):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
n.patcher = self.patcher.clone()
|
n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
|
||||||
n.cond_stage_model = self.cond_stage_model
|
n.cond_stage_model = self.cond_stage_model
|
||||||
n.tokenizer = self.tokenizer
|
n.tokenizer = self.tokenizer
|
||||||
n.layer_idx = self.layer_idx
|
n.layer_idx = self.layer_idx
|
||||||
@ -694,8 +696,9 @@ class VAE:
|
|||||||
self.latent_dim = 3
|
self.latent_dim = 3
|
||||||
self.latent_channels = 16
|
self.latent_channels = 16
|
||||||
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||||
|
self.conv_out_channels = sd["decoder.head.2.weight"].shape[0]
|
||||||
self.pad_channel_value = 1.0
|
self.pad_channel_value = 1.0
|
||||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0}
|
||||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||||
@ -1159,16 +1162,24 @@ class CLIPType(Enum):
|
|||||||
KANDINSKY5_IMAGE = 23
|
KANDINSKY5_IMAGE = 23
|
||||||
NEWBIE = 24
|
NEWBIE = 24
|
||||||
FLUX2 = 25
|
FLUX2 = 25
|
||||||
|
LONGCAT_IMAGE = 26
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
|
||||||
|
def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||||
|
clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
|
||||||
|
return clip.patcher
|
||||||
|
|
||||||
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||||
clip_data = []
|
clip_data = []
|
||||||
for p in ckpt_paths:
|
for p in ckpt_paths:
|
||||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||||
if model_options.get("custom_operations", None) is None:
|
if model_options.get("custom_operations", None) is None:
|
||||||
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
||||||
clip_data.append(sd)
|
clip_data.append(sd)
|
||||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
|
||||||
|
clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
|
||||||
|
return clip
|
||||||
|
|
||||||
|
|
||||||
class TEModel(Enum):
|
class TEModel(Enum):
|
||||||
@ -1273,7 +1284,7 @@ def llama_detect(clip_data):
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||||
clip_data = state_dicts
|
clip_data = state_dicts
|
||||||
|
|
||||||
class EmptyClass:
|
class EmptyClass:
|
||||||
@ -1371,6 +1382,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||||
|
elif clip_type == CLIPType.LONGCAT_IMAGE:
|
||||||
|
clip_target.clip = comfy.text_encoders.longcat_image.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.longcat_image.LongCatImageTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||||
@ -1490,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
parameters += comfy.utils.calculate_parameters(c)
|
parameters += comfy.utils.calculate_parameters(c)
|
||||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load_gligen(ckpt_path):
|
def load_gligen(ckpt_path):
|
||||||
@ -1535,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||||
if out is None:
|
if out is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||||
if output_model:
|
if output_model and out[0] is not None:
|
||||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
|
if output_clip and out[1] is not None:
|
||||||
|
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
@ -1547,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None,
|
|||||||
disable_dynamic=disable_dynamic)
|
disable_dynamic=disable_dynamic)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
|
_, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False,
|
||||||
|
embedding_directory=embedding_directory, output_model=False,
|
||||||
|
model_options=model_options,
|
||||||
|
te_model_options=te_model_options,
|
||||||
|
disable_dynamic=disable_dynamic)
|
||||||
|
return clip.patcher
|
||||||
|
|
||||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
|
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
|
||||||
clip = None
|
clip = None
|
||||||
clipvision = None
|
clipvision = None
|
||||||
@ -1632,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
clip_sd = model_config.process_clip_state_dict(sd)
|
clip_sd = model_config.process_clip_state_dict(sd)
|
||||||
if len(clip_sd) > 0:
|
if len(clip_sd) > 0:
|
||||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
|
||||||
else:
|
else:
|
||||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,7 @@ import comfy.text_encoders.kandinsky5
|
|||||||
import comfy.text_encoders.z_image
|
import comfy.text_encoders.z_image
|
||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
|
import comfy.text_encoders.longcat_image
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@ -525,7 +526,8 @@ class LotusD(SD20):
|
|||||||
}
|
}
|
||||||
|
|
||||||
unet_extra_config = {
|
unet_extra_config = {
|
||||||
"num_classes": 'sequential'
|
"num_classes": 'sequential',
|
||||||
|
"num_head_channels": 64,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
@ -1256,6 +1258,26 @@ class WAN22_T2V(WAN21_T2V):
|
|||||||
out = model_base.WAN22(self, image_to_video=True, device=device)
|
out = model_base.WAN22(self, image_to_video=True, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN21_FlowRVS(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "flow_rvs",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class WAN21_SCAIL(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "scail",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan3d2",
|
"image_model": "hunyuan3d2",
|
||||||
@ -1667,6 +1689,37 @@ class ACEStep15(supported_models_base.BASE):
|
|||||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
class LongCatImage(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "flux",
|
||||||
|
"guidance_embed": False,
|
||||||
|
"vec_in_dim": None,
|
||||||
|
"context_in_dim": 3584,
|
||||||
|
"txt_ids_dims": [1, 2],
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.Flux
|
||||||
|
|
||||||
|
memory_usage_factor = 2.5
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.LongCatImage(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -328,14 +328,14 @@ class ACE15TEModel(torch.nn.Module):
|
|||||||
return getattr(self, self.lm_model).load_sd(sd)
|
return getattr(self, self.lm_model).load_sd(sd)
|
||||||
|
|
||||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
lm_metadata = token_weight_pairs.get("lm_metadata", {})
|
||||||
constant = self.constant
|
constant = self.constant
|
||||||
if comfy.model_management.should_use_bf16(device):
|
if comfy.model_management.should_use_bf16(device):
|
||||||
constant *= 0.5
|
constant *= 0.5
|
||||||
|
|
||||||
token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
|
token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
|
||||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||||
num_tokens += lm_metadata['min_tokens']
|
num_tokens += lm_metadata.get("min_tokens", 0)
|
||||||
return num_tokens * constant * 1024 * 1024
|
return num_tokens * constant * 1024 * 1024
|
||||||
|
|
||||||
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
|
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
|
||||||
|
|||||||
184
comfy/text_encoders/longcat_image.py
Normal file
184
comfy/text_encoders/longcat_image.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
import re
|
||||||
|
import numbers
|
||||||
|
import torch
|
||||||
|
from comfy import sd1_clip
|
||||||
|
from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")]
|
||||||
|
QUOTE_PATTERN = "|".join(
|
||||||
|
[
|
||||||
|
re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2)
|
||||||
|
for q1, q2 in QUOTE_PAIRS
|
||||||
|
]
|
||||||
|
)
|
||||||
|
WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+")
|
||||||
|
|
||||||
|
|
||||||
|
def split_quotation(prompt):
|
||||||
|
matches = WORD_INTERNAL_QUOTE_RE.findall(prompt)
|
||||||
|
mapping = []
|
||||||
|
for i, word_src in enumerate(set(matches)):
|
||||||
|
word_tgt = "longcat_$##$_longcat" * (i + 1)
|
||||||
|
prompt = prompt.replace(word_src, word_tgt)
|
||||||
|
mapping.append((word_src, word_tgt))
|
||||||
|
|
||||||
|
parts = re.split(f"({QUOTE_PATTERN})", prompt)
|
||||||
|
result = []
|
||||||
|
for part in parts:
|
||||||
|
for word_src, word_tgt in mapping:
|
||||||
|
part = part.replace(word_tgt, word_src)
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
is_quoted = bool(re.match(QUOTE_PATTERN, part))
|
||||||
|
result.append((part, is_quoted))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.max_length = 512
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||||
|
parts = split_quotation(text)
|
||||||
|
all_tokens = []
|
||||||
|
for part_text, is_quoted in parts:
|
||||||
|
if is_quoted:
|
||||||
|
for char in part_text:
|
||||||
|
ids = self.tokenizer(char, add_special_tokens=False)["input_ids"]
|
||||||
|
all_tokens.extend(ids)
|
||||||
|
else:
|
||||||
|
ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"]
|
||||||
|
all_tokens.extend(ids)
|
||||||
|
|
||||||
|
if len(all_tokens) > self.max_length:
|
||||||
|
all_tokens = all_tokens[: self.max_length]
|
||||||
|
logger.warning(f"Truncated prompt to {self.max_length} tokens")
|
||||||
|
|
||||||
|
output = [(t, 1.0) for t in all_tokens]
|
||||||
|
# Pad to max length
|
||||||
|
self.pad_tokens(output, self.max_length - len(output))
|
||||||
|
return [output]
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(
|
||||||
|
embedding_directory=embedding_directory,
|
||||||
|
tokenizer_data=tokenizer_data,
|
||||||
|
name="qwen25_7b",
|
||||||
|
tokenizer=LongCatImageBaseTokenizer,
|
||||||
|
)
|
||||||
|
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
|
||||||
|
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||||
|
skip_template = False
|
||||||
|
if text.startswith("<|im_start|>"):
|
||||||
|
skip_template = True
|
||||||
|
if text.startswith("<|start_header_id|>"):
|
||||||
|
skip_template = True
|
||||||
|
if text == "":
|
||||||
|
text = " "
|
||||||
|
|
||||||
|
base_tok = getattr(self, "qwen25_7b")
|
||||||
|
if skip_template:
|
||||||
|
tokens = super().tokenize_with_weights(
|
||||||
|
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefix_ids = base_tok.tokenizer(
|
||||||
|
self.longcat_template_prefix, add_special_tokens=False
|
||||||
|
)["input_ids"]
|
||||||
|
suffix_ids = base_tok.tokenizer(
|
||||||
|
self.longcat_template_suffix, add_special_tokens=False
|
||||||
|
)["input_ids"]
|
||||||
|
|
||||||
|
prompt_tokens = base_tok.tokenize_with_weights(
|
||||||
|
text, return_word_ids=return_word_ids, **kwargs
|
||||||
|
)
|
||||||
|
prompt_pairs = prompt_tokens[0]
|
||||||
|
|
||||||
|
prefix_pairs = [(t, 1.0) for t in prefix_ids]
|
||||||
|
suffix_pairs = [(t, 1.0) for t in suffix_ids]
|
||||||
|
|
||||||
|
combined = prefix_pairs + prompt_pairs + suffix_pairs
|
||||||
|
tokens = {"qwen25_7b": [combined]}
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatImageTEModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
name="qwen25_7b",
|
||||||
|
clip_model=Qwen25_7BVLIModel,
|
||||||
|
model_options=model_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs, template_end=-1):
|
||||||
|
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||||
|
tok_pairs = token_weight_pairs["qwen25_7b"][0]
|
||||||
|
count_im_start = 0
|
||||||
|
if template_end == -1:
|
||||||
|
for i, v in enumerate(tok_pairs):
|
||||||
|
elem = v[0]
|
||||||
|
if not torch.is_tensor(elem):
|
||||||
|
if isinstance(elem, numbers.Integral):
|
||||||
|
if elem == 151644 and count_im_start < 2:
|
||||||
|
template_end = i
|
||||||
|
count_im_start += 1
|
||||||
|
|
||||||
|
if out.shape[1] > (template_end + 3):
|
||||||
|
if tok_pairs[template_end + 1][0] == 872:
|
||||||
|
if tok_pairs[template_end + 2][0] == 198:
|
||||||
|
template_end += 3
|
||||||
|
|
||||||
|
if template_end == -1:
|
||||||
|
template_end = 0
|
||||||
|
|
||||||
|
suffix_start = None
|
||||||
|
for i in range(len(tok_pairs) - 1, -1, -1):
|
||||||
|
elem = tok_pairs[i][0]
|
||||||
|
if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
|
||||||
|
if elem == 151645:
|
||||||
|
suffix_start = i
|
||||||
|
break
|
||||||
|
|
||||||
|
out = out[:, template_end:]
|
||||||
|
|
||||||
|
if "attention_mask" in extra:
|
||||||
|
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
|
||||||
|
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
|
||||||
|
extra.pop("attention_mask")
|
||||||
|
|
||||||
|
if suffix_start is not None:
|
||||||
|
suffix_len = len(tok_pairs) - suffix_start
|
||||||
|
if suffix_len > 0 and out.shape[1] > suffix_len:
|
||||||
|
out = out[:, :-suffix_len]
|
||||||
|
if "attention_mask" in extra:
|
||||||
|
extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len]
|
||||||
|
if extra["attention_mask"].sum() == torch.numel(
|
||||||
|
extra["attention_mask"]
|
||||||
|
):
|
||||||
|
extra.pop("attention_mask")
|
||||||
|
|
||||||
|
return out, pooled, extra
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
class LongCatImageTEModel_(LongCatImageTEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
|
||||||
|
return LongCatImageTEModel_
|
||||||
@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
|
import itertools
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -72,7 +73,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
||||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
|
||||||
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
@ -199,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
constant /= 2.0
|
constant /= 2.0
|
||||||
|
|
||||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
|
||||||
num_tokens = max(num_tokens, 64)
|
|
||||||
|
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
|
||||||
|
num_tokens = max(num_tokens, 642)
|
||||||
return num_tokens * constant * 1024 * 1024
|
return num_tokens * constant * 1024 * 1024
|
||||||
|
|
||||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
|||||||
@ -29,7 +29,7 @@ import itertools
|
|||||||
from torch.nn.functional import interpolate
|
from torch.nn.functional import interpolate
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import mmap
|
import mmap
|
||||||
@ -113,7 +113,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
metadata = None
|
metadata = None
|
||||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||||
try:
|
try:
|
||||||
if enables_dynamic_vram():
|
if comfy.memory_management.aimdo_enabled:
|
||||||
sd, metadata = load_safetensors(ckpt)
|
sd, metadata = load_safetensors(ckpt)
|
||||||
if not return_metadata:
|
if not return_metadata:
|
||||||
metadata = None
|
metadata = None
|
||||||
|
|||||||
@ -1224,9 +1224,10 @@ class BoundingBox(ComfyTypeIO):
|
|||||||
|
|
||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
socketless: bool=True, default: dict=None, component: str=None):
|
socketless: bool=True, default: dict=None, component: str=None, force_input: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
|
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
|
||||||
self.component = component
|
self.component = component
|
||||||
|
self.force_input = force_input
|
||||||
if default is None:
|
if default is None:
|
||||||
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
|
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
|
||||||
|
|
||||||
@ -1234,6 +1235,8 @@ class BoundingBox(ComfyTypeIO):
|
|||||||
d = super().as_dict()
|
d = super().as_dict()
|
||||||
if self.component:
|
if self.component:
|
||||||
d["component"] = self.component
|
d["component"] = self.component
|
||||||
|
if self.force_input is not None:
|
||||||
|
d["forceInput"] = self.force_input
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -127,9 +127,15 @@ class GeminiImageConfig(BaseModel):
|
|||||||
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiThinkingConfig(BaseModel):
|
||||||
|
includeThoughts: bool | None = Field(None)
|
||||||
|
thinkingLevel: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||||
responseModalities: list[str] | None = Field(None)
|
responseModalities: list[str] | None = Field(None)
|
||||||
imageConfig: GeminiImageConfig | None = Field(None)
|
imageConfig: GeminiImageConfig | None = Field(None)
|
||||||
|
thinkingConfig: GeminiThinkingConfig | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerateContentRequest(BaseModel):
|
class GeminiImageGenerateContentRequest(BaseModel):
|
||||||
|
|||||||
@ -186,7 +186,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ByteDanceSeedreamNode",
|
node_id="ByteDanceSeedreamNode",
|
||||||
display_name="ByteDance Seedream 5.0",
|
display_name="ByteDance Seedream 4.5 & 5.0",
|
||||||
category="api node/image/ByteDance",
|
category="api node/image/ByteDance",
|
||||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -29,6 +29,7 @@ from comfy_api_nodes.apis.gemini import (
|
|||||||
GeminiRole,
|
GeminiRole,
|
||||||
GeminiSystemInstructionContent,
|
GeminiSystemInstructionContent,
|
||||||
GeminiTextPart,
|
GeminiTextPart,
|
||||||
|
GeminiThinkingConfig,
|
||||||
Modality,
|
Modality,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
@ -55,6 +56,21 @@ GEMINI_IMAGE_SYS_PROMPT = (
|
|||||||
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
|
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$m := widgets.model;
|
||||||
|
$r := widgets.resolution;
|
||||||
|
$isFlash := $contains($m, "nano banana 2");
|
||||||
|
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
|
||||||
|
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
||||||
|
$prices := $isFlash ? $flashPrices : $proPrices;
|
||||||
|
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GeminiModel(str, Enum):
|
class GeminiModel(str, Enum):
|
||||||
"""
|
"""
|
||||||
@ -229,6 +245,10 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
|||||||
input_tokens_price = 2
|
input_tokens_price = 2
|
||||||
output_text_tokens_price = 12.0
|
output_text_tokens_price = 12.0
|
||||||
output_image_tokens_price = 120.0
|
output_image_tokens_price = 120.0
|
||||||
|
elif response.modelVersion == "gemini-3.1-flash-image-preview":
|
||||||
|
input_tokens_price = 0.5
|
||||||
|
output_text_tokens_price = 3.0
|
||||||
|
output_image_tokens_price = 60.0
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||||
@ -686,7 +706,7 @@ class GeminiImage2(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["gemini-3-pro-image-preview"],
|
options=["gemini-3-pro-image-preview", "Nano Banana 2 (Gemini 3.1 Flash Image)"],
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -750,19 +770,7 @@ class GeminiImage2(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
|
||||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
|
||||||
expr="""
|
|
||||||
(
|
|
||||||
$r := widgets.resolution;
|
|
||||||
($contains($r,"1k") or $contains($r,"2k"))
|
|
||||||
? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}}
|
|
||||||
: $contains($r,"4k")
|
|
||||||
? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}}
|
|
||||||
: {"type":"text","text":"Token-based"}
|
|
||||||
)
|
|
||||||
""",
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -779,6 +787,8 @@ class GeminiImage2(IO.ComfyNode):
|
|||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||||
|
model = "gemini-3.1-flash-image-preview"
|
||||||
|
|
||||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||||
if images is not None:
|
if images is not None:
|
||||||
@ -815,6 +825,169 @@ class GeminiImage2(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiNanoBanana2(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GeminiNanoBanana2",
|
||||||
|
display_name="Nano Banana 2",
|
||||||
|
category="api node/image/Gemini",
|
||||||
|
description="Generate or edit images synchronously via Google Vertex API.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text prompt describing the image to generate or the edits to apply. "
|
||||||
|
"Include any constraints, styles, or details the model should follow.",
|
||||||
|
default="",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=["Nano Banana 2 (Gemini 3.1 Flash Image)"],
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=42,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
|
||||||
|
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||||
|
"Also, changing the model or parameter settings, such as the temperature, "
|
||||||
|
"can cause variations in the response even when you use the same seed value. "
|
||||||
|
"By default, a random seed value is used.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=[
|
||||||
|
"auto",
|
||||||
|
"1:1",
|
||||||
|
"2:3",
|
||||||
|
"3:2",
|
||||||
|
"3:4",
|
||||||
|
"4:3",
|
||||||
|
"4:5",
|
||||||
|
"5:4",
|
||||||
|
"9:16",
|
||||||
|
"16:9",
|
||||||
|
"21:9",
|
||||||
|
# "1:4",
|
||||||
|
# "4:1",
|
||||||
|
# "8:1",
|
||||||
|
# "1:8",
|
||||||
|
],
|
||||||
|
default="auto",
|
||||||
|
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||||
|
"if no image is provided, a 16:9 square is usually generated.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=[
|
||||||
|
# "512px",
|
||||||
|
"1K",
|
||||||
|
"2K",
|
||||||
|
"4K",
|
||||||
|
],
|
||||||
|
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"response_modalities",
|
||||||
|
options=["IMAGE", "IMAGE+TEXT"],
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"thinking_level",
|
||||||
|
options=["MINIMAL", "HIGH"],
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"images",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Optional reference image(s). "
|
||||||
|
"To include multiple images, use the Batch Images node (up to 14).",
|
||||||
|
),
|
||||||
|
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||||
|
"files",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Optional file(s) to use as context for the model. "
|
||||||
|
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||||
|
),
|
||||||
|
IO.String.Input(
|
||||||
|
"system_prompt",
|
||||||
|
multiline=True,
|
||||||
|
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||||
|
optional=True,
|
||||||
|
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Image.Output(),
|
||||||
|
IO.String.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
seed: int,
|
||||||
|
aspect_ratio: str,
|
||||||
|
resolution: str,
|
||||||
|
response_modalities: str,
|
||||||
|
thinking_level: str,
|
||||||
|
images: Input.Image | None = None,
|
||||||
|
files: list[GeminiPart] | None = None,
|
||||||
|
system_prompt: str = "",
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||||
|
model = "gemini-3.1-flash-image-preview"
|
||||||
|
|
||||||
|
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||||
|
if images is not None:
|
||||||
|
if get_number_of_images(images) > 14:
|
||||||
|
raise ValueError("The current maximum number of supported images is 14.")
|
||||||
|
parts.extend(await create_image_parts(cls, images))
|
||||||
|
if files is not None:
|
||||||
|
parts.extend(files)
|
||||||
|
|
||||||
|
image_config = GeminiImageConfig(imageSize=resolution)
|
||||||
|
if aspect_ratio != "auto":
|
||||||
|
image_config.aspectRatio = aspect_ratio
|
||||||
|
|
||||||
|
gemini_system_prompt = None
|
||||||
|
if system_prompt:
|
||||||
|
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||||
|
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
|
||||||
|
data=GeminiImageGenerateContentRequest(
|
||||||
|
contents=[
|
||||||
|
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||||
|
],
|
||||||
|
generationConfig=GeminiImageGenerationConfig(
|
||||||
|
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||||
|
imageConfig=image_config,
|
||||||
|
thinkingConfig=GeminiThinkingConfig(thinkingLevel=thinking_level),
|
||||||
|
),
|
||||||
|
systemInstruction=gemini_system_prompt,
|
||||||
|
),
|
||||||
|
response_model=GeminiGenerateContentResponse,
|
||||||
|
price_extractor=calculate_tokens_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
||||||
|
|
||||||
|
|
||||||
class GeminiExtension(ComfyExtension):
|
class GeminiExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -822,6 +995,7 @@ class GeminiExtension(ComfyExtension):
|
|||||||
GeminiNode,
|
GeminiNode,
|
||||||
GeminiImage,
|
GeminiImage,
|
||||||
GeminiImage2,
|
GeminiImage2,
|
||||||
|
GeminiNanoBanana2,
|
||||||
GeminiInputFiles,
|
GeminiInputFiles,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ class JobStatus:
|
|||||||
|
|
||||||
|
|
||||||
# Media types that can be previewed in the frontend
|
# Media types that can be previewed in the frontend
|
||||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
|
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||||
|
|
||||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
||||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
||||||
@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict:
|
|||||||
normalized[node_id] = normalized_node
|
normalized[node_id] = normalized_node
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
# Text preview truncation limit (1024 characters) to prevent preview_output bloat
|
||||||
|
TEXT_PREVIEW_MAX_LENGTH = 1024
|
||||||
|
|
||||||
|
|
||||||
|
def _create_text_preview(value: str) -> dict:
|
||||||
|
"""Create a text preview dict with optional truncation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with 'content' and optionally 'truncated' flag
|
||||||
|
"""
|
||||||
|
if len(value) <= TEXT_PREVIEW_MAX_LENGTH:
|
||||||
|
return {'content': value}
|
||||||
|
return {
|
||||||
|
'content': value[:TEXT_PREVIEW_MAX_LENGTH],
|
||||||
|
'truncated': True
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||||
"""Extract create_time and workflow_id from extra_data.
|
"""Extract create_time and workflow_id from extra_data.
|
||||||
@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for item in items:
|
for item in items:
|
||||||
normalized = normalize_output_item(item)
|
if not isinstance(item, dict):
|
||||||
if normalized is None:
|
# Handle text outputs (non-dict items like strings or tuples)
|
||||||
continue
|
normalized = normalize_output_item(item)
|
||||||
|
if normalized is None:
|
||||||
|
# Not a 3D file string — check for text preview
|
||||||
|
if media_type == 'text':
|
||||||
|
count += 1
|
||||||
|
if preview_output is None:
|
||||||
|
if isinstance(item, tuple):
|
||||||
|
text_value = item[0] if item else ''
|
||||||
|
else:
|
||||||
|
text_value = str(item)
|
||||||
|
text_preview = _create_text_preview(text_value)
|
||||||
|
enriched = {
|
||||||
|
**text_preview,
|
||||||
|
'nodeId': node_id,
|
||||||
|
'mediaType': media_type
|
||||||
|
}
|
||||||
|
if fallback_preview is None:
|
||||||
|
fallback_preview = enriched
|
||||||
|
continue
|
||||||
|
# normalize_output_item returned a dict (e.g. 3D file)
|
||||||
|
item = normalized
|
||||||
|
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
if preview_output is not None:
|
if preview_output is not None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
|
if is_previewable(media_type, item):
|
||||||
enriched = {
|
enriched = {
|
||||||
**normalized,
|
**item,
|
||||||
'nodeId': node_id,
|
'nodeId': node_id,
|
||||||
}
|
}
|
||||||
if 'mediaType' not in normalized:
|
if 'mediaType' not in item:
|
||||||
enriched['mediaType'] = media_type
|
enriched['mediaType'] = media_type
|
||||||
if normalized.get('type') == 'output':
|
if item.get('type') == 'output':
|
||||||
preview_output = enriched
|
preview_output = enriched
|
||||||
elif fallback_preview is None:
|
elif fallback_preview is None:
|
||||||
fallback_preview = enriched
|
fallback_preview = enriched
|
||||||
|
|||||||
@ -717,11 +717,11 @@ def _render_shader_batch(
|
|||||||
gl.glUseProgram(0)
|
gl.glUseProgram(0)
|
||||||
|
|
||||||
for tex in input_textures:
|
for tex in input_textures:
|
||||||
gl.glDeleteTextures(tex)
|
gl.glDeleteTextures(int(tex))
|
||||||
for tex in output_textures:
|
for tex in output_textures:
|
||||||
gl.glDeleteTextures(tex)
|
gl.glDeleteTextures(int(tex))
|
||||||
for tex in ping_pong_textures:
|
for tex in ping_pong_textures:
|
||||||
gl.glDeleteTextures(tex)
|
gl.glDeleteTextures(int(tex))
|
||||||
if fbo is not None:
|
if fbo is not None:
|
||||||
gl.glDeleteFramebuffers(1, [fbo])
|
gl.glDeleteFramebuffers(1, [fbo])
|
||||||
for pp_fbo in ping_pong_fbos:
|
for pp_fbo in ping_pong_fbos:
|
||||||
@ -865,14 +865,15 @@ class GLSLShader(io.ComfyNode):
|
|||||||
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
|
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
|
||||||
) -> dict[str, list]:
|
) -> dict[str, list]:
|
||||||
"""Build UI output with input and output images for client-side shader execution."""
|
"""Build UI output with input and output images for client-side shader execution."""
|
||||||
combined_inputs = torch.cat(image_list, dim=0)
|
input_images_ui = []
|
||||||
input_images_ui = ui.ImageSaveHelper.save_images(
|
for img in image_list:
|
||||||
combined_inputs,
|
input_images_ui.extend(ui.ImageSaveHelper.save_images(
|
||||||
filename_prefix="GLSLShader_input",
|
img,
|
||||||
folder_type=io.FolderType.temp,
|
filename_prefix="GLSLShader_input",
|
||||||
cls=None,
|
folder_type=io.FolderType.temp,
|
||||||
compress_level=1,
|
cls=None,
|
||||||
)
|
compress_level=1,
|
||||||
|
))
|
||||||
|
|
||||||
output_images_ui = ui.ImageSaveHelper.save_images(
|
output_images_ui = ui.ImageSaveHelper.save_images(
|
||||||
output_batch,
|
output_batch,
|
||||||
|
|||||||
@ -248,7 +248,7 @@ class SetClipHooks:
|
|||||||
|
|
||||||
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||||
if hooks is not None:
|
if hooks is not None:
|
||||||
clip = clip.clone()
|
clip = clip.clone(disable_dynamic=True)
|
||||||
if apply_to_conds:
|
if apply_to_conds:
|
||||||
clip.apply_hooks_to_conds = hooks
|
clip.apply_hooks_to_conds = hooks
|
||||||
clip.patcher.forced_hooks = hooks.clone()
|
clip.patcher.forced_hooks = hooks.clone()
|
||||||
|
|||||||
@ -706,8 +706,8 @@ class SplitImageToTileList(IO.ComfyNode):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_grid_coords(width, height, tile_width, tile_height, overlap):
|
def get_grid_coords(width, height, tile_width, tile_height, overlap):
|
||||||
coords = []
|
coords = []
|
||||||
stride_x = max(1, tile_width - overlap)
|
stride_x = round(max(tile_width * 0.25, tile_width - overlap))
|
||||||
stride_y = max(1, tile_height - overlap)
|
stride_y = round(max(tile_width * 0.25, tile_height - overlap))
|
||||||
|
|
||||||
y = 0
|
y = 0
|
||||||
while y < height:
|
while y < height:
|
||||||
@ -764,34 +764,6 @@ class ImageMergeTileList(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_grid_coords(width, height, tile_width, tile_height, overlap):
|
|
||||||
coords = []
|
|
||||||
stride_x = max(1, tile_width - overlap)
|
|
||||||
stride_y = max(1, tile_height - overlap)
|
|
||||||
|
|
||||||
y = 0
|
|
||||||
while y < height:
|
|
||||||
x = 0
|
|
||||||
y_end = min(y + tile_height, height)
|
|
||||||
y_start = max(0, y_end - tile_height)
|
|
||||||
|
|
||||||
while x < width:
|
|
||||||
x_end = min(x + tile_width, width)
|
|
||||||
x_start = max(0, x_end - tile_width)
|
|
||||||
|
|
||||||
coords.append((x_start, y_start, x_end, y_end))
|
|
||||||
|
|
||||||
if x_end >= width:
|
|
||||||
break
|
|
||||||
x += stride_x
|
|
||||||
|
|
||||||
if y_end >= height:
|
|
||||||
break
|
|
||||||
y += stride_y
|
|
||||||
|
|
||||||
return coords
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, image_list, final_width, final_height, overlap):
|
def execute(cls, image_list, final_width, final_height, overlap):
|
||||||
w = final_width[0]
|
w = final_width[0]
|
||||||
@ -804,7 +776,7 @@ class ImageMergeTileList(IO.ComfyNode):
|
|||||||
device = first_tile.device
|
device = first_tile.device
|
||||||
dtype = first_tile.dtype
|
dtype = first_tile.dtype
|
||||||
|
|
||||||
coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp)
|
coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp)
|
||||||
|
|
||||||
canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
|
canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
|
||||||
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
|
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
|
||||||
|
|||||||
@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
|||||||
generate = execute # TODO: remove
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
|
||||||
|
"""Append a guide_attention_entry to both positive and negative conditioning.
|
||||||
|
|
||||||
|
Each entry tracks one guide reference for per-reference attention control.
|
||||||
|
Entries are derived independently from each conditioning to avoid cross-contamination.
|
||||||
|
"""
|
||||||
|
new_entry = {
|
||||||
|
"pre_filter_count": pre_filter_count,
|
||||||
|
"strength": strength,
|
||||||
|
"pixel_mask": None,
|
||||||
|
"latent_shape": latent_shape,
|
||||||
|
}
|
||||||
|
results = []
|
||||||
|
for cond in (positive, negative):
|
||||||
|
# Read existing entries from this specific conditioning
|
||||||
|
existing = []
|
||||||
|
for t in cond:
|
||||||
|
found = t[1].get("guide_attention_entries", None)
|
||||||
|
if found is not None:
|
||||||
|
existing = found
|
||||||
|
break
|
||||||
|
# Shallow copy and append (no deepcopy needed — entries contain
|
||||||
|
# only scalars and None for pixel_mask at this call site).
|
||||||
|
entries = [*existing, new_entry]
|
||||||
|
results.append(node_helpers.conditioning_set_values(
|
||||||
|
cond, {"guide_attention_entries": entries}
|
||||||
|
))
|
||||||
|
return results[0], results[1]
|
||||||
|
|
||||||
|
|
||||||
def conditioning_get_any_value(conditioning, key, default=None):
|
def conditioning_get_any_value(conditioning, key, default=None):
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
if key in t[1]:
|
if key in t[1]:
|
||||||
@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
scale_factors,
|
scale_factors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track this guide for per-reference attention control.
|
||||||
|
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
|
||||||
|
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
|
||||||
|
positive, negative = _append_guide_attention_entry(
|
||||||
|
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||||
|
)
|
||||||
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||||
|
|
||||||
generate = execute # TODO: remove
|
generate = execute # TODO: remove
|
||||||
@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
|
|||||||
latent_image = latent_image[:, :, :-num_keyframes]
|
latent_image = latent_image[:, :, :-num_keyframes]
|
||||||
noise_mask = noise_mask[:, :, :-num_keyframes]
|
noise_mask = noise_mask[:, :, :-num_keyframes]
|
||||||
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
|
positive = node_helpers.conditioning_set_values(positive, {
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
|
"keyframe_idxs": None,
|
||||||
|
"guide_attention_entries": None,
|
||||||
|
})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {
|
||||||
|
"keyframe_idxs": None,
|
||||||
|
"guide_attention_entries": None,
|
||||||
|
})
|
||||||
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="Mahiro",
|
node_id="Mahiro",
|
||||||
display_name="Mahiro CFG",
|
display_name="Positive-Biased Guidance",
|
||||||
category="_for_testing",
|
category="_for_testing",
|
||||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode):
|
|||||||
io.Model.Output(display_name="patched_model"),
|
io.Model.Output(display_name="patched_model"),
|
||||||
],
|
],
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
|
search_aliases=[
|
||||||
|
"mahiro",
|
||||||
|
"mahiro cfg",
|
||||||
|
"similarity-adaptive guidance",
|
||||||
|
"positive-biased cfg",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model) -> io.NodeOutput:
|
def execute(cls, model) -> io.NodeOutput:
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
def mahiro_normd(args):
|
def mahiro_normd(args):
|
||||||
scale: float = args['cond_scale']
|
scale: float = args["cond_scale"]
|
||||||
cond_p: torch.Tensor = args['cond_denoised']
|
cond_p: torch.Tensor = args["cond_denoised"]
|
||||||
uncond_p: torch.Tensor = args['uncond_denoised']
|
uncond_p: torch.Tensor = args["uncond_denoised"]
|
||||||
#naive leap
|
# naive leap
|
||||||
leap = cond_p * scale
|
leap = cond_p * scale
|
||||||
#sim with uncond leap
|
# sim with uncond leap
|
||||||
u_leap = uncond_p * scale
|
u_leap = uncond_p * scale
|
||||||
cfg = args["denoised"]
|
cfg = args["denoised"]
|
||||||
merge = (leap + cfg) / 2
|
merge = (leap + cfg) / 2
|
||||||
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
|
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
|
||||||
normm = torch.sqrt(merge.abs()) * merge.sign()
|
normm = torch.sqrt(merge.abs()) * merge.sign()
|
||||||
sim = F.cosine_similarity(normu, normm).mean()
|
sim = F.cosine_similarity(normu, normm).mean()
|
||||||
simsc = 2 * (sim+1)
|
simsc = 2 * (sim + 1)
|
||||||
wm = (simsc*cfg + (4-simsc)*leap) / 4
|
wm = (simsc * cfg + (4 - simsc) * leap) / 4
|
||||||
return wm
|
return wm
|
||||||
|
|
||||||
m.set_model_sampler_post_cfg_function(mahiro_normd)
|
m.set_model_sampler_post_cfg_function(mahiro_normd)
|
||||||
return io.NodeOutput(m)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|||||||
@ -52,7 +52,7 @@ class ModelSamplingDiscrete:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "model": ("MODEL",),
|
return {"required": { "model": ("MODEL",),
|
||||||
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
|
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],),
|
||||||
"zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
|
"zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
@ -76,6 +76,8 @@ class ModelSamplingDiscrete:
|
|||||||
sampling_type = comfy.model_sampling.X0
|
sampling_type = comfy.model_sampling.X0
|
||||||
elif sampling == "img_to_img":
|
elif sampling == "img_to_img":
|
||||||
sampling_type = comfy.model_sampling.IMG_TO_IMG
|
sampling_type = comfy.model_sampling.IMG_TO_IMG
|
||||||
|
elif sampling == "img_to_img_flow":
|
||||||
|
sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||||
|
|
||||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -79,7 +79,6 @@ class Blur(io.ComfyNode):
|
|||||||
node_id="ImageBlur",
|
node_id="ImageBlur",
|
||||||
display_name="Image Blur",
|
display_name="Image Blur",
|
||||||
category="image/postprocessing",
|
category="image/postprocessing",
|
||||||
essentials_category="Image Tools",
|
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
|
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
|
||||||
@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode):
|
|||||||
node_id="BatchImagesNode",
|
node_id="BatchImagesNode",
|
||||||
display_name="Batch Images",
|
display_name="Batch Images",
|
||||||
category="image",
|
category="image",
|
||||||
|
essentials_category="Image Tools",
|
||||||
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
|
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Autogrow.Input("images", template=autogrow_template)
|
io.Autogrow.Input("images", template=autogrow_template)
|
||||||
|
|||||||
86
comfy_extras/nodes_resolution.py
Normal file
86
comfy_extras/nodes_resolution.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import math
|
||||||
|
from enum import Enum
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class AspectRatio(str, Enum):
|
||||||
|
SQUARE = "1:1 (Square)"
|
||||||
|
PHOTO_H = "3:2 (Photo)"
|
||||||
|
STANDARD_H = "4:3 (Standard)"
|
||||||
|
WIDESCREEN_H = "16:9 (Widescreen)"
|
||||||
|
ULTRAWIDE_H = "21:9 (Ultrawide)"
|
||||||
|
PHOTO_V = "2:3 (Portrait Photo)"
|
||||||
|
STANDARD_V = "3:4 (Portrait Standard)"
|
||||||
|
WIDESCREEN_V = "9:16 (Portrait Widescreen)"
|
||||||
|
|
||||||
|
|
||||||
|
ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = {
|
||||||
|
AspectRatio.SQUARE: (1, 1),
|
||||||
|
AspectRatio.PHOTO_H: (3, 2),
|
||||||
|
AspectRatio.STANDARD_H: (4, 3),
|
||||||
|
AspectRatio.WIDESCREEN_H: (16, 9),
|
||||||
|
AspectRatio.ULTRAWIDE_H: (21, 9),
|
||||||
|
AspectRatio.PHOTO_V: (2, 3),
|
||||||
|
AspectRatio.STANDARD_V: (3, 4),
|
||||||
|
AspectRatio.WIDESCREEN_V: (9, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ResolutionSelector(io.ComfyNode):
|
||||||
|
"""Calculate width and height from aspect ratio and megapixel target."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ResolutionSelector",
|
||||||
|
display_name="Resolution Selector",
|
||||||
|
category="utils",
|
||||||
|
description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=AspectRatio,
|
||||||
|
default=AspectRatio.SQUARE,
|
||||||
|
tooltip="The aspect ratio for the output dimensions.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"megapixels",
|
||||||
|
default=1.0,
|
||||||
|
min=0.1,
|
||||||
|
max=16.0,
|
||||||
|
step=0.1,
|
||||||
|
tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Int.Output(
|
||||||
|
"width", tooltip="Calculated width in pixels (multiple of 8)."
|
||||||
|
),
|
||||||
|
io.Int.Output(
|
||||||
|
"height", tooltip="Calculated height in pixels (multiple of 8)."
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput:
|
||||||
|
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||||
|
total_pixels = megapixels * 1024 * 1024
|
||||||
|
scale = math.sqrt(total_pixels / (w_ratio * h_ratio))
|
||||||
|
width = round(w_ratio * scale / 8) * 8
|
||||||
|
height = round(h_ratio * scale / 8) * 8
|
||||||
|
return io.NodeOutput(width, height)
|
||||||
|
|
||||||
|
|
||||||
|
class ResolutionExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
ResolutionSelector,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ResolutionExtension:
|
||||||
|
return ResolutionExtension()
|
||||||
740
comfy_extras/nodes_sdpose.py
Normal file
740
comfy_extras/nodes_sdpose.py
Normal file
@ -0,0 +1,740 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import colorsys
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_extras.nodes_lotus import LotusConditioning
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_keypoints(kp_raw, sc_raw):
|
||||||
|
"""Insert neck keypoint and remap from MMPose to OpenPose ordering.
|
||||||
|
|
||||||
|
Returns (kp, sc) where kp has shape (134, 2) and sc has shape (134,).
|
||||||
|
Layout:
|
||||||
|
0-17 body (18 kp, OpenPose order)
|
||||||
|
18-23 feet (6 kp)
|
||||||
|
24-91 face (68 kp)
|
||||||
|
92-112 right hand (21 kp)
|
||||||
|
113-133 left hand (21 kp)
|
||||||
|
"""
|
||||||
|
kp = np.array(kp_raw, dtype=np.float32)
|
||||||
|
sc = np.array(sc_raw, dtype=np.float32)
|
||||||
|
if len(kp) >= 17:
|
||||||
|
neck = (kp[5] + kp[6]) / 2
|
||||||
|
neck_score = min(sc[5], sc[6]) if sc[5] > 0.3 and sc[6] > 0.3 else 0
|
||||||
|
kp = np.insert(kp, 17, neck, axis=0)
|
||||||
|
sc = np.insert(sc, 17, neck_score)
|
||||||
|
mmpose_idx = np.array([17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3])
|
||||||
|
openpose_idx = np.array([ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17])
|
||||||
|
tmp_kp, tmp_sc = kp.copy(), sc.copy()
|
||||||
|
tmp_kp[openpose_idx] = kp[mmpose_idx]
|
||||||
|
tmp_sc[openpose_idx] = sc[mmpose_idx]
|
||||||
|
kp, sc = tmp_kp, tmp_sc
|
||||||
|
return kp, sc
|
||||||
|
|
||||||
|
|
||||||
|
def _to_openpose_frames(all_keypoints, all_scores, height, width):
|
||||||
|
"""Convert raw keypoint lists to a list of OpenPose-style frame dicts.
|
||||||
|
|
||||||
|
Each frame dict contains:
|
||||||
|
canvas_width, canvas_height, people: list of person dicts with keys:
|
||||||
|
pose_keypoints_2d - 18 body kp as flat [x,y,score,...] (absolute pixels)
|
||||||
|
foot_keypoints_2d - 6 foot kp as flat [x,y,score,...] (absolute pixels)
|
||||||
|
face_keypoints_2d - 70 face kp as flat [x,y,score,...] (absolute pixels)
|
||||||
|
indices 0-67: 68 face landmarks
|
||||||
|
index 68: right eye (body[14])
|
||||||
|
index 69: left eye (body[15])
|
||||||
|
hand_right_keypoints_2d - 21 right-hand kp (absolute pixels)
|
||||||
|
hand_left_keypoints_2d - 21 left-hand kp (absolute pixels)
|
||||||
|
"""
|
||||||
|
def _flatten(kp_slice, sc_slice):
|
||||||
|
return np.stack([kp_slice[:, 0], kp_slice[:, 1], sc_slice], axis=1).flatten().tolist()
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
for img_idx in range(len(all_keypoints)):
|
||||||
|
people = []
|
||||||
|
for kp_raw, sc_raw in zip(all_keypoints[img_idx], all_scores[img_idx]):
|
||||||
|
kp, sc = _preprocess_keypoints(kp_raw, sc_raw)
|
||||||
|
# 70 face kp = 68 face landmarks + REye (body[14]) + LEye (body[15])
|
||||||
|
face_kp = np.concatenate([kp[24:92], kp[[14, 15]]], axis=0)
|
||||||
|
face_sc = np.concatenate([sc[24:92], sc[[14, 15]]], axis=0)
|
||||||
|
people.append({
|
||||||
|
"pose_keypoints_2d": _flatten(kp[0:18], sc[0:18]),
|
||||||
|
"foot_keypoints_2d": _flatten(kp[18:24], sc[18:24]),
|
||||||
|
"face_keypoints_2d": _flatten(face_kp, face_sc),
|
||||||
|
"hand_right_keypoints_2d": _flatten(kp[92:113], sc[92:113]),
|
||||||
|
"hand_left_keypoints_2d": _flatten(kp[113:134], sc[113:134]),
|
||||||
|
})
|
||||||
|
frames.append({"canvas_width": width, "canvas_height": height, "people": people})
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
class KeypointDraw:
|
||||||
|
"""
|
||||||
|
Pose keypoint drawing class that supports both numpy and cv2 backends.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
self.draw = cv2
|
||||||
|
except ImportError:
|
||||||
|
self.draw = self
|
||||||
|
|
||||||
|
# Hand connections (same for both hands)
|
||||||
|
self.hand_edges = [
|
||||||
|
[0, 1], [1, 2], [2, 3], [3, 4], # thumb
|
||||||
|
[0, 5], [5, 6], [6, 7], [7, 8], # index
|
||||||
|
[0, 9], [9, 10], [10, 11], [11, 12], # middle
|
||||||
|
[0, 13], [13, 14], [14, 15], [15, 16], # ring
|
||||||
|
[0, 17], [17, 18], [18, 19], [19, 20], # pinky
|
||||||
|
]
|
||||||
|
|
||||||
|
# Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed)
|
||||||
|
self.body_limbSeq = [
|
||||||
|
[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10],
|
||||||
|
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17],
|
||||||
|
[1, 16], [16, 18]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Colors matching DWPose
|
||||||
|
self.colors = [
|
||||||
|
[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
|
||||||
|
[85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
|
||||||
|
[0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
|
||||||
|
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def circle(canvas_np, center, radius, color, **kwargs):
|
||||||
|
"""Draw a filled circle using NumPy vectorized operations."""
|
||||||
|
cx, cy = center
|
||||||
|
h, w = canvas_np.shape[:2]
|
||||||
|
|
||||||
|
radius_int = int(np.ceil(radius))
|
||||||
|
|
||||||
|
y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1)
|
||||||
|
x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1)
|
||||||
|
|
||||||
|
if y_max <= y_min or x_max <= x_min:
|
||||||
|
return
|
||||||
|
|
||||||
|
y, x = np.ogrid[y_min:y_max, x_min:x_max]
|
||||||
|
mask = (x - cx)**2 + (y - cy)**2 <= radius**2
|
||||||
|
canvas_np[y_min:y_max, x_min:x_max][mask] = color
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs):
|
||||||
|
"""Draw line using Bresenham's algorithm with NumPy operations."""
|
||||||
|
x0, y0, x1, y1 = *pt1, *pt2
|
||||||
|
h, w = canvas_np.shape[:2]
|
||||||
|
dx, dy = abs(x1 - x0), abs(y1 - y0)
|
||||||
|
sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1)
|
||||||
|
err, x, y, line_points = dx - dy, x0, y0, []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
line_points.append((x, y))
|
||||||
|
if x == x1 and y == y1:
|
||||||
|
break
|
||||||
|
e2 = 2 * err
|
||||||
|
if e2 > -dy:
|
||||||
|
err, x = err - dy, x + sx
|
||||||
|
if e2 < dx:
|
||||||
|
err, y = err + dx, y + sy
|
||||||
|
|
||||||
|
if thickness > 1:
|
||||||
|
radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5))
|
||||||
|
for px, py in line_points:
|
||||||
|
y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1)
|
||||||
|
if y_max > y_min and x_max > x_min:
|
||||||
|
yy, xx = np.ogrid[y_min:y_max, x_min:x_max]
|
||||||
|
canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color
|
||||||
|
else:
|
||||||
|
line_points = np.array(line_points)
|
||||||
|
valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w)
|
||||||
|
if (valid_points := line_points[valid]).size:
|
||||||
|
canvas_np[valid_points[:, 1], valid_points[:, 0]] = color
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fillConvexPoly(canvas_np, pts, color, **kwargs):
|
||||||
|
"""Fill polygon using vectorized scanline algorithm."""
|
||||||
|
if len(pts) < 3:
|
||||||
|
return
|
||||||
|
pts = np.array(pts, dtype=np.int32)
|
||||||
|
h, w = canvas_np.shape[:2]
|
||||||
|
y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1)
|
||||||
|
if y_max <= y_min or x_max <= x_min:
|
||||||
|
return
|
||||||
|
yy, xx = np.mgrid[y_min:y_max, x_min:x_max]
|
||||||
|
mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool)
|
||||||
|
|
||||||
|
for i in range(len(pts)):
|
||||||
|
p1, p2 = pts[i], pts[(i + 1) % len(pts)]
|
||||||
|
y1, y2 = p1[1], p2[1]
|
||||||
|
if y1 == y2:
|
||||||
|
continue
|
||||||
|
if y1 > y2:
|
||||||
|
p1, p2, y1, y2 = p2, p1, p2[1], p1[1]
|
||||||
|
if not (edge_mask := (yy >= y1) & (yy < y2)).any():
|
||||||
|
continue
|
||||||
|
mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1))
|
||||||
|
|
||||||
|
canvas_np[y_min:y_max, x_min:x_max][mask] = color
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs):
|
||||||
|
"""Python implementation of cv2.ellipse2Poly."""
|
||||||
|
axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output
|
||||||
|
angle = angle % 360
|
||||||
|
if arc_start > arc_end:
|
||||||
|
arc_start, arc_end = arc_end, arc_start
|
||||||
|
while arc_start < 0:
|
||||||
|
arc_start, arc_end = arc_start + 360, arc_end + 360
|
||||||
|
while arc_end > 360:
|
||||||
|
arc_end, arc_start = arc_end - 360, arc_start - 360
|
||||||
|
if arc_end - arc_start > 360:
|
||||||
|
arc_start, arc_end = 0, 360
|
||||||
|
|
||||||
|
angle_rad = math.radians(angle)
|
||||||
|
alpha, beta = math.cos(angle_rad), math.sin(angle_rad)
|
||||||
|
pts = []
|
||||||
|
for i in range(arc_start, arc_end + delta, delta):
|
||||||
|
theta_rad = math.radians(min(i, arc_end))
|
||||||
|
x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad)
|
||||||
|
pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))])
|
||||||
|
|
||||||
|
unique_pts, prev_pt = [], (float('inf'), float('inf'))
|
||||||
|
for pt in pts:
|
||||||
|
if (pt_tuple := tuple(pt)) != prev_pt:
|
||||||
|
unique_pts.append(pt)
|
||||||
|
prev_pt = pt_tuple
|
||||||
|
|
||||||
|
return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]]
|
||||||
|
|
||||||
|
def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3,
|
||||||
|
draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3):
|
||||||
|
"""
|
||||||
|
Draw wholebody keypoints (134 keypoints after processing) in DWPose style.
|
||||||
|
|
||||||
|
Expected keypoint format (after neck insertion and remapping):
|
||||||
|
- Body: 0-17 (18 keypoints in OpenPose format, neck at index 1)
|
||||||
|
- Foot: 18-23 (6 keypoints)
|
||||||
|
- Face: 24-91 (68 landmarks)
|
||||||
|
- Right hand: 92-112 (21 keypoints)
|
||||||
|
- Left hand: 113-133 (21 keypoints)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
canvas: The canvas to draw on (numpy array)
|
||||||
|
keypoints: Array of keypoint coordinates
|
||||||
|
scores: Optional confidence scores for each keypoint
|
||||||
|
threshold: Minimum confidence threshold for drawing keypoints
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
canvas: The canvas with keypoints drawn
|
||||||
|
"""
|
||||||
|
H, W, C = canvas.shape
|
||||||
|
|
||||||
|
# Draw body limbs
|
||||||
|
if draw_body and len(keypoints) >= 18:
|
||||||
|
for i, limb in enumerate(self.body_limbSeq):
|
||||||
|
# Convert from 1-indexed to 0-indexed
|
||||||
|
idx1, idx2 = limb[0] - 1, limb[1] - 1
|
||||||
|
|
||||||
|
if idx1 >= 18 or idx2 >= 18:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if scores is not None:
|
||||||
|
if scores[idx1] < threshold or scores[idx2] < threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
Y = [keypoints[idx1][0], keypoints[idx2][0]]
|
||||||
|
X = [keypoints[idx1][1], keypoints[idx2][1]]
|
||||||
|
mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2
|
||||||
|
length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2)
|
||||||
|
|
||||||
|
if length < 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||||
|
|
||||||
|
polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1)
|
||||||
|
|
||||||
|
self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)])
|
||||||
|
|
||||||
|
# Draw body keypoints
|
||||||
|
if draw_body and len(keypoints) >= 18:
|
||||||
|
for i in range(18):
|
||||||
|
if scores is not None and scores[i] < threshold:
|
||||||
|
continue
|
||||||
|
x, y = int(keypoints[i][0]), int(keypoints[i][1])
|
||||||
|
if 0 <= x < W and 0 <= y < H:
|
||||||
|
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
|
||||||
|
|
||||||
|
# Draw foot keypoints (18-23, 6 keypoints)
|
||||||
|
if draw_feet and len(keypoints) >= 24:
|
||||||
|
for i in range(18, 24):
|
||||||
|
if scores is not None and scores[i] < threshold:
|
||||||
|
continue
|
||||||
|
x, y = int(keypoints[i][0]), int(keypoints[i][1])
|
||||||
|
if 0 <= x < W and 0 <= y < H:
|
||||||
|
self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1)
|
||||||
|
|
||||||
|
# Draw right hand (92-112)
|
||||||
|
if draw_hands and len(keypoints) >= 113:
|
||||||
|
eps = 0.01
|
||||||
|
for ie, edge in enumerate(self.hand_edges):
|
||||||
|
idx1, idx2 = 92 + edge[0], 92 + edge[1]
|
||||||
|
if scores is not None:
|
||||||
|
if scores[idx1] < threshold or scores[idx2] < threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
|
||||||
|
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
|
||||||
|
|
||||||
|
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
||||||
|
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
|
||||||
|
# HSV to RGB conversion for rainbow colors
|
||||||
|
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
|
||||||
|
color = (int(r * 255), int(g * 255), int(b * 255))
|
||||||
|
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
|
||||||
|
|
||||||
|
# Draw right hand keypoints
|
||||||
|
for i in range(92, 113):
|
||||||
|
if scores is not None and scores[i] < threshold:
|
||||||
|
continue
|
||||||
|
x, y = int(keypoints[i][0]), int(keypoints[i][1])
|
||||||
|
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
|
||||||
|
self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
||||||
|
|
||||||
|
# Draw left hand (113-133)
|
||||||
|
if draw_hands and len(keypoints) >= 134:
|
||||||
|
eps = 0.01
|
||||||
|
for ie, edge in enumerate(self.hand_edges):
|
||||||
|
idx1, idx2 = 113 + edge[0], 113 + edge[1]
|
||||||
|
if scores is not None:
|
||||||
|
if scores[idx1] < threshold or scores[idx2] < threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1])
|
||||||
|
x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1])
|
||||||
|
|
||||||
|
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
||||||
|
if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H:
|
||||||
|
# HSV to RGB conversion for rainbow colors
|
||||||
|
r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0)
|
||||||
|
color = (int(r * 255), int(g * 255), int(b * 255))
|
||||||
|
self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2)
|
||||||
|
|
||||||
|
# Draw left hand keypoints
|
||||||
|
for i in range(113, 134):
|
||||||
|
if scores is not None and i < len(scores) and scores[i] < threshold:
|
||||||
|
continue
|
||||||
|
x, y = int(keypoints[i][0]), int(keypoints[i][1])
|
||||||
|
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
|
||||||
|
self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
||||||
|
|
||||||
|
# Draw face keypoints (24-91) - white dots only, no lines
|
||||||
|
if draw_face and len(keypoints) >= 92:
|
||||||
|
eps = 0.01
|
||||||
|
for i in range(24, 92):
|
||||||
|
if scores is not None and scores[i] < threshold:
|
||||||
|
continue
|
||||||
|
x, y = int(keypoints[i][0]), int(keypoints[i][1])
|
||||||
|
if x > eps and y > eps and 0 <= x < W and 0 <= y < H:
|
||||||
|
self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1)
|
||||||
|
|
||||||
|
return canvas
|
||||||
|
|
||||||
|
class SDPoseDrawKeypoints(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SDPoseDrawKeypoints",
|
||||||
|
category="image/preprocessors",
|
||||||
|
search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "pose"],
|
||||||
|
inputs=[
|
||||||
|
io.Custom("POSE_KEYPOINT").Input("keypoints"),
|
||||||
|
io.Boolean.Input("draw_body", default=True),
|
||||||
|
io.Boolean.Input("draw_hands", default=True),
|
||||||
|
io.Boolean.Input("draw_face", default=True),
|
||||||
|
io.Boolean.Input("draw_feet", default=False),
|
||||||
|
io.Int.Input("stick_width", default=4, min=1, max=10, step=1),
|
||||||
|
io.Int.Input("face_point_size", default=3, min=1, max=10, step=1),
|
||||||
|
io.Float.Input("score_threshold", default=0.3, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, keypoints, draw_body, draw_hands, draw_face, draw_feet, stick_width, face_point_size, score_threshold) -> io.NodeOutput:
|
||||||
|
if not keypoints:
|
||||||
|
return io.NodeOutput(torch.zeros((1, 64, 64, 3), dtype=torch.float32))
|
||||||
|
height = keypoints[0]["canvas_height"]
|
||||||
|
width = keypoints[0]["canvas_width"]
|
||||||
|
|
||||||
|
def _parse(flat, n):
|
||||||
|
arr = np.array(flat, dtype=np.float32).reshape(n, 3)
|
||||||
|
return arr[:, :2], arr[:, 2]
|
||||||
|
|
||||||
|
def _zeros(n):
|
||||||
|
return np.zeros((n, 2), dtype=np.float32), np.zeros(n, dtype=np.float32)
|
||||||
|
|
||||||
|
pose_outputs = []
|
||||||
|
drawer = KeypointDraw()
|
||||||
|
|
||||||
|
for frame in tqdm(keypoints, desc="Drawing keypoints on frames"):
|
||||||
|
canvas = np.zeros((height, width, 3), dtype=np.uint8)
|
||||||
|
for person in frame["people"]:
|
||||||
|
body_kp, body_sc = _parse(person["pose_keypoints_2d"], 18)
|
||||||
|
foot_raw = person.get("foot_keypoints_2d")
|
||||||
|
foot_kp, foot_sc = _parse(foot_raw, 6) if foot_raw else _zeros(6)
|
||||||
|
face_kp, face_sc = _parse(person["face_keypoints_2d"], 70)
|
||||||
|
face_kp, face_sc = face_kp[:68], face_sc[:68] # drop appended eye kp; body already draws them
|
||||||
|
rhand_kp, rhand_sc = _parse(person["hand_right_keypoints_2d"], 21)
|
||||||
|
lhand_kp, lhand_sc = _parse(person["hand_left_keypoints_2d"], 21)
|
||||||
|
|
||||||
|
kp = np.concatenate([body_kp, foot_kp, face_kp, rhand_kp, lhand_kp], axis=0)
|
||||||
|
sc = np.concatenate([body_sc, foot_sc, face_sc, rhand_sc, lhand_sc], axis=0)
|
||||||
|
|
||||||
|
canvas = drawer.draw_wholebody_keypoints(
|
||||||
|
canvas, kp, sc,
|
||||||
|
threshold=score_threshold,
|
||||||
|
draw_body=draw_body, draw_feet=draw_feet,
|
||||||
|
draw_face=draw_face, draw_hands=draw_hands,
|
||||||
|
stick_width=stick_width, face_point_size=face_point_size,
|
||||||
|
)
|
||||||
|
pose_outputs.append(canvas)
|
||||||
|
|
||||||
|
pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0)
|
||||||
|
final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0
|
||||||
|
return io.NodeOutput(final_pose_output)
|
||||||
|
|
||||||
|
class SDPoseKeypointExtractor(io.ComfyNode):
    """Extract whole-body pose keypoints from an image batch with the SDPose model.

    A single denoising step is run through the diffusion model; an
    ``output_block_patch`` hook intercepts the 640-channel decoder feature,
    which is then fed to the model's ``heatmap_head`` to produce keypoints
    and per-keypoint confidence scores.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SDPoseKeypointExtractor",
            category="image/preprocessors",
            search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "sdpose"],
            description="Extract pose keypoints from images using the SDPose model: https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints",
            inputs=[
                io.Model.Input("model"),
                io.Vae.Input("vae"),
                io.Image.Input("image"),
                io.Int.Input("batch_size", default=16, min=1, max=10000, step=1),
                io.BoundingBox.Input("bboxes", optional=True, force_input=True, tooltip="Optional bounding boxes for more accurate detections. Required for multi-person detection."),
            ],
            outputs=[
                io.Custom("POSE_KEYPOINT").Output("keypoints", tooltip="Keypoints in OpenPose frame format (canvas_width, canvas_height, people)"),
            ],
        )

    @classmethod
    def execute(cls, model, vae, image, batch_size, bboxes=None) -> io.NodeOutput:
        height, width = image.shape[-3], image.shape[-2]
        context = LotusConditioning().execute().result[0]

        # Captured by the output_block_patch hook during each sampling call.
        grabbed_feat = None

        def output_patch(h, hsp, transformer_options):
            # Grab the 640-channel feature map used by the wholebody head.
            nonlocal grabbed_feat
            if h.shape[1] == 640:
                grabbed_feat = h.clone()
            return h, hsp

        model_clone = model.clone()
        model_clone.model_options["transformer_options"] = {"patches": {"output_block_patch": [output_patch]}}

        if not hasattr(model.model.diffusion_model, 'heatmap_head'):
            raise ValueError("The provided model does not have a heatmap_head. Please use SDPose model from here https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints.")

        head = model.model.diffusion_model.heatmap_head
        total_images = image.shape[0]

        # Model input resolution derived from the head's heatmap size (x4 upsample).
        model_h = int(head.heatmap_size[0]) * 4  # e.g. 192 * 4 = 768
        model_w = int(head.heatmap_size[1]) * 4  # e.g. 256 * 4 = 1024

        def _run_on_latent(latent_batch):
            """Run one forward pass and return (keypoints_list, scores_list) for the batch."""
            nonlocal grabbed_feat
            grabbed_feat = None
            _ = comfy.sample.sample(
                model_clone,
                noise=torch.zeros_like(latent_batch),
                steps=1, cfg=1.0,
                sampler_name="euler", scheduler="simple",
                positive=context, negative=context,
                latent_image=latent_batch, disable_noise=True, disable_pbar=True,
            )
            return head(grabbed_feat)  # keypoints_batch, scores_batch

        # Outer index = input image, inner index = detected person
        # (one entry per bbox, or a single entry in full-image mode).
        all_keypoints = []  # [n_images][n_persons]
        all_scores = []     # [n_images][n_persons]
        pbar = comfy.utils.ProgressBar(total_images)

        if bboxes is not None:
            if not isinstance(bboxes, list):
                bboxes = [[bboxes]]
            elif len(bboxes) == 0:
                # Empty list: fall back to full-image detection per frame.
                bboxes = [None] * total_images
            # --- bbox-crop mode: one forward pass per crop -------------------------
            for frame_idx in tqdm(range(total_images), desc="Extracting keypoints from crops"):
                frame_img = image[frame_idx:frame_idx + 1]  # (1, H, W, C)
                # Broadcasting: with fewer bbox lists than images, reuse the last one.
                frame_bboxes = bboxes[min(frame_idx, len(bboxes) - 1)] if bboxes else None

                frame_kps = []
                frame_scs = []

                if frame_bboxes:
                    for bbox in frame_bboxes:
                        x1 = max(0, int(bbox["x"]))
                        y1 = max(0, int(bbox["y"]))
                        x2 = min(width, int(bbox["x"] + bbox["width"]))
                        y2 = min(height, int(bbox["y"] + bbox["height"]))
                        if x2 <= x1 or y2 <= y1:
                            continue  # degenerate bbox

                        crop_h_px, crop_w_px = y2 - y1, x2 - x1
                        crop = frame_img[:, y1:y2, x1:x2, :]  # (1, crop_h, crop_w, C)

                        # Fit inside (model_h, model_w) keeping aspect ratio, then pad.
                        fit_scale = min(model_h / crop_h_px, model_w / crop_w_px)
                        scaled_h = int(round(crop_h_px * fit_scale))
                        scaled_w = int(round(crop_w_px * fit_scale))
                        pad_top = (model_h - scaled_h) // 2
                        pad_left = (model_w - scaled_w) // 2

                        crop_chw = crop.permute(0, 3, 1, 2).float()  # BHWC -> BCHW
                        scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
                        padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
                        padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
                        crop_resized = padded.permute(0, 2, 3, 1)  # BCHW -> BHWC

                        latent_crop = vae.encode(crop_resized)
                        kp_batch, sc_batch = _run_on_latent(latent_crop)
                        kp, sc = kp_batch[0], sc_batch[0]  # (K, 2) in model pixel space

                        # Remove padding offset, undo scaling, shift to full-image coords.
                        kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
                        kp[..., 0] = (kp[..., 0] - pad_left) / fit_scale + x1
                        kp[..., 1] = (kp[..., 1] - pad_top) / fit_scale + y1

                        frame_kps.append(kp)
                        frame_scs.append(sc)
                else:
                    # No bboxes for this image - run on the full frame.
                    latent_img = vae.encode(frame_img)
                    kp_batch, sc_batch = _run_on_latent(latent_img)
                    frame_kps.append(kp_batch[0])
                    frame_scs.append(sc_batch[0])

                all_keypoints.append(frame_kps)
                all_scores.append(frame_scs)
                pbar.update(1)

        else:  # full-image mode, batched
            tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints")
            for batch_start in range(0, total_images, batch_size):
                batch_end = min(batch_start + batch_size, total_images)
                latent_batch = vae.encode(image[batch_start:batch_end])

                kp_batch, sc_batch = _run_on_latent(latent_batch)

                for kp, sc in zip(kp_batch, sc_batch):
                    all_keypoints.append([kp])
                    all_scores.append([sc])
                    tqdm_pbar.update(1)

                pbar.update(batch_end - batch_start)

        openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width)
        return io.NodeOutput(openpose_frames)
|
||||||
|
|
||||||
|
|
||||||
|
def get_face_bboxes(kp2ds, scale, image_shape):
    """Compute an expanded face bounding box from normalized face keypoints.

    Args:
        kp2ds: (N, 2) array of normalized [0, 1] keypoints; row 0 is a padding
            entry and is skipped.
        scale: multiplier applied to the box *area* when expanding.
        image_shape: (height, width) of the target image in pixels.

    Returns:
        [min_x, max_x, min_y, max_y] in integer pixel coordinates, or
        [0, 0, 0, 0] when the keypoints are degenerate (zero width/height).
    """
    h, w = image_shape
    # Drop the padding row and convert normalized coords to pixels.
    pts = kp2ds.copy()[1:] * (w, h)

    min_x, min_y = np.min(pts, axis=0)
    max_x, max_y = np.max(pts, axis=0)

    box_w = max_x - min_x
    box_h = max_y - min_y
    if box_w <= 0 or box_h <= 0:
        return [0, 0, 0, 0]

    # Grow the area by `scale` while preserving the box's aspect ratio.
    grown_area = (box_w * box_h) * scale
    grown_w = np.sqrt(grown_area * (box_w / box_h))
    grown_h = np.sqrt(grown_area * (box_h / box_w))

    dx = (grown_w - box_w) / 2
    # Vertical growth is split asymmetrically: 3/4 goes above the face
    # (toward the forehead), 1/4 below.
    dy = (grown_h - box_h) / 4

    out_min_x = max(min_x - dx, 0)
    out_max_x = min(max_x + dx, w)
    out_min_y = max(min_y - 3 * dy, 0)
    out_max_y = min(max_y + dy, h)

    return [int(out_min_x), int(out_max_x), int(out_min_y), int(out_max_y)]
|
||||||
|
|
||||||
|
class SDPoseFaceBBoxes(io.ComfyNode):
    """Derive per-frame face bounding boxes from POSE_KEYPOINT data."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="SDPoseFaceBBoxes",
            category="image/preprocessors",
            search_aliases=["face bbox", "face bounding box", "pose", "keypoints"],
            inputs=[
                io.Custom("POSE_KEYPOINT").Input("keypoints"),
                io.Float.Input("scale", default=1.5, min=1.0, max=10.0, step=0.1, tooltip="Multiplier for the bounding box area around each detected face."),
                io.Boolean.Input("force_square", default=True, tooltip="Expand the shorter bbox axis so the crop region is always square."),
            ],
            outputs=[
                io.BoundingBox.Output("bboxes", tooltip="Face bounding boxes per frame, compatible with SDPoseKeypointExtractor bboxes input."),
            ],
        )

    @classmethod
    def execute(cls, keypoints, scale, force_square) -> io.NodeOutput:
        all_bboxes = []
        for frame in keypoints:
            canvas_h = frame["canvas_height"]
            canvas_w = frame["canvas_width"]
            frame_boxes = []
            for person in frame["people"]:
                flat = person.get("face_keypoints_2d", [])
                if not flat:
                    continue  # person has no face keypoints

                # Face keypoints arrive as a flat [x, y, score, ...] list in
                # absolute pixels (70 kp: 68 landmarks + REye + LEye).
                face_xy = np.array(flat, dtype=np.float32).reshape(-1, 3)[:, :2]

                # get_face_bboxes expects normalized coords with a padding row.
                normed = face_xy / np.array([canvas_w, canvas_h], dtype=np.float32)
                padded = np.vstack([np.zeros((1, 2), dtype=np.float32), normed])  # (71, 2)

                x1, x2, y1, y2 = get_face_bboxes(padded, scale, (canvas_h, canvas_w))
                if x2 <= x1 or y2 <= y1:
                    continue  # degenerate box

                if force_square:
                    bw, bh = x2 - x1, y2 - y1
                    if bw != bh:
                        # Center a square of side max(bw, bh), clamped to canvas.
                        side = max(bw, bh)
                        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
                        half = side // 2
                        x1 = max(0, cx - half)
                        y1 = max(0, cy - half)
                        x2 = min(canvas_w, x1 + side)
                        y2 = min(canvas_h, y1 + side)
                        # Re-anchor if the square was clipped at the border.
                        x1 = max(0, x2 - side)
                        y1 = max(0, y2 - side)
                frame_boxes.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1})

            all_bboxes.append(frame_boxes)

        return io.NodeOutput(all_bboxes)
|
||||||
|
|
||||||
|
|
||||||
|
class CropByBBoxes(io.ComfyNode):
    """Crop a region per frame from an image batch (union of its bboxes) and resize."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="CropByBBoxes",
            category="image/preprocessors",
            search_aliases=["crop", "face crop", "bbox crop", "pose", "bounding box"],
            description="Crop and resize regions from the input image batch based on provided bounding boxes.",
            inputs=[
                io.Image.Input("image"),
                io.BoundingBox.Input("bboxes", force_input=True),
                io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
                io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
                io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
            ],
            outputs=[
                io.Image.Output(tooltip="All crops stacked into a single image batch."),
            ],
        )

    @classmethod
    def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput:
        total_frames, img_h, img_w, num_ch = image.shape[0], image.shape[1], image.shape[2], image.shape[3]

        if not isinstance(bboxes, list):
            bboxes = [[bboxes]]
        elif len(bboxes) == 0:
            # Nothing to crop: pass the input through unchanged.
            return io.NodeOutput(image)

        crops = []

        for frame_idx in range(total_frames):
            # Broadcasting: with fewer bbox lists than frames, reuse the last one.
            frame_boxes = bboxes[min(frame_idx, len(bboxes) - 1)]
            if not frame_boxes:
                continue

            frame_chw = image[frame_idx].permute(2, 0, 1).unsqueeze(0)  # BHWC -> BCHW (1, C, H, W)

            # Union of all bboxes on this frame -> a single crop region.
            x1 = min(int(box["x"]) for box in frame_boxes)
            y1 = min(int(box["y"]) for box in frame_boxes)
            x2 = max(int(box["x"] + box["width"]) for box in frame_boxes)
            y2 = max(int(box["y"] + box["height"]) for box in frame_boxes)

            if padding > 0:
                x1 = max(0, x1 - padding)
                y1 = max(0, y1 - padding)
                x2 = min(img_w, x2 + padding)
                y2 = min(img_h, y2 + padding)

            # Clamp to the frame even when padding == 0.
            x1, x2 = max(0, x1), min(img_w, x2)
            y1, y2 = max(0, y1), min(img_h, y2)

            if x2 <= x1 or y2 <= y1:
                # Degenerate region: fall back to a centered box in the upper
                # part of the frame (30% of the shorter side).
                fb_side = int(min(img_h, img_w) * 0.3)
                fb_x1 = max(0, (img_w - fb_side) // 2)
                fb_y1 = max(0, int(img_h * 0.1))
                fb_x2 = min(img_w, fb_x1 + fb_side)
                fb_y2 = min(img_h, fb_y1 + fb_side)
                if fb_x2 <= fb_x1 or fb_y2 <= fb_y1:
                    # Even the fallback is empty: emit a black crop.
                    crops.append(torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device))
                    continue
                x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2

            region = frame_chw[:, :, y1:y2, x1:x2]  # (1, C, crop_h, crop_w)
            crops.append(comfy.utils.common_upscale(region, output_width, output_height, upscale_method="bilinear", crop="disabled"))

        if not crops:
            return io.NodeOutput(image)

        return io.NodeOutput(torch.cat(crops, dim=0).permute(0, 2, 3, 1))  # (N, H, W, C)
|
||||||
|
|
||||||
|
|
||||||
|
class SDPoseExtension(ComfyExtension):
    """Extension bundle exposing the SDPose preprocessor nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        node_classes = [
            SDPoseKeypointExtractor,
            SDPoseDrawKeypoints,
            SDPoseFaceBBoxes,
            CropByBBoxes,
        ]
        return node_classes
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> SDPoseExtension:
    """Entry point called by ComfyUI to load this node pack."""
    return SDPoseExtension()
|
||||||
@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode):
|
|||||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||||
display_name="Get Video Components",
|
display_name="Get Video Components",
|
||||||
category="image/video",
|
category="image/video",
|
||||||
essentials_category="Video Tools",
|
|
||||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||||
@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode):
|
|||||||
"start time",
|
"start time",
|
||||||
],
|
],
|
||||||
category="image/video",
|
category="image/video",
|
||||||
|
essentials_category="Video Tools",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video"),
|
io.Video.Input("video"),
|
||||||
io.Float.Input(
|
io.Float.Input(
|
||||||
|
|||||||
@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
|||||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||||
|
|
||||||
|
|
||||||
|
class WanSCAILToVideo(io.ComfyNode):
    """Build Wan SCAIL video conditioning: reference latents, CLIP vision
    features, and an optional pose-latent active over a timestep window."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="WanSCAILToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                io.Image.Input("reference_image", optional=True),
                io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
                io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
                io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
                io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
        # Empty generation latent: 16 channels, temporal dim = ceil(length / 4).
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())

        if reference_image is not None:
            # Scale the (first) reference frame to generation size and encode it.
            scaled_ref = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            ref_latent = vae.encode(scaled_ref[:, :, :, :3])
            positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
            # Negative branch gets a zeroed reference latent of the same shape.
            negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)

        if clip_vision_output is not None:
            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

        if pose_video is not None:
            # Pose conditioning runs at half the generation resolution.
            scaled_pose = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
            pose_latent = vae.encode(scaled_pose[:, :, :, :3]) * pose_strength
            # Only active within [pose_start, pose_end] of the schedule.
            positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_latent}, pose_start, pose_end)
            negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_latent}, pose_start, pose_end)

        return io.NodeOutput(positive, negative, {"samples": latent})
|
||||||
|
|
||||||
|
|
||||||
class WanExtension(ComfyExtension):
|
class WanExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension):
|
|||||||
WanAnimateToVideo,
|
WanAnimateToVideo,
|
||||||
Wan22ImageToVideoLatent,
|
Wan22ImageToVideoLatent,
|
||||||
WanInfiniteTalkToVideo,
|
WanInfiniteTalkToVideo,
|
||||||
|
WanSCAILToVideo,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> WanExtension:
|
async def comfy_entrypoint() -> WanExtension:
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.15.0"
|
__version__ = "0.15.1"
|
||||||
|
|||||||
10
main.py
10
main.py
@ -16,6 +16,10 @@ from comfy_execution.progress import get_progress_state
|
|||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_api import feature_flags
|
from comfy_api import feature_flags
|
||||||
|
|
||||||
|
import comfy_aimdo.control
|
||||||
|
|
||||||
|
if enables_dynamic_vram():
|
||||||
|
comfy_aimdo.control.init()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||||
@ -173,10 +177,6 @@ import gc
|
|||||||
if 'torch' in sys.modules:
|
if 'torch' in sys.modules:
|
||||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||||
|
|
||||||
import comfy_aimdo.control
|
|
||||||
|
|
||||||
if enables_dynamic_vram():
|
|
||||||
comfy_aimdo.control.init()
|
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
|
||||||
@ -192,7 +192,7 @@ import hook_breaker_ac10a0
|
|||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
|
||||||
if enables_dynamic_vram():
|
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
|
||||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import torch
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False):
|
|||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0):
    """
    Apply values to conditioning only during [start_percent, end_percent], keeping the
    original conditioning active outside that range. Respects existing per-entry ranges.
    """
    if start_percent > end_percent:
        logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})")
        return conditioning

    # The sampler gates entries with strict > / < comparisons; nudging the
    # split boundaries by EPS keeps exactly one piece active per timestep.
    EPS = 1e-5

    out = []
    for entry in conditioning:
        entry_start = entry[1].get("start_percent", 0.0)
        entry_end = entry[1].get("end_percent", 1.0)

        overlap_start = max(start_percent, entry_start)
        overlap_end = min(end_percent, entry_end)

        if overlap_start >= overlap_end:
            # No overlap with the requested window: keep the entry untouched.
            out.append(entry)
            continue

        # Segment before the window keeps the original values.
        if overlap_start > entry_start:
            out.extend(conditioning_set_values([entry], {"start_percent": entry_start, "end_percent": overlap_start - EPS}))

        # Segment inside the window carries the new values.
        out.extend(conditioning_set_values([entry], {**values, "start_percent": overlap_start, "end_percent": overlap_end}))

        # Segment after the window keeps the original values.
        if overlap_end < entry_end:
            out.extend(conditioning_set_values([entry], {"start_percent": overlap_end + EPS, "end_percent": entry_end}))
    return out
|
||||||
|
|
||||||
def pillow(fn, arg):
|
def pillow(fn, arg):
|
||||||
prev_value = None
|
prev_value = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
5
nodes.py
5
nodes.py
@ -976,7 +976,7 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -1925,7 +1925,6 @@ class ImageInvert:
|
|||||||
|
|
||||||
class ImageBatch:
|
class ImageBatch:
|
||||||
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
|
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
|
||||||
ESSENTIALS_CATEGORY = "Image Tools"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -2436,6 +2435,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_audio_encoder.py",
|
"nodes_audio_encoder.py",
|
||||||
"nodes_rope.py",
|
"nodes_rope.py",
|
||||||
"nodes_logic.py",
|
"nodes_logic.py",
|
||||||
|
"nodes_resolution.py",
|
||||||
"nodes_nop.py",
|
"nodes_nop.py",
|
||||||
"nodes_kandinsky5.py",
|
"nodes_kandinsky5.py",
|
||||||
"nodes_wanmove.py",
|
"nodes_wanmove.py",
|
||||||
@ -2448,6 +2448,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_toolkit.py",
|
"nodes_toolkit.py",
|
||||||
"nodes_replacements.py",
|
"nodes_replacements.py",
|
||||||
"nodes_nag.py",
|
"nodes_nag.py",
|
||||||
|
"nodes_sdpose.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.15.0"
|
version = "0.15.1"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.39.16
|
comfyui-frontend-package==1.39.19
|
||||||
comfyui-workflow-templates==0.9.3
|
comfyui-workflow-templates==0.9.4
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
@ -22,7 +22,7 @@ alembic
|
|||||||
SQLAlchemy
|
SQLAlchemy
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
comfy-kitchen>=0.2.7
|
comfy-kitchen>=0.2.7
|
||||||
comfy-aimdo>=0.2.1
|
comfy-aimdo>=0.2.4
|
||||||
requests
|
requests
|
||||||
|
|
||||||
#non essential dependencies:
|
#non essential dependencies:
|
||||||
@ -31,5 +31,4 @@ spandrel
|
|||||||
pydantic~=2.0
|
pydantic~=2.0
|
||||||
pydantic-settings~=2.0
|
pydantic-settings~=2.0
|
||||||
PyOpenGL
|
PyOpenGL
|
||||||
PyOpenGL-accelerate
|
|
||||||
glfw
|
glfw
|
||||||
|
|||||||
@ -49,6 +49,12 @@ def mock_provider(mock_releases):
|
|||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
import utils.install_util
|
||||||
|
utils.install_util.PACKAGE_VERSIONS = {}
|
||||||
|
|
||||||
|
|
||||||
def test_get_release(mock_provider, mock_releases):
|
def test_get_release(mock_provider, mock_releases):
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
release = mock_provider.get_release(version)
|
release = mock_provider.get_release(version)
|
||||||
|
|||||||
112
tests-unit/comfy_test/model_detection_test.py
Normal file
112
tests-unit/comfy_test/model_detection_test.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
|
||||||
|
import comfy.supported_models
|
||||||
|
|
||||||
|
|
||||||
|
def _make_longcat_comfyui_sd():
|
||||||
|
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
|
||||||
|
sd = {}
|
||||||
|
H = 32 # Reduce hidden state dimension to reduce memory usage
|
||||||
|
C_IN = 16
|
||||||
|
C_CTX = 3584
|
||||||
|
|
||||||
|
sd["img_in.weight"] = torch.empty(H, C_IN * 4)
|
||||||
|
sd["img_in.bias"] = torch.empty(H)
|
||||||
|
sd["txt_in.weight"] = torch.empty(H, C_CTX)
|
||||||
|
sd["txt_in.bias"] = torch.empty(H)
|
||||||
|
|
||||||
|
sd["time_in.in_layer.weight"] = torch.empty(H, 256)
|
||||||
|
sd["time_in.in_layer.bias"] = torch.empty(H)
|
||||||
|
sd["time_in.out_layer.weight"] = torch.empty(H, H)
|
||||||
|
sd["time_in.out_layer.bias"] = torch.empty(H)
|
||||||
|
|
||||||
|
sd["final_layer.adaLN_modulation.1.weight"] = torch.empty(2 * H, H)
|
||||||
|
sd["final_layer.adaLN_modulation.1.bias"] = torch.empty(2 * H)
|
||||||
|
sd["final_layer.linear.weight"] = torch.empty(C_IN * 4, H)
|
||||||
|
sd["final_layer.linear.bias"] = torch.empty(C_IN * 4)
|
||||||
|
|
||||||
|
for i in range(19):
|
||||||
|
sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128)
|
||||||
|
sd[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.empty(3 * H, H)
|
||||||
|
sd[f"double_blocks.{i}.img_mod.lin.weight"] = torch.empty(H, H)
|
||||||
|
for i in range(38):
|
||||||
|
sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H)
|
||||||
|
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def _make_flux_schnell_comfyui_sd():
|
||||||
|
"""Minimal ComfyUI-format state dict for standard Flux Schnell."""
|
||||||
|
sd = {}
|
||||||
|
H = 32 # Reduce hidden state dimension to reduce memory usage
|
||||||
|
C_IN = 16
|
||||||
|
|
||||||
|
sd["img_in.weight"] = torch.empty(H, C_IN * 4)
|
||||||
|
sd["img_in.bias"] = torch.empty(H)
|
||||||
|
sd["txt_in.weight"] = torch.empty(H, 4096)
|
||||||
|
sd["txt_in.bias"] = torch.empty(H)
|
||||||
|
|
||||||
|
sd["double_blocks.0.img_attn.norm.key_norm.weight"] = torch.empty(128)
|
||||||
|
sd["double_blocks.0.img_attn.qkv.weight"] = torch.empty(3 * H, H)
|
||||||
|
sd["double_blocks.0.img_mod.lin.weight"] = torch.empty(H, H)
|
||||||
|
|
||||||
|
for i in range(19):
|
||||||
|
sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128)
|
||||||
|
for i in range(38):
|
||||||
|
sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H)
|
||||||
|
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelDetection:
    """Verify that first-match model detection selects the correct model
    based on list ordering and unet_config specificity."""

    def test_longcat_before_schnell_in_models_list(self):
        """LongCatImage must appear before FluxSchnell in the models list."""
        registered = comfy.supported_models.models
        idx_of = lambda cls_name: next(i for i, m in enumerate(registered) if m.__name__ == cls_name)
        longcat_idx = idx_of("LongCatImage")
        schnell_idx = idx_of("FluxSchnell")
        assert longcat_idx < schnell_idx, (
            f"LongCatImage (index {longcat_idx}) must come before "
            f"FluxSchnell (index {schnell_idx}) in the models list"
        )

    def test_longcat_comfyui_detected_as_longcat(self):
        state_dict = _make_longcat_comfyui_sd()
        cfg = detect_unet_config(state_dict, "")
        assert cfg is not None
        # LongCat is a Flux-family model distinguished by its context width.
        assert cfg["image_model"] == "flux"
        assert cfg["context_in_dim"] == 3584
        assert cfg["vec_in_dim"] is None
        assert cfg["guidance_embed"] is False
        assert cfg["txt_ids_dims"] == [1, 2]

        detected = model_config_from_unet_config(cfg, state_dict)
        assert detected is not None
        assert type(detected).__name__ == "LongCatImage"

    def test_longcat_comfyui_keys_pass_through_unchanged(self):
        """Pre-converted weights should not be transformed by process_unet_state_dict."""
        state_dict = _make_longcat_comfyui_sd()
        cfg = detect_unet_config(state_dict, "")
        detected = model_config_from_unet_config(cfg, state_dict)

        processed = detected.process_unet_state_dict(dict(state_dict))
        for key in ("img_in.weight", "txt_in.weight", "time_in.in_layer.weight", "final_layer.linear.weight"):
            assert key in processed

    def test_flux_schnell_comfyui_detected_as_flux_schnell(self):
        state_dict = _make_flux_schnell_comfyui_sd()
        cfg = detect_unet_config(state_dict, "")
        assert cfg is not None
        assert cfg["image_model"] == "flux"
        assert cfg["context_in_dim"] == 4096
        assert cfg["txt_ids_dims"] == []

        detected = model_config_from_unet_config(cfg, state_dict)
        assert detected is not None
        assert type(detected).__name__ == "FluxSchnell"
@ -38,13 +38,13 @@ class TestIsPreviewable:
|
|||||||
"""Unit tests for is_previewable()"""
|
"""Unit tests for is_previewable()"""
|
||||||
|
|
||||||
def test_previewable_media_types(self):
|
def test_previewable_media_types(self):
|
||||||
"""Images, video, audio, 3d media types should be previewable."""
|
"""Images, video, audio, 3d, text media types should be previewable."""
|
||||||
for media_type in ['images', 'video', 'audio', '3d']:
|
for media_type in ['images', 'video', 'audio', '3d', 'text']:
|
||||||
assert is_previewable(media_type, {}) is True
|
assert is_previewable(media_type, {}) is True
|
||||||
|
|
||||||
def test_non_previewable_media_types(self):
|
def test_non_previewable_media_types(self):
|
||||||
"""Other media types should not be previewable."""
|
"""Other media types should not be previewable."""
|
||||||
for media_type in ['latents', 'text', 'metadata', 'files']:
|
for media_type in ['latents', 'metadata', 'files']:
|
||||||
assert is_previewable(media_type, {}) is False
|
assert is_previewable(media_type, {}) is False
|
||||||
|
|
||||||
def test_3d_extensions_previewable(self):
|
def test_3d_extensions_previewable(self):
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
# The path to the requirements.txt file
|
# The path to the requirements.txt file
|
||||||
requirements_path = Path(__file__).parents[1] / "requirements.txt"
|
requirements_path = Path(__file__).parents[1] / "requirements.txt"
|
||||||
@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running:
|
|||||||
{sys.executable} {extra}-m pip install -r {requirements_path}
|
{sys.executable} {extra}-m pip install -r {requirements_path}
|
||||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
|
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_version(version: str) -> bool:
    """Validate if a string is a valid semantic version (X.Y.Z format)."""
    # Anchored pattern: exactly three dot-separated numeric components.
    semver_re = r"^(\d+)\.(\d+)\.(\d+)$"
    return re.match(semver_re, version) is not None
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level cache: {package_name: version_string} parsed from requirements.txt.
PACKAGE_VERSIONS = {}

def get_required_packages_versions():
    """Parse requirements.txt into a {package: version} mapping.

    Only lines pinned with "==" (or ">=", which is treated as an exact pin
    for parsing purposes) contribute an entry; versions that are not plain
    X.Y.Z are logged and skipped. The result is cached in PACKAGE_VERSIONS
    after the first successful parse.

    Returns:
        A copy of the mapping, or None if requirements.txt is missing or
        cannot be read.
    """
    if len(PACKAGE_VERSIONS) > 0:
        return PACKAGE_VERSIONS.copy()
    # Parse into a local dict first and commit to the module-level cache only
    # on success. The previous version wrote directly into PACKAGE_VERSIONS,
    # so an error partway through the file returned None but left a partially
    # populated cache — every later call would then short-circuit and return
    # an incomplete dict instead of None.
    parsed = {}
    try:
        with open(requirements_path, "r", encoding="utf-8") as f:
            for line in f:
                # Normalize ">=" lower bounds to "==" so both forms parse.
                line = line.strip().replace(">=", "==")
                parts = line.split("==")
                if len(parts) == 2:
                    version_str = parts[-1]
                    if not is_valid_version(version_str):
                        logging.error(f"Invalid version format in requirements.txt: {version_str}")
                        continue
                    parsed[parts[0]] = version_str
    except FileNotFoundError:
        logging.error("requirements.txt not found.")
        return None
    except Exception as e:
        logging.error(f"Error reading requirements.txt: {e}")
        return None
    # Complete, successful read: populate the cache and hand back a copy so
    # callers cannot mutate the cached mapping.
    PACKAGE_VERSIONS.update(parsed)
    return PACKAGE_VERSIONS.copy()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user