mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-18 04:57:26 +08:00
feat: Add downscaled IC-LoRA support to LTXVAddGuide (CORE-102) (#13896)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This commit is contained in:
parent
5d5a4554e1
commit
d3607a8e6d
@ -79,7 +79,7 @@ import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
@ -91,6 +91,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||
if lora_metadata:
|
||||
new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
@ -98,6 +100,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||
if lora_metadata:
|
||||
new_clip.patcher.set_attachments("lora_metadata", lora_metadata)
|
||||
else:
|
||||
k1 = ()
|
||||
new_clip = None
|
||||
|
||||
@ -14,6 +14,49 @@ from typing_extensions import override
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS")
|
||||
|
||||
|
||||
class GetICLoRAParameters(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GetICLoRAParameters",
|
||||
display_name="Get IC-LoRA Parameters",
|
||||
description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
|
||||
"model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
|
||||
category="conditioning/video_models",
|
||||
search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
|
||||
inputs=[
|
||||
io.Model.Input(
|
||||
"iclora_model",
|
||||
tooltip="Direct output from a LoRA Loader for the specific IC-LoRA "
|
||||
"from which to extract the metadata.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
ICLoRAParameters.Output(
|
||||
"iclora_parameters",
|
||||
tooltip="IC-LoRA parameters extracted from the LoRA metadata "
|
||||
"(eg. reference_downscale_factor). Connect to LTXVAddGuide "
|
||||
"if the LoRA requires special handling of the guides.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, iclora_model) -> io.NodeOutput:
|
||||
metadata = iclora_model.get_attachment("lora_metadata")
|
||||
factor = 1
|
||||
if metadata:
|
||||
try:
|
||||
factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
|
||||
except (TypeError, ValueError):
|
||||
factor = 1
|
||||
parameters = {"reference_downscale_factor": factor}
|
||||
return io.NodeOutput(parameters)
|
||||
|
||||
|
||||
class EmptyLTXVLatentVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -220,6 +263,14 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||
),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
ICLoRAParameters.Input(
|
||||
"iclora_parameters",
|
||||
optional=True,
|
||||
tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. "
|
||||
"Used for adjusting guide processing as required by certain IC-LoRAs "
|
||||
"(eg. those with a reference_downscale_factor > 1). "
|
||||
"When chained, each LTXVAddGuide uses only the parameters connected to it.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
@ -229,14 +280,41 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
|
||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
|
||||
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
||||
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
|
||||
target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
|
||||
target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
|
||||
@classmethod
|
||||
def dilate_latent(cls, guide_latent, latent_downscale_factor):
|
||||
if latent_downscale_factor <= 1:
|
||||
return guide_latent, None
|
||||
scale = int(latent_downscale_factor)
|
||||
dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale)
|
||||
dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype)
|
||||
dilated[..., ::scale, ::scale] = guide_latent
|
||||
dilated_mask = torch.full(
|
||||
(dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]),
|
||||
-1.0, device=guide_latent.device, dtype=guide_latent.dtype,
|
||||
)
|
||||
dilated_mask[..., ::scale, ::scale] = 1.0
|
||||
return dilated, dilated_mask
|
||||
|
||||
@classmethod
|
||||
def get_reference_downscale_factor(cls, iclora_parameters):
|
||||
if not iclora_parameters:
|
||||
return 1
|
||||
try:
|
||||
factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1))))
|
||||
except (TypeError, ValueError):
|
||||
factor = 1
|
||||
return factor
|
||||
|
||||
@classmethod
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
@ -332,13 +410,21 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return latent_image, noise_mask
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
|
||||
scale_factors = vae.downscale_index_formula
|
||||
latent_image = latent["samples"]
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
||||
|
||||
latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters)
|
||||
if latent_downscale_factor > 1:
|
||||
if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
|
||||
raise ValueError(
|
||||
f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
|
||||
f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters."
|
||||
)
|
||||
|
||||
# For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
|
||||
time_scale_factor = scale_factors[0]
|
||||
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
||||
@ -351,12 +437,17 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
if not causal_fix:
|
||||
image = torch.cat([image[:1], image], dim=0)
|
||||
|
||||
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
|
||||
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
|
||||
|
||||
if not causal_fix:
|
||||
t = t[:, :, 1:, :, :]
|
||||
image = image[1:]
|
||||
|
||||
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling
|
||||
guide_mask = None
|
||||
if latent_downscale_factor > 1:
|
||||
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
|
||||
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
@ -369,12 +460,13 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
t,
|
||||
strength,
|
||||
scale_factors,
|
||||
guide_mask=guide_mask,
|
||||
latent_downscale_factor=latent_downscale_factor,
|
||||
causal_fix=causal_fix,
|
||||
)
|
||||
|
||||
# Track this guide for per-reference attention control.
|
||||
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
|
||||
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
|
||||
positive, negative = _append_guide_attention_entry(
|
||||
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||
)
|
||||
@ -794,6 +886,7 @@ class LtxvExtension(ComfyExtension):
|
||||
ModelSamplingLTXV,
|
||||
LTXVConditioning,
|
||||
LTXVScheduler,
|
||||
GetICLoRAParameters,
|
||||
LTXVAddGuide,
|
||||
LTXVPreprocess,
|
||||
LTXVCropGuides,
|
||||
|
||||
8
nodes.py
8
nodes.py
@ -700,17 +700,19 @@ class LoraLoader:
|
||||
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
||||
lora = None
|
||||
lora_metadata = None
|
||||
if self.loaded_lora is not None:
|
||||
if self.loaded_lora[0] == lora_path:
|
||||
lora = self.loaded_lora[1]
|
||||
lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
|
||||
else:
|
||||
self.loaded_lora = None
|
||||
|
||||
if lora is None:
|
||||
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
self.loaded_lora = (lora_path, lora)
|
||||
lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
|
||||
self.loaded_lora = (lora_path, lora, lora_metadata)
|
||||
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
|
||||
return (model_lora, clip_lora)
|
||||
|
||||
class LoraLoaderModelOnly(LoraLoader):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user