mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-05 08:16:47 +08:00
Merge branch 'master' into breifnet
This commit is contained in:
commit
bf1c4ed745
@ -17,7 +17,7 @@ from importlib.metadata import version
|
|||||||
import requests
|
import requests
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
from utils.install_util import get_missing_requirements_message, requirements_path
|
from utils.install_util import get_missing_requirements_message, get_required_packages_versions
|
||||||
|
|
||||||
from comfy.cli_args import DEFAULT_VERSION_STRING
|
from comfy.cli_args import DEFAULT_VERSION_STRING
|
||||||
import app.logger
|
import app.logger
|
||||||
@ -45,25 +45,7 @@ def get_installed_frontend_version():
|
|||||||
|
|
||||||
|
|
||||||
def get_required_frontend_version():
|
def get_required_frontend_version():
|
||||||
"""Get the required frontend version from requirements.txt."""
|
return get_required_packages_versions().get("comfyui-frontend-package", None)
|
||||||
try:
|
|
||||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if line.startswith("comfyui-frontend-package=="):
|
|
||||||
version_str = line.split("==")[-1]
|
|
||||||
if not is_valid_version(version_str):
|
|
||||||
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
|
||||||
return None
|
|
||||||
return version_str
|
|
||||||
logging.error("comfyui-frontend-package not found in requirements.txt")
|
|
||||||
return None
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.error("requirements.txt not found. Cannot determine required frontend version.")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error reading requirements.txt: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def check_frontend_version():
|
def check_frontend_version():
|
||||||
@ -217,25 +199,7 @@ class FrontendManager:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_required_templates_version(cls) -> str:
|
def get_required_templates_version(cls) -> str:
|
||||||
"""Get the required workflow templates version from requirements.txt."""
|
return get_required_packages_versions().get("comfyui-workflow-templates", None)
|
||||||
try:
|
|
||||||
with open(requirements_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if line.startswith("comfyui-workflow-templates=="):
|
|
||||||
version_str = line.split("==")[-1]
|
|
||||||
if not is_valid_version(version_str):
|
|
||||||
logging.error(f"Invalid templates version format in requirements.txt: {version_str}")
|
|
||||||
return None
|
|
||||||
return version_str
|
|
||||||
logging.error("comfyui-workflow-templates not found in requirements.txt")
|
|
||||||
return None
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.error("requirements.txt not found. Cannot determine required templates version.")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error reading requirements.txt: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_frontend_path(cls) -> str:
|
def default_frontend_path(cls) -> str:
|
||||||
|
|||||||
@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
|||||||
|
|
||||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||||
|
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||||
|
|
||||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||||
|
|
||||||
@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum):
|
|||||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||||
CublasOps = "cublas_ops"
|
CublasOps = "cublas_ops"
|
||||||
AutoTune = "autotune"
|
AutoTune = "autotune"
|
||||||
DynamicVRAM = "dynamic_vram"
|
|
||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||||
|
|
||||||
@ -260,4 +260,4 @@ else:
|
|||||||
args.fast = set(args.fast)
|
args.fast = set(args.fast)
|
||||||
|
|
||||||
def enables_dynamic_vram():
|
def enables_dynamic_vram():
|
||||||
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
|
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||||
|
|||||||
@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||||
matches = torch.nonzero(mask)
|
matches = torch.nonzero(mask)
|
||||||
if torch.numel(matches) == 0:
|
if torch.numel(matches) == 0:
|
||||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
return # substep from multi-step sampler: keep self._step from the last full step
|
||||||
self._step = int(matches[0].item())
|
self._step = int(matches[0].item())
|
||||||
|
|
||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||||
|
|||||||
@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat):
|
|||||||
|
|
||||||
def process_out(self, latent):
|
def process_out(self, latent):
|
||||||
return latent
|
return latent
|
||||||
|
|
||||||
|
|
||||||
|
class ZImagePixelSpace(ChromaRadiance):
|
||||||
|
"""Pixel-space latent format for ZImage DCT variant.
|
||||||
|
No VAE encoding/decoding — the model operates directly on RGB pixels.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND
|
|||||||
from comfy.ldm.flux.math import apply_rope
|
from comfy.ldm.flux.math import apply_rope
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
|
||||||
|
|
||||||
|
|
||||||
def invert_slices(slices, length):
|
def invert_slices(slices, length):
|
||||||
@ -858,3 +859,267 @@ class NextDiT(nn.Module):
|
|||||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||||
return -img
|
return -img
|
||||||
|
|
||||||
|
|
||||||
|
#############################################################################
|
||||||
|
# Pixel Space Decoder Components #
|
||||||
|
#############################################################################
|
||||||
|
|
||||||
|
def _modulate_shift_scale(x, shift, scale):
|
||||||
|
return x * (1 + scale) + shift
|
||||||
|
|
||||||
|
|
||||||
|
class PixelResBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Residual block with AdaLN modulation, zero-initialised so it starts as
|
||||||
|
an identity at the beginning of training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels: int, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
|
||||||
|
h = _modulate_shift_scale(self.in_ln(x), shift, scale)
|
||||||
|
h = self.mlp(h)
|
||||||
|
return x + gate * h
|
||||||
|
|
||||||
|
|
||||||
|
class DCTFinalLayer(nn.Module):
|
||||||
|
"""Zero-initialised output projection (adopted from DiT)."""
|
||||||
|
|
||||||
|
def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.linear(self.norm_final(x))
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleMLPAdaLN(nn.Module):
|
||||||
|
"""
|
||||||
|
Small MLP decoder head for the pixel-space variant.
|
||||||
|
|
||||||
|
Takes per-patch pixel values and a per-patch conditioning vector from the
|
||||||
|
transformer backbone and predicts the denoised pixel values.
|
||||||
|
|
||||||
|
x : [B*N, P^2, C] – noisy pixel values per patch position
|
||||||
|
c : [B*N, dim] – backbone hidden state per patch (conditioning)
|
||||||
|
→ [B*N, P^2, C]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
model_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
z_channels: int,
|
||||||
|
num_res_blocks: int,
|
||||||
|
max_freqs: int = 8,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# Project backbone hidden state → per-patch conditioning
|
||||||
|
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# Input projection with DCT positional encoding
|
||||||
|
self.input_embedder = NerfEmbedder(
|
||||||
|
in_channels=in_channels,
|
||||||
|
hidden_size_input=model_channels,
|
||||||
|
max_freqs=max_freqs,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Residual blocks
|
||||||
|
self.res_blocks = nn.ModuleList([
|
||||||
|
PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Output projection
|
||||||
|
self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||||
|
# x: [B*N, 1, P^2*C], c: [B*N, dim]
|
||||||
|
original_dtype = x.dtype
|
||||||
|
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
|
||||||
|
x = self.input_embedder(x) # [B*N, 1, model_channels]
|
||||||
|
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
|
||||||
|
x = x.to(weight_dtype)
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x, y)
|
||||||
|
return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C]
|
||||||
|
|
||||||
|
|
||||||
|
#############################################################################
|
||||||
|
# NextDiT – Pixel Space #
|
||||||
|
#############################################################################
|
||||||
|
|
||||||
|
class NextDiTPixelSpace(NextDiT):
|
||||||
|
"""
|
||||||
|
Pixel-space variant of NextDiT.
|
||||||
|
|
||||||
|
Identical transformer backbone to NextDiT, but the output head is replaced
|
||||||
|
with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
|
||||||
|
per patch rather than a single affine projection.
|
||||||
|
|
||||||
|
Key differences vs NextDiT:
|
||||||
|
• ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
|
||||||
|
• ``_forward`` stores the raw patchified pixel values before the backbone
|
||||||
|
embedding and feeds them to ``dec_net`` together with the per-patch
|
||||||
|
backbone hidden states.
|
||||||
|
• Supports optional x0 prediction via ``use_x0``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# decoder-specific
|
||||||
|
decoder_hidden_size: int = 3840,
|
||||||
|
decoder_num_res_blocks: int = 4,
|
||||||
|
decoder_max_freqs: int = 8,
|
||||||
|
decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels)
|
||||||
|
use_x0: bool = False,
|
||||||
|
# all NextDiT args forwarded unchanged
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# Remove the latent-space final layer – not used in pixel space
|
||||||
|
del self.final_layer
|
||||||
|
|
||||||
|
patch_size = kwargs.get("patch_size", 2)
|
||||||
|
in_channels = kwargs.get("in_channels", 4)
|
||||||
|
dim = kwargs.get("dim", 4096)
|
||||||
|
|
||||||
|
# decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
|
||||||
|
dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels
|
||||||
|
|
||||||
|
self.dec_net = SimpleMLPAdaLN(
|
||||||
|
in_channels=dec_in_ch,
|
||||||
|
model_channels=decoder_hidden_size,
|
||||||
|
out_channels=dec_in_ch,
|
||||||
|
z_channels=dim,
|
||||||
|
num_res_blocks=decoder_num_res_blocks,
|
||||||
|
max_freqs=decoder_max_freqs,
|
||||||
|
dtype=kwargs.get("dtype"),
|
||||||
|
device=kwargs.get("device"),
|
||||||
|
operations=kwargs.get("operations"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_x0:
|
||||||
|
self.register_buffer("__x0__", torch.tensor([]))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Forward — mirrors NextDiT._forward exactly, replacing final_layer
|
||||||
|
# with the pixel-space dec_net decoder.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
|
||||||
|
omni = len(ref_latents) > 0
|
||||||
|
if omni:
|
||||||
|
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
|
||||||
|
|
||||||
|
t = 1.0 - timesteps
|
||||||
|
cap_feats = context
|
||||||
|
cap_mask = attention_mask
|
||||||
|
bs, c, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||||
|
|
||||||
|
t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
|
||||||
|
adaln_input = t
|
||||||
|
|
||||||
|
if self.clip_text_pooled_proj is not None:
|
||||||
|
pooled = kwargs.get("clip_text_pooled", None)
|
||||||
|
if pooled is not None:
|
||||||
|
pooled = self.clip_text_pooled_proj(pooled)
|
||||||
|
else:
|
||||||
|
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||||
|
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||||
|
|
||||||
|
# ---- capture raw pixel patches before patchify_and_embed embeds them ----
|
||||||
|
pH = pW = self.patch_size
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
pixel_patches = (
|
||||||
|
x.view(B, C, H // pH, pH, W // pW, pW)
|
||||||
|
.permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C]
|
||||||
|
.flatten(3) # [B, Ht, Wt, pH*pW*C]
|
||||||
|
.flatten(1, 2) # [B, N, pH*pW*C]
|
||||||
|
)
|
||||||
|
N = pixel_patches.shape[1]
|
||||||
|
# decoder sees one token per patch: [B*N, 1, P^2*C]
|
||||||
|
pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)
|
||||||
|
|
||||||
|
patches = transformer_options.get("patches", {})
|
||||||
|
x_is_tensor = isinstance(x, torch.Tensor)
|
||||||
|
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
|
||||||
|
x, cap_feats, cap_mask, adaln_input, num_tokens,
|
||||||
|
ref_latents=ref_latents, ref_contexts=ref_contexts,
|
||||||
|
siglip_feats=siglip_feats, transformer_options=transformer_options
|
||||||
|
)
|
||||||
|
freqs_cis = freqs_cis.to(img.device)
|
||||||
|
|
||||||
|
transformer_options["total_blocks"] = len(self.layers)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
|
img_input = img
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
|
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||||
|
if "double_block" in patches:
|
||||||
|
for p in patches["double_block"]:
|
||||||
|
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||||
|
if "img" in out:
|
||||||
|
img[:, cap_size[0]:] = out["img"]
|
||||||
|
if "txt" in out:
|
||||||
|
img[:, :cap_size[0]] = out["txt"]
|
||||||
|
|
||||||
|
# ---- pixel-space decoder (replaces final_layer + unpatchify) ----
|
||||||
|
# img may have padding tokens beyond N; only the first N are real image patches
|
||||||
|
img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim]
|
||||||
|
decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim]
|
||||||
|
|
||||||
|
output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C]
|
||||||
|
output = output.reshape(B, N, -1) # [B, N, P^2*C]
|
||||||
|
|
||||||
|
# prepend zero cap placeholder so unpatchify indexing works unchanged
|
||||||
|
cap_placeholder = torch.zeros(
|
||||||
|
B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
|
||||||
|
)
|
||||||
|
img_out = self.unpatchify(
|
||||||
|
torch.cat([cap_placeholder, output], dim=1),
|
||||||
|
img_size, cap_size, return_tensor=x_is_tensor
|
||||||
|
)[:, :, :h, :w]
|
||||||
|
|
||||||
|
return -img_out
|
||||||
|
|
||||||
|
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||||
|
# _forward returns neg_x0 = -x0 (negated decoder output).
|
||||||
|
#
|
||||||
|
# Reference inference (working_inference_reference.py):
|
||||||
|
# out = _forward(img, t) # = -x0
|
||||||
|
# pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual]
|
||||||
|
# img += (t_prev - t_curr) * pred # Euler step
|
||||||
|
#
|
||||||
|
# ComfyUI's Euler sampler does the same:
|
||||||
|
# x_next = x + (sigma_next - sigma) * model_output
|
||||||
|
# So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
|
||||||
|
neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||||
|
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||||
|
|
||||||
|
return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)
|
||||||
|
|||||||
@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel):
|
|||||||
# unpatchify
|
# unpatchify
|
||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class SCAILWanModel(WanModel):
|
||||||
|
def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||||
|
super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||||
|
|
||||||
|
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||||
|
|
||||||
|
if reference_latent is not None:
|
||||||
|
x = torch.cat((reference_latent, x), dim=2)
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
|
grid_sizes = x.shape[2:]
|
||||||
|
transformer_options["grid_sizes"] = grid_sizes
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
scail_pose_seq_len = 0
|
||||||
|
if pose_latents is not None:
|
||||||
|
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||||
|
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||||
|
scail_pose_seq_len = scail_x.shape[1]
|
||||||
|
x = torch.cat([x, scail_x], dim=1)
|
||||||
|
del scail_x
|
||||||
|
|
||||||
|
# time embeddings
|
||||||
|
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||||
|
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||||
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
|
||||||
|
# context
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
context_img_len = None
|
||||||
|
if clip_fea is not None:
|
||||||
|
if self.img_emb is not None:
|
||||||
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
|
context = torch.cat([context_clip, context], dim=1)
|
||||||
|
context_img_len = clip_fea.shape[-2]
|
||||||
|
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||||
|
return out
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||||
|
x = out["img"]
|
||||||
|
else:
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
if scail_pose_seq_len > 0:
|
||||||
|
x = x[:, :-scail_pose_seq_len]
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, grid_sizes)
|
||||||
|
|
||||||
|
if reference_latent is not None:
|
||||||
|
x = x[:, :, reference_latent.shape[2]:]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||||
|
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
if pose_latents is None:
|
||||||
|
return main_freqs
|
||||||
|
|
||||||
|
ref_t_patches = 0
|
||||||
|
if reference_latent is not None:
|
||||||
|
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||||
|
|
||||||
|
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||||
|
|
||||||
|
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
|
||||||
|
h_scale = h / H_pose
|
||||||
|
w_scale = w / W_pose
|
||||||
|
|
||||||
|
# 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code
|
||||||
|
h_shift = (h_scale - 1) / 2
|
||||||
|
w_shift = (w_scale - 1) / 2
|
||||||
|
pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||||
|
pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options)
|
||||||
|
|
||||||
|
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||||
|
bs, c, t, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
|
||||||
|
if pose_latents is not None:
|
||||||
|
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||||
|
|
||||||
|
t_len = t
|
||||||
|
if time_dim_concat is not None:
|
||||||
|
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||||
|
x = torch.cat([x, time_dim_concat], dim=2)
|
||||||
|
t_len = x.shape[2]
|
||||||
|
|
||||||
|
reference_latent = None
|
||||||
|
if "reference_latent" in kwargs:
|
||||||
|
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||||
|
t_len += reference_latent.shape[2]
|
||||||
|
|
||||||
|
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||||
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||||
|
|||||||
@ -1263,6 +1263,11 @@ class Lumina2(BaseModel):
|
|||||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class ZImagePixelSpace(Lumina2):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||||
|
self.memory_usage_factor_conds = ("ref_latents",)
|
||||||
|
|
||||||
class WAN21(BaseModel):
|
class WAN21(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
@ -1502,6 +1507,44 @@ class WAN21_FlowRVS(WAN21):
|
|||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
self.image_to_video = image_to_video
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
|
class WAN21_SCAIL(WAN21):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
|
||||||
|
self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
|
||||||
|
self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
|
||||||
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
|
||||||
|
reference_latents = kwargs.get("reference_latents", None)
|
||||||
|
if reference_latents is not None:
|
||||||
|
ref_latent = self.process_latent_in(reference_latents[-1])
|
||||||
|
ref_mask = torch.ones_like(ref_latent[:, :4])
|
||||||
|
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
|
||||||
|
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
|
||||||
|
|
||||||
|
pose_latents = kwargs.get("pose_video_latent", None)
|
||||||
|
if pose_latents is not None:
|
||||||
|
pose_latents = self.process_latent_in(pose_latents)
|
||||||
|
pose_mask = torch.ones_like(pose_latents[:, :4])
|
||||||
|
pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
|
||||||
|
out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def extra_conds_shapes(self, **kwargs):
|
||||||
|
out = {}
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||||
|
|
||||||
|
pose_latents = kwargs.get("pose_video_latent", None)
|
||||||
|
if pose_latents is not None:
|
||||||
|
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
class Hunyuan3Dv2(BaseModel):
|
class Hunyuan3Dv2(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||||
|
|||||||
@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "lumina2"
|
dit_config["image_model"] = "lumina2"
|
||||||
dit_config["patch_size"] = 2
|
dit_config["patch_size"] = 2
|
||||||
@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
if sig_weight is not None:
|
if sig_weight is not None:
|
||||||
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
|
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
|
||||||
|
|
||||||
|
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
|
||||||
|
if dec_cond_key in state_dict_keys: # pixel-space variant
|
||||||
|
dit_config["image_model"] = "zimage_pixel"
|
||||||
|
# patch_size and in_channels are derived from x_embedder:
|
||||||
|
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
|
||||||
|
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
|
||||||
|
x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
|
||||||
|
dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
|
||||||
|
# patch_size: infer from decoder final layer output matching x_embedder input
|
||||||
|
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
|
||||||
|
embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
|
||||||
|
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
|
||||||
|
dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
|
||||||
|
dit_config["in_channels"] = 3
|
||||||
|
dit_config["decoder_in_channels"] = dec_in_ch
|
||||||
|
dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
|
||||||
|
dit_config["decoder_num_res_blocks"] = count_blocks(
|
||||||
|
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
|
||||||
|
)
|
||||||
|
dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
|
||||||
|
if '{}__x0__'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["use_x0"] = True
|
||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||||
@ -498,6 +521,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["model_type"] = "humo"
|
dit_config["model_type"] = "humo"
|
||||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "animate"
|
dit_config["model_type"] = "animate"
|
||||||
|
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["model_type"] = "scail"
|
||||||
else:
|
else:
|
||||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "i2v"
|
dit_config["model_type"] = "i2v"
|
||||||
@ -531,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1
|
||||||
|
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "hunyuan3d2_1"
|
dit_config["image_model"] = "hunyuan3d2_1"
|
||||||
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
|
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
|
||||||
@ -1053,6 +1077,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
|||||||
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
|
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
|
||||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||||
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
|
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
|
||||||
|
elif 'noise_refiner.0.attention.norm_k.weight' in state_dict:
|
||||||
|
n_layers = count_blocks(state_dict, 'layers.{}.')
|
||||||
|
dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0]
|
||||||
|
sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix)
|
||||||
|
for k in state_dict: # For zeta chroma
|
||||||
|
if k not in sd_map:
|
||||||
|
sd_map[k] = k
|
||||||
elif 'x_embedder.weight' in state_dict: #Flux
|
elif 'x_embedder.weight' in state_dict: #Flux
|
||||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||||
|
|||||||
@ -32,9 +32,6 @@ import comfy.memory_management
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
|
|
||||||
import comfy_aimdo.torch
|
|
||||||
import comfy_aimdo.model_vbar
|
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||||
@ -180,6 +177,14 @@ def is_ixuca():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_wsl():
|
||||||
|
version = platform.uname().release
|
||||||
|
if version.endswith("-Microsoft"):
|
||||||
|
return True
|
||||||
|
elif version.endswith("microsoft-standard-WSL2"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_torch_device():
|
def get_torch_device():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
global cpu_state
|
global cpu_state
|
||||||
@ -631,12 +636,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
if not DISABLE_SMART_MEMORY:
|
if not DISABLE_SMART_MEMORY:
|
||||||
memory_to_free = memory_required - get_free_memory(device)
|
memory_to_free = memory_required - get_free_memory(device)
|
||||||
ram_to_free = ram_required - get_free_ram()
|
ram_to_free = ram_required - get_free_ram()
|
||||||
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
#don't actually unload dynamic models for the sake of other dynamic models
|
||||||
#don't actually unload dynamic models for the sake of other dynamic models
|
#as that works on-demand.
|
||||||
#as that works on-demand.
|
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
memory_to_free = 0
|
||||||
memory_to_free = 0
|
|
||||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
unloaded_model.append(i)
|
unloaded_model.append(i)
|
||||||
@ -1199,43 +1203,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
|||||||
|
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||||
if hasattr(weight, "_v"):
|
|
||||||
#Unexpected usage patterns. There is no reason these don't work but they
|
|
||||||
#have no testing and no callers do this.
|
|
||||||
assert r is None
|
|
||||||
assert stream is None
|
|
||||||
|
|
||||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
|
|
||||||
|
|
||||||
if dtype is None:
|
|
||||||
dtype = weight._model_dtype
|
|
||||||
|
|
||||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
|
||||||
if signature is not None:
|
|
||||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
|
||||||
v_tensor = weight._v_tensor
|
|
||||||
else:
|
|
||||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
|
||||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
|
||||||
weight._v_tensor = v_tensor
|
|
||||||
weight._v_signature = signature
|
|
||||||
#Send it over
|
|
||||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
|
||||||
return v_tensor.to(dtype=dtype)
|
|
||||||
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
|
||||||
#Offloaded casting could skip this, however it would make the quantizations
|
|
||||||
#inconsistent between loaded and offloaded weights. So force the double casting
|
|
||||||
#that would happen in regular flow to make offload deterministic.
|
|
||||||
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
|
||||||
cast_buffer.copy_(weight, non_blocking=non_blocking)
|
|
||||||
weight = cast_buffer
|
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
return r
|
|
||||||
|
|
||||||
if device is None or weight.device == device:
|
if device is None or weight.device == device:
|
||||||
if not copy:
|
if not copy:
|
||||||
if dtype is None or weight.dtype == dtype:
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
|||||||
@ -308,15 +308,22 @@ class ModelPatcher:
|
|||||||
def get_free_memory(self, device):
|
def get_free_memory(self, device):
|
||||||
return comfy.model_management.get_free_memory(device)
|
return comfy.model_management.get_free_memory(device)
|
||||||
|
|
||||||
def clone(self, disable_dynamic=False):
|
def get_clone_model_override(self):
|
||||||
|
return self.model, (self.backup, self.object_patches_backup, self.pinned)
|
||||||
|
|
||||||
|
def clone(self, disable_dynamic=False, model_override=None):
|
||||||
class_ = self.__class__
|
class_ = self.__class__
|
||||||
model = self.model
|
|
||||||
if self.is_dynamic() and disable_dynamic:
|
if self.is_dynamic() and disable_dynamic:
|
||||||
class_ = ModelPatcher
|
class_ = ModelPatcher
|
||||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
if model_override is None:
|
||||||
model = temp_model_patcher.model
|
if self.cached_patcher_init is None:
|
||||||
|
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||||
|
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||||
|
model_override = temp_model_patcher.get_clone_model_override()
|
||||||
|
if model_override is None:
|
||||||
|
model_override = self.get_clone_model_override()
|
||||||
|
|
||||||
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
for k in self.patches:
|
for k in self.patches:
|
||||||
n.patches[k] = self.patches[k][:]
|
n.patches[k] = self.patches[k][:]
|
||||||
@ -325,13 +332,12 @@ class ModelPatcher:
|
|||||||
n.object_patches = self.object_patches.copy()
|
n.object_patches = self.object_patches.copy()
|
||||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||||
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||||
n.backup = self.backup
|
|
||||||
n.object_patches_backup = self.object_patches_backup
|
|
||||||
n.parent = self
|
n.parent = self
|
||||||
n.pinned = self.pinned
|
|
||||||
|
|
||||||
n.force_cast_weights = self.force_cast_weights
|
n.force_cast_weights = self.force_cast_weights
|
||||||
|
|
||||||
|
n.backup, n.object_patches_backup, n.pinned = model_override[1]
|
||||||
|
|
||||||
# attachments
|
# attachments
|
||||||
n.attachments = {}
|
n.attachments = {}
|
||||||
for k in self.attachments:
|
for k in self.attachments:
|
||||||
@ -1429,12 +1435,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||||
#this is now way more dynamic and we dont support the same base model for both Dynamic
|
|
||||||
#and non-dynamic patchers.
|
|
||||||
if hasattr(self.model, "model_loaded_weight_memory"):
|
|
||||||
del self.model.model_loaded_weight_memory
|
|
||||||
if not hasattr(self.model, "dynamic_vbars"):
|
if not hasattr(self.model, "dynamic_vbars"):
|
||||||
self.model.dynamic_vbars = {}
|
self.model.dynamic_vbars = {}
|
||||||
|
self.non_dynamic_delegate_model = None
|
||||||
assert load_device is not None
|
assert load_device is not None
|
||||||
|
|
||||||
def is_dynamic(self):
|
def is_dynamic(self):
|
||||||
@ -1454,9 +1457,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
def loaded_size(self):
|
def loaded_size(self):
|
||||||
vbar = self._vbar_get()
|
vbar = self._vbar_get()
|
||||||
if vbar is None:
|
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
|
||||||
return 0
|
|
||||||
return vbar.loaded_size()
|
|
||||||
|
|
||||||
def get_free_memory(self, device):
|
def get_free_memory(self, device):
|
||||||
#NOTE: on high condition / batch counts, estimate should have already vacated
|
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||||
@ -1497,6 +1498,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
num_patches = 0
|
num_patches = 0
|
||||||
allocated_size = 0
|
allocated_size = 0
|
||||||
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
@ -1505,10 +1507,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
if vbar is not None:
|
if vbar is not None:
|
||||||
vbar.prioritize()
|
vbar.prioritize()
|
||||||
|
|
||||||
#We force reserve VRAM for the non comfy-weight so we dont have to deal
|
|
||||||
#with pin and unpin syncrhonization which can be expensive for small weights
|
|
||||||
#with a high layer rate (e.g. autoregressive LLMs).
|
|
||||||
#prioritize the non-comfy weights (note the order reverse).
|
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
|
|
||||||
@ -1551,6 +1549,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
if key in self.backup:
|
if key in self.backup:
|
||||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||||
self.patch_weight_to_device(key, device_to=device_to)
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
|
weight, _, _ = get_key_weight(self.model, key)
|
||||||
|
if weight is not None:
|
||||||
|
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||||
|
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
@ -1576,21 +1577,15 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
for param in params:
|
for param in params:
|
||||||
key = key_param_name_to_key(n, param)
|
key = key_param_name_to_key(n, param)
|
||||||
weight, _, _ = get_key_weight(self.model, key)
|
weight, _, _ = get_key_weight(self.model, key)
|
||||||
weight.seed_key = key
|
if key not in self.backup:
|
||||||
set_dirty(weight, dirty)
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
|
||||||
geometry = weight
|
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
|
||||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
|
||||||
weight_size = geometry.numel() * geometry.element_size()
|
|
||||||
if vbar is not None and not hasattr(weight, "_v"):
|
|
||||||
weight._v = vbar.alloc(weight_size)
|
|
||||||
weight._model_dtype = model_dtype
|
|
||||||
allocated_size += weight_size
|
|
||||||
vbar.set_watermark_limit(allocated_size)
|
|
||||||
|
|
||||||
move_weight_functions(m, device_to)
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||||
|
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||||
|
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||||
@ -1606,7 +1601,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
assert self.load_device != torch.device("cpu")
|
assert self.load_device != torch.device("cpu")
|
||||||
|
|
||||||
vbar = self._vbar_get()
|
vbar = self._vbar_get()
|
||||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||||
|
|
||||||
|
if freed < memory_to_free:
|
||||||
|
for key in list(self.backup.keys()):
|
||||||
|
bk = self.backup.pop(key)
|
||||||
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||||
|
freed += self.model.model_loaded_weight_memory
|
||||||
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
|
return freed
|
||||||
|
|
||||||
def partially_unload_ram(self, ram_to_unload):
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||||
@ -1633,11 +1637,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
move_weight_functions(m, device_to)
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
keys = list(self.backup.keys())
|
|
||||||
for k in keys:
|
|
||||||
bk = self.backup[k]
|
|
||||||
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
assert not force_patch_weights #See above
|
assert not force_patch_weights #See above
|
||||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||||
@ -1669,4 +1668,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_non_dynamic_delegate(self):
|
||||||
|
model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
|
||||||
|
self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
|
||||||
|
return model_patcher
|
||||||
|
|
||||||
|
|
||||||
CoreModelPatcher = ModelPatcher
|
CoreModelPatcher = ModelPatcher
|
||||||
|
|||||||
@ -66,6 +66,18 @@ def convert_cond(cond):
|
|||||||
out.append(temp)
|
out.append(temp)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def cond_has_hooks(cond):
|
||||||
|
for c in cond:
|
||||||
|
temp = c[1]
|
||||||
|
if "hooks" in temp:
|
||||||
|
return True
|
||||||
|
if "control" in temp:
|
||||||
|
control = temp["control"]
|
||||||
|
extra_hooks = control.get_extra_hooks()
|
||||||
|
if len(extra_hooks) > 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def get_additional_models(conds, dtype):
|
def get_additional_models(conds, dtype):
|
||||||
"""loads additional models in conditioning"""
|
"""loads additional models in conditioning"""
|
||||||
cnets: list[ControlBase] = []
|
cnets: list[ControlBase] = []
|
||||||
|
|||||||
@ -946,6 +946,8 @@ class CFGGuider:
|
|||||||
|
|
||||||
def inner_set_conds(self, conds):
|
def inner_set_conds(self, conds):
|
||||||
for k in conds:
|
for k in conds:
|
||||||
|
if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
|
||||||
|
self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
|
||||||
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
|||||||
38
comfy/sd.py
38
comfy/sd.py
@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip
|
|||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
|
||||||
if no_init:
|
if no_init:
|
||||||
return
|
return
|
||||||
params = target.params.copy()
|
params = target.params.copy()
|
||||||
@ -233,7 +233,8 @@ class CLIP:
|
|||||||
model_management.archive_model_dtypes(self.cond_stage_model)
|
model_management.archive_model_dtypes(self.cond_stage_model)
|
||||||
|
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||||
|
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||||
#Match torch.float32 hardcode upcast in TE implemention
|
#Match torch.float32 hardcode upcast in TE implemention
|
||||||
self.patcher.set_model_compute_dtype(torch.float32)
|
self.patcher.set_model_compute_dtype(torch.float32)
|
||||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||||
@ -267,9 +268,9 @@ class CLIP:
|
|||||||
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||||
self.tokenizer_options = {}
|
self.tokenizer_options = {}
|
||||||
|
|
||||||
def clone(self):
|
def clone(self, disable_dynamic=False):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
n.patcher = self.patcher.clone()
|
n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
|
||||||
n.cond_stage_model = self.cond_stage_model
|
n.cond_stage_model = self.cond_stage_model
|
||||||
n.tokenizer = self.tokenizer
|
n.tokenizer = self.tokenizer
|
||||||
n.layer_idx = self.layer_idx
|
n.layer_idx = self.layer_idx
|
||||||
@ -1164,14 +1165,21 @@ class CLIPType(Enum):
|
|||||||
LONGCAT_IMAGE = 26
|
LONGCAT_IMAGE = 26
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
|
||||||
|
def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||||
|
clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
|
||||||
|
return clip.patcher
|
||||||
|
|
||||||
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||||
clip_data = []
|
clip_data = []
|
||||||
for p in ckpt_paths:
|
for p in ckpt_paths:
|
||||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||||
if model_options.get("custom_operations", None) is None:
|
if model_options.get("custom_operations", None) is None:
|
||||||
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
||||||
clip_data.append(sd)
|
clip_data.append(sd)
|
||||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
|
||||||
|
clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
|
||||||
|
return clip
|
||||||
|
|
||||||
|
|
||||||
class TEModel(Enum):
|
class TEModel(Enum):
|
||||||
@ -1276,7 +1284,7 @@ def llama_detect(clip_data):
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
|
||||||
clip_data = state_dicts
|
clip_data = state_dicts
|
||||||
|
|
||||||
class EmptyClass:
|
class EmptyClass:
|
||||||
@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
parameters += comfy.utils.calculate_parameters(c)
|
parameters += comfy.utils.calculate_parameters(c)
|
||||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load_gligen(ckpt_path):
|
def load_gligen(ckpt_path):
|
||||||
@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||||
if out is None:
|
if out is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||||
if output_model:
|
if output_model and out[0] is not None:
|
||||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
|
if output_clip and out[1] is not None:
|
||||||
|
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None,
|
|||||||
disable_dynamic=disable_dynamic)
|
disable_dynamic=disable_dynamic)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
|
_, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False,
|
||||||
|
embedding_directory=embedding_directory, output_model=False,
|
||||||
|
model_options=model_options,
|
||||||
|
te_model_options=te_model_options,
|
||||||
|
disable_dynamic=disable_dynamic)
|
||||||
|
return clip.patcher
|
||||||
|
|
||||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
|
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
|
||||||
clip = None
|
clip = None
|
||||||
clipvision = None
|
clipvision = None
|
||||||
@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
clip_sd = model_config.process_clip_state_dict(sd)
|
clip_sd = model_config.process_clip_state_dict(sd)
|
||||||
if len(clip_sd) > 0:
|
if len(clip_sd) > 0:
|
||||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
|
||||||
else:
|
else:
|
||||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||||
|
|
||||||
|
|||||||
@ -1118,6 +1118,20 @@ class ZImage(Lumina2):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
class ZImagePixelSpace(ZImage):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "zimage_pixel",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Pixel-space model: no spatial compression, operates on raw RGB patches.
|
||||||
|
latent_format = latent_formats.ZImagePixelSpace
|
||||||
|
|
||||||
|
# Much lower memory than latent-space models (no VAE, small patches).
|
||||||
|
memory_usage_factor = 0.03 # TODO: figure out the optimal value for this.
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
return model_base.ZImagePixelSpace(self, device=device)
|
||||||
|
|
||||||
class WAN21_T2V(supported_models_base.BASE):
|
class WAN21_T2V(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@ -1268,6 +1282,16 @@ class WAN21_FlowRVS(WAN21_T2V):
|
|||||||
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
|
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN21_SCAIL(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "scail",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "hunyuan3d2",
|
"image_model": "hunyuan3d2",
|
||||||
@ -1710,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -789,8 +789,6 @@ class GeminiImage2(IO.ComfyNode):
|
|||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
if model == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||||
model = "gemini-3.1-flash-image-preview"
|
model = "gemini-3.1-flash-image-preview"
|
||||||
if response_modalities == "IMAGE+TEXT":
|
|
||||||
raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.")
|
|
||||||
|
|
||||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||||
if images is not None:
|
if images is not None:
|
||||||
@ -895,7 +893,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"response_modalities",
|
"response_modalities",
|
||||||
options=["IMAGE"],
|
options=["IMAGE", "IMAGE+TEXT"],
|
||||||
advanced=True,
|
advanced=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -925,6 +923,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Image.Output(),
|
IO.Image.Output(),
|
||||||
|
IO.String.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
|||||||
@ -20,7 +20,7 @@ class JobStatus:
|
|||||||
|
|
||||||
|
|
||||||
# Media types that can be previewed in the frontend
|
# Media types that can be previewed in the frontend
|
||||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
|
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||||
|
|
||||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
||||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
||||||
@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict:
|
|||||||
normalized[node_id] = normalized_node
|
normalized[node_id] = normalized_node
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
# Text preview truncation limit (1024 characters) to prevent preview_output bloat
|
||||||
|
TEXT_PREVIEW_MAX_LENGTH = 1024
|
||||||
|
|
||||||
|
|
||||||
|
def _create_text_preview(value: str) -> dict:
|
||||||
|
"""Create a text preview dict with optional truncation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with 'content' and optionally 'truncated' flag
|
||||||
|
"""
|
||||||
|
if len(value) <= TEXT_PREVIEW_MAX_LENGTH:
|
||||||
|
return {'content': value}
|
||||||
|
return {
|
||||||
|
'content': value[:TEXT_PREVIEW_MAX_LENGTH],
|
||||||
|
'truncated': True
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||||
"""Extract create_time and workflow_id from extra_data.
|
"""Extract create_time and workflow_id from extra_data.
|
||||||
@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for item in items:
|
for item in items:
|
||||||
normalized = normalize_output_item(item)
|
if not isinstance(item, dict):
|
||||||
if normalized is None:
|
# Handle text outputs (non-dict items like strings or tuples)
|
||||||
continue
|
normalized = normalize_output_item(item)
|
||||||
|
if normalized is None:
|
||||||
|
# Not a 3D file string — check for text preview
|
||||||
|
if media_type == 'text':
|
||||||
|
count += 1
|
||||||
|
if preview_output is None:
|
||||||
|
if isinstance(item, tuple):
|
||||||
|
text_value = item[0] if item else ''
|
||||||
|
else:
|
||||||
|
text_value = str(item)
|
||||||
|
text_preview = _create_text_preview(text_value)
|
||||||
|
enriched = {
|
||||||
|
**text_preview,
|
||||||
|
'nodeId': node_id,
|
||||||
|
'mediaType': media_type
|
||||||
|
}
|
||||||
|
if fallback_preview is None:
|
||||||
|
fallback_preview = enriched
|
||||||
|
continue
|
||||||
|
# normalize_output_item returned a dict (e.g. 3D file)
|
||||||
|
item = normalized
|
||||||
|
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
if preview_output is not None:
|
if preview_output is not None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
|
if is_previewable(media_type, item):
|
||||||
enriched = {
|
enriched = {
|
||||||
**normalized,
|
**item,
|
||||||
'nodeId': node_id,
|
'nodeId': node_id,
|
||||||
}
|
}
|
||||||
if 'mediaType' not in normalized:
|
if 'mediaType' not in item:
|
||||||
enriched['mediaType'] = media_type
|
enriched['mediaType'] = media_type
|
||||||
if normalized.get('type') == 'output':
|
if item.get('type') == 'output':
|
||||||
preview_output = enriched
|
preview_output = enriched
|
||||||
elif fallback_preview is None:
|
elif fallback_preview is None:
|
||||||
fallback_preview = enriched
|
fallback_preview = enriched
|
||||||
|
|||||||
@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode):
|
|||||||
|
|
||||||
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||||
if tile is not None:
|
if tile is not None:
|
||||||
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
|
audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1)
|
||||||
else:
|
else:
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
|
|
||||||
|
|||||||
@ -248,7 +248,7 @@ class SetClipHooks:
|
|||||||
|
|
||||||
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
||||||
if hooks is not None:
|
if hooks is not None:
|
||||||
clip = clip.clone()
|
clip = clip.clone(disable_dynamic=True)
|
||||||
if apply_to_conds:
|
if apply_to_conds:
|
||||||
clip.apply_hooks_to_conds = hooks
|
clip.apply_hooks_to_conds = hooks
|
||||||
clip.patcher.forced_hooks = hooks.clone()
|
clip.patcher.forced_hooks = hooks.clone()
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="Mahiro",
|
node_id="Mahiro",
|
||||||
display_name="Mahiro CFG",
|
display_name="Positive-Biased Guidance",
|
||||||
category="_for_testing",
|
category="_for_testing",
|
||||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode):
|
|||||||
io.Model.Output(display_name="patched_model"),
|
io.Model.Output(display_name="patched_model"),
|
||||||
],
|
],
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
|
search_aliases=[
|
||||||
|
"mahiro",
|
||||||
|
"mahiro cfg",
|
||||||
|
"similarity-adaptive guidance",
|
||||||
|
"positive-biased cfg",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model) -> io.NodeOutput:
|
def execute(cls, model) -> io.NodeOutput:
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
def mahiro_normd(args):
|
def mahiro_normd(args):
|
||||||
scale: float = args['cond_scale']
|
scale: float = args["cond_scale"]
|
||||||
cond_p: torch.Tensor = args['cond_denoised']
|
cond_p: torch.Tensor = args["cond_denoised"]
|
||||||
uncond_p: torch.Tensor = args['uncond_denoised']
|
uncond_p: torch.Tensor = args["uncond_denoised"]
|
||||||
#naive leap
|
# naive leap
|
||||||
leap = cond_p * scale
|
leap = cond_p * scale
|
||||||
#sim with uncond leap
|
# sim with uncond leap
|
||||||
u_leap = uncond_p * scale
|
u_leap = uncond_p * scale
|
||||||
cfg = args["denoised"]
|
cfg = args["denoised"]
|
||||||
merge = (leap + cfg) / 2
|
merge = (leap + cfg) / 2
|
||||||
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
|
normu = torch.sqrt(u_leap.abs()) * u_leap.sign()
|
||||||
normm = torch.sqrt(merge.abs()) * merge.sign()
|
normm = torch.sqrt(merge.abs()) * merge.sign()
|
||||||
sim = F.cosine_similarity(normu, normm).mean()
|
sim = F.cosine_similarity(normu, normm).mean()
|
||||||
simsc = 2 * (sim+1)
|
simsc = 2 * (sim + 1)
|
||||||
wm = (simsc*cfg + (4-simsc)*leap) / 4
|
wm = (simsc * cfg + (4 - simsc) * leap) / 4
|
||||||
return wm
|
return wm
|
||||||
|
|
||||||
m.set_model_sampler_post_cfg_function(mahiro_normd)
|
m.set_model_sampler_post_cfg_function(mahiro_normd)
|
||||||
return io.NodeOutput(m)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|||||||
@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
|||||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||||
|
|
||||||
|
|
||||||
|
class WanSCAILToVideo(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="WanSCAILToVideo",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||||
|
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||||
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||||
|
io.Image.Input("reference_image", optional=True),
|
||||||
|
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
|
||||||
|
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
|
||||||
|
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
|
||||||
|
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
|
||||||
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
|
ref_latent = None
|
||||||
|
if reference_image is not None:
|
||||||
|
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
ref_latent = vae.encode(reference_image[:, :, :, :3])
|
||||||
|
|
||||||
|
if ref_latent is not None:
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
|
||||||
|
|
||||||
|
if clip_vision_output is not None:
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||||
|
|
||||||
|
if pose_video is not None:
|
||||||
|
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||||
|
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
|
||||||
|
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||||
|
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
class WanExtension(ComfyExtension):
|
class WanExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension):
|
|||||||
WanAnimateToVideo,
|
WanAnimateToVideo,
|
||||||
Wan22ImageToVideoLatent,
|
Wan22ImageToVideoLatent,
|
||||||
WanInfiniteTalkToVideo,
|
WanInfiniteTalkToVideo,
|
||||||
|
WanSCAILToVideo,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> WanExtension:
|
async def comfy_entrypoint() -> WanExtension:
|
||||||
|
|||||||
12
main.py
12
main.py
@ -16,11 +16,6 @@ from comfy_execution.progress import get_progress_state
|
|||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_api import feature_flags
|
from comfy_api import feature_flags
|
||||||
|
|
||||||
import comfy_aimdo.control
|
|
||||||
|
|
||||||
if enables_dynamic_vram():
|
|
||||||
comfy_aimdo.control.init()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||||
@ -28,6 +23,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||||
|
|
||||||
|
import comfy_aimdo.control
|
||||||
|
|
||||||
|
if enables_dynamic_vram():
|
||||||
|
comfy_aimdo.control.init()
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
||||||
|
|
||||||
@ -192,7 +192,7 @@ import hook_breaker_ac10a0
|
|||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
|
||||||
if enables_dynamic_vram():
|
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
|
||||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import torch
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False):
|
|||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0):
|
||||||
|
"""
|
||||||
|
Apply values to conditioning only during [start_percent, end_percent], keeping the
|
||||||
|
original conditioning active outside that range. Respects existing per-entry ranges.
|
||||||
|
"""
|
||||||
|
if start_percent > end_percent:
|
||||||
|
logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})")
|
||||||
|
return conditioning
|
||||||
|
|
||||||
|
EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep
|
||||||
|
c = []
|
||||||
|
for t in conditioning:
|
||||||
|
cond_start = t[1].get("start_percent", 0.0)
|
||||||
|
cond_end = t[1].get("end_percent", 1.0)
|
||||||
|
intersect_start = max(start_percent, cond_start)
|
||||||
|
intersect_end = min(end_percent, cond_end)
|
||||||
|
|
||||||
|
if intersect_start >= intersect_end: # no overlap: emit unchanged
|
||||||
|
c.append(t)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if intersect_start > cond_start: # part before the requested range
|
||||||
|
c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS}))
|
||||||
|
|
||||||
|
c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end}))
|
||||||
|
|
||||||
|
if intersect_end < cond_end: # part after the requested range
|
||||||
|
c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end}))
|
||||||
|
return c
|
||||||
|
|
||||||
def pillow(fn, arg):
|
def pillow(fn, arg):
|
||||||
prev_value = None
|
prev_value = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.39.19
|
comfyui-frontend-package==1.39.19
|
||||||
comfyui-workflow-templates==0.9.4
|
comfyui-workflow-templates==0.9.5
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
@ -22,7 +22,7 @@ alembic
|
|||||||
SQLAlchemy
|
SQLAlchemy
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
comfy-kitchen>=0.2.7
|
comfy-kitchen>=0.2.7
|
||||||
comfy-aimdo>=0.2.2
|
comfy-aimdo>=0.2.4
|
||||||
requests
|
requests
|
||||||
|
|
||||||
#non essential dependencies:
|
#non essential dependencies:
|
||||||
|
|||||||
@ -49,6 +49,12 @@ def mock_provider(mock_releases):
|
|||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
import utils.install_util
|
||||||
|
utils.install_util.PACKAGE_VERSIONS = {}
|
||||||
|
|
||||||
|
|
||||||
def test_get_release(mock_provider, mock_releases):
|
def test_get_release(mock_provider, mock_releases):
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
release = mock_provider.get_release(version)
|
release = mock_provider.get_release(version)
|
||||||
|
|||||||
@ -38,13 +38,13 @@ class TestIsPreviewable:
|
|||||||
"""Unit tests for is_previewable()"""
|
"""Unit tests for is_previewable()"""
|
||||||
|
|
||||||
def test_previewable_media_types(self):
|
def test_previewable_media_types(self):
|
||||||
"""Images, video, audio, 3d media types should be previewable."""
|
"""Images, video, audio, 3d, text media types should be previewable."""
|
||||||
for media_type in ['images', 'video', 'audio', '3d']:
|
for media_type in ['images', 'video', 'audio', '3d', 'text']:
|
||||||
assert is_previewable(media_type, {}) is True
|
assert is_previewable(media_type, {}) is True
|
||||||
|
|
||||||
def test_non_previewable_media_types(self):
|
def test_non_previewable_media_types(self):
|
||||||
"""Other media types should not be previewable."""
|
"""Other media types should not be previewable."""
|
||||||
for media_type in ['latents', 'text', 'metadata', 'files']:
|
for media_type in ['latents', 'metadata', 'files']:
|
||||||
assert is_previewable(media_type, {}) is False
|
assert is_previewable(media_type, {}) is False
|
||||||
|
|
||||||
def test_3d_extensions_previewable(self):
|
def test_3d_extensions_previewable(self):
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
# The path to the requirements.txt file
|
# The path to the requirements.txt file
|
||||||
requirements_path = Path(__file__).parents[1] / "requirements.txt"
|
requirements_path = Path(__file__).parents[1] / "requirements.txt"
|
||||||
@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running:
|
|||||||
{sys.executable} {extra}-m pip install -r {requirements_path}
|
{sys.executable} {extra}-m pip install -r {requirements_path}
|
||||||
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
|
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem.
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_version(version: str) -> bool:
|
||||||
|
"""Validate if a string is a valid semantic version (X.Y.Z format)."""
|
||||||
|
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
|
||||||
|
return bool(re.match(pattern, version))
|
||||||
|
|
||||||
|
|
||||||
|
PACKAGE_VERSIONS = {}
|
||||||
|
def get_required_packages_versions():
|
||||||
|
if len(PACKAGE_VERSIONS) > 0:
|
||||||
|
return PACKAGE_VERSIONS.copy()
|
||||||
|
out = PACKAGE_VERSIONS
|
||||||
|
try:
|
||||||
|
with open(requirements_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip().replace(">=", "==")
|
||||||
|
s = line.split("==")
|
||||||
|
if len(s) == 2:
|
||||||
|
version_str = s[-1]
|
||||||
|
if not is_valid_version(version_str):
|
||||||
|
logging.error(f"Invalid version format in requirements.txt: {version_str}")
|
||||||
|
continue
|
||||||
|
out[s[0]] = version_str
|
||||||
|
return out.copy()
|
||||||
|
except FileNotFoundError:
|
||||||
|
logging.error("requirements.txt not found.")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error reading requirements.txt: {e}")
|
||||||
|
return None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user