commit e39ef3a0b7
Author: drozbay
Date: 2026-05-08 04:38:41 +03:00 (committed via GitHub)


@@ -224,6 +224,15 @@ class LTXVAddGuide(io.ComfyNode):
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
                 io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input(
+                    "latent_downscale_factor",
+                    default=1.0,
+                    min=1.0,
+                    max=10.0,
+                    step=1.0,
+                    tooltip="Encodes the guide image at a fraction of the target size, then dilates back to full size. "
+                    "1 = original size, 2 = half size, 3 = third, etc. Used for IC-LoRA on small grids.",
+                ),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
@@ -233,10 +242,12 @@ class LTXVAddGuide(io.ComfyNode):
         )

     @classmethod
-    def encode(cls, vae, latent_width, latent_height, images, scale_factors):
+    def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
         time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
         images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
-        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
+        target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
+        target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
+        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="disabled").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
         return encode_pixels, t
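
The size arithmetic in encode() is the core of the change: the guide is encoded at 1/factor of the pixel size the latent actually represents. A minimal sketch of that math, assuming LTXV's usual spatial VAE scale factor of 32 (the real values come from the incoming latent and vae.downscale_index_formula):

# Illustrative numbers only; not taken from the diff.
latent_width, latent_height = 96, 64           # latent grid of the target video
width_scale_factor = height_scale_factor = 32  # assumed LTXV spatial VAE factor
latent_downscale_factor = 2

# Full-size pixel target vs. the reduced encode target.
full_w = latent_width * width_scale_factor                                    # 3072
full_h = latent_height * height_scale_factor                                  # 2048
target_w = int(latent_width * width_scale_factor / latent_downscale_factor)   # 1536
target_h = int(latent_height * height_scale_factor / latent_downscale_factor) # 1024

# Encoding at the reduced size yields a 48x32 guide latent instead of 96x64;
# dilate_latent() (added below) stretches it back onto the 96x64 grid.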
@@ -336,13 +347,48 @@ class LTXVAddGuide(io.ComfyNode):
         return latent_image, noise_mask

     @classmethod
-    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
+    def dilate_latent(cls, guide_latent, latent_downscale_factor):
+        if latent_downscale_factor <= 1:
+            return guide_latent, None
+
+        scale = int(latent_downscale_factor)
+        samples = guide_latent
+        dilated_shape = samples.shape[:3] + (
+            samples.shape[3] * scale,
+            samples.shape[4] * scale,
+        )
+        dilated_samples = torch.zeros(dilated_shape, device=samples.device, dtype=samples.dtype)
+        dilated_samples[..., ::scale, ::scale] = samples
+
+        dilated_mask = torch.full(
+            (dilated_samples.shape[0], 1, dilated_samples.shape[2], dilated_samples.shape[3], dilated_samples.shape[4]),
+            -1.0, device=samples.device, dtype=samples.dtype,
+        )
+        dilated_mask[..., ::scale, ::scale] = 1.0
+        return dilated_samples, dilated_mask
+
+    @classmethod
+    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, latent_downscale_factor=1.0) -> io.NodeOutput:
         scale_factors = vae.downscale_index_formula
         latent_image = latent["samples"]
         noise_mask = get_noise_mask(latent)
         _, _, latent_length, latent_height, latent_width = latent_image.shape
-        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
+
+        latent_downscale_factor = max(1, round(latent_downscale_factor))
+        if latent_downscale_factor > 1:
+            if latent_width % int(latent_downscale_factor) != 0 or latent_height % int(latent_downscale_factor) != 0:
+                raise ValueError(f"Latent spatial size {latent_width}x{latent_height} must be divisible by latent_downscale_factor {int(latent_downscale_factor)}")
+
+        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
+        guide_latent_shape = list(t.shape[2:])  # pre-dilation [F, H, W] for spatial mask downsampling
+        guide_mask = None
+        if latent_downscale_factor > 1:
+            t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
+        pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
+
         frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
         assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
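
dilate_latent() is plain zero-stuffing: each encoded latent value lands on every scale-th row and column of a grid scale times larger, and a companion ±1 mask records which positions carry real data. A self-contained sketch of the same strided-assignment pattern, with made-up tensor sizes:

import torch

scale = 2
samples = torch.randn(1, 128, 3, 4, 6)  # [B, C, F, H, W]; sizes invented for the demo

# Zero-stuffed grid: real values on every `scale`-th spatial position, zeros elsewhere.
dilated = torch.zeros(*samples.shape[:3], samples.shape[3] * scale, samples.shape[4] * scale)
dilated[..., ::scale, ::scale] = samples

# Single-channel mask: 1.0 marks real guide values, -1.0 the stuffed zeros.
mask = torch.full((1, 1, 3, 4 * scale, 6 * scale), -1.0)
mask[..., ::scale, ::scale] = 1.0

assert dilated.shape == (1, 128, 3, 8, 12)
assert (mask > 0).sum() == 3 * 4 * 6  # one real position per original latent cell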
@@ -356,11 +402,11 @@ class LTXVAddGuide(io.ComfyNode):
             t,
             strength,
             scale_factors,
+            guide_mask=guide_mask,
+            latent_downscale_factor=int(latent_downscale_factor),
         )

         # Track this guide for per-reference attention control.
-        pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
-        guide_latent_shape = list(t.shape[2:])  # [F, H, W]
         positive, negative = _append_guide_attention_entry(
             positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
         )
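
The diff doesn't show how the prepare step consumes guide_mask, but a ±1 mask is sufficient for downstream code to separate real guide tokens from the stuffed zeros. A hedged illustration of that possibility only (variable names invented, sizes matching the sketch above), not the actual masking logic in this PR:

import torch

scale = 2
dilated = torch.randn(1, 128, 3, 8, 12)  # stand-in for a dilated guide latent
mask = torch.full((1, 1, 3, 8, 12), -1.0)
mask[..., ::scale, ::scale] = 1.0

# Boolean-select only the real positions and restore the compact [B, C, F, H, W] layout.
real = dilated[(mask > 0).expand_as(dilated)].reshape(1, 128, 3, 4, 6)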