Commit e39ef3a0b7: Merge 822cce0c72 into ef8f25601a
Repository: https://github.com/comfyanonymous/ComfyUI.git
@@ -224,6 +224,15 @@ class LTXVAddGuide(io.ComfyNode):
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
                 io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Float.Input(
+                    "latent_downscale_factor",
+                    default=1.0,
+                    min=1.0,
+                    max=10.0,
+                    step=1.0,
+                    tooltip="Encodes the guide image at a fraction of the target size, then dilates back to full size. "
+                    "1 = original size, 2 = half size, 3 = third, etc. Used for IC-LoRA on small grids.",
+                ),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
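
The tooltip's size mapping ("1 = original size, 2 = half size, 3 = third") reduces to a single division. A minimal sketch of that mapping, with a hypothetical encode_size helper and illustrative values:

    # Sketch only: hypothetical helper illustrating the tooltip's mapping.
    def encode_size(full_size: int, latent_downscale_factor: float) -> int:
        # 1 -> original size, 2 -> half size, 3 -> third, etc.
        return int(full_size / latent_downscale_factor)

    assert encode_size(768, 1.0) == 768
    assert encode_size(768, 2.0) == 384
    assert encode_size(768, 3.0) == 256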
@@ -233,10 +242,12 @@ class LTXVAddGuide(io.ComfyNode):
         )

     @classmethod
-    def encode(cls, vae, latent_width, latent_height, images, scale_factors):
+    def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
         time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
         images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
-        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
+        target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
+        target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
+        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="disabled").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
         return encode_pixels, t
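
To make the new target-size arithmetic in encode() concrete, here is a worked example; the latent grid 32x32 and the 32x spatial VAE scale factor are assumed values for illustration, not taken from the diff:

    latent_width, latent_height = 32, 32
    width_scale_factor = height_scale_factor = 32  # assumed spatial VAE factor
    latent_downscale_factor = 2

    # Same formula as encode() above: full pixel size divided by the factor.
    target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
    target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
    assert (target_width, target_height) == (512, 512)  # half of the full 1024x1024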
@@ -336,13 +347,48 @@ class LTXVAddGuide(io.ComfyNode):
         return latent_image, noise_mask

     @classmethod
-    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
+    def dilate_latent(cls, guide_latent, latent_downscale_factor):
+        if latent_downscale_factor <= 1:
+            return guide_latent, None
+
+        scale = int(latent_downscale_factor)
+        samples = guide_latent
+        dilated_shape = samples.shape[:3] + (
+            samples.shape[3] * scale,
+            samples.shape[4] * scale,
+        )
+        dilated_samples = torch.zeros(dilated_shape, device=samples.device, dtype=samples.dtype)
+        dilated_samples[..., ::scale, ::scale] = samples
+
+        dilated_mask = torch.full(
+            (dilated_samples.shape[0], 1, dilated_samples.shape[2], dilated_samples.shape[3], dilated_samples.shape[4]),
+            -1.0, device=samples.device, dtype=samples.dtype,
+        )
+        dilated_mask[..., ::scale, ::scale] = 1.0
+        return dilated_samples, dilated_mask
+
+    @classmethod
+    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, latent_downscale_factor=1.0) -> io.NodeOutput:
         scale_factors = vae.downscale_index_formula
         latent_image = latent["samples"]
         noise_mask = get_noise_mask(latent)

         _, _, latent_length, latent_height, latent_width = latent_image.shape
-        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
+        latent_downscale_factor = max(1, round(latent_downscale_factor))
+        if latent_downscale_factor > 1:
+            if latent_width % int(latent_downscale_factor) != 0 or latent_height % int(latent_downscale_factor) != 0:
+                raise ValueError(f"Latent spatial size {latent_width}x{latent_height} must be divisible by latent_downscale_factor {int(latent_downscale_factor)}")
+
+        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
+
+        guide_latent_shape = list(t.shape[2:])  # pre-dilation [F, H, W] for spatial mask downsampling
+
+        guide_mask = None
+        if latent_downscale_factor > 1:
+            t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
+
+        pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
+
         frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
         assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
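
The dilation step places the small encoded guide latent on a strided grid and records which positions hold real values in a mask. A self-contained sketch of that placement logic, assuming the [B, C, F, H, W] latent layout used above (toy tensor values chosen for illustration):

    import torch

    scale = 2
    # Toy guide latent: batch 1, 1 channel, 2 frames, 2x2 spatial, values 1..8.
    samples = torch.arange(1, 9, dtype=torch.float32).reshape(1, 1, 2, 2, 2)

    # Zero-filled tensor at full spatial size; real values land on the strided grid.
    dilated = torch.zeros(samples.shape[:3] + (2 * scale, 2 * scale))
    dilated[..., ::scale, ::scale] = samples

    # Mask: 1.0 where a real latent sits, -1.0 on the filler positions.
    mask = torch.full((1, 1, 2, 2 * scale, 2 * scale), -1.0)
    mask[..., ::scale, ::scale] = 1.0

    assert dilated[0, 0, 0, 0, 0] == 1.0  # real value copied from samples
    assert dilated[0, 0, 0, 0, 1] == 0.0  # filler position stays zero
    assert mask[0, 0, 0, 0, 1] == -1.0    # filler marked for downstream handling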
@@ -356,11 +402,11 @@ class LTXVAddGuide(io.ComfyNode):
             t,
             strength,
             scale_factors,
+            guide_mask=guide_mask,
+            latent_downscale_factor=int(latent_downscale_factor),
         )

         # Track this guide for per-reference attention control.
-        pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
-        guide_latent_shape = list(t.shape[2:])  # [F, H, W]
         positive, negative = _append_guide_attention_entry(
             positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
         )