diff --git a/comfy/sd.py b/comfy/sd.py
index ab2718892..22747a685 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -79,7 +79,7 @@ import comfy.latent_formats
 
 import comfy.ldm.flux.redux
 
-def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
     key_map = {}
     if model is not None:
         key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
@@ -91,6 +91,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     if model is not None:
         new_modelpatcher = model.clone()
         k = new_modelpatcher.add_patches(loaded, strength_model)
+        if lora_metadata:
+            new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
     else:
         k = ()
         new_modelpatcher = None
@@ -98,6 +100,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     if clip is not None:
         new_clip = clip.clone()
         k1 = new_clip.add_patches(loaded, strength_clip)
+        if lora_metadata:
+            new_clip.set_attachments("lora_metadata", lora_metadata)
     else:
         k1 = ()
         new_clip = None
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index 3dc1199c2..b2828d916 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -220,6 +220,14 @@ class LTXVAddGuide(io.ComfyNode):
                     "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                 ),
                 io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
+                io.Model.Input(
+                    "ic_lora",
+                    optional=True,
+                    tooltip="Optional connection from an IC-LoRA loader. If the LoRA's safetensors metadata "
+                    "contains 'reference_downscale_factor', the guide image will be encoded at "
+                    "1/factor resolution and dilated back to full size (for IC-LoRAs trained on small grids). "
+                    "Defaults to 1 (no downscale) when absent.",
+                ),
             ],
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
@@ -229,14 +237,56 @@ class LTXVAddGuide(io.ComfyNode):
         )
 
     @classmethod
-    def encode(cls, vae, latent_width, latent_height, images, scale_factors):
+    def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
+        """Resize `images` to the target pixel size (divided by `latent_downscale_factor`) and VAE-encode them."""
         time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
         images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
-        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
+        target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
+        target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
+        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
         encode_pixels = pixels[:, :, :, :3]
         t = vae.encode(encode_pixels)
         return encode_pixels, t
 
+    @classmethod
+    def dilate_latent(cls, guide_latent, latent_downscale_factor):
+        """Spread a downscaled guide latent over a grid `latent_downscale_factor` times larger.
+
+        Values land on every `scale`-th position of the last two dims; the rest stay zero.
+        Returns `(dilated, mask)` where mask is 1.0 at populated positions and -1.0 elsewhere,
+        or `(guide_latent, None)` unchanged when the factor is <= 1.
+        """
+        if latent_downscale_factor <= 1:
+            return guide_latent, None
+        scale = int(latent_downscale_factor)
+        dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale)
+        dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype)
+        dilated[..., ::scale, ::scale] = guide_latent
+        dilated_mask = torch.full(
+            (dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]),
+            -1.0, device=guide_latent.device, dtype=guide_latent.dtype,
+        )
+        dilated_mask[..., ::scale, ::scale] = 1.0
+        return dilated, dilated_mask
+
+    @classmethod
+    def get_reference_downscale_factor(cls, ic_lora):
+        """Return the integer `reference_downscale_factor` stored in the LoRA metadata of `ic_lora`.
+
+        Falls back to 1 when `ic_lora` is None, carries no "lora_metadata" attachment, or the
+        stored value cannot be converted to a finite number. The result is never below 1.
+        """
+        if ic_lora is None:
+            return 1
+        metadata = ic_lora.get_attachment("lora_metadata")
+        if not metadata:
+            return 1
+        try:
+            factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
+        except (TypeError, ValueError, OverflowError):  # round() raises OverflowError on inf, ValueError on NaN
+            factor = 1
+        return factor
+
     @classmethod
     def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
         time_scale_factor, _, _ = scale_factors
@@ -332,13 +382,21 @@ class LTXVAddGuide(io.ComfyNode):
         return latent_image, noise_mask
 
     @classmethod
-    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
+    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, ic_lora=None) -> io.NodeOutput:
         scale_factors = vae.downscale_index_formula
         latent_image = latent["samples"]
         noise_mask = get_noise_mask(latent)
 
         _, _, latent_length, latent_height, latent_width = latent_image.shape
 
+        latent_downscale_factor = cls.get_reference_downscale_factor(ic_lora)
+        if latent_downscale_factor > 1:
+            if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
+                raise ValueError(
+                    f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
+                    f"reference_downscale_factor {latent_downscale_factor} from the ic_lora metadata."
+                )
+
         # For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
         time_scale_factor = scale_factors[0]
         num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
@@ -351,12 +409,17 @@ class LTXVAddGuide(io.ComfyNode):
         if not causal_fix:
             image = torch.cat([image[:1], image], dim=0)
 
-        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
+        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
 
         if not causal_fix:
             t = t[:, :, 1:, :, :]
             image = image[1:]
 
+        guide_latent_shape = list(t.shape[2:])  # pre-dilation [F, H, W] for spatial-mask downsampling
+        guide_mask = None
+        if latent_downscale_factor > 1:
+            t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
+
         frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
         assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
 
@@ -369,12 +432,13 @@ class LTXVAddGuide(io.ComfyNode):
             t,
             strength,
             scale_factors,
+            guide_mask=guide_mask,
+            latent_downscale_factor=latent_downscale_factor,
             causal_fix=causal_fix,
         )
 
         # Track this guide for per-reference attention control.
         pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
-        guide_latent_shape = list(t.shape[2:])  # [F, H, W]
         positive, negative = _append_guide_attention_entry(
             positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
         )
diff --git a/nodes.py b/nodes.py
index 2b63f9fbb..bec3e9ca6 100644
--- a/nodes.py
+++ b/nodes.py
@@ -700,17 +700,19 @@ class LoraLoader:
 
         lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
         lora = None
+        lora_metadata = None
         if self.loaded_lora is not None:
             if self.loaded_lora[0] == lora_path:
                 lora = self.loaded_lora[1]
+                lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
             else:
                 self.loaded_lora = None
 
         if lora is None:
-            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
-            self.loaded_lora = (lora_path, lora)
+            lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
+            self.loaded_lora = (lora_path, lora, lora_metadata)
 
-        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
         return (model_lora, clip_lora)
 
 class LoraLoaderModelOnly(LoraLoader):