Add latent_downscale_factor to LTXVAddGuide for IC-LoRA on small grids

This commit is contained in:
ozbayb 2026-04-29 16:51:39 -06:00
parent 0e25a6936e
commit b8bd32c3ed

View File

@ -224,6 +224,15 @@ class LTXVAddGuide(io.ComfyNode):
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input(
"latent_downscale_factor",
default=1.0,
min=1.0,
max=10.0,
step=1.0,
tooltip="Encodes the guide image at a fraction of the target size, then dilates back to full size. "
"1 = original size, 2 = half size, 3 = third, etc. Used for IC-LoRA on small grids.",
),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
@ -233,10 +242,12 @@ class LTXVAddGuide(io.ComfyNode):
)
@classmethod
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="disabled").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t
@ -336,13 +347,48 @@ class LTXVAddGuide(io.ComfyNode):
return latent_image, noise_mask
@classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
def dilate_latent(cls, guide_latent, latent_downscale_factor):
if latent_downscale_factor <= 1:
return guide_latent, None
scale = int(latent_downscale_factor)
samples = guide_latent
dilated_shape = samples.shape[:3] + (
samples.shape[3] * scale,
samples.shape[4] * scale,
)
dilated_samples = torch.zeros(dilated_shape, device=samples.device, dtype=samples.dtype)
dilated_samples[..., ::scale, ::scale] = samples
dilated_mask = torch.full(
(dilated_samples.shape[0], 1, dilated_samples.shape[2], dilated_samples.shape[3], dilated_samples.shape[4]),
-1.0, device=samples.device, dtype=samples.dtype,
)
dilated_mask[..., ::scale, ::scale] = 1.0
return dilated_samples, dilated_mask
@classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, latent_downscale_factor=1.0) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
latent_downscale_factor = max(1, round(latent_downscale_factor))
if latent_downscale_factor > 1:
if latent_width % int(latent_downscale_factor) != 0 or latent_height % int(latent_downscale_factor) != 0:
raise ValueError(f"Latent spatial size {latent_width}x{latent_height} must be divisible by latent_downscale_factor {int(latent_downscale_factor)}")
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial mask downsampling
guide_mask = None
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
@ -356,11 +402,11 @@ class LTXVAddGuide(io.ComfyNode):
t,
strength,
scale_factors,
guide_mask=guide_mask,
latent_downscale_factor=int(latent_downscale_factor),
)
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
)