diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 6556677e0..95cc48f88 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -16,7 +16,7 @@ body: ## Very Important - Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored. + Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored. - type: checkboxes id: custom-nodes-test attributes: diff --git a/README.md b/README.md index d97c6181d..56b7966cf 100644 --- a/README.md +++ b/README.md @@ -189,8 +189,6 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) -[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z). - [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? diff --git a/app/frontend_management.py b/app/frontend_management.py index bdaa85812..f753ef0de 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -17,7 +17,7 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired -from utils.install_util import get_missing_requirements_message, requirements_path +from utils.install_util import get_missing_requirements_message, get_required_packages_versions from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger @@ -45,25 +45,7 @@ def get_installed_frontend_version(): def get_required_frontend_version(): - """Get the required frontend version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-frontend-package=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-frontend-package not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required frontend version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-frontend-package", None) def check_frontend_version(): @@ -217,25 +199,7 @@ class FrontendManager: @classmethod def get_required_templates_version(cls) -> str: - """Get the required workflow templates version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-workflow-templates=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid templates version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-workflow-templates not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required templates version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-workflow-templates", None) @classmethod def default_frontend_path(cls) -> str: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 32296db22..2af2d56cf 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -147,6 +147,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") +parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") @@ -160,7 +161,6 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" - DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -261,4 +261,4 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): - return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu diff --git a/comfy/conds.py b/comfy/conds.py index 5af3e93ea..55d8cdd78 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -4,6 +4,25 @@ import comfy.utils import logging +def is_equal(x, y): + if torch.is_tensor(x) and torch.is_tensor(y): + return torch.equal(x, y) + elif isinstance(x, dict) and isinstance(y, dict): + if x.keys() != y.keys(): + return False + return all(is_equal(x[k], y[k]) for k in x) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + if type(x) is not type(y) or len(x) != len(y): + return False + return all(is_equal(a, b) for a, b in zip(x, y)) + else: + try: + return x == y + except Exception: + logging.warning("comparison issue with COND") + return False + + class CONDRegular: def __init__(self, cond): self.cond = cond @@ -84,7 +103,7 @@ class CONDConstant(CONDRegular): return self._copy_with(self.cond) def can_concat(self, other): - if self.cond != other.cond: + if not is_equal(self.cond, other.cond): return False return True diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2f82d51da..b54f7f39a 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC): mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: - raise Exception("No sample_sigmas matched current timestep; something went wrong.") + return # substep from multi-step sampler: keep self._step from the last full step self._step = int(matches[0].item()) def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 2b080aaeb..553fd5b38 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module): def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, - v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, + v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module): vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa del vshift_msa, vscale_msa - attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) + attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options) del norm_vx # video cross-attention vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] @@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel): return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] def _process_transformer_blocks( - self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs ): vx = x[0] ax = x[1] @@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel): v_cross_gate_timestep=args["v_cross_gate_timestep"], a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], + self_attention_mask=args.get("self_attention_mask"), ) return out @@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel): "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, + "self_attention_mask": self_attention_mask, }, {"original_block": block_wrap}, ) @@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel): v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index d61e19d6e..60d760d29 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from enum import Enum import functools +import logging import math from typing import Dict, Optional, Tuple @@ -14,6 +15,8 @@ import comfy.ldm.common_dit from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +logger = logging.getLogger(__name__) + def _log_base(x, base): return np.log(x) / np.log(base) @@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module): self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) attn1_input = comfy.ldm.common_dit.rms_norm(x) attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) x.addcmul_(attn1_input, gate_msa) del attn1_input @@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC): """Process input data. Must be implemented by subclasses.""" pass + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Base implementation returns None (no attenuation). Subclasses that + support guide-based attention control should override this. + """ + return None + @abstractmethod - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs): """Process transformer blocks. Must be implemented by subclasses.""" pass @@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC): attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + # Build self-attention mask for per-guide attenuation + self_attention_mask = self._build_guide_self_attention_mask( + x, transformer_options, merged_args + ) + # Process transformer blocks x = self._process_transformer_blocks( - x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + x, context, attention_mask, timestep, pe, + transformer_options=transformer_options, + self_attention_mask=self_attention_mask, + **merged_args, ) # Process output @@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel): pixel_coords = pixel_coords[:, :, grid_mask, ...] kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + + # Compute per-guide surviving token counts from guide_attention_entries. + # Each entry tracks one guide reference; they are appended in order and + # their pre_filter_counts partition the kf_grid_mask. + guide_entries = kwargs.get("guide_attention_entries", None) + if guide_entries: + total_pfc = sum(e["pre_filter_count"] for e in guide_entries) + if total_pfc != len(kf_grid_mask): + raise ValueError( + f"guide pre_filter_counts ({total_pfc}) != " + f"keyframe grid mask length ({len(kf_grid_mask)})" + ) + resolved_entries = [] + offset = 0 + for entry in guide_entries: + pfc = entry["pre_filter_count"] + entry_mask = kf_grid_mask[offset:offset + pfc] + surviving = int(entry_mask.sum().item()) + resolved_entries.append({ + **entry, + "surviving_count": surviving, + }) + offset += pfc + additional_args["resolved_guide_entries"] = resolved_entries + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs + # Total surviving guide tokens (all guides) + additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] + x = self.patchify_proj(x) return x, pixel_coords, additional_args - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Reads resolved_guide_entries from merged_args (computed in _process_input) + to build a log-space additive bias mask that attenuates noisy ↔ guide + attention for each guide reference independently. + + Returns None if no attenuation is needed (all strengths == 1.0 and no + spatial masks, or no guide tokens). + """ + if isinstance(x, list): + # AV model: x = [vx, ax]; use vx for token count and device + total_tokens = x[0].shape[1] + device = x[0].device + dtype = x[0].dtype + else: + total_tokens = x.shape[1] + device = x.device + dtype = x.dtype + + num_guide_tokens = merged_args.get("num_guide_tokens", 0) + if num_guide_tokens == 0: + return None + + resolved_entries = merged_args.get("resolved_guide_entries", None) + if not resolved_entries: + return None + + # Check if any attenuation is actually needed + needs_attenuation = any( + e["strength"] < 1.0 or e.get("pixel_mask") is not None + for e in resolved_entries + ) + if not needs_attenuation: + return None + + # Build per-guide-token weights for all tracked guide tokens. + # Guides are appended in order at the end of the sequence. + guide_start = total_tokens - num_guide_tokens + all_weights = [] + total_tracked = 0 + + for entry in resolved_entries: + surviving = entry["surviving_count"] + if surviving == 0: + continue + + strength = entry["strength"] + pixel_mask = entry.get("pixel_mask") + latent_shape = entry.get("latent_shape") + + if pixel_mask is not None and latent_shape is not None: + f_lat, h_lat, w_lat = latent_shape + per_token = self._downsample_mask_to_latent( + pixel_mask.to(device=device, dtype=dtype), + f_lat, h_lat, w_lat, + ) + # per_token shape: (B, f_lat*h_lat*w_lat). + # Collapse batch dim — the mask is assumed identical across the + # batch; validate and take the first element to get (1, tokens). + if per_token.shape[0] > 1: + ref = per_token[0] + for bi in range(1, per_token.shape[0]): + if not torch.equal(ref, per_token[bi]): + logger.warning( + "pixel_mask differs across batch elements; " + "using first element only." + ) + break + per_token = per_token[:1] + # `surviving` is the post-grid_mask token count. + # Clamp to surviving to handle any mismatch safely. + n_weights = min(per_token.shape[1], surviving) + weights = per_token[:, :n_weights] * strength # (1, n_weights) + else: + weights = torch.full( + (1, surviving), strength, device=device, dtype=dtype + ) + + all_weights.append(weights) + total_tracked += weights.shape[1] + + if not all_weights: + return None + + # Concatenate per-token weights for all tracked guides + tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked) + + # Check if any weight is actually < 1.0 (otherwise no attenuation needed) + if (tracked_weights >= 1.0).all(): + return None + + # Build the mask: guide tokens are at the end of the sequence. + # Tracked guides come first (in order), untracked follow. + return self._build_self_attention_mask( + total_tokens, num_guide_tokens, total_tracked, + tracked_weights, guide_start, device, dtype, + ) + + @staticmethod + def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat): + """Downsample a pixel-space mask to per-token latent weights. + + Args: + mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1]. + f_lat: Number of latent frames (pre-dilation original count). + h_lat: Latent height (pre-dilation original height). + w_lat: Latent width (pre-dilation original width). + + Returns: + (B, F_lat * H_lat * W_lat) flattened per-token weights. + """ + b = mask.shape[0] + f_pix = mask.shape[2] + + # Spatial downsampling: area interpolation per frame + spatial_down = torch.nn.functional.interpolate( + rearrange(mask, "b 1 f h w -> (b f) 1 h w"), + size=(h_lat, w_lat), + mode="area", + ) + spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b) + + # Temporal downsampling: first pixel frame maps to first latent frame, + # remaining pixel frames are averaged in groups for causal temporal structure. + first_frame = spatial_down[:, :, :1, :, :] + if f_pix > 1 and f_lat > 1: + remaining_pix = f_pix - 1 + remaining_lat = f_lat - 1 + t = remaining_pix // remaining_lat + if t < 1: + # Fewer pixel frames than latent frames — upsample by repeating + # the available pixel frames via nearest interpolation. + rest_flat = rearrange( + spatial_down[:, :, 1:, :, :], + "b 1 f h w -> (b h w) 1 f", + ) + rest_up = torch.nn.functional.interpolate( + rest_flat, size=remaining_lat, mode="nearest", + ) + rest = rearrange( + rest_up, "(b h w) 1 f -> b 1 f h w", + b=b, h=h_lat, w=w_lat, + ) + else: + # Trim trailing pixel frames that don't fill a complete group + usable = remaining_lat * t + rest = rearrange( + spatial_down[:, :, 1:1 + usable, :, :], + "b 1 (f t) h w -> b 1 f t h w", + t=t, + ) + rest = rest.mean(dim=3) + latent_mask = torch.cat([first_frame, rest], dim=2) + elif f_lat > 1: + # Single pixel frame but multiple latent frames — repeat the + # single frame across all latent frames. + latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1) + else: + latent_mask = first_frame + + return rearrange(latent_mask, "b 1 f h w -> b (f h w)") + + @staticmethod + def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count, + tracked_weights, guide_start, device, dtype): + """Build a log-space additive self-attention bias mask. + + Attenuates attention between noisy tokens and tracked guide tokens. + Untracked guide tokens (at the end of the guide portion) keep full attention. + + Args: + total_tokens: Total sequence length. + num_guide_tokens: Total guide tokens (all guides) at end of sequence. + tracked_count: Number of tracked guide tokens (first in the guide portion). + tracked_weights: (1, tracked_count) tensor, values in [0, 1]. + guide_start: Index where guide tokens begin in the sequence. + device: Target device. + dtype: Target dtype. + + Returns: + (1, 1, total_tokens, total_tokens) additive bias mask. + 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked. + """ + finfo = torch.finfo(dtype) + mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype) + tracked_end = guide_start + tracked_count + + # Convert weights to log-space bias + w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count) + log_w = torch.full_like(w, finfo.min) + positive_mask = w > 0 + if positive_mask.any(): + log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny)) + + # noisy → tracked guides: each noisy row gets the same per-guide weight + mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) + # tracked guides → noisy: each guide row broadcasts its weight across noisy cols + mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1) + + return mask + + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel): def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel): timestep=timestep, pe=pe, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, ) return x diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c8d53cac..295310df6 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -18,6 +18,8 @@ import comfy.patcher_extension import comfy.ops ops = comfy.ops.disable_weight_init +from ..sdpose import HeatmapHead + class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. @@ -441,6 +443,7 @@ class UNetModel(nn.Module): disable_temporal_crossattention=False, max_ddpm_temb_period=10000, attn_precision=None, + heatmap_head=False, device=None, operations=ops, ): @@ -827,6 +830,9 @@ class UNetModel(nn.Module): #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) + if heatmap_head: + self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations) + def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, diff --git a/comfy/ldm/modules/sdpose.py b/comfy/ldm/modules/sdpose.py new file mode 100644 index 000000000..d67b60b76 --- /dev/null +++ b/comfy/ldm/modules/sdpose.py @@ -0,0 +1,130 @@ +import torch +import numpy as np +from scipy.ndimage import gaussian_filter + +class HeatmapHead(torch.nn.Module): + def __init__( + self, + in_channels=640, + out_channels=133, + input_size=(768, 1024), + heatmap_scale=4, + deconv_out_channels=(640,), + deconv_kernel_sizes=(4,), + conv_out_channels=(640,), + conv_kernel_sizes=(1,), + final_layer_kernel_size=1, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale) + self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32) + + # Deconv layers + if deconv_out_channels: + deconv_layers = [] + for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes): + if kernel_size == 4: + padding, output_padding = 1, 0 + elif kernel_size == 3: + padding, output_padding = 1, 1 + elif kernel_size == 2: + padding, output_padding = 0, 0 + else: + raise ValueError(f'Unsupported kernel size {kernel_size}') + + deconv_layers.extend([ + operations.ConvTranspose2d(in_channels, out_ch, kernel_size, + stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype), + torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype), + torch.nn.SiLU(inplace=True) + ]) + in_channels = out_ch + self.deconv_layers = torch.nn.Sequential(*deconv_layers) + else: + self.deconv_layers = torch.nn.Identity() + + # Conv layers + if conv_out_channels: + conv_layers = [] + for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes): + padding = (kernel_size - 1) // 2 + conv_layers.extend([ + operations.Conv2d(in_channels, out_ch, kernel_size, + stride=1, padding=padding, device=device, dtype=dtype), + torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype), + torch.nn.SiLU(inplace=True) + ]) + in_channels = out_ch + self.conv_layers = torch.nn.Sequential(*conv_layers) + else: + self.conv_layers = torch.nn.Identity() + + self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype) + + def forward(self, x): # Decode heatmaps to keypoints + heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x))) + heatmaps_np = heatmaps.float().cpu().numpy() # (B, K, H, W) + B, K, H, W = heatmaps_np.shape + + batch_keypoints = [] + batch_scores = [] + + for b in range(B): + hm = heatmaps_np[b].copy() # (K, H, W) + + # --- vectorised argmax --- + flat = hm.reshape(K, -1) + idx = np.argmax(flat, axis=1) + scores = flat[np.arange(K), idx].copy() + y_locs, x_locs = np.unravel_index(idx, (H, W)) + keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32) # (K, 2) in heatmap space + invalid = scores <= 0. + keypoints[invalid] = -1 + + # --- DARK sub-pixel refinement (UDP) --- + # 1. Gaussian blur with max-preserving normalisation + border = 5 # (kernel-1)//2 for kernel=11 + for k in range(K): + origin_max = np.max(hm[k]) + dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = hm[k].copy() + dr = gaussian_filter(dr, sigma=2.0) + hm[k] = dr[border:-border, border:-border].copy() + cur_max = np.max(hm[k]) + if cur_max > 0: + hm[k] *= origin_max / cur_max + # 2. Log-space for Taylor expansion + np.clip(hm, 1e-3, 50., hm) + np.log(hm, hm) + # 3. Hessian-based Newton step + hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten() + index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, K) + index = index.astype(int).reshape(-1, 1) + i_ = hm_pad[index] + ix1 = hm_pad[index + 1] + iy1 = hm_pad[index + W + 2] + ix1y1 = hm_pad[index + W + 3] + ix1_y1_ = hm_pad[index - W - 3] + ix1_ = hm_pad[index - 1] + iy1_ = hm_pad[index - 2 - W] + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1) + + # --- restore to input image space --- + keypoints = keypoints * self.scale_factor + keypoints[invalid] = -1 + + batch_keypoints.append(keypoints) + batch_scores.append(scores) + + return batch_keypoints, batch_scores diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index ea123acb4..b2287dba9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + +class SCAILWanModel(WanModel): + def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + + self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) + + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + + if reference_latent is not None: + x = torch.cat((reference_latent, x), dim=2) + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes + x = x.flatten(2).transpose(1, 2) + + scail_pose_seq_len = 0 + if pose_latents is not None: + scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + scail_x = scail_x.flatten(2).transpose(1, 2) + scail_pose_seq_len = scail_x.shape[1] + x = torch.cat([x, scail_x], dim=1) + del scail_x + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.cat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + if scail_pose_seq_len > 0: + x = x[:, :-scail_pose_seq_len] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + if reference_latent is not None: + x = x[:, :, reference_latent.shape[2]:] + + return x + + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) + + if pose_latents is None: + return main_freqs + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + + # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames + h_scale = h / H_pose + w_scale = w / W_pose + + # 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options) + + return torch.cat([main_freqs, pose_freqs], dim=1) + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if pose_latents is not None: + pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) + + t_len = t + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = x.shape[2] + + reference_latent = None + if "reference_latent" in kwargs: + reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) + t_len += reference_latent.shape[2] + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index fd125ceed..71f73c64e 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -459,6 +459,7 @@ class WanVAE(nn.Module): attn_scales=[], temperal_downsample=[True, True, False], image_channels=3, + conv_out_channels=3, dropout=0.0): super().__init__() self.dim = dim @@ -474,7 +475,7 @@ class WanVAE(nn.Module): attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, + self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def encode(self, x): @@ -484,7 +485,7 @@ class WanVAE(nn.Module): iter_ = 1 + (t - 1) // 4 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.decoder) + feat_map = [None] * count_conv3d(self.encoder) ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): conv_idx = [0] diff --git a/comfy/lora.py b/comfy/lora.py index 279cf38bb..f36ddb046 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -337,6 +337,7 @@ def model_lora_keys_unet(model, key_map={}): if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"): key_lora = k[len("diffusion_model.decoder."):-len(".weight")] key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 2f49578f6..a1c690b9b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -76,6 +76,7 @@ class ModelType(Enum): FLUX = 8 IMG_TO_IMG = 9 FLOW_COSMOS = 10 + IMG_TO_IMG_FLOW = 11 def model_sampling(model_config, model_type): @@ -108,6 +109,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLOW_COSMOS: c = comfy.model_sampling.COSMOS_RFLOW s = comfy.model_sampling.ModelSamplingCosmosRFlow + elif model_type == ModelType.IMG_TO_IMG_FLOW: + c = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSampling(s, c): pass @@ -922,6 +925,25 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class LongCatImage(Flux): + def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + transformer_options = transformer_options.copy() + rope_opts = transformer_options.get("rope_options", {}) + rope_opts = dict(rope_opts) + rope_opts.setdefault("shift_t", 1.0) + rope_opts.setdefault("shift_y", 512.0) + rope_opts.setdefault("shift_x", 512.0) + transformer_options["rope_options"] = rope_opts + return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def encode_adm(self, **kwargs): + return None + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out.pop('guidance', None) + return out + class Flux2(Flux): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -971,6 +993,10 @@ class LTXV(BaseModel): if keyframe_idxs is not None: out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): @@ -1023,6 +1049,10 @@ class LTXAV(BaseModel): if latent_shapes is not None: out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): @@ -1466,6 +1496,50 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class WAN21_FlowRVS(WAN21): + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): + model_config.unet_config["model_type"] = "t2v" + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + self.image_to_video = image_to_video + +class WAN21_SCAIL(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents") + self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_mask = torch.ones_like(ref_latent[:, :4]) + ref_latent = torch.cat([ref_latent, ref_mask], dim=1) + out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + pose_latents = self.process_latent_in(pose_latents) + pose_mask = torch.ones_like(pose_latents[:, :4]) + pose_latents = torch.cat([pose_latents, pose_mask], dim=1) + out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] + + return out + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 30ea03e8e..3faa950ca 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -279,6 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"]) if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model dit_config["txt_ids_dims"] = [1, 2] + if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image + dit_config["txt_ids_dims"] = [1, 2] return dit_config @@ -496,6 +498,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" @@ -509,6 +513,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if ref_conv_weight is not None: dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + if metadata is not None and "config" in metadata: + dit_config.update(json.loads(metadata["config"]).get("transformer", {})) + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D @@ -792,6 +799,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["use_temporal_resblock"] = False unet_config["use_temporal_attention"] = False + heatmap_key = '{}heatmap_head.conv_layers.0.weight'.format(key_prefix) + if heatmap_key in state_dict_keys: + unet_config["heatmap_head"] = True + return unet_config def model_config_from_unet_config(unet_config, state_dict=None): @@ -1012,7 +1023,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], - 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8, + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'use_temporal_attention': False, 'use_temporal_resblock': False} diff --git a/comfy/model_management.py b/comfy/model_management.py index 1fe56a62b..c817d43b5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -180,6 +180,14 @@ def is_ixuca(): return True return False +def is_wsl(): + version = platform.uname().release + if version.endswith("-Microsoft"): + return True + elif version.endswith("microsoft-standard-WSL2"): + return True + return False + def get_torch_device(): global directml_enabled global cpu_state @@ -350,7 +358,7 @@ AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0] if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD @@ -378,7 +386,7 @@ try: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): if any((a in arch) for a in ["gfx1200", "gfx1201"]): @@ -631,12 +639,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ if not DISABLE_SMART_MEMORY: memory_to_free = memory_required - get_free_memory(device) ram_to_free = ram_required - get_free_ram() - - if current_loaded_models[i].model.is_dynamic() and for_dynamic: - #don't actually unload dynamic models for the sake of other dynamic models - #as that works on-demand. - memory_required -= current_loaded_models[i].model.loaded_size() - memory_to_free = 0 + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1c9ba8096..3fc76d9db 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -308,15 +308,22 @@ class ModelPatcher: def get_free_memory(self, device): return comfy.model_management.get_free_memory(device) - def clone(self, disable_dynamic=False): + def get_clone_model_override(self): + return self.model, (self.backup, self.object_patches_backup, self.pinned) + + def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ - model = self.model if self.is_dynamic() and disable_dynamic: class_ = ModelPatcher - temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) - model = temp_model_patcher.model + if model_override is None: + if self.cached_patcher_init is None: + raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") + temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + model_override = temp_model_patcher.get_clone_model_override() + if model_override is None: + model_override = self.get_clone_model_override() - n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) + n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -325,13 +332,12 @@ class ModelPatcher: n.object_patches = self.object_patches.copy() n.weight_wrapper_patches = self.weight_wrapper_patches.copy() n.model_options = comfy.utils.deepcopy_list_dict(self.model_options) - n.backup = self.backup - n.object_patches_backup = self.object_patches_backup n.parent = self - n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights + n.backup, n.object_patches_backup, n.pinned = model_override[1] + # attachments n.attachments = {} for k in self.attachments: @@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher): del self.model.model_loaded_weight_memory if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + self.non_dynamic_delegate_model = None assert load_device is not None def is_dynamic(self): @@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher): def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: pass + def get_non_dynamic_delegate(self): + model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model) + self.non_dynamic_delegate_model = model_patcher.get_clone_model_override() + return model_patcher + + CoreModelPatcher = ModelPatcher diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 2a00ed819..13860e6a2 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -83,6 +83,16 @@ class IMG_TO_IMG(X0): def calculate_input(self, sigma, noise): return noise +class IMG_TO_IMG_FLOW(CONST): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + return latent_image + + def inverse_noise_scaling(self, sigma, latent): + return 1.0 - latent + class COSMOS_RFLOW: def calculate_input(self, sigma, noise): sigma = (sigma / (sigma + 1)) diff --git a/comfy/ops.py b/comfy/ops.py index 34f72ff17..6ee6075fb 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,7 +19,7 @@ import torch import logging import comfy.model_management -from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram +from comfy.cli_args import args, PerformanceFeature import comfy.float import json import comfy.memory_management @@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = to_dequant(x, dtype) if not resident and lowvram_fn is not None: x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) - #FIXME: this is not accurate, we need to be sensitive to the compute dtype x = lowvram_fn(x) - if (isinstance(orig, QuantizedTensor) and - (want_requant and len(fns) == 0 or update_weight)): + if (want_requant and len(fns) == 0 or update_weight): seed = comfy.utils.string_to_seed(s.seed_key) - y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) - if want_requant and len(fns) == 0: - #The layer actually wants our freshly saved QT - x = y - elif update_weight: - y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key)) + if isinstance(orig, QuantizedTensor): + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + else: + y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed) + if want_requant and len(fns) == 0: + x = y if update_weight: orig.copy_(y) for f in fns: @@ -296,7 +294,7 @@ class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): - if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: super().__init__(in_features, out_features, bias, device, dtype) return @@ -317,7 +315,7 @@ class disable_weight_init: def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) @@ -617,7 +615,8 @@ def fp8_linear(self, input): if input.ndim != 2: return None - w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device) + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True) scale_weight = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1f75f2ba7..bbba09e26 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -66,6 +66,18 @@ def convert_cond(cond): out.append(temp) return out +def cond_has_hooks(cond): + for c in cond: + temp = c[1] + if "hooks" in temp: + return True + if "control" in temp: + control = temp["control"] + extra_hooks = control.get_extra_hooks() + if len(extra_hooks) > 0: + return True + return False + def get_additional_models(conds, dtype): """loads additional models in conditioning""" cnets: list[ControlBase] = [] diff --git a/comfy/samplers.py b/comfy/samplers.py index 8b9782956..8be449ef7 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -946,6 +946,8 @@ class CFGGuider: def inner_set_conds(self, conds): for k in conds: + if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]): + self.model_patcher = self.model_patcher.get_non_dynamic_delegate() self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) def __call__(self, *args, **kwargs): diff --git a/comfy/sd.py b/comfy/sd.py index 69d4531e3..a9ad7c2d2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -60,6 +60,7 @@ import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie import comfy.text_encoders.anima import comfy.text_encoders.ace15 +import comfy.text_encoders.longcat_image import comfy.model_patcher import comfy.lora @@ -203,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False): if no_init: return params = target.params.copy() @@ -232,7 +233,8 @@ class CLIP: model_management.archive_model_dtypes(self.cond_stage_model) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -266,9 +268,9 @@ class CLIP: logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) self.tokenizer_options = {} - def clone(self): + def clone(self, disable_dynamic=False): n = CLIP(no_init=True) - n.patcher = self.patcher.clone() + n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic) n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx @@ -694,8 +696,9 @@ class VAE: self.latent_dim = 3 self.latent_channels = 16 self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.conv_out_channels = sd["decoder.head.2.weight"].shape[0] self.pad_channel_value = 1.0 - ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) @@ -1159,16 +1162,24 @@ class CLIPType(Enum): KANDINSKY5_IMAGE = 23 NEWBIE = 24 FLUX2 = 25 + LONGCAT_IMAGE = 26 -def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): + +def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): + clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic) + return clip.patcher + +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) if model_options.get("custom_operations", None) is None: sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) - return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) + clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic) + clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options)) + return clip class TEModel(Enum): @@ -1273,7 +1284,7 @@ def llama_detect(clip_data): return {} -def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): +def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = state_dicts class EmptyClass: @@ -1371,6 +1382,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip if clip_type == CLIPType.HUNYUAN_IMAGE: clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + elif clip_type == CLIPType.LONGCAT_IMAGE: + clip_target.clip = comfy.text_encoders.longcat_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.longcat_image.LongCatImageTokenizer else: clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer @@ -1490,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic) return clip def load_gligen(ckpt_path): @@ -1535,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) - if output_model: + if output_model and out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) + if output_clip and out[1] is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options)) return out def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): @@ -1547,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, disable_dynamic=disable_dynamic) return model +def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + _, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False, + embedding_directory=embedding_directory, output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic) + return clip.patcher + def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False): clip = None clipvision = None @@ -1632,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c28be1716..4f63e8327 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -25,6 +25,7 @@ import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.anima import comfy.text_encoders.ace15 +import comfy.text_encoders.longcat_image from . import supported_models_base from . import latent_formats @@ -525,7 +526,8 @@ class LotusD(SD20): } unet_extra_config = { - "num_classes": 'sequential' + "num_classes": 'sequential', + "num_head_channels": 64, } def get_model(self, state_dict, prefix="", device=None): @@ -1256,6 +1258,26 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class WAN21_FlowRVS(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "flow_rvs", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device) + return out + +class WAN21_SCAIL(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1667,6 +1689,37 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +class LongCatImage(supported_models_base.BASE): + unet_config = { + "image_model": "flux", + "guidance_embed": False, + "vec_in_dim": None, + "context_in_dim": 3584, + "txt_ids_dims": [1, 2], + } + + sampling_settings = { + } + + unet_extra_config = {} + latent_format = latent_formats.Flux + + memory_usage_factor = 2.5 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LongCatImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index f135d74c1..853f021ae 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -328,14 +328,14 @@ class ACE15TEModel(torch.nn.Module): return getattr(self, self.lm_model).load_sd(sd) def memory_estimation_function(self, token_weight_pairs, device=None): - lm_metadata = token_weight_pairs["lm_metadata"] + lm_metadata = token_weight_pairs.get("lm_metadata", {}) constant = self.constant if comfy.model_management.should_use_bf16(device): constant *= 0.5 token_weight_pairs = token_weight_pairs.get("lm_prompt", []) num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - num_tokens += lm_metadata['min_tokens'] + num_tokens += lm_metadata.get("min_tokens", 0) return num_tokens * constant * 1024 * 1024 def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"): diff --git a/comfy/text_encoders/longcat_image.py b/comfy/text_encoders/longcat_image.py new file mode 100644 index 000000000..882d80901 --- /dev/null +++ b/comfy/text_encoders/longcat_image.py @@ -0,0 +1,184 @@ +import re +import numbers +import torch +from comfy import sd1_clip +from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel +import logging + +logger = logging.getLogger(__name__) + +QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")] +QUOTE_PATTERN = "|".join( + [ + re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) + for q1, q2 in QUOTE_PAIRS + ] +) +WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + + +def split_quotation(prompt): + matches = WORD_INTERNAL_QUOTE_RE.findall(prompt) + mapping = [] + for i, word_src in enumerate(set(matches)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping.append((word_src, word_tgt)) + + parts = re.split(f"({QUOTE_PATTERN})", prompt) + result = [] + for part in parts: + for word_src, word_tgt in mapping: + part = part.replace(word_tgt, word_src) + if not part: + continue + is_quoted = bool(re.match(QUOTE_PATTERN, part)) + result.append((part, is_quoted)) + return result + + +class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_length = 512 + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + parts = split_quotation(text) + all_tokens = [] + for part_text, is_quoted in parts: + if is_quoted: + for char in part_text: + ids = self.tokenizer(char, add_special_tokens=False)["input_ids"] + all_tokens.extend(ids) + else: + ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"] + all_tokens.extend(ids) + + if len(all_tokens) > self.max_length: + all_tokens = all_tokens[: self.max_length] + logger.warning(f"Truncated prompt to {self.max_length} tokens") + + output = [(t, 1.0) for t in all_tokens] + # Pad to max length + self.pad_tokens(output, self.max_length - len(output)) + return [output] + + +class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="qwen25_7b", + tokenizer=LongCatImageBaseTokenizer, + ) + self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + skip_template = False + if text.startswith("<|im_start|>"): + skip_template = True + if text.startswith("<|start_header_id|>"): + skip_template = True + if text == "": + text = " " + + base_tok = getattr(self, "qwen25_7b") + if skip_template: + tokens = super().tokenize_with_weights( + text, return_word_ids=return_word_ids, disable_weights=True, **kwargs + ) + else: + prefix_ids = base_tok.tokenizer( + self.longcat_template_prefix, add_special_tokens=False + )["input_ids"] + suffix_ids = base_tok.tokenizer( + self.longcat_template_suffix, add_special_tokens=False + )["input_ids"] + + prompt_tokens = base_tok.tokenize_with_weights( + text, return_word_ids=return_word_ids, **kwargs + ) + prompt_pairs = prompt_tokens[0] + + prefix_pairs = [(t, 1.0) for t in prefix_ids] + suffix_pairs = [(t, 1.0) for t in suffix_ids] + + combined = prefix_pairs + prompt_pairs + suffix_pairs + tokens = {"qwen25_7b": [combined]} + + return tokens + + +class LongCatImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__( + device=device, + dtype=dtype, + name="qwen25_7b", + clip_model=Qwen25_7BVLIModel, + model_options=model_options, + ) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 + + if template_end == -1: + template_end = 0 + + suffix_start = None + for i in range(len(tok_pairs) - 1, -1, -1): + elem = tok_pairs[i][0] + if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral): + if elem == 151645: + suffix_start = i + break + + out = out[:, template_end:] + + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") + + if suffix_start is not None: + suffix_len = len(tok_pairs) - suffix_start + if suffix_len > 0 and out.shape[1] > suffix_len: + out = out[:, :-suffix_len] + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len] + if extra["attention_mask"].sum() == torch.numel( + extra["attention_mask"] + ): + extra.pop("attention_mask") + + return out, pooled, extra + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class LongCatImageTEModel_(LongCatImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + + return LongCatImageTEModel_ diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 64ce64f89..e86ea9f4e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,6 +6,7 @@ import comfy.text_encoders.genmo import torch import comfy.utils import math +import itertools class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -72,7 +73,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) special_tokens = {"": 262144, "": 106} - super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): @@ -199,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module): constant /= 2.0 token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) - num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - num_tokens = max(num_tokens, 64) + m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs]) + + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m + num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): diff --git a/comfy/utils.py b/comfy/utils.py index 5fe66ecdb..0769cef44 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,7 +29,7 @@ import itertools from torch.nn.functional import interpolate from tqdm.auto import trange from einops import rearrange -from comfy.cli_args import args, enables_dynamic_vram +from comfy.cli_args import args import json import time import mmap @@ -113,7 +113,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: - if enables_dynamic_vram(): + if comfy.memory_management.aimdo_enabled: sd, metadata = load_safetensors(ckpt) if not return_metadata: metadata = None diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 025727071..189d7d9bc 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1224,9 +1224,10 @@ class BoundingBox(ComfyTypeIO): class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, - socketless: bool=True, default: dict=None, component: str=None): + socketless: bool=True, default: dict=None, component: str=None, force_input: bool=None): super().__init__(id, display_name, optional, tooltip, None, default, socketless) self.component = component + self.force_input = force_input if default is None: self.default = {"x": 0, "y": 0, "width": 512, "height": 512} @@ -1234,6 +1235,8 @@ class BoundingBox(ComfyTypeIO): d = super().as_dict() if self.component: d["component"] = self.component + if self.force_input is not None: + d["forceInput"] = self.force_input return d diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 3304d7e76..639035fef 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -127,9 +127,15 @@ class GeminiImageConfig(BaseModel): imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions) +class GeminiThinkingConfig(BaseModel): + includeThoughts: bool | None = Field(None) + thinkingLevel: str = Field(...) + + class GeminiImageGenerationConfig(GeminiGenerationConfig): responseModalities: list[str] | None = Field(None) imageConfig: GeminiImageConfig | None = Field(None) + thinkingConfig: GeminiThinkingConfig | None = Field(None) class GeminiImageGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index cfc604aa3..6dbd5984e 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -186,7 +186,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 5.0", + display_name="ByteDance Seedream 4.5 & 5.0", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index b69285be5..d83d2fc15 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -29,6 +29,7 @@ from comfy_api_nodes.apis.gemini import ( GeminiRole, GeminiSystemInstructionContent, GeminiTextPart, + GeminiThinkingConfig, Modality, ) from comfy_api_nodes.util import ( @@ -55,6 +56,21 @@ GEMINI_IMAGE_SYS_PROMPT = ( "Prioritize generating the visual representation above any text, formatting, or conversational requests." ) +GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution"]), + expr=""" + ( + $m := widgets.model; + $r := widgets.resolution; + $isFlash := $contains($m, "nano banana 2"); + $flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; + $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; + $prices := $isFlash ? $flashPrices : $proPrices; + {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} + ) + """, +) + class GeminiModel(str, Enum): """ @@ -229,6 +245,10 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 120.0 + elif response.modelVersion == "gemini-3.1-flash-image-preview": + input_tokens_price = 0.5 + output_text_tokens_price = 3.0 + output_image_tokens_price = 60.0 else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price @@ -686,7 +706,7 @@ class GeminiImage2(IO.ComfyNode): ), IO.Combo.Input( "model", - options=["gemini-3-pro-image-preview"], + options=["gemini-3-pro-image-preview", "Nano Banana 2 (Gemini 3.1 Flash Image)"], ), IO.Int.Input( "seed", @@ -750,19 +770,7 @@ class GeminiImage2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, - price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), - expr=""" - ( - $r := widgets.resolution; - ($contains($r,"1k") or $contains($r,"2k")) - ? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}} - : $contains($r,"4k") - ? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}} - : {"type":"text","text":"Token-based"} - ) - """, - ), + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, ) @classmethod @@ -779,6 +787,8 @@ class GeminiImage2(IO.ComfyNode): system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" parts: list[GeminiPart] = [GeminiPart(text=prompt)] if images is not None: @@ -815,6 +825,169 @@ class GeminiImage2(IO.ComfyNode): return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) +class GeminiNanoBanana2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNanoBanana2", + display_name="Nano Banana 2", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["Nano Banana 2 (Gemini 3.1 Flash Image)"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", + # "1:4", + # "4:1", + # "8:1", + # "1:8", + ], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, a 16:9 square is usually generated.", + ), + IO.Combo.Input( + "resolution", + options=[ + # "512px", + "1K", + "2K", + "4K", + ], + tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE", "IMAGE+TEXT"], + advanced=True, + ), + IO.Combo.Input( + "thinking_level", + options=["MINIMAL", "HIGH"], + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + thinking_level: str, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(await create_image_parts(cls, images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=thinking_level), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) + + class GeminiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -822,6 +995,7 @@ class GeminiExtension(ComfyExtension): GeminiNode, GeminiImage, GeminiImage2, + GeminiNanoBanana2, GeminiInputFiles, ] diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 370014fb6..fcd7ef735 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -20,7 +20,7 @@ class JobStatus: # Media types that can be previewed in the frontend -PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'}) +PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) # 3D file extensions for preview fallback (no dedicated media_type exists) THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'}) @@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict: normalized[node_id] = normalized_node return normalized +# Text preview truncation limit (1024 characters) to prevent preview_output bloat +TEXT_PREVIEW_MAX_LENGTH = 1024 + + +def _create_text_preview(value: str) -> dict: + """Create a text preview dict with optional truncation. + + Returns: + dict with 'content' and optionally 'truncated' flag + """ + if len(value) <= TEXT_PREVIEW_MAX_LENGTH: + return {'content': value} + return { + 'content': value[:TEXT_PREVIEW_MAX_LENGTH], + 'truncated': True + } + def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: """Extract create_time and workflow_id from extra_data. @@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: continue for item in items: - normalized = normalize_output_item(item) - if normalized is None: - continue + if not isinstance(item, dict): + # Handle text outputs (non-dict items like strings or tuples) + normalized = normalize_output_item(item) + if normalized is None: + # Not a 3D file string — check for text preview + if media_type == 'text': + count += 1 + if preview_output is None: + if isinstance(item, tuple): + text_value = item[0] if item else '' + else: + text_value = str(item) + text_preview = _create_text_preview(text_value) + enriched = { + **text_preview, + 'nodeId': node_id, + 'mediaType': media_type + } + if fallback_preview is None: + fallback_preview = enriched + continue + # normalize_output_item returned a dict (e.g. 3D file) + item = normalized count += 1 if preview_output is not None: continue - if isinstance(normalized, dict) and is_previewable(media_type, normalized): + if is_previewable(media_type, item): enriched = { - **normalized, + **item, 'nodeId': node_id, } - if 'mediaType' not in normalized: + if 'mediaType' not in item: enriched['mediaType'] = media_type - if normalized.get('type') == 'output': + if item.get('type') == 'output': preview_output = enriched elif fallback_preview is None: fallback_preview = enriched diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index 75ffb6d80..2a59a9285 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -717,11 +717,11 @@ def _render_shader_batch( gl.glUseProgram(0) for tex in input_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) for tex in output_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) for tex in ping_pong_textures: - gl.glDeleteTextures(tex) + gl.glDeleteTextures(int(tex)) if fbo is not None: gl.glDeleteFramebuffers(1, [fbo]) for pp_fbo in ping_pong_fbos: @@ -865,14 +865,15 @@ class GLSLShader(io.ComfyNode): cls, image_list: list[torch.Tensor], output_batch: torch.Tensor ) -> dict[str, list]: """Build UI output with input and output images for client-side shader execution.""" - combined_inputs = torch.cat(image_list, dim=0) - input_images_ui = ui.ImageSaveHelper.save_images( - combined_inputs, - filename_prefix="GLSLShader_input", - folder_type=io.FolderType.temp, - cls=None, - compress_level=1, - ) + input_images_ui = [] + for img in image_list: + input_images_ui.extend(ui.ImageSaveHelper.save_images( + img, + filename_prefix="GLSLShader_input", + folder_type=io.FolderType.temp, + cls=None, + compress_level=1, + )) output_images_ui = ui.ImageSaveHelper.save_images( output_batch, diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index be7d600cd..056369e86 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -248,7 +248,7 @@ class SetClipHooks: def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: - clip = clip.clone() + clip = clip.clone(disable_dynamic=True) if apply_to_conds: clip.apply_hooks_to_conds = hooks clip.patcher.forced_hooks = hooks.clone() diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 727d7d09d..4c57bb5cb 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -706,8 +706,8 @@ class SplitImageToTileList(IO.ComfyNode): @staticmethod def get_grid_coords(width, height, tile_width, tile_height, overlap): coords = [] - stride_x = max(1, tile_width - overlap) - stride_y = max(1, tile_height - overlap) + stride_x = round(max(tile_width * 0.25, tile_width - overlap)) + stride_y = round(max(tile_width * 0.25, tile_height - overlap)) y = 0 while y < height: @@ -764,34 +764,6 @@ class ImageMergeTileList(IO.ComfyNode): ], ) - @staticmethod - def get_grid_coords(width, height, tile_width, tile_height, overlap): - coords = [] - stride_x = max(1, tile_width - overlap) - stride_y = max(1, tile_height - overlap) - - y = 0 - while y < height: - x = 0 - y_end = min(y + tile_height, height) - y_start = max(0, y_end - tile_height) - - while x < width: - x_end = min(x + tile_width, width) - x_start = max(0, x_end - tile_width) - - coords.append((x_start, y_start, x_end, y_end)) - - if x_end >= width: - break - x += stride_x - - if y_end >= height: - break - y += stride_y - - return coords - @classmethod def execute(cls, image_list, final_width, final_height, overlap): w = final_width[0] @@ -804,7 +776,7 @@ class ImageMergeTileList(IO.ComfyNode): device = first_tile.device dtype = first_tile.dtype - coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp) + coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp) canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype) weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 1eeeec011..32fe921ff 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode): generate = execute # TODO: remove +def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0): + """Append a guide_attention_entry to both positive and negative conditioning. + + Each entry tracks one guide reference for per-reference attention control. + Entries are derived independently from each conditioning to avoid cross-contamination. + """ + new_entry = { + "pre_filter_count": pre_filter_count, + "strength": strength, + "pixel_mask": None, + "latent_shape": latent_shape, + } + results = [] + for cond in (positive, negative): + # Read existing entries from this specific conditioning + existing = [] + for t in cond: + found = t[1].get("guide_attention_entries", None) + if found is not None: + existing = found + break + # Shallow copy and append (no deepcopy needed — entries contain + # only scalars and None for pixel_mask at this call site). + entries = [*existing, new_entry] + results.append(node_helpers.conditioning_set_values( + cond, {"guide_attention_entries": entries} + )) + return results[0], results[1] + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode): scale_factors, ) + # Track this guide for per-reference attention control. + pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] + guide_latent_shape = list(t.shape[2:]) # [F, H, W] + positive, negative = _append_guide_attention_entry( + positive, negative, pre_filter_count, guide_latent_shape, strength=strength, + ) + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) generate = execute # TODO: remove @@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode): latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] - positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) - negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) + positive = node_helpers.conditioning_set_values(positive, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) + negative = node_helpers.conditioning_set_values(negative, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 6459ca8c1..a25226e6d 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Mahiro", - display_name="Mahiro CFG", + display_name="Positive-Biased Guidance", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ @@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode): io.Model.Output(display_name="patched_model"), ], is_experimental=True, + search_aliases=[ + "mahiro", + "mahiro cfg", + "similarity-adaptive guidance", + "positive-biased cfg", + ], ) @classmethod def execute(cls, model) -> io.NodeOutput: m = model.clone() + def mahiro_normd(args): - scale: float = args['cond_scale'] - cond_p: torch.Tensor = args['cond_denoised'] - uncond_p: torch.Tensor = args['uncond_denoised'] - #naive leap + scale: float = args["cond_scale"] + cond_p: torch.Tensor = args["cond_denoised"] + uncond_p: torch.Tensor = args["uncond_denoised"] + # naive leap leap = cond_p * scale - #sim with uncond leap + # sim with uncond leap u_leap = uncond_p * scale cfg = args["denoised"] merge = (leap + cfg) / 2 normu = torch.sqrt(u_leap.abs()) * u_leap.sign() normm = torch.sqrt(merge.abs()) * merge.sign() sim = F.cosine_similarity(normu, normm).mean() - simsc = 2 * (sim+1) - wm = (simsc*cfg + (4-simsc)*leap) / 4 + simsc = 2 * (sim + 1) + wm = (simsc * cfg + (4 - simsc) * leap) / 4 return wm + m.set_model_sampler_post_cfg_function(mahiro_normd) return io.NodeOutput(m) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 9601a5f76..8bf6a1afa 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -52,7 +52,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],), "zsnr": ("BOOLEAN", {"default": False, "advanced": True}), }} @@ -76,6 +76,8 @@ class ModelSamplingDiscrete: sampling_type = comfy.model_sampling.X0 elif sampling == "img_to_img": sampling_type = comfy.model_sampling.IMG_TO_IMG + elif sampling == "img_to_img_flow": + sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSamplingAdvanced(sampling_base, sampling_type): pass diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c67245e7a..4a0f7141a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -79,7 +79,6 @@ class Blur(io.ComfyNode): node_id="ImageBlur", display_name="Image Blur", category="image/postprocessing", - essentials_category="Image Tools", inputs=[ io.Image.Input("image"), io.Int.Input("blur_radius", default=1, min=1, max=31, step=1), @@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode): node_id="BatchImagesNode", display_name="Batch Images", category="image", + essentials_category="Image Tools", search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py new file mode 100644 index 000000000..520b4067e --- /dev/null +++ b/comfy_extras/nodes_resolution.py @@ -0,0 +1,86 @@ +from __future__ import annotations +import math +from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class AspectRatio(str, Enum): + SQUARE = "1:1 (Square)" + PHOTO_H = "3:2 (Photo)" + STANDARD_H = "4:3 (Standard)" + WIDESCREEN_H = "16:9 (Widescreen)" + ULTRAWIDE_H = "21:9 (Ultrawide)" + PHOTO_V = "2:3 (Portrait Photo)" + STANDARD_V = "3:4 (Portrait Standard)" + WIDESCREEN_V = "9:16 (Portrait Widescreen)" + + +ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = { + AspectRatio.SQUARE: (1, 1), + AspectRatio.PHOTO_H: (3, 2), + AspectRatio.STANDARD_H: (4, 3), + AspectRatio.WIDESCREEN_H: (16, 9), + AspectRatio.ULTRAWIDE_H: (21, 9), + AspectRatio.PHOTO_V: (2, 3), + AspectRatio.STANDARD_V: (3, 4), + AspectRatio.WIDESCREEN_V: (9, 16), +} + + +class ResolutionSelector(io.ComfyNode): + """Calculate width and height from aspect ratio and megapixel target.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionSelector", + display_name="Resolution Selector", + category="utils", + description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.", + inputs=[ + io.Combo.Input( + "aspect_ratio", + options=AspectRatio, + default=AspectRatio.SQUARE, + tooltip="The aspect ratio for the output dimensions.", + ), + io.Float.Input( + "megapixels", + default=1.0, + min=0.1, + max=16.0, + step=0.1, + tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.", + ), + ], + outputs=[ + io.Int.Output( + "width", tooltip="Calculated width in pixels (multiple of 8)." + ), + io.Int.Output( + "height", tooltip="Calculated height in pixels (multiple of 8)." + ), + ], + ) + + @classmethod + def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput: + w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] + total_pixels = megapixels * 1024 * 1024 + scale = math.sqrt(total_pixels / (w_ratio * h_ratio)) + width = round(w_ratio * scale / 8) * 8 + height = round(h_ratio * scale / 8) * 8 + return io.NodeOutput(width, height) + + +class ResolutionExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ResolutionSelector, + ] + + +async def comfy_entrypoint() -> ResolutionExtension: + return ResolutionExtension() diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py new file mode 100644 index 000000000..71441848e --- /dev/null +++ b/comfy_extras/nodes_sdpose.py @@ -0,0 +1,740 @@ +import torch +import comfy.utils +import numpy as np +import math +import colorsys +from tqdm import tqdm +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_lotus import LotusConditioning + + +def _preprocess_keypoints(kp_raw, sc_raw): + """Insert neck keypoint and remap from MMPose to OpenPose ordering. + + Returns (kp, sc) where kp has shape (134, 2) and sc has shape (134,). + Layout: + 0-17 body (18 kp, OpenPose order) + 18-23 feet (6 kp) + 24-91 face (68 kp) + 92-112 right hand (21 kp) + 113-133 left hand (21 kp) + """ + kp = np.array(kp_raw, dtype=np.float32) + sc = np.array(sc_raw, dtype=np.float32) + if len(kp) >= 17: + neck = (kp[5] + kp[6]) / 2 + neck_score = min(sc[5], sc[6]) if sc[5] > 0.3 and sc[6] > 0.3 else 0 + kp = np.insert(kp, 17, neck, axis=0) + sc = np.insert(sc, 17, neck_score) + mmpose_idx = np.array([17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]) + openpose_idx = np.array([ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]) + tmp_kp, tmp_sc = kp.copy(), sc.copy() + tmp_kp[openpose_idx] = kp[mmpose_idx] + tmp_sc[openpose_idx] = sc[mmpose_idx] + kp, sc = tmp_kp, tmp_sc + return kp, sc + + +def _to_openpose_frames(all_keypoints, all_scores, height, width): + """Convert raw keypoint lists to a list of OpenPose-style frame dicts. + + Each frame dict contains: + canvas_width, canvas_height, people: list of person dicts with keys: + pose_keypoints_2d - 18 body kp as flat [x,y,score,...] (absolute pixels) + foot_keypoints_2d - 6 foot kp as flat [x,y,score,...] (absolute pixels) + face_keypoints_2d - 70 face kp as flat [x,y,score,...] (absolute pixels) + indices 0-67: 68 face landmarks + index 68: right eye (body[14]) + index 69: left eye (body[15]) + hand_right_keypoints_2d - 21 right-hand kp (absolute pixels) + hand_left_keypoints_2d - 21 left-hand kp (absolute pixels) + """ + def _flatten(kp_slice, sc_slice): + return np.stack([kp_slice[:, 0], kp_slice[:, 1], sc_slice], axis=1).flatten().tolist() + + frames = [] + for img_idx in range(len(all_keypoints)): + people = [] + for kp_raw, sc_raw in zip(all_keypoints[img_idx], all_scores[img_idx]): + kp, sc = _preprocess_keypoints(kp_raw, sc_raw) + # 70 face kp = 68 face landmarks + REye (body[14]) + LEye (body[15]) + face_kp = np.concatenate([kp[24:92], kp[[14, 15]]], axis=0) + face_sc = np.concatenate([sc[24:92], sc[[14, 15]]], axis=0) + people.append({ + "pose_keypoints_2d": _flatten(kp[0:18], sc[0:18]), + "foot_keypoints_2d": _flatten(kp[18:24], sc[18:24]), + "face_keypoints_2d": _flatten(face_kp, face_sc), + "hand_right_keypoints_2d": _flatten(kp[92:113], sc[92:113]), + "hand_left_keypoints_2d": _flatten(kp[113:134], sc[113:134]), + }) + frames.append({"canvas_width": width, "canvas_height": height, "people": people}) + return frames + + +class KeypointDraw: + """ + Pose keypoint drawing class that supports both numpy and cv2 backends. + """ + def __init__(self): + try: + import cv2 + self.draw = cv2 + except ImportError: + self.draw = self + + # Hand connections (same for both hands) + self.hand_edges = [ + [0, 1], [1, 2], [2, 3], [3, 4], # thumb + [0, 5], [5, 6], [6, 7], [7, 8], # index + [0, 9], [9, 10], [10, 11], [11, 12], # middle + [0, 13], [13, 14], [14, 15], [15, 16], # ring + [0, 17], [17, 18], [18, 19], [19, 20], # pinky + ] + + # Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed) + self.body_limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], + [1, 16], [16, 18] + ] + + # Colors matching DWPose + self.colors = [ + [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], + [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], + [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85] + ] + + @staticmethod + def circle(canvas_np, center, radius, color, **kwargs): + """Draw a filled circle using NumPy vectorized operations.""" + cx, cy = center + h, w = canvas_np.shape[:2] + + radius_int = int(np.ceil(radius)) + + y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1) + x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1) + + if y_max <= y_min or x_max <= x_min: + return + + y, x = np.ogrid[y_min:y_max, x_min:x_max] + mask = (x - cx)**2 + (y - cy)**2 <= radius**2 + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs): + """Draw line using Bresenham's algorithm with NumPy operations.""" + x0, y0, x1, y1 = *pt1, *pt2 + h, w = canvas_np.shape[:2] + dx, dy = abs(x1 - x0), abs(y1 - y0) + sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1) + err, x, y, line_points = dx - dy, x0, y0, [] + + while True: + line_points.append((x, y)) + if x == x1 and y == y1: + break + e2 = 2 * err + if e2 > -dy: + err, x = err - dy, x + sx + if e2 < dx: + err, y = err + dx, y + sy + + if thickness > 1: + radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5)) + for px, py in line_points: + y_min, y_max, x_min, x_max = max(0, py - radius_int), min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1) + if y_max > y_min and x_max > x_min: + yy, xx = np.ogrid[y_min:y_max, x_min:x_max] + canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color + else: + line_points = np.array(line_points) + valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w) + if (valid_points := line_points[valid]).size: + canvas_np[valid_points[:, 1], valid_points[:, 0]] = color + + @staticmethod + def fillConvexPoly(canvas_np, pts, color, **kwargs): + """Fill polygon using vectorized scanline algorithm.""" + if len(pts) < 3: + return + pts = np.array(pts, dtype=np.int32) + h, w = canvas_np.shape[:2] + y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1) + if y_max <= y_min or x_max <= x_min: + return + yy, xx = np.mgrid[y_min:y_max, x_min:x_max] + mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool) + + for i in range(len(pts)): + p1, p2 = pts[i], pts[(i + 1) % len(pts)] + y1, y2 = p1[1], p2[1] + if y1 == y2: + continue + if y1 > y2: + p1, p2, y1, y2 = p2, p1, p2[1], p1[1] + if not (edge_mask := (yy >= y1) & (yy < y2)).any(): + continue + mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1)) + + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs): + """Python implementation of cv2.ellipse2Poly.""" + axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output + angle = angle % 360 + if arc_start > arc_end: + arc_start, arc_end = arc_end, arc_start + while arc_start < 0: + arc_start, arc_end = arc_start + 360, arc_end + 360 + while arc_end > 360: + arc_end, arc_start = arc_end - 360, arc_start - 360 + if arc_end - arc_start > 360: + arc_start, arc_end = 0, 360 + + angle_rad = math.radians(angle) + alpha, beta = math.cos(angle_rad), math.sin(angle_rad) + pts = [] + for i in range(arc_start, arc_end + delta, delta): + theta_rad = math.radians(min(i, arc_end)) + x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad) + pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))]) + + unique_pts, prev_pt = [], (float('inf'), float('inf')) + for pt in pts: + if (pt_tuple := tuple(pt)) != prev_pt: + unique_pts.append(pt) + prev_pt = pt_tuple + + return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]] + + def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3, + draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3): + """ + Draw wholebody keypoints (134 keypoints after processing) in DWPose style. + + Expected keypoint format (after neck insertion and remapping): + - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1) + - Foot: 18-23 (6 keypoints) + - Face: 24-91 (68 landmarks) + - Right hand: 92-112 (21 keypoints) + - Left hand: 113-133 (21 keypoints) + + Args: + canvas: The canvas to draw on (numpy array) + keypoints: Array of keypoint coordinates + scores: Optional confidence scores for each keypoint + threshold: Minimum confidence threshold for drawing keypoints + + Returns: + canvas: The canvas with keypoints drawn + """ + H, W, C = canvas.shape + + # Draw body limbs + if draw_body and len(keypoints) >= 18: + for i, limb in enumerate(self.body_limbSeq): + # Convert from 1-indexed to 0-indexed + idx1, idx2 = limb[0] - 1, limb[1] - 1 + + if idx1 >= 18 or idx2 >= 18: + continue + + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + Y = [keypoints[idx1][0], keypoints[idx2][0]] + X = [keypoints[idx1][1], keypoints[idx2][1]] + mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2 + length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) + + if length < 1: + continue + + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1) + + self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)]) + + # Draw body keypoints + if draw_body and len(keypoints) >= 18: + for i in range(18): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) + + # Draw foot keypoints (18-23, 6 keypoints) + if draw_feet and len(keypoints) >= 24: + for i in range(18, 24): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) + + # Draw right hand (92-112) + if draw_hands and len(keypoints) >= 113: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 92 + edge[0], 92 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) + + # Draw right hand keypoints + for i in range(92, 113): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + + # Draw left hand (113-133) + if draw_hands and len(keypoints) >= 134: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 113 + edge[0], 113 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) + + # Draw left hand keypoints + for i in range(113, 134): + if scores is not None and i < len(scores) and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + + # Draw face keypoints (24-91) - white dots only, no lines + if draw_face and len(keypoints) >= 92: + eps = 0.01 + for i in range(24, 92): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1) + + return canvas + +class SDPoseDrawKeypoints(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseDrawKeypoints", + category="image/preprocessors", + search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "pose"], + inputs=[ + io.Custom("POSE_KEYPOINT").Input("keypoints"), + io.Boolean.Input("draw_body", default=True), + io.Boolean.Input("draw_hands", default=True), + io.Boolean.Input("draw_face", default=True), + io.Boolean.Input("draw_feet", default=False), + io.Int.Input("stick_width", default=4, min=1, max=10, step=1), + io.Int.Input("face_point_size", default=3, min=1, max=10, step=1), + io.Float.Input("score_threshold", default=0.3, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, keypoints, draw_body, draw_hands, draw_face, draw_feet, stick_width, face_point_size, score_threshold) -> io.NodeOutput: + if not keypoints: + return io.NodeOutput(torch.zeros((1, 64, 64, 3), dtype=torch.float32)) + height = keypoints[0]["canvas_height"] + width = keypoints[0]["canvas_width"] + + def _parse(flat, n): + arr = np.array(flat, dtype=np.float32).reshape(n, 3) + return arr[:, :2], arr[:, 2] + + def _zeros(n): + return np.zeros((n, 2), dtype=np.float32), np.zeros(n, dtype=np.float32) + + pose_outputs = [] + drawer = KeypointDraw() + + for frame in tqdm(keypoints, desc="Drawing keypoints on frames"): + canvas = np.zeros((height, width, 3), dtype=np.uint8) + for person in frame["people"]: + body_kp, body_sc = _parse(person["pose_keypoints_2d"], 18) + foot_raw = person.get("foot_keypoints_2d") + foot_kp, foot_sc = _parse(foot_raw, 6) if foot_raw else _zeros(6) + face_kp, face_sc = _parse(person["face_keypoints_2d"], 70) + face_kp, face_sc = face_kp[:68], face_sc[:68] # drop appended eye kp; body already draws them + rhand_kp, rhand_sc = _parse(person["hand_right_keypoints_2d"], 21) + lhand_kp, lhand_sc = _parse(person["hand_left_keypoints_2d"], 21) + + kp = np.concatenate([body_kp, foot_kp, face_kp, rhand_kp, lhand_kp], axis=0) + sc = np.concatenate([body_sc, foot_sc, face_sc, rhand_sc, lhand_sc], axis=0) + + canvas = drawer.draw_wholebody_keypoints( + canvas, kp, sc, + threshold=score_threshold, + draw_body=draw_body, draw_feet=draw_feet, + draw_face=draw_face, draw_hands=draw_hands, + stick_width=stick_width, face_point_size=face_point_size, + ) + pose_outputs.append(canvas) + + pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0) + final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0 + return io.NodeOutput(final_pose_output) + +class SDPoseKeypointExtractor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseKeypointExtractor", + category="image/preprocessors", + search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "sdpose"], + description="Extract pose keypoints from images using the SDPose model: https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints", + inputs=[ + io.Model.Input("model"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Int.Input("batch_size", default=16, min=1, max=10000, step=1), + io.BoundingBox.Input("bboxes", optional=True, force_input=True, tooltip="Optional bounding boxes for more accurate detections. Required for multi-person detection."), + ], + outputs=[ + io.Custom("POSE_KEYPOINT").Output("keypoints", tooltip="Keypoints in OpenPose frame format (canvas_width, canvas_height, people)"), + ], + ) + + @classmethod + def execute(cls, model, vae, image, batch_size, bboxes=None) -> io.NodeOutput: + + height, width = image.shape[-3], image.shape[-2] + context = LotusConditioning().execute().result[0] + + # Use output_block_patch to capture the last 640-channel feature + def output_patch(h, hsp, transformer_options): + nonlocal captured_feat + if h.shape[1] == 640: # Capture the features for wholebody + captured_feat = h.clone() + return h, hsp + + model_clone = model.clone() + model_clone.model_options["transformer_options"] = {"patches": {"output_block_patch": [output_patch]}} + + if not hasattr(model.model.diffusion_model, 'heatmap_head'): + raise ValueError("The provided model does not have a heatmap_head. Please use SDPose model from here https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints.") + + head = model.model.diffusion_model.heatmap_head + total_images = image.shape[0] + captured_feat = None + + model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768 + model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024 + + def _run_on_latent(latent_batch): + """Run one forward pass and return (keypoints_list, scores_list) for the batch.""" + nonlocal captured_feat + captured_feat = None + _ = comfy.sample.sample( + model_clone, + noise=torch.zeros_like(latent_batch), + steps=1, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=context, negative=context, + latent_image=latent_batch, disable_noise=True, disable_pbar=True, + ) + return head(captured_feat) # keypoints_batch, scores_batch + + # all_keypoints / all_scores are lists-of-lists: + # outer index = input image index + # inner index = detected person (one per bbox, or one for full-image) + all_keypoints = [] # shape: [n_images][n_persons] + all_scores = [] # shape: [n_images][n_persons] + pbar = comfy.utils.ProgressBar(total_images) + + if bboxes is not None: + if not isinstance(bboxes, list): + bboxes = [[bboxes]] + elif len(bboxes) == 0: + bboxes = [None] * total_images + # --- bbox-crop mode: one forward pass per crop ------------------------- + for img_idx in tqdm(range(total_images), desc="Extracting keypoints from crops"): + img = image[img_idx:img_idx + 1] # (1, H, W, C) + # Broadcasting: if fewer bbox lists than images, repeat the last one. + img_bboxes = bboxes[min(img_idx, len(bboxes) - 1)] if bboxes else None + + img_keypoints = [] + img_scores = [] + + if img_bboxes: + for bbox in img_bboxes: + x1 = max(0, int(bbox["x"])) + y1 = max(0, int(bbox["y"])) + x2 = min(width, int(bbox["x"] + bbox["width"])) + y2 = min(height, int(bbox["y"] + bbox["height"])) + + if x2 <= x1 or y2 <= y1: + continue + + crop_h_px, crop_w_px = y2 - y1, x2 - x1 + crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C) + + # scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size. + scale = min(model_h / crop_h_px, model_w / crop_w_px) + scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale)) + pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2 + + crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW + scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled") + padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device) + padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled + crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC + + latent_crop = vae.encode(crop_resized) + kp_batch, sc_batch = _run_on_latent(latent_crop) + kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space + + # remove padding offset, undo scale, offset to full-image coordinates. + kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32) + kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1 + kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1 + + img_keypoints.append(kp) + img_scores.append(sc) + else: + # No bboxes for this image – run on the full image + latent_img = vae.encode(img) + kp_batch, sc_batch = _run_on_latent(latent_img) + img_keypoints.append(kp_batch[0]) + img_scores.append(sc_batch[0]) + + all_keypoints.append(img_keypoints) + all_scores.append(img_scores) + pbar.update(1) + + else: # full-image mode, batched + tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints") + for batch_start in range(0, total_images, batch_size): + batch_end = min(batch_start + batch_size, total_images) + latent_batch = vae.encode(image[batch_start:batch_end]) + + kp_batch, sc_batch = _run_on_latent(latent_batch) + + for kp, sc in zip(kp_batch, sc_batch): + all_keypoints.append([kp]) + all_scores.append([sc]) + tqdm_pbar.update(1) + + pbar.update(batch_end - batch_start) + + openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width) + return io.NodeOutput(openpose_frames) + + +def get_face_bboxes(kp2ds, scale, image_shape): + h, w = image_shape + kp2ds_face = kp2ds.copy()[1:] * (w, h) + + min_x, min_y = np.min(kp2ds_face, axis=0) + max_x, max_y = np.max(kp2ds_face, axis=0) + + initial_width = max_x - min_x + initial_height = max_y - min_y + + if initial_width <= 0 or initial_height <= 0: + return [0, 0, 0, 0] + + initial_area = initial_width * initial_height + + expanded_area = initial_area * scale + + new_width = np.sqrt(expanded_area * (initial_width / initial_height)) + new_height = np.sqrt(expanded_area * (initial_height / initial_width)) + + delta_width = (new_width - initial_width) / 2 + delta_height = (new_height - initial_height) / 4 + + expanded_min_x = max(min_x - delta_width, 0) + expanded_max_x = min(max_x + delta_width, w) + expanded_min_y = max(min_y - 3 * delta_height, 0) + expanded_max_y = min(max_y + delta_height, h) + + return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), int(expanded_max_y)] + +class SDPoseFaceBBoxes(io.ComfyNode): + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseFaceBBoxes", + category="image/preprocessors", + search_aliases=["face bbox", "face bounding box", "pose", "keypoints"], + inputs=[ + io.Custom("POSE_KEYPOINT").Input("keypoints"), + io.Float.Input("scale", default=1.5, min=1.0, max=10.0, step=0.1, tooltip="Multiplier for the bounding box area around each detected face."), + io.Boolean.Input("force_square", default=True, tooltip="Expand the shorter bbox axis so the crop region is always square."), + ], + outputs=[ + io.BoundingBox.Output("bboxes", tooltip="Face bounding boxes per frame, compatible with SDPoseKeypointExtractor bboxes input."), + ], + ) + + @classmethod + def execute(cls, keypoints, scale, force_square) -> io.NodeOutput: + all_bboxes = [] + for frame in keypoints: + h = frame["canvas_height"] + w = frame["canvas_width"] + frame_bboxes = [] + for person in frame["people"]: + face_flat = person.get("face_keypoints_2d", []) + if not face_flat: + continue + # Parse absolute-pixel face keypoints (70 kp: 68 landmarks + REye + LEye) + face_arr = np.array(face_flat, dtype=np.float32).reshape(-1, 3) + face_xy = face_arr[:, :2] # (70, 2) in absolute pixels + + kp_norm = face_xy / np.array([w, h], dtype=np.float32) + kp_padded = np.vstack([np.zeros((1, 2), dtype=np.float32), kp_norm]) # (71, 2) + + x1, x2, y1, y2 = get_face_bboxes(kp_padded, scale, (h, w)) + if x2 > x1 and y2 > y1: + if force_square: + bw, bh = x2 - x1, y2 - y1 + if bw != bh: + side = max(bw, bh) + cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 + half = side // 2 + x1 = max(0, cx - half) + y1 = max(0, cy - half) + x2 = min(w, x1 + side) + y2 = min(h, y1 + side) + # Re-anchor if clamped + x1 = max(0, x2 - side) + y1 = max(0, y2 - side) + frame_bboxes.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}) + + all_bboxes.append(frame_bboxes) + + return io.NodeOutput(all_bboxes) + + +class CropByBBoxes(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CropByBBoxes", + category="image/preprocessors", + search_aliases=["crop", "face crop", "bbox crop", "pose", "bounding box"], + description="Crop and resize regions from the input image batch based on provided bounding boxes.", + inputs=[ + io.Image.Input("image"), + io.BoundingBox.Input("bboxes", force_input=True), + io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."), + io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."), + io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."), + ], + outputs=[ + io.Image.Output(tooltip="All crops stacked into a single image batch."), + ], + ) + + @classmethod + def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput: + total_frames = image.shape[0] + img_h = image.shape[1] + img_w = image.shape[2] + num_ch = image.shape[3] + + if not isinstance(bboxes, list): + bboxes = [[bboxes]] + elif len(bboxes) == 0: + return io.NodeOutput(image) + + crops = [] + + for frame_idx in range(total_frames): + frame_bboxes = bboxes[min(frame_idx, len(bboxes) - 1)] + if not frame_bboxes: + continue + + frame_chw = image[frame_idx].permute(2, 0, 1).unsqueeze(0) # BHWC → BCHW (1, C, H, W) + + # Union all bboxes for this frame into a single crop region + x1 = min(int(b["x"]) for b in frame_bboxes) + y1 = min(int(b["y"]) for b in frame_bboxes) + x2 = max(int(b["x"] + b["width"]) for b in frame_bboxes) + y2 = max(int(b["y"] + b["height"]) for b in frame_bboxes) + + if padding > 0: + x1 = max(0, x1 - padding) + y1 = max(0, y1 - padding) + x2 = min(img_w, x2 + padding) + y2 = min(img_h, y2 + padding) + + x1, x2 = max(0, x1), min(img_w, x2) + y1, y2 = max(0, y1), min(img_h, y2) + + # Fallback for empty/degenerate crops + if x2 <= x1 or y2 <= y1: + fallback_size = int(min(img_h, img_w) * 0.3) + fb_x1 = max(0, (img_w - fallback_size) // 2) + fb_y1 = max(0, int(img_h * 0.1)) + fb_x2 = min(img_w, fb_x1 + fallback_size) + fb_y2 = min(img_h, fb_y1 + fallback_size) + if fb_x2 <= fb_x1 or fb_y2 <= fb_y1: + crops.append(torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)) + continue + x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2 + + crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w) + resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled") + crops.append(resized) + + if not crops: + return io.NodeOutput(image) + + out_images = torch.cat(crops, dim=0).permute(0, 2, 3, 1) # (N, H, W, C) + return io.NodeOutput(out_images) + + +class SDPoseExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SDPoseKeypointExtractor, + SDPoseDrawKeypoints, + SDPoseFaceBBoxes, + CropByBBoxes, + ] + +async def comfy_entrypoint() -> SDPoseExtension: + return SDPoseExtension() diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index db7c171a2..5c096c232 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode): search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", category="image/video", - essentials_category="Video Tools", description="Extracts all components from a video: frames, audio, and framerate.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), @@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode): "start time", ], category="image/video", + essentials_category="Video Tools", inputs=[ io.Video.Input("video"), io.Float.Input( diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index effa994d1..e50bfcd2c 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode): return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("reference_image", optional=True), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension): WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, + WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/comfyui_version.py b/comfyui_version.py index 6dbda1a87..6a35c6de3 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.0" +__version__ = "0.15.1" diff --git a/main.py b/main.py index 39e605deb..af701f8df 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,10 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. @@ -173,10 +177,6 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() import comfy.utils @@ -192,7 +192,7 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if enables_dynamic_vram(): +if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl(): if comfy.model_management.torch_version_numeric < (2, 8): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): diff --git a/node_helpers.py b/node_helpers.py index 4ff960ef8..d3d834516 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -1,5 +1,6 @@ import hashlib import torch +import logging from comfy.cli_args import args @@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False): return c +def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0): + """ + Apply values to conditioning only during [start_percent, end_percent], keeping the + original conditioning active outside that range. Respects existing per-entry ranges. + """ + if start_percent > end_percent: + logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})") + return conditioning + + EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep + c = [] + for t in conditioning: + cond_start = t[1].get("start_percent", 0.0) + cond_end = t[1].get("end_percent", 1.0) + intersect_start = max(start_percent, cond_start) + intersect_end = min(end_percent, cond_end) + + if intersect_start >= intersect_end: # no overlap: emit unchanged + c.append(t) + continue + + if intersect_start > cond_start: # part before the requested range + c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS})) + + c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end})) + + if intersect_end < cond_end: # part after the requested range + c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end})) + return c + def pillow(fn, arg): prev_value = None try: diff --git a/nodes.py b/nodes.py index e2fc20d53..5be9b16f9 100644 --- a/nodes.py +++ b/nodes.py @@ -976,7 +976,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1925,7 +1925,6 @@ class ImageInvert: class ImageBatch: SEARCH_ALIASES = ["combine images", "merge images", "stack images"] - ESSENTIALS_CATEGORY = "Image Tools" @classmethod def INPUT_TYPES(s): @@ -2436,6 +2435,7 @@ async def init_builtin_extra_nodes(): "nodes_audio_encoder.py", "nodes_rope.py", "nodes_logic.py", + "nodes_resolution.py", "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", @@ -2448,6 +2448,7 @@ async def init_builtin_extra_nodes(): "nodes_toolkit.py", "nodes_replacements.py", "nodes_nag.py", + "nodes_sdpose.py", ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index eaf558740..1b2318273 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.0" +version = "0.15.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index 0064fb4ba..71019c16f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.39.16 -comfyui-workflow-templates==0.9.3 +comfyui-frontend-package==1.39.19 +comfyui-workflow-templates==0.9.4 comfyui-embedded-docs==0.4.3 torch torchsde @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.1 +comfy-aimdo>=0.2.4 requests #non essential dependencies: @@ -31,5 +31,4 @@ spandrel pydantic~=2.0 pydantic-settings~=2.0 PyOpenGL -PyOpenGL-accelerate glfw diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index 643f04e72..1d5a84b47 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -49,6 +49,12 @@ def mock_provider(mock_releases): return provider +@pytest.fixture(autouse=True) +def clear_cache(): + import utils.install_util + utils.install_util.PACKAGE_VERSIONS = {} + + def test_get_release(mock_provider, mock_releases): version = "1.0.0" release = mock_provider.get_release(version) diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py new file mode 100644 index 000000000..2551a417b --- /dev/null +++ b/tests-unit/comfy_test/model_detection_test.py @@ -0,0 +1,112 @@ +import torch + +from comfy.model_detection import detect_unet_config, model_config_from_unet_config +import comfy.supported_models + + +def _make_longcat_comfyui_sd(): + """Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights.""" + sd = {} + H = 32 # Reduce hidden state dimension to reduce memory usage + C_IN = 16 + C_CTX = 3584 + + sd["img_in.weight"] = torch.empty(H, C_IN * 4) + sd["img_in.bias"] = torch.empty(H) + sd["txt_in.weight"] = torch.empty(H, C_CTX) + sd["txt_in.bias"] = torch.empty(H) + + sd["time_in.in_layer.weight"] = torch.empty(H, 256) + sd["time_in.in_layer.bias"] = torch.empty(H) + sd["time_in.out_layer.weight"] = torch.empty(H, H) + sd["time_in.out_layer.bias"] = torch.empty(H) + + sd["final_layer.adaLN_modulation.1.weight"] = torch.empty(2 * H, H) + sd["final_layer.adaLN_modulation.1.bias"] = torch.empty(2 * H) + sd["final_layer.linear.weight"] = torch.empty(C_IN * 4, H) + sd["final_layer.linear.bias"] = torch.empty(C_IN * 4) + + for i in range(19): + sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128) + sd[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.empty(3 * H, H) + sd[f"double_blocks.{i}.img_mod.lin.weight"] = torch.empty(H, H) + for i in range(38): + sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H) + + return sd + + +def _make_flux_schnell_comfyui_sd(): + """Minimal ComfyUI-format state dict for standard Flux Schnell.""" + sd = {} + H = 32 # Reduce hidden state dimension to reduce memory usage + C_IN = 16 + + sd["img_in.weight"] = torch.empty(H, C_IN * 4) + sd["img_in.bias"] = torch.empty(H) + sd["txt_in.weight"] = torch.empty(H, 4096) + sd["txt_in.bias"] = torch.empty(H) + + sd["double_blocks.0.img_attn.norm.key_norm.weight"] = torch.empty(128) + sd["double_blocks.0.img_attn.qkv.weight"] = torch.empty(3 * H, H) + sd["double_blocks.0.img_mod.lin.weight"] = torch.empty(H, H) + + for i in range(19): + sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128) + for i in range(38): + sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H) + + return sd + + +class TestModelDetection: + """Verify that first-match model detection selects the correct model + based on list ordering and unet_config specificity.""" + + def test_longcat_before_schnell_in_models_list(self): + """LongCatImage must appear before FluxSchnell in the models list.""" + models = comfy.supported_models.models + longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage") + schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell") + assert longcat_idx < schnell_idx, ( + f"LongCatImage (index {longcat_idx}) must come before " + f"FluxSchnell (index {schnell_idx}) in the models list" + ) + + def test_longcat_comfyui_detected_as_longcat(self): + sd = _make_longcat_comfyui_sd() + unet_config = detect_unet_config(sd, "") + assert unet_config is not None + assert unet_config["image_model"] == "flux" + assert unet_config["context_in_dim"] == 3584 + assert unet_config["vec_in_dim"] is None + assert unet_config["guidance_embed"] is False + assert unet_config["txt_ids_dims"] == [1, 2] + + model_config = model_config_from_unet_config(unet_config, sd) + assert model_config is not None + assert type(model_config).__name__ == "LongCatImage" + + def test_longcat_comfyui_keys_pass_through_unchanged(self): + """Pre-converted weights should not be transformed by process_unet_state_dict.""" + sd = _make_longcat_comfyui_sd() + unet_config = detect_unet_config(sd, "") + model_config = model_config_from_unet_config(unet_config, sd) + + processed = model_config.process_unet_state_dict(dict(sd)) + assert "img_in.weight" in processed + assert "txt_in.weight" in processed + assert "time_in.in_layer.weight" in processed + assert "final_layer.linear.weight" in processed + + def test_flux_schnell_comfyui_detected_as_flux_schnell(self): + sd = _make_flux_schnell_comfyui_sd() + unet_config = detect_unet_config(sd, "") + assert unet_config is not None + assert unet_config["image_model"] == "flux" + assert unet_config["context_in_dim"] == 4096 + assert unet_config["txt_ids_dims"] == [] + + model_config = model_config_from_unet_config(unet_config, sd) + assert model_config is not None + assert type(model_config).__name__ == "FluxSchnell" diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 83c36fe48..814af5c13 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -38,13 +38,13 @@ class TestIsPreviewable: """Unit tests for is_previewable()""" def test_previewable_media_types(self): - """Images, video, audio, 3d media types should be previewable.""" - for media_type in ['images', 'video', 'audio', '3d']: + """Images, video, audio, 3d, text media types should be previewable.""" + for media_type in ['images', 'video', 'audio', '3d', 'text']: assert is_previewable(media_type, {}) is True def test_non_previewable_media_types(self): """Other media types should not be previewable.""" - for media_type in ['latents', 'text', 'metadata', 'files']: + for media_type in ['latents', 'metadata', 'files']: assert is_previewable(media_type, {}) is False def test_3d_extensions_previewable(self): diff --git a/utils/install_util.py b/utils/install_util.py index 0f59bcf91..34489aec5 100644 --- a/utils/install_util.py +++ b/utils/install_util.py @@ -1,5 +1,7 @@ from pathlib import Path import sys +import logging +import re # The path to the requirements.txt file requirements_path = Path(__file__).parents[1] / "requirements.txt" @@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running: {sys.executable} {extra}-m pip install -r {requirements_path} If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem. """.strip() + + +def is_valid_version(version: str) -> bool: + """Validate if a string is a valid semantic version (X.Y.Z format).""" + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + return bool(re.match(pattern, version)) + + +PACKAGE_VERSIONS = {} +def get_required_packages_versions(): + if len(PACKAGE_VERSIONS) > 0: + return PACKAGE_VERSIONS.copy() + out = PACKAGE_VERSIONS + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip().replace(">=", "==") + s = line.split("==") + if len(s) == 2: + version_str = s[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid version format in requirements.txt: {version_str}") + continue + out[s[0]] = version_str + return out.copy() + except FileNotFoundError: + logging.error("requirements.txt not found.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None