feat: Add downscaled IC-LoRA support to LTXVAddGuide (CORE-102) (#13896)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

This commit is contained in:
drozbay 2026-05-16 01:02:57 -06:00 committed by GitHub
parent 5d5a4554e1
commit d3607a8e6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 108 additions and 9 deletions

View File

@ -79,7 +79,7 @@ import comfy.latent_formats
import comfy.ldm.flux.redux import comfy.ldm.flux.redux
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
key_map = {} key_map = {}
if model is not None: if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
@ -91,6 +91,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if model is not None: if model is not None:
new_modelpatcher = model.clone() new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model) k = new_modelpatcher.add_patches(loaded, strength_model)
if lora_metadata:
new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
else: else:
k = () k = ()
new_modelpatcher = None new_modelpatcher = None
@ -98,6 +100,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None: if clip is not None:
new_clip = clip.clone() new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip) k1 = new_clip.add_patches(loaded, strength_clip)
if lora_metadata:
new_clip.patcher.set_attachments("lora_metadata", lora_metadata)
else: else:
k1 = () k1 = ()
new_clip = None new_clip = None

View File

@ -14,6 +14,49 @@ from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
# Custom node I/O type "IC_LORA_PARAMETERS": produced by GetICLoRAParameters as a dict
# and accepted by LTXVAddGuide through its optional "iclora_parameters" input.
ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS")
class GetICLoRAParameters(io.ComfyNode):
    """Node that exposes IC-LoRA parameters stored in a model's LoRA metadata.

    The node reads the "lora_metadata" attachment of the incoming model and
    emits a dict with a single key, "reference_downscale_factor", intended to
    be connected to LTXVAddGuide.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node: one model input and one IC-LoRA parameters output."""
        return io.Schema(
            node_id="GetICLoRAParameters",
            display_name="Get IC-LoRA Parameters",
            description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
                        "model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
            category="conditioning/video_models",
            search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
            inputs=[
                io.Model.Input(
                    "iclora_model",
                    tooltip="Direct output from a LoRA Loader for the specific IC-LoRA "
                            "from which to extract the metadata.",
                ),
            ],
            outputs=[
                ICLoRAParameters.Output(
                    "iclora_parameters",
                    tooltip="IC-LoRA parameters extracted from the LoRA metadata "
                            "(eg. reference_downscale_factor). Connect to LTXVAddGuide "
                            "if the LoRA requires special handling of the guides.",
                ),
            ],
        )

    @classmethod
    def execute(cls, iclora_model) -> io.NodeOutput:
        """Build the IC-LoRA parameter dict from the model's "lora_metadata" attachment.

        Args:
            iclora_model: Model object exposing ``get_attachment``; its
                "lora_metadata" attachment is read when present.

        Returns:
            io.NodeOutput wrapping ``{"reference_downscale_factor": int}``.
            The factor is the metadata value converted to float, rounded and
            clamped to at least 1. It is 1 when the metadata is missing or
            empty, or when the value cannot be converted.
        """
        metadata = iclora_model.get_attachment("lora_metadata")
        factor = 1
        if metadata:
            try:
                factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
            except (TypeError, ValueError, OverflowError):
                # TypeError/ValueError: value is not numeric (or is NaN).
                # OverflowError: round() of an infinite value, e.g. the string "inf".
                factor = 1

        parameters = {"reference_downscale_factor": factor}
        return io.NodeOutput(parameters)
