feat: Read reference_downscale_factor from IC-LoRA metadata in LTXVAddGuide

ozbayb 2026-05-12 22:12:49 -06:00
parent 1f28908d6e
commit cbebbd75bc
3 changed files with 67 additions and 9 deletions

File 1 of 3

@@ -79,7 +79,7 @@ import comfy.latent_formats
 import comfy.ldm.flux.redux


-def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
     key_map = {}
     if model is not None:
         key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
@@ -91,6 +91,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     if model is not None:
         new_modelpatcher = model.clone()
         k = new_modelpatcher.add_patches(loaded, strength_model)
+        if lora_metadata:
+            new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
     else:
         k = ()
         new_modelpatcher = None
@@ -98,6 +100,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     if clip is not None:
         new_clip = clip.clone()
         k1 = new_clip.add_patches(loaded, strength_clip)
+        if lora_metadata:
+            new_clip.set_attachments("lora_metadata", lora_metadata)
     else:
         k1 = ()
         new_clip = None
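Attachments set with set_attachments are carried across ModelPatcher.clone(), which is how the metadata travels from the loader to downstream nodes. A minimal sketch of the retrieval side, assuming model is the MODEL output of the patched LoraLoader shown in the third file:

# Hedged sketch: reading the attachment back from a downstream ModelPatcher.
# The "lora_metadata" key matches the one set in load_lora_for_models above.
metadata = model.get_attachment("lora_metadata")  # None if nothing was attached
if metadata is not None:
    factor = metadata.get("reference_downscale_factor")  # header values are strings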

File 2 of 3

@@ -220,6 +220,14 @@ class LTXVAddGuide(io.ComfyNode):
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
                 io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Model.Input(
+                    "ic_lora",
+                    optional=True,
+                    tooltip="Optional connection from an IC-LoRA loader. If the LoRA's safetensors metadata "
+                    "contains 'reference_downscale_factor', the guide image will be encoded at "
+                    "1/factor resolution and dilated back to full size (for IC-LoRAs trained on small grids). "
+                    "Defaults to 1 (no downscale) when absent.",
+                ),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
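The metadata the tooltip refers to is the LoRA file's safetensors header, delivered via the ic_lora MODEL connection. For context, a hypothetical IC-LoRA trained on a 2x-downscaled reference grid would carry an entry like this (only the key name comes from this commit; the value is illustrative):

# Hypothetical safetensors __metadata__ entry; headers map strings to strings.
lora_metadata = {"reference_downscale_factor": "2"}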
@@ -229,14 +237,44 @@
         )

     @classmethod
-    def encode(cls, vae, latent_width, latent_height, images, scale_factors):
+    def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
         time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
         images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
-        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
+        target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
+        target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
+        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
         return encode_pixels, t

+    @classmethod
+    def dilate_latent(cls, guide_latent, latent_downscale_factor):
+        if latent_downscale_factor <= 1:
+            return guide_latent, None
+        scale = int(latent_downscale_factor)
+        dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale)
+        dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype)
+        dilated[..., ::scale, ::scale] = guide_latent
+        dilated_mask = torch.full(
+            (dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]),
+            -1.0, device=guide_latent.device, dtype=guide_latent.dtype,
+        )
+        dilated_mask[..., ::scale, ::scale] = 1.0
+        return dilated, dilated_mask
+
+    @classmethod
+    def get_reference_downscale_factor(cls, ic_lora):
+        if ic_lora is None:
+            return 1
+        metadata = ic_lora.get_attachment("lora_metadata")
+        if not metadata:
+            return 1
+        try:
+            factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
+        except (TypeError, ValueError):
+            factor = 1
+        return factor
+
     @classmethod
     def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
         time_scale_factor, _, _ = scale_factors
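To make the dilation concrete, here is a toy run with hypothetical shapes (a [1, 1, 1, 2, 2] latent and factor 2; the real inputs are [B, C, F, H, W] VAE latents):

import torch

# Toy check of the dilation pattern: original values land on the ::2 grid,
# zeros fill the gaps; the mask is 1.0 at kept positions and -1.0 elsewhere.
scale = 2
g = torch.arange(4.0).reshape(1, 1, 1, 2, 2)
dilated = torch.zeros(1, 1, 1, 4, 4)
dilated[..., ::scale, ::scale] = g
mask = torch.full((1, 1, 1, 4, 4), -1.0)
mask[..., ::scale, ::scale] = 1.0
print(dilated[0, 0, 0])
# tensor([[0., 0., 1., 0.],
#         [0., 0., 0., 0.],
#         [2., 0., 3., 0.],
#         [0., 0., 0., 0.]])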
@@ -332,13 +370,21 @@
         return latent_image, noise_mask

     @classmethod
-    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
+    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, ic_lora=None) -> io.NodeOutput:
         scale_factors = vae.downscale_index_formula
         latent_image = latent["samples"]
         noise_mask = get_noise_mask(latent)
         _, _, latent_length, latent_height, latent_width = latent_image.shape

+        latent_downscale_factor = cls.get_reference_downscale_factor(ic_lora)
+        if latent_downscale_factor > 1:
+            if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
+                raise ValueError(
+                    f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
+                    f"reference_downscale_factor {latent_downscale_factor} from the ic_lora metadata."
+                )
+
         # For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
         time_scale_factor = scale_factors[0]
         num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
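get_reference_downscale_factor parses through float() before round() because safetensors header values are always strings, and the try/except keeps a malformed or missing value from breaking the node. A standalone re-creation of the same parsing, for illustration only:

def parse_factor(metadata):
    # Mirrors the parsing in get_reference_downscale_factor above.
    try:
        return max(1, round(float(metadata.get("reference_downscale_factor", 1))))
    except (TypeError, ValueError):
        return 1

print(parse_factor({"reference_downscale_factor": "2"}))     # 2
print(parse_factor({"reference_downscale_factor": "oops"}))  # 1 (malformed -> fallback)
print(parse_factor({}))                                      # 1 (absent -> no downscale)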
@@ -351,12 +397,17 @@
         if not causal_fix:
             image = torch.cat([image[:1], image], dim=0)

-        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
+        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
         if not causal_fix:
             t = t[:, :, 1:, :, :]
             image = image[1:]

+        guide_latent_shape = list(t.shape[2:])  # pre-dilation [F, H, W] for spatial-mask downsampling
+        guide_mask = None
+        if latent_downscale_factor > 1:
+            t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
+
         frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
         assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
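With the factor applied, encode works at 1/factor resolution before dilate_latent spreads the result back onto the full grid. A quick check with hypothetical numbers (latent 32x32, VAE spatial scale 32, factor 2; none of these sizes come from the commit):

latent_width = 32
width_scale_factor = 32
latent_downscale_factor = 2
target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
print(target_width)  # 512 instead of 1024; the VAE then yields a 16x16 spatial
                     # latent that dilate_latent expands back to the 32x32 grid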
@@ -369,12 +420,13 @@
             t,
             strength,
             scale_factors,
+            guide_mask=guide_mask,
+            latent_downscale_factor=latent_downscale_factor,
             causal_fix=causal_fix,
         )

         # Track this guide for per-reference attention control.
         pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
-        guide_latent_shape = list(t.shape[2:])  # [F, H, W]
         positive, negative = _append_guide_attention_entry(
             positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
         )
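Note the ordering in execute: guide_latent_shape is recorded before dilation, while pre_filter_count is computed on the possibly dilated tensor. A toy check with hypothetical sizes:

import torch

# A [1, 4, 3, 8, 8] encoded guide dilated by 2: the attention entry keeps the
# pre-dilation [3, 8, 8] shape, but the token count reflects the dilated grid.
t = torch.randn(1, 4, 3, 8, 8)
guide_latent_shape = list(t.shape[2:])  # [3, 8, 8]
dilated = torch.zeros(1, 4, 3, 16, 16)
dilated[..., ::2, ::2] = t
pre_filter_count = dilated.shape[2] * dilated.shape[3] * dilated.shape[4]
print(guide_latent_shape, pre_filter_count)  # [3, 8, 8] 768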

File 3 of 3

@@ -700,17 +700,19 @@ class LoraLoader:
         lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
         lora = None
+        lora_metadata = None
         if self.loaded_lora is not None:
             if self.loaded_lora[0] == lora_path:
                 lora = self.loaded_lora[1]
+                lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
             else:
                 self.loaded_lora = None

         if lora is None:
-            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
-            self.loaded_lora = (lora_path, lora)
+            lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
+            self.loaded_lora = (lora_path, lora, lora_metadata)

-        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
         return (model_lora, clip_lora)


 class LoraLoaderModelOnly(LoraLoader):
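return_metadata=True is what surfaces the safetensors header; for non-safetensors files it comes back as None, which load_lora_for_models treats as nothing to attach. A hedged sketch of the same call outside the node, with a placeholder path:

import comfy.utils

# "loras/my_ic_lora.safetensors" is a placeholder, not a file from this commit.
lora, lora_metadata = comfy.utils.load_torch_file(
    "loras/my_ic_lora.safetensors", safe_load=True, return_metadata=True
)
print((lora_metadata or {}).get("reference_downscale_factor"))  # e.g. "2", or None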