mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
Merge branch 'master' into master
This commit is contained in:
commit
66e721b1bc
@ -38,7 +38,7 @@
|
||||
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
|
||||
- ComfyUI natively supports the latest open-source state of the art models.
|
||||
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
|
||||
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
|
||||
- It is available on Windows, Linux, and macOS, locally with our [desktop application](https://www.comfy.org/download), our [portable install](#installing) or on our [cloud](https://www.comfy.org/cloud).
|
||||
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
|
||||
- It integrates seamlessly into production pipelines with our API endpoints.
|
||||
|
||||
|
||||
@ -141,8 +141,7 @@ manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", he
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
||||
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
||||
vram_group.add_argument("--lowvram", action="store_true", help="Doesn't do anything if dynamic vram is enabled. If dynamic vram isn't being used this option makes the text encoders run on the CPU.")
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
|
||||
@ -22,26 +22,25 @@ class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
|
||||
"""
|
||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
||||
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
|
||||
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
|
||||
patches_per_frame: spatial patches per frame; pass None to disable compression.
|
||||
"""
|
||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
||||
|
||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
||||
self.batch_size, n, self.feature_dim = tensor.shape
|
||||
if per_frame:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = num_tokens // patches_per_frame
|
||||
|
||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
||||
# All patches in a frame are identical, so we only keep the first one
|
||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
||||
self.num_frames = n
|
||||
self.data = tensor
|
||||
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
|
||||
self.patches_per_frame = patches_per_frame
|
||||
self.num_frames = n // patches_per_frame
|
||||
# All patches in a frame are identical — keep only the first.
|
||||
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
|
||||
else:
|
||||
# Not divisible or too small - store directly without compression
|
||||
self.patches_per_frame = 1
|
||||
self.num_frames = num_tokens
|
||||
self.num_frames = n
|
||||
self.data = tensor
|
||||
|
||||
def expand(self):
|
||||
@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||
"""Prepare timestep embeddings."""
|
||||
# TODO: some code reuse is needed here.
|
||||
grid_mask = kwargs.get("grid_mask", None)
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
|
||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
||||
orig_shape = kwargs.get("orig_shape")
|
||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||
v_patches_per_frame = None
|
||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||
|
||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
# Used by compute_prompt_timestep and the audio cross-attention paths.
|
||||
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
|
||||
|
||||
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
|
||||
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
|
||||
if per_frame_path:
|
||||
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
|
||||
if grid_mask is not None:
|
||||
# All-or-nothing per frame when has_spatial_mask=False.
|
||||
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
|
||||
ts_input = per_frame * self.timestep_scale_multiplier
|
||||
else:
|
||||
ts_input = timestep_scaled
|
||||
|
||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||
ts_input.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||
|
||||
v_prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
|
||||
@ -358,6 +358,61 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
|
||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||
|
||||
|
||||
class GuideAttentionMask:
|
||||
"""Holds the two per-group masks for LTXV guide self-attention.
|
||||
_attention_with_guide_mask splits queries into noisy and tracked-guide
|
||||
groups, so the largest mask is (1, 1, tracked_count, T).
|
||||
"""
|
||||
__slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
|
||||
|
||||
def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
|
||||
device = tracked_weights.device
|
||||
dtype = tracked_weights.dtype
|
||||
finfo = torch.finfo(dtype)
|
||||
|
||||
pos = tracked_weights > 0
|
||||
log_w = torch.full_like(tracked_weights, finfo.min)
|
||||
log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
|
||||
|
||||
self.guide_start = guide_start
|
||||
self.tracked_count = tracked_count
|
||||
|
||||
self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
|
||||
self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
|
||||
|
||||
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
|
||||
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
|
||||
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
|
||||
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
|
||||
groups, so each group needs only its own sub-mask. Avoids materializing
|
||||
the (1,1,T,T) dense mask.
|
||||
"""
|
||||
guide_start = guide_mask.guide_start
|
||||
tracked_end = guide_start + guide_mask.tracked_count
|
||||
|
||||
out = torch.empty_like(q)
|
||||
|
||||
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
|
||||
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
low_precision_attention=False, # sageattn mask support is unreliable
|
||||
)
|
||||
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
|
||||
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||
q[:, tracked_end:, :], k, v, heads,
|
||||
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -412,8 +467,10 @@ class CrossAttention(nn.Module):
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
elif isinstance(mask, GuideAttentionMask):
|
||||
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
@ -1063,7 +1120,9 @@ class LTXVModel(LTXBaseModel):
|
||||
additional_args["resolved_guide_entries"] = resolved_entries
|
||||
|
||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
|
||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||
|
||||
# Total surviving guide tokens (all guides)
|
||||
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||
@ -1099,12 +1158,12 @@ class LTXVModel(LTXBaseModel):
|
||||
if not resolved_entries:
|
||||
return None
|
||||
|
||||
# Check if any attenuation is actually needed
|
||||
needs_attenuation = any(
|
||||
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
|
||||
needs_mask = any(
|
||||
e["strength"] != 1.0 or e.get("pixel_mask") is not None
|
||||
for e in resolved_entries
|
||||
)
|
||||
if not needs_attenuation:
|
||||
if not needs_mask:
|
||||
return None
|
||||
|
||||
# Build per-guide-token weights for all tracked guide tokens.
|
||||
@ -1159,16 +1218,11 @@ class LTXVModel(LTXBaseModel):
|
||||
# Concatenate per-token weights for all tracked guides
|
||||
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||
|
||||
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||
if (tracked_weights >= 1.0).all():
|
||||
# Skip when every weight is exactly 1.0 (additive bias would be 0).
|
||||
if (tracked_weights == 1.0).all():
|
||||
return None
|
||||
|
||||
# Build the mask: guide tokens are at the end of the sequence.
|
||||
# Tracked guides come first (in order), untracked follow.
|
||||
return self._build_self_attention_mask(
|
||||
total_tokens, num_guide_tokens, total_tracked,
|
||||
tracked_weights, guide_start, device, dtype,
|
||||
)
|
||||
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
|
||||
|
||||
@staticmethod
|
||||
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||
@ -1234,45 +1288,6 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||
|
||||
@staticmethod
|
||||
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
||||
tracked_weights, guide_start, device, dtype):
|
||||
"""Build a log-space additive self-attention bias mask.
|
||||
|
||||
Attenuates attention between noisy tokens and tracked guide tokens.
|
||||
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
||||
|
||||
Args:
|
||||
total_tokens: Total sequence length.
|
||||
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
||||
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
||||
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
||||
guide_start: Index where guide tokens begin in the sequence.
|
||||
device: Target device.
|
||||
dtype: Target dtype.
|
||||
|
||||
Returns:
|
||||
(1, 1, total_tokens, total_tokens) additive bias mask.
|
||||
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
||||
"""
|
||||
finfo = torch.finfo(dtype)
|
||||
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
||||
tracked_end = guide_start + tracked_count
|
||||
|
||||
# Convert weights to log-space bias
|
||||
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
||||
log_w = torch.full_like(w, finfo.min)
|
||||
positive_mask = w > 0
|
||||
if positive_mask.any():
|
||||
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
||||
|
||||
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
||||
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
||||
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
||||
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
@ -94,6 +94,7 @@ def get_twinflow_z_image_config(state_dict):
|
||||
}
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
@ -105,6 +106,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||
if lora_metadata:
|
||||
new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
@ -112,6 +115,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||
if lora_metadata:
|
||||
new_clip.patcher.set_attachments("lora_metadata", lora_metadata)
|
||||
else:
|
||||
k1 = ()
|
||||
new_clip = None
|
||||
|
||||
@ -136,7 +136,7 @@ class ImageFromBatch(IO.ComfyNode):
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("batch_index", default=0, min=0, max=4095),
|
||||
IO.Int.Input("batch_index", default=0, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("length", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
@ -145,7 +145,9 @@ class ImageFromBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
|
||||
s_in = image
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
if batch_index < 0:
|
||||
batch_index += s_in.shape[0]
|
||||
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s = s_in[batch_index:batch_index + length].clone()
|
||||
return IO.NodeOutput(s)
|
||||
|
||||
@ -14,6 +14,49 @@ from typing_extensions import override
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS")
|
||||
|
||||
|
||||
class GetICLoRAParameters(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GetICLoRAParameters",
|
||||
display_name="Get IC-LoRA Parameters",
|
||||
description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
|
||||
"model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
|
||||
category="conditioning/video_models",
|
||||
search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
|
||||
inputs=[
|
||||
io.Model.Input(
|
||||
"iclora_model",
|
||||
tooltip="Direct output from a LoRA Loader for the specific IC-LoRA "
|
||||
"from which to extract the metadata.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
ICLoRAParameters.Output(
|
||||
"iclora_parameters",
|
||||
tooltip="IC-LoRA parameters extracted from the LoRA metadata "
|
||||
"(eg. reference_downscale_factor). Connect to LTXVAddGuide "
|
||||
"if the LoRA requires special handling of the guides.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, iclora_model) -> io.NodeOutput:
|
||||
metadata = iclora_model.get_attachment("lora_metadata")
|
||||
factor = 1
|
||||
if metadata:
|
||||
try:
|
||||
factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
|
||||
except (TypeError, ValueError):
|
||||
factor = 1
|
||||
parameters = {"reference_downscale_factor": factor}
|
||||
return io.NodeOutput(parameters)
|
||||
|
||||
|
||||
class EmptyLTXVLatentVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -219,7 +262,15 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
|
||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||
),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
ICLoRAParameters.Input(
|
||||
"iclora_parameters",
|
||||
optional=True,
|
||||
tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. "
|
||||
"Used for adjusting guide processing as required by certain IC-LoRAs "
|
||||
"(eg. those with a reference_downscale_factor > 1). "
|
||||
"When chained, each LTXVAddGuide uses only the parameters connected to it.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
@ -229,14 +280,41 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
|
||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
|
||||
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
||||
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
|
||||
target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
|
||||
target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
|
||||
@classmethod
|
||||
def dilate_latent(cls, guide_latent, latent_downscale_factor):
|
||||
if latent_downscale_factor <= 1:
|
||||
return guide_latent, None
|
||||
scale = int(latent_downscale_factor)
|
||||
dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale)
|
||||
dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype)
|
||||
dilated[..., ::scale, ::scale] = guide_latent
|
||||
dilated_mask = torch.full(
|
||||
(dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]),
|
||||
-1.0, device=guide_latent.device, dtype=guide_latent.dtype,
|
||||
)
|
||||
dilated_mask[..., ::scale, ::scale] = 1.0
|
||||
return dilated, dilated_mask
|
||||
|
||||
@classmethod
|
||||
def get_reference_downscale_factor(cls, iclora_parameters):
|
||||
if not iclora_parameters:
|
||||
return 1
|
||||
try:
|
||||
factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1))))
|
||||
except (TypeError, ValueError):
|
||||
factor = 1
|
||||
return factor
|
||||
|
||||
@classmethod
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
@ -298,7 +376,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
else:
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||
1.0 - strength,
|
||||
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
@ -318,7 +396,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, cond_length, 1, 1),
|
||||
1.0 - strength,
|
||||
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
@ -332,13 +410,21 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return latent_image, noise_mask
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
|
||||
scale_factors = vae.downscale_index_formula
|
||||
latent_image = latent["samples"]
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
||||
|
||||
latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters)
|
||||
if latent_downscale_factor > 1:
|
||||
if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
|
||||
raise ValueError(
|
||||
f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
|
||||
f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters."
|
||||
)
|
||||
|
||||
# For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
|
||||
time_scale_factor = scale_factors[0]
|
||||
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
||||
@ -351,12 +437,17 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
if not causal_fix:
|
||||
image = torch.cat([image[:1], image], dim=0)
|
||||
|
||||
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
|
||||
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
|
||||
|
||||
if not causal_fix:
|
||||
t = t[:, :, 1:, :, :]
|
||||
image = image[1:]
|
||||
|
||||
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling
|
||||
guide_mask = None
|
||||
if latent_downscale_factor > 1:
|
||||
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
|
||||
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
@ -369,12 +460,13 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
t,
|
||||
strength,
|
||||
scale_factors,
|
||||
guide_mask=guide_mask,
|
||||
latent_downscale_factor=latent_downscale_factor,
|
||||
causal_fix=causal_fix,
|
||||
)
|
||||
|
||||
# Track this guide for per-reference attention control.
|
||||
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
|
||||
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
|
||||
positive, negative = _append_guide_attention_entry(
|
||||
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||
)
|
||||
@ -794,6 +886,7 @@ class LtxvExtension(ComfyExtension):
|
||||
ModelSamplingLTXV,
|
||||
LTXVConditioning,
|
||||
LTXVScheduler,
|
||||
GetICLoRAParameters,
|
||||
LTXVAddGuide,
|
||||
LTXVPreprocess,
|
||||
LTXVCropGuides,
|
||||
|
||||
14
nodes.py
14
nodes.py
@ -700,17 +700,19 @@ class LoraLoader:
|
||||
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||
lora = None
|
||||
lora_metadata = None
|
||||
if self.loaded_lora is not None:
|
||||
if self.loaded_lora[0] == lora_path:
|
||||
lora = self.loaded_lora[1]
|
||||
lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
|
||||
else:
|
||||
self.loaded_lora = None
|
||||
|
||||
if lora is None:
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
self.loaded_lora = (lora_path, lora)
|
||||
lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
|
||||
self.loaded_lora = (lora_path, lora, lora_metadata)
|
||||
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
|
||||
return (model_lora, clip_lora)
|
||||
|
||||
class LoraLoaderModelOnly(LoraLoader):
|
||||
@ -1221,7 +1223,7 @@ class LatentFromBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT",),
|
||||
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
|
||||
"batch_index": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}),
|
||||
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
@ -1232,7 +1234,9 @@ class LatentFromBatch:
|
||||
def frombatch(self, samples, batch_index, length):
|
||||
s = samples.copy()
|
||||
s_in = samples["samples"]
|
||||
batch_index = min(s_in.shape[0] - 1, batch_index)
|
||||
if batch_index < 0:
|
||||
batch_index += s_in.shape[0]
|
||||
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
|
||||
length = min(s_in.shape[0] - batch_index, length)
|
||||
s["samples"] = s_in[batch_index:batch_index + length].clone()
|
||||
if "noise_mask" in samples:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user