class EmptyLTXVLatentVideo(io.ComfyNode): class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -220,6 +263,14 @@ class LTXVAddGuide(io.ComfyNode):
"down to the nearest multiple of 8. Negative values are counted from the end of the video.", "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
), ),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
ICLoRAParameters.Input(
"iclora_parameters",
optional=True,
tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. "
"Used for adjusting guide processing as required by certain IC-LoRAs "
"(eg. those with a reference_downscale_factor > 1). "
"When chained, each LTXVAddGuide uses only the parameters connected to it.",
),
], ],
outputs=[ outputs=[
io.Conditioning.Output(display_name="positive"), io.Conditioning.Output(display_name="positive"),
@ -229,14 +280,41 @@ class LTXVAddGuide(io.ComfyNode):
) )
@classmethod @classmethod
def encode(cls, vae, latent_width, latent_height, images, scale_factors): def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1) target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3] encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
return encode_pixels, t return encode_pixels, t
@classmethod
def dilate_latent(cls, guide_latent, latent_downscale_factor):
    """Spread a guide latent onto a grid that is ``factor`` times larger in its last two dims.

    Args:
        guide_latent: Tensor indexed up to dimension 4; dimensions 3 and 4
            are the ones that get enlarged.
        latent_downscale_factor: Enlargement factor. Values <= 1 leave the
            latent untouched; other values are truncated with ``int()``.

    Returns:
        ``(latent, mask)``. For factors <= 1 this is ``(guide_latent, None)``.
        Otherwise ``latent`` is zero everywhere except at every ``factor``-th
        position of the last two dims, where it holds the original values.
        ``mask`` has a single channel, with 1.0 at those positions and -1.0
        elsewhere.
    """
    if latent_downscale_factor <= 1:
        return guide_latent, None

    stride = int(latent_downscale_factor)
    out_height = guide_latent.shape[3] * stride
    out_width = guide_latent.shape[4] * stride
    # Both new tensors live on the same device and use the same dtype as the input.
    tensor_opts = {"device": guide_latent.device, "dtype": guide_latent.dtype}

    expanded = torch.zeros(tuple(guide_latent.shape[:3]) + (out_height, out_width), **tensor_opts)
    # Original samples land on the strided positions; the gaps stay zero.
    expanded[..., ::stride, ::stride] = guide_latent

    mask_shape = (expanded.shape[0], 1, expanded.shape[2], expanded.shape[3], expanded.shape[4])
    expanded_mask = torch.full(mask_shape, -1.0, **tensor_opts)
    expanded_mask[..., ::stride, ::stride] = 1.0

    return expanded, expanded_mask
@classmethod
def get_reference_downscale_factor(cls, iclora_parameters):
    """Return the integer reference downscale factor from ``iclora_parameters``.

    Args:
        iclora_parameters: Mapping that may contain the key
            "reference_downscale_factor", or a falsy value (``None``, empty
            dict) when no parameters are connected.

    Returns:
        The value converted to float, rounded with ``round()`` and clamped to
        at least 1. Returns 1 when the parameters are missing or empty, or
        when the value is non-numeric, NaN or infinite.
    """
    if not iclora_parameters:
        return 1
    try:
        factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1))))
    except (TypeError, ValueError, OverflowError):
        # TypeError/ValueError: value is not numeric (or is NaN).
        # OverflowError: round() of +/-infinity, e.g. the string "inf".
        factor = 1
    return factor
@classmethod @classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors time_scale_factor, _, _ = scale_factors
@ -332,13 +410,21 @@ class LTXVAddGuide(io.ComfyNode):
return latent_image, noise_mask return latent_image, noise_mask
@classmethod @classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput: def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula scale_factors = vae.downscale_index_formula
latent_image = latent["samples"] latent_image = latent["samples"]
noise_mask = get_noise_mask(latent) noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape _, _, latent_length, latent_height, latent_width = latent_image.shape
latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters)
if latent_downscale_factor > 1:
if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
raise ValueError(
f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters."
)
# For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot # For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
time_scale_factor = scale_factors[0] time_scale_factor = scale_factors[0]
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1 num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
@ -351,12 +437,17 @@ class LTXVAddGuide(io.ComfyNode):
if not causal_fix: if not causal_fix:
image = torch.cat([image[:1], image], dim=0) image = torch.cat([image[:1], image], dim=0)
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors) image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
if not causal_fix: if not causal_fix:
t = t[:, :, 1:, :, :] t = t[:, :, 1:, :, :]
image = image[1:] image = image[1:]
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling
guide_mask = None
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
@ -369,12 +460,13 @@ class LTXVAddGuide(io.ComfyNode):
t, t,
strength, strength,
scale_factors, scale_factors,
guide_mask=guide_mask,
latent_downscale_factor=latent_downscale_factor,
causal_fix=causal_fix, causal_fix=causal_fix,
) )
# Track this guide for per-reference attention control. # Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry( positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength, positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
) )
@ -794,6 +886,7 @@ class LtxvExtension(ComfyExtension):
ModelSamplingLTXV, ModelSamplingLTXV,
LTXVConditioning, LTXVConditioning,
LTXVScheduler, LTXVScheduler,
GetICLoRAParameters,
LTXVAddGuide, LTXVAddGuide,
LTXVPreprocess, LTXVPreprocess,
LTXVCropGuides, LTXVCropGuides,

View File

@ -700,17 +700,19 @@ class LoraLoader:
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None lora = None
lora_metadata = None
if self.loaded_lora is not None: if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path: if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1] lora = self.loaded_lora[1]
lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
else: else:
self.loaded_lora = None self.loaded_lora = None
if lora is None: if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True) lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
self.loaded_lora = (lora_path, lora) self.loaded_lora = (lora_path, lora, lora_metadata)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
return (model_lora, clip_lora) return (model_lora, clip_lora)
class LoraLoaderModelOnly(LoraLoader): class LoraLoaderModelOnly(LoraLoader):