Merge branch 'master' into master

This commit is contained in:
azazeal04 2026-05-17 12:01:45 +02:00 committed by GitHub
commit 66e721b1bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 225 additions and 105 deletions

View File

@ -38,7 +38,7 @@
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more... ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models. - ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc. - API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud. - It is available on Windows, Linux, and macOS, locally with our [desktop application](https://www.comfy.org/download), our [portable install](#installing) or on our [cloud](https://www.comfy.org/cloud).
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode. - The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints. - It integrates seamlessly into production pipelines with our API endpoints.

View File

@ -141,8 +141,7 @@ manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", he
vram_group = parser.add_mutually_exclusive_group() vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") vram_group.add_argument("--lowvram", action="store_true", help="Doesn't do anything if dynamic vram is enabled. If dynamic vram isn't being used this option makes the text encoders run on the CPU.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

View File

@ -22,26 +22,25 @@ class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing.""" """Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim') __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
def __init__(self, tensor: torch.Tensor, patches_per_frame: int): def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
""" """
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame tensor: [batch, num_tokens, feature_dim] (per-token, default) or
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression [batch, num_frames, feature_dim] (per_frame=True, already compressed).
patches_per_frame: spatial patches per frame; pass None to disable compression.
""" """
self.batch_size, num_tokens, self.feature_dim = tensor.shape self.batch_size, n, self.feature_dim = tensor.shape
if per_frame:
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame self.num_frames = n
self.data = tensor
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
# All patches in a frame are identical, so we only keep the first one self.patches_per_frame = patches_per_frame
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim) self.num_frames = n // patches_per_frame
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim] # All patches in a frame are identical — keep only the first.
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
else: else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1 self.patches_per_frame = 1
self.num_frames = num_tokens self.num_frames = n
self.data = tensor self.data = tensor
def expand(self): def expand(self):
@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel):
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings.""" """Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None) grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep_scaled = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape") orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None) has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5: if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4] v_patches_per_frame = orig_shape[3] * orig_shape[4]
# Reshape to [batch_size, num_tokens, dim] and compress for storage # Used by compute_prompt_timestep and the audio cross-attention paths.
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
if per_frame_path:
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
if grid_mask is not None:
# All-or-nothing per frame when has_spatial_mask=False.
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
ts_input = per_frame * self.timestep_scale_multiplier
else:
ts_input = timestep_scaled
v_timestep, v_embedded_timestep = self.adaln_single(
ts_input.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
v_prompt_timestep = compute_prompt_timestep( v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype

View File

@ -358,6 +358,61 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
class GuideAttentionMask:
"""Holds the two per-group masks for LTXV guide self-attention.
_attention_with_guide_mask splits queries into noisy and tracked-guide
groups, so the largest mask is (1, 1, tracked_count, T).
"""
__slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
device = tracked_weights.device
dtype = tracked_weights.dtype
finfo = torch.finfo(dtype)
pos = tracked_weights > 0
log_w = torch.full_like(tracked_weights, finfo.min)
log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
self.guide_start = guide_start
self.tracked_count = tracked_count
self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
groups, so each group needs only its own sub-mask. Avoids materializing
the (1,1,T,T) dense mask.
"""
guide_start = guide_mask.guide_start
tracked_end = guide_start + guide_mask.tracked_count
out = torch.empty_like(q)
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False, # sageattn mask support is unreliable
)
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False,
)
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, tracked_end:, :], k, v, heads,
attn_precision=attn_precision, transformer_options=transformer_options,
)
return out
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__( def __init__(
self, self,
@ -412,8 +467,10 @@ class CrossAttention(nn.Module):
if mask is None: if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
elif isinstance(mask, GuideAttentionMask):
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
else: else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled # Apply per-head gating if enabled
if self.to_gate_logits is not None: if self.to_gate_logits is not None:
@ -1063,7 +1120,9 @@ class LTXVModel(LTXBaseModel):
additional_args["resolved_guide_entries"] = resolved_entries additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides) # Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@ -1099,12 +1158,12 @@ class LTXVModel(LTXBaseModel):
if not resolved_entries: if not resolved_entries:
return None return None
# Check if any attenuation is actually needed # strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
needs_attenuation = any( needs_mask = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None e["strength"] != 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries for e in resolved_entries
) )
if not needs_attenuation: if not needs_mask:
return None return None
# Build per-guide-token weights for all tracked guide tokens. # Build per-guide-token weights for all tracked guide tokens.
@ -1159,16 +1218,11 @@ class LTXVModel(LTXBaseModel):
# Concatenate per-token weights for all tracked guides # Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked) tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed) # Skip when every weight is exactly 1.0 (additive bias would be 0).
if (tracked_weights >= 1.0).all(): if (tracked_weights == 1.0).all():
return None return None
# Build the mask: guide tokens are at the end of the sequence. return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod @staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat): def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
@ -1234,45 +1288,6 @@ class LTXVModel(LTXBaseModel):
return rearrange(latent_mask, "b 1 f h w -> b (f h w)") return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV.""" """Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})

View File

@ -94,6 +94,7 @@ def get_twinflow_z_image_config(state_dict):
} }
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
key_map = {} key_map = {}
if model is not None: if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
@ -105,6 +106,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if model is not None: if model is not None:
new_modelpatcher = model.clone() new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model) k = new_modelpatcher.add_patches(loaded, strength_model)
if lora_metadata:
new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
else: else:
k = () k = ()
new_modelpatcher = None new_modelpatcher = None
@ -112,6 +115,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None: if clip is not None:
new_clip = clip.clone() new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip) k1 = new_clip.add_patches(loaded, strength_clip)
if lora_metadata:
new_clip.patcher.set_attachments("lora_metadata", lora_metadata)
else: else:
k1 = () k1 = ()
new_clip = None new_clip = None

View File

@ -136,7 +136,7 @@ class ImageFromBatch(IO.ComfyNode):
category="image/batch", category="image/batch",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Int.Input("batch_index", default=0, min=0, max=4095), IO.Int.Input("batch_index", default=0, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
IO.Int.Input("length", default=1, min=1, max=4096), IO.Int.Input("length", default=1, min=1, max=4096),
], ],
outputs=[IO.Image.Output()], outputs=[IO.Image.Output()],
@ -145,7 +145,9 @@ class ImageFromBatch(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, image, batch_index, length) -> IO.NodeOutput: def execute(cls, image, batch_index, length) -> IO.NodeOutput:
s_in = image s_in = image
batch_index = min(s_in.shape[0] - 1, batch_index) if batch_index < 0:
batch_index += s_in.shape[0]
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
length = min(s_in.shape[0] - batch_index, length) length = min(s_in.shape[0] - batch_index, length)
s = s_in[batch_index:batch_index + length].clone() s = s_in[batch_index:batch_index + length].clone()
return IO.NodeOutput(s) return IO.NodeOutput(s)

View File

@ -14,6 +14,49 @@ from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS")
class GetICLoRAParameters(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GetICLoRAParameters",
display_name="Get IC-LoRA Parameters",
description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
"model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
category="conditioning/video_models",
search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
inputs=[
io.Model.Input(
"iclora_model",
tooltip="Direct output from a LoRA Loader for the specific IC-LoRA "
"from which to extract the metadata.",
),
],
outputs=[
ICLoRAParameters.Output(
"iclora_parameters",
tooltip="IC-LoRA parameters extracted from the LoRA metadata "
"(eg. reference_downscale_factor). Connect to LTXVAddGuide "
"if the LoRA requires special handling of the guides.",
),
],
)
@classmethod
def execute(cls, iclora_model) -> io.NodeOutput:
metadata = iclora_model.get_attachment("lora_metadata")
factor = 1
if metadata:
try:
factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
except (TypeError, ValueError):
factor = 1
parameters = {"reference_downscale_factor": factor}
return io.NodeOutput(parameters)
class EmptyLTXVLatentVideo(io.ComfyNode): class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -219,7 +262,15 @@ class LTXVAddGuide(io.ComfyNode):
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded " "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.", "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
), ),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01), io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
ICLoRAParameters.Input(
"iclora_parameters",
optional=True,
tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. "
"Used for adjusting guide processing as required by certain IC-LoRAs "
"(eg. those with a reference_downscale_factor > 1). "
"When chained, each LTXVAddGuide uses only the parameters connected to it.",
),
], ],
outputs=[ outputs=[
io.Conditioning.Output(display_name="positive"), io.Conditioning.Output(display_name="positive"),
@ -229,14 +280,41 @@ class LTXVAddGuide(io.ComfyNode):
) )
@classmethod @classmethod
def encode(cls, vae, latent_width, latent_height, images, scale_factors): def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1) target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3] encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
return encode_pixels, t return encode_pixels, t
@classmethod
def dilate_latent(cls, guide_latent, latent_downscale_factor):
if latent_downscale_factor <= 1:
return guide_latent, None
scale = int(latent_downscale_factor)
dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale)
dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype)
dilated[..., ::scale, ::scale] = guide_latent
dilated_mask = torch.full(
(dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]),
-1.0, device=guide_latent.device, dtype=guide_latent.dtype,
)
dilated_mask[..., ::scale, ::scale] = 1.0
return dilated, dilated_mask
@classmethod
def get_reference_downscale_factor(cls, iclora_parameters):
if not iclora_parameters:
return 1
try:
factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1))))
except (TypeError, ValueError):
factor = 1
return factor
@classmethod @classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors time_scale_factor, _, _ = scale_factors
@ -298,7 +376,7 @@ class LTXVAddGuide(io.ComfyNode):
else: else:
mask = torch.full( mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength, max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
dtype=noise_mask.dtype, dtype=noise_mask.dtype,
device=noise_mask.device, device=noise_mask.device,
) )
@ -318,7 +396,7 @@ class LTXVAddGuide(io.ComfyNode):
mask = torch.full( mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1), (noise_mask.shape[0], 1, cond_length, 1, 1),
1.0 - strength, max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
dtype=noise_mask.dtype, dtype=noise_mask.dtype,
device=noise_mask.device, device=noise_mask.device,
) )
@ -332,13 +410,21 @@ class LTXVAddGuide(io.ComfyNode):
return latent_image, noise_mask return latent_image, noise_mask
@classmethod @classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput: def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula scale_factors = vae.downscale_index_formula
latent_image = latent["samples"] latent_image = latent["samples"]
noise_mask = get_noise_mask(latent) noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape _, _, latent_length, latent_height, latent_width = latent_image.shape
latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters)
if latent_downscale_factor > 1:
if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
raise ValueError(
f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters."
)
# For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot # For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
time_scale_factor = scale_factors[0] time_scale_factor = scale_factors[0]
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1 num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
@ -351,12 +437,17 @@ class LTXVAddGuide(io.ComfyNode):
if not causal_fix: if not causal_fix:
image = torch.cat([image[:1], image], dim=0) image = torch.cat([image[:1], image], dim=0)
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors) image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
if not causal_fix: if not causal_fix:
t = t[:, :, 1:, :, :] t = t[:, :, 1:, :, :]
image = image[1:] image = image[1:]
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling
guide_mask = None
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
@ -369,12 +460,13 @@ class LTXVAddGuide(io.ComfyNode):
t, t,
strength, strength,
scale_factors, scale_factors,
guide_mask=guide_mask,
latent_downscale_factor=latent_downscale_factor,
causal_fix=causal_fix, causal_fix=causal_fix,
) )
# Track this guide for per-reference attention control. # Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry( positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength, positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
) )
@ -794,6 +886,7 @@ class LtxvExtension(ComfyExtension):
ModelSamplingLTXV, ModelSamplingLTXV,
LTXVConditioning, LTXVConditioning,
LTXVScheduler, LTXVScheduler,
GetICLoRAParameters,
LTXVAddGuide, LTXVAddGuide,
LTXVPreprocess, LTXVPreprocess,
LTXVCropGuides, LTXVCropGuides,

View File

@ -700,17 +700,19 @@ class LoraLoader:
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None lora = None
lora_metadata = None
if self.loaded_lora is not None: if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path: if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1] lora = self.loaded_lora[1]
lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
else: else:
self.loaded_lora = None self.loaded_lora = None
if lora is None: if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True) lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
self.loaded_lora = (lora_path, lora) self.loaded_lora = (lora_path, lora, lora_metadata)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
return (model_lora, clip_lora) return (model_lora, clip_lora)
class LoraLoaderModelOnly(LoraLoader): class LoraLoaderModelOnly(LoraLoader):
@ -1221,7 +1223,7 @@ class LatentFromBatch:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), "batch_index": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}), "length": ("INT", {"default": 1, "min": 1, "max": 64}),
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
@ -1232,7 +1234,9 @@ class LatentFromBatch:
def frombatch(self, samples, batch_index, length): def frombatch(self, samples, batch_index, length):
s = samples.copy() s = samples.copy()
s_in = samples["samples"] s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index) if batch_index < 0:
batch_index += s_in.shape[0]
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
length = min(s_in.shape[0] - batch_index, length) length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone() s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples: if "noise_mask" in samples: