From b8bd32c3eda319182265fe8fff84760f13d556e6 Mon Sep 17 00:00:00 2001
From: ozbayb <17261091+ozbayb@users.noreply.github.com>
Date: Wed, 29 Apr 2026 16:51:39 -0600
Subject: [PATCH] Add latent_downscale_factor to LTXVAddGuide for IC-LoRA on small grids
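
Add an optional latent_downscale_factor input to LTXVAddGuide (integer
steps, 1-10, default 1). When the factor is greater than 1, the guide
image is encoded at 1/factor of the target spatial size and the
resulting latent is dilated back to full size, together with a mask
that marks which positions carry real guide values. The latent width
and height must be divisible by the factor, and the guide shape and
token count for per-reference attention control are recorded before
dilation. Intended for IC-LoRA conditioning on small grids.

The dilation zero-interleaves the small latent into a full-size
tensor. A minimal standalone sketch of the pattern (shapes are
illustrative, not taken from a real workflow):

    import torch

    t = torch.randn(1, 128, 1, 4, 4)    # small guide latent [B, C, F, h, w]
    scale = 2                           # latent_downscale_factor
    dilated = torch.zeros(*t.shape[:3], t.shape[3] * scale, t.shape[4] * scale)
    dilated[..., ::scale, ::scale] = t  # guide values at every `scale`-th position
    mask = torch.full_like(dilated[:, :1], -1.0)
    mask[..., ::scale, ::scale] = 1.0   # +1 = real guide value, -1 = padding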
---
 comfy_extras/nodes_lt.py | 58 +++++++++++++++++++++++++++++++++++-----
 1 file changed, 52 insertions(+), 6 deletions(-)

diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index 19d8a387f..cd3bbf6b7 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -224,6 +224,15 @@ class LTXVAddGuide(io.ComfyNode):
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
                 io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input(
+                    "latent_downscale_factor",
+                    default=1.0,
+                    min=1.0,
+                    max=10.0,
+                    step=1.0,
+                    tooltip="Encodes the guide image at a fraction of the target size, then dilates back to full size. "
+                    "1 = original size, 2 = half size, 3 = third, etc. Used for IC-LoRA on small grids.",
+                ),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
@@ -233,10 +242,12 @@
         )
 
     @classmethod
-    def encode(cls, vae, latent_width, latent_height, images, scale_factors):
+    def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
         time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
         images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
-        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
+        target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
+        target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
+        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="disabled").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
         return encode_pixels, t
@@ -336,13 +347,48 @@
         return latent_image, noise_mask
 
     @classmethod
-    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
+    def dilate_latent(cls, guide_latent, latent_downscale_factor):
+        if latent_downscale_factor <= 1:
+            return guide_latent, None
+
+        scale = int(latent_downscale_factor)
+        samples = guide_latent
+        dilated_shape = samples.shape[:3] + (
+            samples.shape[3] * scale,
+            samples.shape[4] * scale,
+        )
+        dilated_samples = torch.zeros(dilated_shape, device=samples.device, dtype=samples.dtype)
+        dilated_samples[..., ::scale, ::scale] = samples
+
+        dilated_mask = torch.full(
+            (dilated_samples.shape[0], 1, dilated_samples.shape[2], dilated_samples.shape[3], dilated_samples.shape[4]),
+            -1.0, device=samples.device, dtype=samples.dtype,
+        )
+        dilated_mask[..., ::scale, ::scale] = 1.0
+        return dilated_samples, dilated_mask
+
+    @classmethod
+    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, latent_downscale_factor=1.0) -> io.NodeOutput:
         scale_factors = vae.downscale_index_formula
         latent_image = latent["samples"]
         noise_mask = get_noise_mask(latent)
 
         _, _, latent_length, latent_height, latent_width = latent_image.shape
-        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
+
+        latent_downscale_factor = max(1, round(latent_downscale_factor))
+        if latent_downscale_factor > 1:
+            if latent_width % int(latent_downscale_factor) != 0 or latent_height % int(latent_downscale_factor) != 0:
+                raise ValueError(f"Latent spatial size {latent_width}x{latent_height} must be divisible by latent_downscale_factor {int(latent_downscale_factor)}")
+
+        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
+
+        guide_latent_shape = list(t.shape[2:])  # pre-dilation [F, H, W] for spatial mask downsampling
+
+        guide_mask = None
+        if latent_downscale_factor > 1:
+            t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
+
+        pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
 
         frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
         assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
@@ -356,11 +402,11 @@
             t,
             strength,
             scale_factors,
+            guide_mask=guide_mask,
+            latent_downscale_factor=int(latent_downscale_factor),
         )
 
         # Track this guide for per-reference attention control.
-        pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
-        guide_latent_shape = list(t.shape[2:])  # [F, H, W]
         positive, negative = _append_guide_attention_entry(
             positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
         )