diff --git a/app/frontend_management.py b/app/frontend_management.py index bdaa85812..f753ef0de 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -17,7 +17,7 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired -from utils.install_util import get_missing_requirements_message, requirements_path +from utils.install_util import get_missing_requirements_message, get_required_packages_versions from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger @@ -45,25 +45,7 @@ def get_installed_frontend_version(): def get_required_frontend_version(): - """Get the required frontend version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-frontend-package=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-frontend-package not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required frontend version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-frontend-package", None) def check_frontend_version(): @@ -217,25 +199,7 @@ class FrontendManager: @classmethod def get_required_templates_version(cls) -> str: - """Get the required workflow templates version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-workflow-templates=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid templates version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-workflow-templates not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required templates version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-workflow-templates", None) @classmethod def default_frontend_path(cls) -> str: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 63daca861..13079c7bc 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -146,6 +146,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") +parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") @@ -159,7 +160,6 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" - DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -260,4 +260,4 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): - return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2f82d51da..b54f7f39a 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -214,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC): mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: - raise Exception("No sample_sigmas matched current timestep; something went wrong.") + return # substep from multi-step sampler: keep self._step from the last full step self._step = int(matches[0].item()) def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f59999af6..6a57bca1c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat): def process_out(self, latent): return latent + + +class ZImagePixelSpace(ChromaRadiance): + """Pixel-space latent format for ZImage DCT variant. + No VAE encoding/decoding — the model operates directly on RGB pixels. + """ + pass diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 77d1abc97..9e432d5c0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension import comfy.utils +from comfy.ldm.chroma_radiance.layers import NerfEmbedder def invert_slices(slices, length): @@ -858,3 +859,267 @@ class NextDiT(nn.Module): img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] return -img + +############################################################################# +# Pixel Space Decoder Components # +############################################################################# + +def _modulate_shift_scale(x, shift, scale): + return x * (1 + scale) + shift + + +class PixelResBlock(nn.Module): + """ + Residual block with AdaLN modulation, zero-initialised so it starts as + an identity at the beginning of training. + """ + + def __init__(self, channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device) + self.mlp = nn.Sequential( + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) + h = _modulate_shift_scale(self.in_ln(x), shift, scale) + h = self.mlp(h) + return x + gate * h + + +class DCTFinalLayer(nn.Module): + """Zero-initialised output projection (adopted from DiT).""" + + def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + """ + Small MLP decoder head for the pixel-space variant. + + Takes per-patch pixel values and a per-patch conditioning vector from the + transformer backbone and predicts the denoised pixel values. + + x : [B*N, P^2, C] – noisy pixel values per patch position + c : [B*N, dim] – backbone hidden state per patch (conditioning) + → [B*N, P^2, C] + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + max_freqs: int = 8, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + + # Project backbone hidden state → per-patch conditioning + self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device) + + # Input projection with DCT positional encoding + self.input_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size_input=model_channels, + max_freqs=max_freqs, + dtype=dtype, + device=device, + operations=operations, + ) + + # Residual blocks + self.res_blocks = nn.ModuleList([ + PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks) + ]) + + # Output projection + self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B*N, 1, P^2*C], c: [B*N, dim] + original_dtype = x.dtype + weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype) + x = self.input_embedder(x) # [B*N, 1, model_channels] + y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] + x = x.to(weight_dtype) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C] + + +############################################################################# +# NextDiT – Pixel Space # +############################################################################# + +class NextDiTPixelSpace(NextDiT): + """ + Pixel-space variant of NextDiT. + + Identical transformer backbone to NextDiT, but the output head is replaced + with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values + per patch rather than a single affine projection. + + Key differences vs NextDiT: + • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead. + • ``_forward`` stores the raw patchified pixel values before the backbone + embedding and feeds them to ``dec_net`` together with the per-patch + backbone hidden states. + • Supports optional x0 prediction via ``use_x0``. + """ + + def __init__( + self, + # decoder-specific + decoder_hidden_size: int = 3840, + decoder_num_res_blocks: int = 4, + decoder_max_freqs: int = 8, + decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels) + use_x0: bool = False, + # all NextDiT args forwarded unchanged + **kwargs, + ): + super().__init__(**kwargs) + + # Remove the latent-space final layer – not used in pixel space + del self.final_layer + + patch_size = kwargs.get("patch_size", 2) + in_channels = kwargs.get("in_channels", 4) + dim = kwargs.get("dim", 4096) + + # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels + dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels + + self.dec_net = SimpleMLPAdaLN( + in_channels=dec_in_ch, + model_channels=decoder_hidden_size, + out_channels=dec_in_ch, + z_channels=dim, + num_res_blocks=decoder_num_res_blocks, + max_freqs=decoder_max_freqs, + dtype=kwargs.get("dtype"), + device=kwargs.get("device"), + operations=kwargs.get("operations"), + ) + + if use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + # ------------------------------------------------------------------ + # Forward — mirrors NextDiT._forward exactly, replacing final_layer + # with the pixel-space dec_net decoder. + # ------------------------------------------------------------------ + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) + adaln_input = t + + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + # ---- capture raw pixel patches before patchify_and_embed embeds them ---- + pH = pW = self.patch_size + B, C, H, W = x.shape + pixel_patches = ( + x.view(B, C, H // pH, pH, W // pW, pW) + .permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C] + .flatten(3) # [B, Ht, Wt, pH*pW*C] + .flatten(1, 2) # [B, N, pH*pW*C] + ) + N = pixel_patches.shape[1] + # decoder sees one token per patch: [B*N, 1, P^2*C] + pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C) + + patches = transformer_options.get("patches", {}) + x_is_tensor = isinstance(x, torch.Tensor) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed( + x, cap_feats, cap_mask, adaln_input, num_tokens, + ref_latents=ref_latents, ref_contexts=ref_contexts, + siglip_feats=siglip_feats, transformer_options=transformer_options + ) + freqs_cis = freqs_cis.to(img.device) + + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + img_input = img + for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] + + # ---- pixel-space decoder (replaces final_layer + unpatchify) ---- + # img may have padding tokens beyond N; only the first N are real image patches + img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim] + decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim] + + output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C] + output = output.reshape(B, N, -1) # [B, N, P^2*C] + + # prepend zero cap placeholder so unpatchify indexing works unchanged + cap_placeholder = torch.zeros( + B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype + ) + img_out = self.unpatchify( + torch.cat([cap_placeholder, output], dim=1), + img_size, cap_size, return_tensor=x_is_tensor + )[:, :, :h, :w] + + return -img_out + + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + # _forward returns neg_x0 = -x0 (negated decoder output). + # + # Reference inference (working_inference_reference.py): + # out = _forward(img, t) # = -x0 + # pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual] + # img += (t_prev - t_curr) * pred # Euler step + # + # ComfyUI's Euler sampler does the same: + # x_next = x + (sigma_next - sigma) * model_output + # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t + neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + return (x - neg_x0) / timesteps.view(-1, 1, 1, 1) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index ea123acb4..b2287dba9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1621,3 +1621,118 @@ class HumoWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + +class SCAILWanModel(WanModel): + def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + + self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) + + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + + if reference_latent is not None: + x = torch.cat((reference_latent, x), dim=2) + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes + x = x.flatten(2).transpose(1, 2) + + scail_pose_seq_len = 0 + if pose_latents is not None: + scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + scail_x = scail_x.flatten(2).transpose(1, 2) + scail_pose_seq_len = scail_x.shape[1] + x = torch.cat([x, scail_x], dim=1) + del scail_x + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.cat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + if scail_pose_seq_len > 0: + x = x[:, :-scail_pose_seq_len] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + if reference_latent is not None: + x = x[:, :, reference_latent.shape[2]:] + + return x + + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) + + if pose_latents is None: + return main_freqs + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + + # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames + h_scale = h / H_pose + w_scale = w / W_pose + + # 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options) + + return torch.cat([main_freqs, pose_freqs], dim=1) + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if pose_latents is not None: + pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) + + t_len = t + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = x.shape[2] + + reference_latent = None + if "reference_latent" in kwargs: + reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) + t_len += reference_latent.shape[2] + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] diff --git a/comfy/model_base.py b/comfy/model_base.py index 85cd30bae..1e01e9edc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1263,6 +1263,11 @@ class Lumina2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class ZImagePixelSpace(Lumina2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + self.memory_usage_factor_conds = ("ref_latents",) + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) @@ -1502,6 +1507,44 @@ class WAN21_FlowRVS(WAN21): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) self.image_to_video = image_to_video +class WAN21_SCAIL(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents") + self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_mask = torch.ones_like(ref_latent[:, :4]) + ref_latent = torch.cat([ref_latent, ref_mask], dim=1) + out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + pose_latents = self.process_latent_in(pose_latents) + pose_mask = torch.ones_like(pose_latents[:, :4]) + pose_latents = torch.cat([pose_latents, pose_mask], dim=1) + out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] + + return out + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8a1d8ea4d..6eace4628 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config - if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 @@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if sig_weight is not None: dit_config["siglip_feat_dim"] = sig_weight.shape[0] + dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) + if dec_cond_key in state_dict_keys: # pixel-space variant + dit_config["image_model"] = "zimage_pixel" + # patch_size and in_channels are derived from x_embedder: + # x_embedder: Linear(patch_size * patch_size * in_channels, dim) + # The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim. + x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] + dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0] + # patch_size: infer from decoder final layer output matching x_embedder input + # in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2) + embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)] + dec_in_ch = dec_out # decoder in == decoder out (same pixel space) + dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3) + dit_config["in_channels"] = 3 + dit_config["decoder_in_channels"] = dec_in_ch + dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0] + dit_config["decoder_num_res_blocks"] = count_blocks( + state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' + ) + dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5) + if '{}__x0__'.format(key_prefix) in state_dict_keys: + dit_config["use_x0"] = True + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 @@ -498,6 +521,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" @@ -531,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config - if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 - + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1 dit_config = {} dit_config["image_model"] = "hunyuan3d2_1" dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] @@ -1053,6 +1077,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) + elif 'noise_refiner.0.attention.norm_k.weight' in state_dict: + n_layers = count_blocks(state_dict, 'layers.{}.') + dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0] + sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix) + for k in state_dict: # For zeta chroma + if k not in sd_map: + sd_map[k] = k elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') diff --git a/comfy/model_management.py b/comfy/model_management.py index f73613f17..0e0e96672 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -32,9 +32,6 @@ import comfy.memory_management import comfy.utils import comfy.quant_ops -import comfy_aimdo.torch -import comfy_aimdo.model_vbar - class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -180,6 +177,14 @@ def is_ixuca(): return True return False +def is_wsl(): + version = platform.uname().release + if version.endswith("-Microsoft"): + return True + elif version.endswith("microsoft-standard-WSL2"): + return True + return False + def get_torch_device(): global directml_enabled global cpu_state @@ -631,12 +636,11 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ if not DISABLE_SMART_MEMORY: memory_to_free = memory_required - get_free_memory(device) ram_to_free = ram_required - get_free_ram() - - if current_loaded_models[i].model.is_dynamic() and for_dynamic: - #don't actually unload dynamic models for the sake of other dynamic models - #as that works on-demand. - memory_required -= current_loaded_models[i].model.loaded_size() - memory_to_free = 0 + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) @@ -1199,43 +1203,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): - if hasattr(weight, "_v"): - #Unexpected usage patterns. There is no reason these don't work but they - #have no testing and no callers do this. - assert r is None - assert stream is None - - cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) - - if dtype is None: - dtype = weight._model_dtype - - signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) - if signature is not None: - if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): - v_tensor = weight._v_tensor - else: - raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] - weight._v_tensor = v_tensor - weight._v_signature = signature - #Send it over - v_tensor.copy_(weight, non_blocking=non_blocking) - return v_tensor.to(dtype=dtype) - - r = torch.empty_like(weight, dtype=dtype, device=device) - - if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: - #Offloaded casting could skip this, however it would make the quantizations - #inconsistent between loaded and offloaded weights. So force the double casting - #that would happen in regular flow to make offload deterministic. - cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) - cast_buffer.copy_(weight, non_blocking=non_blocking) - weight = cast_buffer - r.copy_(weight, non_blocking=non_blocking) - - return r - if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1c9ba8096..e380e406b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -308,15 +308,22 @@ class ModelPatcher: def get_free_memory(self, device): return comfy.model_management.get_free_memory(device) - def clone(self, disable_dynamic=False): + def get_clone_model_override(self): + return self.model, (self.backup, self.object_patches_backup, self.pinned) + + def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ - model = self.model if self.is_dynamic() and disable_dynamic: class_ = ModelPatcher - temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) - model = temp_model_patcher.model + if model_override is None: + if self.cached_patcher_init is None: + raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") + temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + model_override = temp_model_patcher.get_clone_model_override() + if model_override is None: + model_override = self.get_clone_model_override() - n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) + n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -325,13 +332,12 @@ class ModelPatcher: n.object_patches = self.object_patches.copy() n.weight_wrapper_patches = self.weight_wrapper_patches.copy() n.model_options = comfy.utils.deepcopy_list_dict(self.model_options) - n.backup = self.backup - n.object_patches_backup = self.object_patches_backup n.parent = self - n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights + n.backup, n.object_patches_backup, n.pinned = model_override[1] + # attachments n.attachments = {} for k in self.attachments: @@ -1429,12 +1435,9 @@ class ModelPatcherDynamic(ModelPatcher): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): super().__init__(model, load_device, offload_device, size, weight_inplace_update) - #this is now way more dynamic and we dont support the same base model for both Dynamic - #and non-dynamic patchers. - if hasattr(self.model, "model_loaded_weight_memory"): - del self.model.model_loaded_weight_memory if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + self.non_dynamic_delegate_model = None assert load_device is not None def is_dynamic(self): @@ -1454,9 +1457,7 @@ class ModelPatcherDynamic(ModelPatcher): def loaded_size(self): vbar = self._vbar_get() - if vbar is None: - return 0 - return vbar.loaded_size() + return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory def get_free_memory(self, device): #NOTE: on high condition / batch counts, estimate should have already vacated @@ -1497,6 +1498,7 @@ class ModelPatcherDynamic(ModelPatcher): num_patches = 0 allocated_size = 0 + self.model.model_loaded_weight_memory = 0 with self.use_ejected(): self.unpatch_hooks() @@ -1505,10 +1507,6 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - #We force reserve VRAM for the non comfy-weight so we dont have to deal - #with pin and unpin syncrhonization which can be expensive for small weights - #with a high layer rate (e.g. autoregressive LLMs). - #prioritize the non-comfy weights (note the order reverse). loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) loading.sort(reverse=True) @@ -1551,6 +1549,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to) + weight, _, _ = get_key_weight(self.model, key) + if weight is not None: + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True @@ -1576,21 +1577,15 @@ class ModelPatcherDynamic(ModelPatcher): for param in params: key = key_param_name_to_key(n, param) weight, _, _ = get_key_weight(self.model, key) - weight.seed_key = key - set_dirty(weight, dirty) - geometry = weight - model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype - geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) - weight_size = geometry.numel() * geometry.element_size() - if vbar is not None and not hasattr(weight, "_v"): - weight._v = vbar.alloc(weight_size) - weight._model_dtype = model_dtype - allocated_size += weight_size - vbar.set_watermark_limit(allocated_size) + if key not in self.backup: + self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) + comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() move_weight_functions(m, device_to) - logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") self.model.device = device_to self.model.current_weight_patches_uuid = self.patches_uuid @@ -1606,7 +1601,16 @@ class ModelPatcherDynamic(ModelPatcher): assert self.load_device != torch.device("cpu") vbar = self._vbar_get() - return 0 if vbar is None else vbar.free_memory(memory_to_free) + freed = 0 if vbar is None else vbar.free_memory(memory_to_free) + + if freed < memory_to_free: + for key in list(self.backup.keys()): + bk = self.backup.pop(key) + comfy.utils.set_attr_param(self.model, key, bk.weight) + freed += self.model.model_loaded_weight_memory + self.model.model_loaded_weight_memory = 0 + + return freed def partially_unload_ram(self, ram_to_unload): loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) @@ -1633,11 +1637,6 @@ class ModelPatcherDynamic(ModelPatcher): for m in self.model.modules(): move_weight_functions(m, device_to) - keys = list(self.backup.keys()) - for k in keys: - bk = self.backup[k] - comfy.utils.set_attr_param(self.model, k, bk.weight) - def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above with self.use_ejected(skip_and_inject_on_exit_only=True): @@ -1669,4 +1668,10 @@ class ModelPatcherDynamic(ModelPatcher): def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: pass + def get_non_dynamic_delegate(self): + model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model) + self.non_dynamic_delegate_model = model_patcher.get_clone_model_override() + return model_patcher + + CoreModelPatcher = ModelPatcher diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1f75f2ba7..bbba09e26 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -66,6 +66,18 @@ def convert_cond(cond): out.append(temp) return out +def cond_has_hooks(cond): + for c in cond: + temp = c[1] + if "hooks" in temp: + return True + if "control" in temp: + control = temp["control"] + extra_hooks = control.get_extra_hooks() + if len(extra_hooks) > 0: + return True + return False + def get_additional_models(conds, dtype): """loads additional models in conditioning""" cnets: list[ControlBase] = [] diff --git a/comfy/samplers.py b/comfy/samplers.py index 8b9782956..8be449ef7 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -946,6 +946,8 @@ class CFGGuider: def inner_set_conds(self, conds): for k in conds: + if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]): + self.model_patcher = self.model_patcher.get_non_dynamic_delegate() self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) def __call__(self, *args, **kwargs): diff --git a/comfy/sd.py b/comfy/sd.py index 7713d4678..a9ad7c2d2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False): if no_init: return params = target.params.copy() @@ -233,7 +233,8 @@ class CLIP: model_management.archive_model_dtypes(self.cond_stage_model) self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -267,9 +268,9 @@ class CLIP: logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) self.tokenizer_options = {} - def clone(self): + def clone(self, disable_dynamic=False): n = CLIP(no_init=True) - n.patcher = self.patcher.clone() + n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic) n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx @@ -1164,14 +1165,21 @@ class CLIPType(Enum): LONGCAT_IMAGE = 26 -def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): + +def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): + clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic) + return clip.patcher + +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) if model_options.get("custom_operations", None) is None: sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) - return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) + clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic) + clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options)) + return clip class TEModel(Enum): @@ -1276,7 +1284,7 @@ def llama_detect(clip_data): return {} -def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): +def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = state_dicts class EmptyClass: @@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic) return clip def load_gligen(ckpt_path): @@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) - if output_model: + if output_model and out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) + if output_clip and out[1] is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options)) return out def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): @@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, disable_dynamic=disable_dynamic) return model +def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + _, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False, + embedding_directory=embedding_directory, output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic) + return clip.patcher + def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False): clip = None clipvision = None @@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 473fbbfd4..07feb31b3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1118,6 +1118,20 @@ class ZImage(Lumina2): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) +class ZImagePixelSpace(ZImage): + unet_config = { + "image_model": "zimage_pixel", + } + + # Pixel-space model: no spatial compression, operates on raw RGB patches. + latent_format = latent_formats.ZImagePixelSpace + + # Much lower memory than latent-space models (no VAE, small patches). + memory_usage_factor = 0.03 # TODO: figure out the optimal value for this. + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ZImagePixelSpace(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1268,6 +1282,16 @@ class WAN21_FlowRVS(WAN21_T2V): out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device) return out +class WAN21_SCAIL(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1710,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3fe804e0b..d83d2fc15 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -789,8 +789,6 @@ class GeminiImage2(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": model = "gemini-3.1-flash-image-preview" - if response_modalities == "IMAGE+TEXT": - raise ValueError("IMAGE+TEXT is not currently available for the Nano Banana 2 model.") parts: list[GeminiPart] = [GeminiPart(text=prompt)] if images is not None: @@ -895,7 +893,7 @@ class GeminiNanoBanana2(IO.ComfyNode): ), IO.Combo.Input( "response_modalities", - options=["IMAGE"], + options=["IMAGE", "IMAGE+TEXT"], advanced=True, ), IO.Combo.Input( @@ -925,6 +923,7 @@ class GeminiNanoBanana2(IO.ComfyNode): ], outputs=[ IO.Image.Output(), + IO.String.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 370014fb6..fcd7ef735 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -20,7 +20,7 @@ class JobStatus: # Media types that can be previewed in the frontend -PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'}) +PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) # 3D file extensions for preview fallback (no dedicated media_type exists) THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'}) @@ -75,6 +75,23 @@ def normalize_outputs(outputs: dict) -> dict: normalized[node_id] = normalized_node return normalized +# Text preview truncation limit (1024 characters) to prevent preview_output bloat +TEXT_PREVIEW_MAX_LENGTH = 1024 + + +def _create_text_preview(value: str) -> dict: + """Create a text preview dict with optional truncation. + + Returns: + dict with 'content' and optionally 'truncated' flag + """ + if len(value) <= TEXT_PREVIEW_MAX_LENGTH: + return {'content': value} + return { + 'content': value[:TEXT_PREVIEW_MAX_LENGTH], + 'truncated': True + } + def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: """Extract create_time and workflow_id from extra_data. @@ -221,23 +238,43 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: continue for item in items: - normalized = normalize_output_item(item) - if normalized is None: - continue + if not isinstance(item, dict): + # Handle text outputs (non-dict items like strings or tuples) + normalized = normalize_output_item(item) + if normalized is None: + # Not a 3D file string — check for text preview + if media_type == 'text': + count += 1 + if preview_output is None: + if isinstance(item, tuple): + text_value = item[0] if item else '' + else: + text_value = str(item) + text_preview = _create_text_preview(text_value) + enriched = { + **text_preview, + 'nodeId': node_id, + 'mediaType': media_type + } + if fallback_preview is None: + fallback_preview = enriched + continue + # normalize_output_item returned a dict (e.g. 3D file) + item = normalized count += 1 if preview_output is not None: continue - if isinstance(normalized, dict) and is_previewable(media_type, normalized): + if is_previewable(media_type, item): enriched = { - **normalized, + **item, 'nodeId': node_id, } - if 'mediaType' not in normalized: + if 'mediaType' not in item: enriched['mediaType'] = media_type - if normalized.get('type') == 'output': + if item.get('type') == 'output': preview_output = enriched elif fallback_preview is None: fallback_preview = enriched diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 43df0512f..5d8d9bf6f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode): def vae_decode_audio(vae, samples, tile=None, overlap=None): if tile is not None: - audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1) + audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1) else: audio = vae.decode(samples["samples"]).movedim(-1, 1) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index be7d600cd..056369e86 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -248,7 +248,7 @@ class SetClipHooks: def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: - clip = clip.clone() + clip = clip.clone(disable_dynamic=True) if apply_to_conds: clip.apply_hooks_to_conds = hooks clip.patcher.forced_hooks = hooks.clone() diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 6459ca8c1..a25226e6d 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Mahiro", - display_name="Mahiro CFG", + display_name="Positive-Biased Guidance", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ @@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode): io.Model.Output(display_name="patched_model"), ], is_experimental=True, + search_aliases=[ + "mahiro", + "mahiro cfg", + "similarity-adaptive guidance", + "positive-biased cfg", + ], ) @classmethod def execute(cls, model) -> io.NodeOutput: m = model.clone() + def mahiro_normd(args): - scale: float = args['cond_scale'] - cond_p: torch.Tensor = args['cond_denoised'] - uncond_p: torch.Tensor = args['uncond_denoised'] - #naive leap + scale: float = args["cond_scale"] + cond_p: torch.Tensor = args["cond_denoised"] + uncond_p: torch.Tensor = args["uncond_denoised"] + # naive leap leap = cond_p * scale - #sim with uncond leap + # sim with uncond leap u_leap = uncond_p * scale cfg = args["denoised"] merge = (leap + cfg) / 2 normu = torch.sqrt(u_leap.abs()) * u_leap.sign() normm = torch.sqrt(merge.abs()) * merge.sign() sim = F.cosine_similarity(normu, normm).mean() - simsc = 2 * (sim+1) - wm = (simsc*cfg + (4-simsc)*leap) / 4 + simsc = 2 * (sim + 1) + wm = (simsc * cfg + (4 - simsc) * leap) / 4 return wm + m.set_model_sampler_post_cfg_function(mahiro_normd) return io.NodeOutput(m) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index effa994d1..e50bfcd2c 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1456,6 +1456,63 @@ class WanInfiniteTalkToVideo(io.ComfyNode): return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("reference_image", optional=True), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1476,6 +1533,7 @@ class WanExtension(ComfyExtension): WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, + WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/main.py b/main.py index 3fe8f0589..0f58d57b8 100644 --- a/main.py +++ b/main.py @@ -16,11 +16,6 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() - if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -28,6 +23,11 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() + if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' @@ -192,7 +192,7 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if enables_dynamic_vram(): +if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl(): if comfy.model_management.torch_version_numeric < (2, 8): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): diff --git a/node_helpers.py b/node_helpers.py index 4ff960ef8..d3d834516 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -1,5 +1,6 @@ import hashlib import torch +import logging from comfy.cli_args import args @@ -21,6 +22,36 @@ def conditioning_set_values(conditioning, values={}, append=False): return c +def conditioning_set_values_with_timestep_range(conditioning, values={}, start_percent=0.0, end_percent=1.0): + """ + Apply values to conditioning only during [start_percent, end_percent], keeping the + original conditioning active outside that range. Respects existing per-entry ranges. + """ + if start_percent > end_percent: + logging.warning(f"start_percent ({start_percent}) must be <= end_percent ({end_percent})") + return conditioning + + EPS = 1e-5 # the sampler gates entries with strict > / <, shift boundaries slightly to ensure only one conditioning is active per timestep + c = [] + for t in conditioning: + cond_start = t[1].get("start_percent", 0.0) + cond_end = t[1].get("end_percent", 1.0) + intersect_start = max(start_percent, cond_start) + intersect_end = min(end_percent, cond_end) + + if intersect_start >= intersect_end: # no overlap: emit unchanged + c.append(t) + continue + + if intersect_start > cond_start: # part before the requested range + c.extend(conditioning_set_values([t], {"start_percent": cond_start, "end_percent": intersect_start - EPS})) + + c.extend(conditioning_set_values([t], {**values, "start_percent": intersect_start, "end_percent": intersect_end})) + + if intersect_end < cond_end: # part after the requested range + c.extend(conditioning_set_values([t], {"start_percent": intersect_end + EPS, "end_percent": cond_end})) + return c + def pillow(fn, arg): prev_value = None try: diff --git a/requirements.txt b/requirements.txt index 1b2bd0ae6..608b0cfa6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.4 +comfyui-workflow-templates==0.9.5 comfyui-embedded-docs==0.4.3 torch torchsde @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.2 +comfy-aimdo>=0.2.4 requests #non essential dependencies: diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index 643f04e72..1d5a84b47 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -49,6 +49,12 @@ def mock_provider(mock_releases): return provider +@pytest.fixture(autouse=True) +def clear_cache(): + import utils.install_util + utils.install_util.PACKAGE_VERSIONS = {} + + def test_get_release(mock_provider, mock_releases): version = "1.0.0" release = mock_provider.get_release(version) diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 83c36fe48..814af5c13 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -38,13 +38,13 @@ class TestIsPreviewable: """Unit tests for is_previewable()""" def test_previewable_media_types(self): - """Images, video, audio, 3d media types should be previewable.""" - for media_type in ['images', 'video', 'audio', '3d']: + """Images, video, audio, 3d, text media types should be previewable.""" + for media_type in ['images', 'video', 'audio', '3d', 'text']: assert is_previewable(media_type, {}) is True def test_non_previewable_media_types(self): """Other media types should not be previewable.""" - for media_type in ['latents', 'text', 'metadata', 'files']: + for media_type in ['latents', 'metadata', 'files']: assert is_previewable(media_type, {}) is False def test_3d_extensions_previewable(self): diff --git a/utils/install_util.py b/utils/install_util.py index 0f59bcf91..34489aec5 100644 --- a/utils/install_util.py +++ b/utils/install_util.py @@ -1,5 +1,7 @@ from pathlib import Path import sys +import logging +import re # The path to the requirements.txt file requirements_path = Path(__file__).parents[1] / "requirements.txt" @@ -16,3 +18,34 @@ Please install the updated requirements.txt file by running: {sys.executable} {extra}-m pip install -r {requirements_path} If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem. """.strip() + + +def is_valid_version(version: str) -> bool: + """Validate if a string is a valid semantic version (X.Y.Z format).""" + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + return bool(re.match(pattern, version)) + + +PACKAGE_VERSIONS = {} +def get_required_packages_versions(): + if len(PACKAGE_VERSIONS) > 0: + return PACKAGE_VERSIONS.copy() + out = PACKAGE_VERSIONS + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip().replace(">=", "==") + s = line.split("==") + if len(s) == 2: + version_str = s[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid version format in requirements.txt: {version_str}") + continue + out[s[0]] = version_str + return out.copy() + except FileNotFoundError: + logging.error("requirements.txt not found.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None