diff --git a/README.md b/README.md index 64d494f20..0eecd8a4b 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more... - ComfyUI natively supports the latest open-source state of the art models. - API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc. -- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud. +- It is available on Windows, Linux, and macOS, locally with our [desktop application](https://www.comfy.org/download), our [portable install](#installing) or on our [cloud](https://www.comfy.org/cloud). - The most sophisticated workflows can be exposed through a simple UI thanks to App Mode. - It integrates seamlessly into production pipelines with our API endpoints. diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 9dadb0093..76faed3ad 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -141,8 +141,7 @@ manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", he vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") -vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") -vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") +vram_group.add_argument("--lowvram", action="store_true", help="Doesn't do anything if dynamic vram is enabled. If dynamic vram isn't being used this option makes the text encoders run on the CPU.") vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 3fb87b4a3..bc09fb77e 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -22,26 +22,25 @@ class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim') - def __init__(self, tensor: torch.Tensor, patches_per_frame: int): + def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False): """ - tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame - patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression + tensor: [batch, num_tokens, feature_dim] (per-token, default) or + [batch, num_frames, feature_dim] (per_frame=True, already compressed). + patches_per_frame: spatial patches per frame; pass None to disable compression. """ - self.batch_size, num_tokens, self.feature_dim = tensor.shape - - # Check if compression is valid (num_tokens must be divisible by patches_per_frame) - if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: + self.batch_size, n, self.feature_dim = tensor.shape + if per_frame: self.patches_per_frame = patches_per_frame - self.num_frames = num_tokens // patches_per_frame - - # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame - # All patches in a frame are identical, so we only keep the first one - reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim) - self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim] + self.num_frames = n + self.data = tensor + elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0: + self.patches_per_frame = patches_per_frame + self.num_frames = n // patches_per_frame + # All patches in a frame are identical — keep only the first. + self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous() else: - # Not divisible or too small - store directly without compression self.patches_per_frame = 1 - self.num_frames = num_tokens + self.num_frames = n self.data = tensor def expand(self): @@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel): def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): """Prepare timestep embeddings.""" - # TODO: some code reuse is needed here. grid_mask = kwargs.get("grid_mask", None) - if grid_mask is not None: - timestep = timestep[:, grid_mask] - - timestep_scaled = timestep * self.timestep_scale_multiplier - - v_timestep, v_embedded_timestep = self.adaln_single( - timestep_scaled.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=hidden_dtype, - ) - - # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width] - # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width orig_shape = kwargs.get("orig_shape") has_spatial_mask = kwargs.get("has_spatial_mask", None) v_patches_per_frame = None if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5: - # orig_shape[3] = height, orig_shape[4] = width (in latent space) v_patches_per_frame = orig_shape[3] * orig_shape[4] - # Reshape to [batch_size, num_tokens, dim] and compress for storage - v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) - v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + # Used by compute_prompt_timestep and the audio cross-attention paths. + timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier + + # When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token + per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0 + if per_frame_path: + per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0] + if grid_mask is not None: + # All-or-nothing per frame when has_spatial_mask=False. + per_frame = per_frame[:, grid_mask[::v_patches_per_frame]] + ts_input = per_frame * self.timestep_scale_multiplier + else: + ts_input = timestep_scaled + + v_timestep, v_embedded_timestep = self.adaln_single( + ts_input.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path) + v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path) v_prompt_timestep = compute_prompt_timestep( self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index bfbc08357..e0a4a0f9b 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -358,6 +358,61 @@ def apply_split_rotary_emb(input_tensor, cos, sin): return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output +class GuideAttentionMask: + """Holds the two per-group masks for LTXV guide self-attention. + _attention_with_guide_mask splits queries into noisy and tracked-guide + groups, so the largest mask is (1, 1, tracked_count, T). + """ + __slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask") + + def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights): + device = tracked_weights.device + dtype = tracked_weights.dtype + finfo = torch.finfo(dtype) + + pos = tracked_weights > 0 + log_w = torch.full_like(tracked_weights, finfo.min) + log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny)) + + self.guide_start = guide_start + self.tracked_count = tracked_count + + self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype) + self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1) + + self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype) + self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1) + + +def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options): + """Apply the guide mask by partitioning Q into noisy and tracked-guide + groups, so each group needs only its own sub-mask. Avoids materializing + the (1,1,T,T) dense mask. + """ + guide_start = guide_mask.guide_start + tracked_end = guide_start + guide_mask.tracked_count + + out = torch.empty_like(q) + + if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes. + out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention( + q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask, + attn_precision=attn_precision, transformer_options=transformer_options, + low_precision_attention=False, # sageattn mask support is unreliable + ) + out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention( + q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask, + attn_precision=attn_precision, transformer_options=transformer_options, + low_precision_attention=False, + ) + if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes. + out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention( + q[:, tracked_end:, :], k, v, heads, + attn_precision=attn_precision, transformer_options=transformer_options, + ) + return out + + class CrossAttention(nn.Module): def __init__( self, @@ -412,8 +467,10 @@ class CrossAttention(nn.Module): if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) + elif isinstance(mask, GuideAttentionMask): + out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options) # Apply per-head gating if enabled if self.to_gate_logits is not None: @@ -1063,7 +1120,9 @@ class LTXVModel(LTXBaseModel): additional_args["resolved_guide_entries"] = resolved_entries keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] - pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs + + if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving + pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs # Total surviving guide tokens (all guides) additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] @@ -1099,12 +1158,12 @@ class LTXVModel(LTXBaseModel): if not resolved_entries: return None - # Check if any attenuation is actually needed - needs_attenuation = any( - e["strength"] < 1.0 or e.get("pixel_mask") is not None + # strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention. + needs_mask = any( + e["strength"] != 1.0 or e.get("pixel_mask") is not None for e in resolved_entries ) - if not needs_attenuation: + if not needs_mask: return None # Build per-guide-token weights for all tracked guide tokens. @@ -1159,16 +1218,11 @@ class LTXVModel(LTXBaseModel): # Concatenate per-token weights for all tracked guides tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked) - # Check if any weight is actually < 1.0 (otherwise no attenuation needed) - if (tracked_weights >= 1.0).all(): + # Skip when every weight is exactly 1.0 (additive bias would be 0). + if (tracked_weights == 1.0).all(): return None - # Build the mask: guide tokens are at the end of the sequence. - # Tracked guides come first (in order), untracked follow. - return self._build_self_attention_mask( - total_tokens, num_guide_tokens, total_tracked, - tracked_weights, guide_start, device, dtype, - ) + return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights) @staticmethod def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat): @@ -1234,45 +1288,6 @@ class LTXVModel(LTXBaseModel): return rearrange(latent_mask, "b 1 f h w -> b (f h w)") - @staticmethod - def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count, - tracked_weights, guide_start, device, dtype): - """Build a log-space additive self-attention bias mask. - - Attenuates attention between noisy tokens and tracked guide tokens. - Untracked guide tokens (at the end of the guide portion) keep full attention. - - Args: - total_tokens: Total sequence length. - num_guide_tokens: Total guide tokens (all guides) at end of sequence. - tracked_count: Number of tracked guide tokens (first in the guide portion). - tracked_weights: (1, tracked_count) tensor, values in [0, 1]. - guide_start: Index where guide tokens begin in the sequence. - device: Target device. - dtype: Target dtype. - - Returns: - (1, 1, total_tokens, total_tokens) additive bias mask. - 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked. - """ - finfo = torch.finfo(dtype) - mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype) - tracked_end = guide_start + tracked_count - - # Convert weights to log-space bias - w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count) - log_w = torch.full_like(w, finfo.min) - positive_mask = w > 0 - if positive_mask.any(): - log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny)) - - # noisy → tracked guides: each noisy row gets the same per-guide weight - mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) - # tracked guides → noisy: each guide row broadcasts its weight across noisy cols - mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1) - - return mask - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0736321b3..c22705655 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1691,6 +1691,13 @@ class HiDreamO1(BaseModel): if text_input_ids is None or noise is None: return out + # handle area conds + area = kwargs.get("area", None) + if area is not None: + crop_h = min(noise.shape[-2] - area[2], area[0]) + crop_w = min(noise.shape[-1] - area[3], area[1]) + noise = torch.empty((noise.shape[0], 3, crop_h, crop_w), dtype=noise.dtype, device=noise.device) + conds = build_extra_conds( text_input_ids, noise, ref_images=kwargs.get("reference_latents", None), diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2ea14bc2c..4f9d8403e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1493,27 +1493,30 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - original_state_dict = self.model.diffusion_model.state_dict() - unet_state_dict = {} + def model_state_dict_for_saving(self, model=None, prefix=""): + if model is None: + model = self.model + + original_state_dict = model.state_dict() + output_state_dict = {} keys = list(original_state_dict) while len(keys) > 0: k = keys.pop(0) v = original_state_dict[k] op_keys = k.rsplit('.', 1) if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: - unet_state_dict[k] = v + output_state_dict[k] = v continue try: - op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + op = comfy.utils.get_attr(model, op_keys[0]) except: - unet_state_dict[k] = v + output_state_dict[k] = v continue if not op or not hasattr(op, "comfy_cast_weights") or \ (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): - unet_state_dict[k] = v + output_state_dict[k] = v continue - key = "diffusion_model." + k + key = prefix + k weight = comfy.utils.get_attr(self.model, key) if isinstance(weight, QuantizedTensor) and k in original_state_dict: qt_state_dict = weight.state_dict(k) @@ -1521,10 +1524,14 @@ class ModelPatcher: for group_key in (x for x in qt_state_dict if x in original_state_dict): if group_key in keys: keys.remove(group_key) - unet_state_dict.pop(group_key, "") - unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key]) + output_state_dict.pop(group_key, "") + output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key]) continue - unet_state_dict[k] = LazyCastingParam(self, key, weight) + output_state_dict[k] = LazyCastingParam(self, key, weight) + return output_state_dict + + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.") return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) def __del__(self): diff --git a/comfy/ops.py b/comfy/ops.py index 117cdd327..f9456854b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1376,6 +1376,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ if not fp8_compute: disabled.add("float8_e4m3fn") disabled.add("float8_e5m2") + logging.info("Native ops: {} {}".format(", ".join(QUANT_ALGOS.keys() - disabled), ", emulated ops: {}".format(", ".join(disabled)) if len(disabled) > 0 else "")) return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( diff --git a/comfy/sd.py b/comfy/sd.py index ab2718892..2443353a4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -79,7 +79,7 @@ import comfy.latent_formats import comfy.ldm.flux.redux -def load_lora_for_models(model, clip, lora, strength_model, strength_clip): +def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None): key_map = {} if model is not None: key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) @@ -91,6 +91,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): if model is not None: new_modelpatcher = model.clone() k = new_modelpatcher.add_patches(loaded, strength_model) + if lora_metadata: + new_modelpatcher.set_attachments("lora_metadata", lora_metadata) else: k = () new_modelpatcher = None @@ -98,6 +100,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): if clip is not None: new_clip = clip.clone() k1 = new_clip.add_patches(loaded, strength_clip) + if lora_metadata: + new_clip.patcher.set_attachments("lora_metadata", lora_metadata) else: k1 = () new_clip = None @@ -419,6 +423,13 @@ class CLIP: sd_clip[k] = sd_tokenizer[k] return sd_clip + def state_dict_for_saving(self): + sd_clip = self.patcher.model_state_dict_for_saving() + sd_tokenizer = self.tokenizer.state_dict() + for k in sd_tokenizer: + sd_clip[k] = sd_tokenizer[k] + return sd_clip + def load_model(self, tokens={}): memory_used = 0 if hasattr(self.cond_stage_model, "memory_estimation_function"): @@ -1904,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m load_models = [model] if clip is not None: load_models.append(clip.load_model()) - clip_sd = clip.get_sd() + clip_sd = clip.state_dict_for_saving() vae_sd = None if vae is not None: vae_sd = vae.get_sd() diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index b022009b1..416ce9d18 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -760,7 +760,7 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer): def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs): image = kwargs.get("image", None) if image is not None and len(images) == 0: - images = [image] + images = [image[i:i + 1] for i in range(image.shape[0])] skip_template = False if text.startswith('<|im_start|>'): @@ -771,13 +771,16 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer): if skip_template: llama_text = text else: - if llama_template is None: - if len(images) > 0: - llama_text = self.llama_template_images.format(text) - else: - llama_text = self.llama_template.format(text) + if llama_template is not None: + template = llama_template + elif len(images) == 0: + template = self.llama_template else: - llama_text = llama_template.format(text) + template = self.llama_template_images + if len(images) > 1: + vision_block = "<|vision_start|><|image_pad|><|vision_end|>" + template = template.replace(vision_block, vision_block * len(images), 1) + llama_text = template.format(text) if not thinking: llama_text += "\n\n" diff --git a/comfy_api_nodes/nodes_anthropic.py b/comfy_api_nodes/nodes_anthropic.py index 60e1624f7..28dd70d4e 100644 --- a/comfy_api_nodes/nodes_anthropic.py +++ b/comfy_api_nodes/nodes_anthropic.py @@ -49,7 +49,7 @@ def _claude_model_inputs(): min=0.0, max=1.0, step=0.01, - tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random.", + tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.", advanced=True, ), ] @@ -208,7 +208,7 @@ class ClaudeNode(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) model_label = model["model"] max_tokens = model["max_tokens"] - temperature = model["temperature"] + temperature = None if model_label == "Opus 4.7" else model["temperature"] image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None] if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES: diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 3dc1199c2..fdae458e5 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -14,6 +14,49 @@ from typing_extensions import override from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from comfy_api.latest import ComfyExtension, io +ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS") + + +class GetICLoRAParameters(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="GetICLoRAParameters", + display_name="Get IC-LoRA Parameters", + description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded " + "model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).", + category="conditioning/video_models", + search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"], + inputs=[ + io.Model.Input( + "iclora_model", + tooltip="Direct output from a LoRA Loader for the specific IC-LoRA " + "from which to extract the metadata.", + ), + ], + outputs=[ + ICLoRAParameters.Output( + "iclora_parameters", + tooltip="IC-LoRA parameters extracted from the LoRA metadata " + "(eg. reference_downscale_factor). Connect to LTXVAddGuide " + "if the LoRA requires special handling of the guides.", + ), + ], + ) + + @classmethod + def execute(cls, iclora_model) -> io.NodeOutput: + metadata = iclora_model.get_attachment("lora_metadata") + factor = 1 + if metadata: + try: + factor = max(1, round(float(metadata.get("reference_downscale_factor", 1)))) + except (TypeError, ValueError): + factor = 1 + parameters = {"reference_downscale_factor": factor} + return io.NodeOutput(parameters) + + class EmptyLTXVLatentVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -219,7 +262,15 @@ class LTXVAddGuide(io.ComfyNode): "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded " "down to the nearest multiple of 8. Negative values are counted from the end of the video.", ), - io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + ICLoRAParameters.Input( + "iclora_parameters", + optional=True, + tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. " + "Used for adjusting guide processing as required by certain IC-LoRAs " + "(eg. those with a reference_downscale_factor > 1). " + "When chained, each LTXVAddGuide uses only the parameters connected to it.", + ), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -229,14 +280,41 @@ class LTXVAddGuide(io.ComfyNode): ) @classmethod - def encode(cls, vae, latent_width, latent_height, images, scale_factors): + def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1): time_scale_factor, width_scale_factor, height_scale_factor = scale_factors images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] - pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1) + target_width = int(latent_width * width_scale_factor / latent_downscale_factor) + target_height = int(latent_height * height_scale_factor / latent_downscale_factor) + pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) return encode_pixels, t + @classmethod + def dilate_latent(cls, guide_latent, latent_downscale_factor): + if latent_downscale_factor <= 1: + return guide_latent, None + scale = int(latent_downscale_factor) + dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale) + dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype) + dilated[..., ::scale, ::scale] = guide_latent + dilated_mask = torch.full( + (dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]), + -1.0, device=guide_latent.device, dtype=guide_latent.dtype, + ) + dilated_mask[..., ::scale, ::scale] = 1.0 + return dilated, dilated_mask + + @classmethod + def get_reference_downscale_factor(cls, iclora_parameters): + if not iclora_parameters: + return 1 + try: + factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1)))) + except (TypeError, ValueError): + factor = 1 + return factor + @classmethod def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): time_scale_factor, _, _ = scale_factors @@ -298,7 +376,7 @@ class LTXVAddGuide(io.ComfyNode): else: mask = torch.full( (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), - 1.0 - strength, + max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask dtype=noise_mask.dtype, device=noise_mask.device, ) @@ -318,7 +396,7 @@ class LTXVAddGuide(io.ComfyNode): mask = torch.full( (noise_mask.shape[0], 1, cond_length, 1, 1), - 1.0 - strength, + max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask dtype=noise_mask.dtype, device=noise_mask.device, ) @@ -332,13 +410,21 @@ class LTXVAddGuide(io.ComfyNode): return latent_image, noise_mask @classmethod - def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput: + def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput: scale_factors = vae.downscale_index_formula latent_image = latent["samples"] noise_mask = get_noise_mask(latent) _, _, latent_length, latent_height, latent_width = latent_image.shape + latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters) + if latent_downscale_factor > 1: + if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0: + raise ValueError( + f"Latent spatial size {latent_width}x{latent_height} must be divisible by " + f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters." + ) + # For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot time_scale_factor = scale_factors[0] num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1 @@ -351,12 +437,17 @@ class LTXVAddGuide(io.ComfyNode): if not causal_fix: image = torch.cat([image[:1], image], dim=0) - image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors) + image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor) if not causal_fix: t = t[:, :, 1:, :, :] image = image[1:] + guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling + guide_mask = None + if latent_downscale_factor > 1: + t, guide_mask = cls.dilate_latent(t, latent_downscale_factor) + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." @@ -369,12 +460,13 @@ class LTXVAddGuide(io.ComfyNode): t, strength, scale_factors, + guide_mask=guide_mask, + latent_downscale_factor=latent_downscale_factor, causal_fix=causal_fix, ) # Track this guide for per-reference attention control. pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] - guide_latent_shape = list(t.shape[2:]) # [F, H, W] positive, negative = _append_guide_attention_entry( positive, negative, pre_filter_count, guide_latent_shape, strength=strength, ) @@ -794,6 +886,7 @@ class LtxvExtension(ComfyExtension): ModelSamplingLTXV, LTXVConditioning, LTXVScheduler, + GetICLoRAParameters, LTXVAddGuide, LTXVPreprocess, LTXVCropGuides, diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 96ee1a0f8..419e561ba 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -330,7 +330,7 @@ class FeatherMask(IO.ComfyNode): for x in range(right): feather_rate = (x + 1) / right - output[:, :, -x] *= feather_rate + output[:, :, -(x + 1)] *= feather_rate for y in range(top): feather_rate = (y + 1) / top @@ -338,7 +338,7 @@ class FeatherMask(IO.ComfyNode): for y in range(bottom): feather_rate = (y + 1) / bottom - output[:, -y, :] *= feather_rate + output[:, -(y + 1), :] *= feather_rate return IO.NodeOutput(output) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 5384ed531..b6b29e34a 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -276,8 +276,8 @@ class CLIPSave: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) - comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True) - clip_sd = clip.get_sd() + clip.load_model() + clip_sd = clip.state_dict_for_saving() for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]: k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys())) diff --git a/execution.py b/execution.py index f37d0360d..4c7de2e84 100644 --- a/execution.py +++ b/execution.py @@ -626,7 +626,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." - logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) + logging.info("Memory summary:\n{}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type: diff --git a/nodes.py b/nodes.py index a59e8ebde..374217eea 100644 --- a/nodes.py +++ b/nodes.py @@ -700,17 +700,19 @@ class LoraLoader: lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) lora = None + lora_metadata = None if self.loaded_lora is not None: if self.loaded_lora[0] == lora_path: lora = self.loaded_lora[1] + lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None else: self.loaded_lora = None if lora is None: - lora = comfy.utils.load_torch_file(lora_path, safe_load=True) - self.loaded_lora = (lora_path, lora) + lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True) + self.loaded_lora = (lora_path, lora, lora_metadata) - model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) + model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata) return (model_lora, clip_lora) class LoraLoaderModelOnly(LoraLoader):