Merge branch 'master' into kosinkadink/batch-nodes-min-1-required
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
Alexis Rolland 2026-05-19 09:45:40 +08:00 committed by GitHub
commit 8f8b2a0ff3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 551 additions and 41 deletions

View File

@ -38,7 +38,7 @@
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state-of-the-art models.
- API nodes provide access to the best closed-source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- It is available on Windows, Linux, and macOS, locally with our [desktop application](https://www.comfy.org/download), our [portable install](#installing) or on our [cloud](https://www.comfy.org/cloud).
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.

View File

@ -44,7 +44,14 @@ class BackgroundRemovalModel():
comfy.model_management.load_model_gpu(self.patcher)
H, W = image.shape[1], image.shape[2]
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
out = self.model(pixel_values=pixel_values)
if pixel_values.shape[0] > 1:
out = torch.cat([
self.model(pixel_values=pixel_values[i:i+1])
for i in range(pixel_values.shape[0])
], dim=0)
else:
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())

View File

@ -1691,6 +1691,13 @@ class HiDreamO1(BaseModel):
if text_input_ids is None or noise is None:
return out
# handle area conds
area = kwargs.get("area", None)
if area is not None:
crop_h = min(noise.shape[-2] - area[2], area[0])
crop_w = min(noise.shape[-1] - area[3], area[1])
noise = torch.empty((noise.shape[0], 3, crop_h, crop_w), dtype=noise.dtype, device=noise.device)
conds = build_extra_conds(
text_input_ids, noise,
ref_images=kwargs.get("reference_latents", None),

View File

@ -1493,27 +1493,30 @@ class ModelPatcher:
self.unpatch_hooks()
self.clear_cached_hook_weights()
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
original_state_dict = self.model.diffusion_model.state_dict()
unet_state_dict = {}
def model_state_dict_for_saving(self, model=None, prefix=""):
if model is None:
model = self.model
original_state_dict = model.state_dict()
output_state_dict = {}
keys = list(original_state_dict)
while len(keys) > 0:
k = keys.pop(0)
v = original_state_dict[k]
op_keys = k.rsplit('.', 1)
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
unet_state_dict[k] = v
output_state_dict[k] = v
continue
try:
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
op = comfy.utils.get_attr(model, op_keys[0])
except:
unet_state_dict[k] = v
output_state_dict[k] = v
continue
if not op or not hasattr(op, "comfy_cast_weights") or \
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
unet_state_dict[k] = v
output_state_dict[k] = v
continue
key = "diffusion_model." + k
key = prefix + k
weight = comfy.utils.get_attr(self.model, key)
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
qt_state_dict = weight.state_dict(k)
@ -1521,10 +1524,14 @@ class ModelPatcher:
for group_key in (x for x in qt_state_dict if x in original_state_dict):
if group_key in keys:
keys.remove(group_key)
unet_state_dict.pop(group_key, "")
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
output_state_dict.pop(group_key, "")
output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key])
continue
unet_state_dict[k] = LazyCastingParam(self, key, weight)
output_state_dict[k] = LazyCastingParam(self, key, weight)
return output_state_dict
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.")
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def __del__(self):

View File

@ -1376,6 +1376,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")
logging.info("Native ops: {} {}".format(", ".join(QUANT_ALGOS.keys() - disabled), ", emulated ops: {}".format(", ".join(disabled)) if len(disabled) > 0 else ""))
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
if (

View File

@ -79,7 +79,7 @@ import comfy.latent_formats
import comfy.ldm.flux.redux
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
@ -91,6 +91,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
if lora_metadata:
new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
else:
k = ()
new_modelpatcher = None
@ -98,6 +100,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip)
if lora_metadata:
new_clip.patcher.set_attachments("lora_metadata", lora_metadata)
else:
k1 = ()
new_clip = None
@ -419,6 +423,13 @@ class CLIP:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def state_dict_for_saving(self):
    """Return the save-ready model state dict merged with the tokenizer state.

    The base mapping comes from ``self.patcher.model_state_dict_for_saving()``;
    every entry of ``self.tokenizer.state_dict()`` is then copied on top of it,
    so a tokenizer key overwrites a model key of the same name.
    """
    merged = self.patcher.model_state_dict_for_saving()
    tokenizer_state = self.tokenizer.state_dict()
    for name in tokenizer_state:
        merged[name] = tokenizer_state[name]
    return merged
def load_model(self, tokens={}):
memory_used = 0
if hasattr(self.cond_stage_model, "memory_estimation_function"):
@ -1904,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
clip_sd = clip.state_dict_for_saving()
vae_sd = None
if vae is not None:
vae_sd = vae.get_sd()

View File

@ -760,7 +760,7 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
image = kwargs.get("image", None)
if image is not None and len(images) == 0:
images = [image]
images = [image[i:i + 1] for i in range(image.shape[0])]
skip_template = False
if text.startswith('<|im_start|>'):
@ -771,13 +771,16 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
if llama_template is not None:
template = llama_template
elif len(images) == 0:
template = self.llama_template
else:
llama_text = llama_template.format(text)
template = self.llama_template_images
if len(images) > 1:
vision_block = "<|vision_start|><|image_pad|><|vision_end|>"
template = template.replace(vision_block, vision_block * len(images), 1)
llama_text = template.format(text)
if not thinking:
llama_text += "<think>\n</think>\n"

View File

@ -0,0 +1,101 @@
"""Pydantic models for BytePlus ModelArk Responses API.
See: https://docs.byteplus.com/en/docs/ModelArk/1585128 (request)
https://docs.byteplus.com/en/docs/ModelArk/1783703 (response)
"""
from typing import Literal
from pydantic import BaseModel, Field
class BytePlusInputText(BaseModel):
    """Text content part of an input message."""

    # Discriminator; the only accepted value is "input_text".
    type: Literal["input_text"] = "input_text"
    # Required text payload.
    text: str
class BytePlusInputImage(BaseModel):
    """Image content part of an input message."""

    # Discriminator; the only accepted value is "input_image".
    type: Literal["input_image"] = "input_image"
    image_url: str = Field(description="Image URL or `data:image/...;base64,...` payload")
    detail: str = Field(default="auto", description="One of high, low, auto")
class BytePlusInputVideo(BaseModel):
    """Video content part of an input message."""

    # Discriminator; the only accepted value is "input_video".
    type: Literal["input_video"] = "input_video"
    video_url: str = Field(description="Video URL or `data:video/...;base64,...` payload")
    # Optional; when given it is validated to lie within [0.2, 5.0].
    fps: float | None = Field(default=None, ge=0.2, le=5.0)
# Union of the content-part models accepted in an input message's `content` list.
BytePlusMessageContent = BytePlusInputText | BytePlusInputImage | BytePlusInputVideo
class BytePlusInputMessage(BaseModel):
    """One input message: a role plus an ordered list of content parts."""

    # Discriminator; the only accepted value is "message".
    type: Literal["message"] = "message"
    role: str = Field(description="One of user, system, assistant, developer")
    # Required list of text / image / video parts.
    content: list[BytePlusMessageContent]
class BytePlusResponseCreateRequest(BaseModel):
    """Request model for the Responses API: model, input messages and optional settings."""

    # Required fields.
    model: str
    input: list[BytePlusInputMessage]
    # Optional settings; each defaults to None or False.
    instructions: str | None = None
    max_output_tokens: int | None = Field(default=None, ge=1)
    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    store: bool | None = False
    stream: bool | None = False
class BytePlusOutputText(BaseModel):
    """Output content block carrying text."""

    # Discriminator; the only accepted value is "output_text".
    type: Literal["output_text"] = "output_text"
    # Required text payload.
    text: str
class BytePlusOutputRefusal(BaseModel):
    """Output content block carrying a refusal message."""

    # Discriminator; the only accepted value is "refusal".
    type: Literal["refusal"] = "refusal"
    # Required refusal text.
    refusal: str
class BytePlusOutputContent(BaseModel):
    """Loosely-typed content block of an output message.

    Only ``type`` is required; ``text`` and ``refusal`` are both optional.
    """

    type: str
    text: str | None = None
    refusal: str | None = None
class BytePlusOutputMessage(BaseModel):
    """One item of a response's ``output`` list; only ``type`` is required."""

    type: str
    id: str | None = None
    role: str | None = None
    status: str | None = None
    content: list[BytePlusOutputContent] | None = None
class BytePlusInputTokensDetails(BaseModel):
    """Optional breakdown of input-token usage."""

    cached_tokens: int | None = None
class BytePlusOutputTokensDetails(BaseModel):
    """Optional breakdown of output-token usage."""

    reasoning_tokens: int | None = None
class BytePlusResponseUsage(BaseModel):
    """Token usage attached to a response; every field is optional."""

    input_tokens: int | None = None
    output_tokens: int | None = None
    total_tokens: int | None = None
    input_tokens_details: BytePlusInputTokensDetails | None = None
    output_tokens_details: BytePlusOutputTokensDetails | None = None
class BytePlusResponseError(BaseModel):
    """Error payload of a response; both fields are required."""

    code: str
    message: str
class BytePlusResponseObject(BaseModel):
    """Top-level response model; every field is optional and defaults to None."""

    id: str | None = None
    object: str | None = None
    created_at: int | None = None
    model: str | None = None
    status: str | None = None
    error: BytePlusResponseError | None = None
    output: list[BytePlusOutputMessage] | None = None
    usage: BytePlusResponseUsage | None = None

View File

@ -49,7 +49,7 @@ def _claude_model_inputs():
min=0.0,
max=1.0,
step=0.01,
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random.",
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.",
advanced=True,
),
]
@ -208,7 +208,7 @@ class ClaudeNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1)
model_label = model["model"]
max_tokens = model["max_tokens"]
temperature = model["temperature"]
temperature = None if model_label == "Opus 4.7" else model["temperature"]
image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None]
if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES:

View File

@ -0,0 +1,271 @@
"""API Nodes for ByteDance Seed LLM via the BytePlus ModelArk Responses API.
See: https://docs.byteplus.com/en/docs/ModelArk/1585128
"""
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_llm import (
BytePlusInputImage,
BytePlusInputMessage,
BytePlusInputText,
BytePlusInputVideo,
BytePlusMessageContent,
BytePlusResponseCreateRequest,
BytePlusResponseObject,
)
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
)
# Endpoint path the node POSTs Seed requests to.
BYTEPLUS_RESPONSES_ENDPOINT = "/proxy/byteplus/api/v3/responses"
# Per-request caps on image and video inputs enforced by the node.
SEED_MAX_IMAGES = 20
SEED_MAX_VIDEOS = 4
# Display label shown in the model selector -> model ID sent in the request.
SEED_MODELS: dict[str, str] = {
    "Seed 2.0 Pro": "seed-2-0-pro-260328",
    "Seed 2.0 Lite": "seed-2-0-lite-260228",
    "Seed 2.0 Mini": "seed-2-0-mini-260215",
}
# USD per 1M tokens: (input, cache_hit_input, output)
# NOTE(review): the node's price badge lists per-1K figures for Mini
# (0.00025-0.0009) and Lite (0.0003-0.002) that differ from what these rates
# imply (0.0001-0.0004 and 0.00025-0.002) — confirm which is authoritative.
_SEED_PRICES_PER_MILLION: dict[str, tuple[float, float, float]] = {
    "seed-2-0-pro-260328": (0.50, 0.10, 3.00),
    "seed-2-0-lite-260228": (0.25, 0.05, 2.00),
    "seed-2-0-mini-260215": (0.10, 0.02, 0.40),
}
def _seed_model_inputs(max_images: int = SEED_MAX_IMAGES, max_videos: int = SEED_MAX_VIDEOS):
    """Build the nested inputs attached to each model option.

    Returns a list of three inputs, in order: an auto-growing group of
    optional images (``image_1`` .. ``image_<max_images>``), an auto-growing
    group of optional videos (``video_1`` .. ``video_<max_videos>``) and an
    advanced ``temperature`` float in [0.0, 2.0] defaulting to 1.0.
    """
    images_input = IO.Autogrow.Input(
        "images",
        template=IO.Autogrow.TemplateNames(
            IO.Image.Input("image"),
            names=[f"image_{i}" for i in range(1, max_images + 1)],
            min=0,
        ),
        tooltip=f"Optional image(s) to use as context for the model. Up to {max_images} images.",
    )
    videos_input = IO.Autogrow.Input(
        "videos",
        template=IO.Autogrow.TemplateNames(
            IO.Video.Input("video"),
            names=[f"video_{i}" for i in range(1, max_videos + 1)],
            min=0,
        ),
        tooltip=f"Optional video(s) to use as context for the model. Up to {max_videos} videos.",
    )
    temperature_input = IO.Float.Input(
        "temperature",
        default=1.0,
        min=0.0,
        max=2.0,
        step=0.01,
        tooltip="Controls randomness. 0.0 is deterministic, higher values are more random.",
        advanced=True,
    )
    return [images_input, videos_input, temperature_input]
def _calculate_price(model_id: str, response: BytePlusResponseObject) -> float | None:
"""Compute approximate USD price from response usage."""
if not response.usage:
return None
rates = _SEED_PRICES_PER_MILLION.get(model_id)
if rates is None:
return None
input_rate, cache_hit_rate, output_rate = rates
input_tokens = response.usage.input_tokens or 0
output_tokens = response.usage.output_tokens or 0
cached = 0
if response.usage.input_tokens_details:
cached = response.usage.input_tokens_details.cached_tokens or 0
fresh_input = max(0, input_tokens - cached)
total = fresh_input * input_rate + cached * cache_hit_rate + output_tokens * output_rate
return total / 1_000_000.0
def _get_text_from_response(response: BytePlusResponseObject) -> str:
"""Extract concatenated text from all assistant message output_text blocks."""
if not response.output:
return ""
chunks: list[str] = []
for item in response.output:
if item.type != "message" or not item.content:
continue
for block in item.content:
if block.type == "output_text" and block.text:
chunks.append(block.text)
elif block.type == "refusal" and block.refusal:
raise ValueError(f"Model refused to respond: {block.refusal}")
return "\n".join(chunks)
async def _build_image_content_blocks(
    cls: type[IO.ComfyNode],
    image_tensors: list[Input.Image],
) -> list[BytePlusInputImage]:
    """Upload the images and wrap each returned URL in a ``BytePlusInputImage``."""
    image_urls = await upload_images_to_comfyapi(
        cls,
        image_tensors,
        max_images=SEED_MAX_IMAGES,
        wait_label="Uploading reference images",
    )
    image_blocks = []
    for image_url in image_urls:
        image_blocks.append(BytePlusInputImage(image_url=image_url))
    return image_blocks
async def _build_video_content_blocks(
    cls: type[IO.ComfyNode],
    videos: list[Input.Video],
) -> list[BytePlusInputVideo]:
    """Upload the videos one after another and wrap each URL in a ``BytePlusInputVideo``.

    When more than one video is given, the wait label gets a "(k/total)" suffix.
    """
    video_count = len(videos)
    video_blocks: list[BytePlusInputVideo] = []
    for position, clip in enumerate(videos, start=1):
        wait_label = "Uploading reference video"
        if video_count > 1:
            wait_label = f"{wait_label} ({position}/{video_count})"
        video_url = await upload_video_to_comfyapi(cls, clip, wait_label=wait_label)
        video_blocks.append(BytePlusInputVideo(video_url=video_url))
    return video_blocks
class ByteDanceSeedNode(IO.ComfyNode):
    """Generate text responses from a ByteDance Seed 2.0 model."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs, single string output, hidden fields and price badge."""
        node_description = (
            "Generate text responses with ByteDance's Seed 2.0 models. "
            "Provide a text prompt and optionally one or more images or videos for multimodal context."
        )
        # Price badge expression; picks a range by substring match on the model widget.
        # NOTE(review): the Mini and Lite ranges differ from what
        # _SEED_PRICES_PER_MILLION implies per 1K tokens — confirm.
        price_expr = """
(
$m := widgets.model;
$contains($m, "mini") ? {
"type": "list_usd",
"usd": [0.00025, 0.0009],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "lite") ? {
"type": "list_usd",
"usd": [0.0003, 0.002],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "pro") ? {
"type": "list_usd",
"usd": [0.0005, 0.003],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type":"text", "text":"Token-based"}
)
"""
        node_inputs = [
            IO.String.Input(
                "prompt",
                multiline=True,
                default="",
                tooltip="Text input to the model.",
            ),
            IO.DynamicCombo.Input(
                "model",
                # One option per known label, each with its own nested inputs.
                options=[IO.DynamicCombo.Option(label, _seed_model_inputs()) for label in SEED_MODELS],
                tooltip="The Seed model used to generate the response.",
            ),
            IO.Int.Input(
                "seed",
                default=0,
                min=0,
                max=2147483647,
                control_after_generate=True,
                tooltip="Seed controls whether the node should re-run; "
                "results are non-deterministic regardless of seed.",
            ),
            IO.String.Input(
                "system_prompt",
                multiline=True,
                default="",
                optional=True,
                advanced=True,
                tooltip="Foundational instructions that dictate the model's behavior.",
            ),
        ]
        return IO.Schema(
            node_id="ByteDanceSeedNode",
            display_name="ByteDance Seed",
            category="api node/text/ByteDance",
            essentials_category="Text Generation",
            description=node_description,
            inputs=node_inputs,
            outputs=[IO.String.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
                expr=price_expr,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int,
        system_prompt: str = "",
    ) -> IO.NodeOutput:
        """Send the prompt and any media to the selected Seed model and return its text.

        ``model`` is the selector's value dict: it is read for "model" (label),
        "temperature", and the optional "images" / "videos" groups. ``seed`` is
        not included in the request. Media blocks come first in the user
        message (images, then videos), followed by the prompt text.

        Raises:
            ValueError: too many images or videos, an error reported in the
                response, a model refusal, or an empty result.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1)
        model_label = model["model"]
        temperature = model["temperature"]
        model_id = SEED_MODELS[model_label]

        # Unconnected slots arrive as None and are dropped.
        images_by_slot = model.get("images") or {}
        image_tensors: list[Input.Image] = [tensor for tensor in images_by_slot.values() if tensor is not None]
        image_total = sum(get_number_of_images(tensor) for tensor in image_tensors)
        if image_total > SEED_MAX_IMAGES:
            raise ValueError(f"Up to {SEED_MAX_IMAGES} images are supported per request.")

        videos_by_slot = model.get("videos") or {}
        video_inputs: list[Input.Video] = [clip for clip in videos_by_slot.values() if clip is not None]
        if len(video_inputs) > SEED_MAX_VIDEOS:
            raise ValueError(f"Up to {SEED_MAX_VIDEOS} videos are supported per request.")

        content: list[BytePlusMessageContent] = []
        if image_tensors:
            content.extend(await _build_image_content_blocks(cls, image_tensors))
        if video_inputs:
            content.extend(await _build_video_content_blocks(cls, video_inputs))
        content.append(BytePlusInputText(text=prompt))

        request = BytePlusResponseCreateRequest(
            model=model_id,
            input=[BytePlusInputMessage(role="user", content=content)],
            # An empty system prompt is sent as None.
            instructions=system_prompt or None,
            temperature=temperature,
            store=False,
            stream=False,
        )
        response = await sync_op(
            cls,
            ApiEndpoint(path=BYTEPLUS_RESPONSES_ENDPOINT, method="POST"),
            response_model=BytePlusResponseObject,
            data=request,
            price_extractor=lambda r: _calculate_price(model_id, r),
        )
        if response.error:
            raise ValueError(f"Seed API error ({response.error.code}): {response.error.message}")

        result = _get_text_from_response(response)
        if not result:
            raise ValueError("Empty response from Seed model.")
        return IO.NodeOutput(result)
class ByteDanceLLMExtension(ComfyExtension):
    """Extension that exposes the ByteDance Seed node."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [ByteDanceSeedNode]
        return node_classes
async def comfy_entrypoint() -> ByteDanceLLMExtension:
    """Create and return a new ``ByteDanceLLMExtension`` instance."""
    extension = ByteDanceLLMExtension()
    return extension

View File

@ -14,6 +14,49 @@ from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io
ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS")
class GetICLoRAParameters(io.ComfyNode):
    """Node that reads IC-LoRA parameters from the LoRA metadata attached to a model."""

    @classmethod
    def define_schema(cls):
        """Declare one model input and one IC-LoRA parameters output."""
        return io.Schema(
            node_id="GetICLoRAParameters",
            display_name="Get IC-LoRA Parameters",
            description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
                        "model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
            category="conditioning/video_models",
            search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
            inputs=[
                io.Model.Input(
                    "iclora_model",
                    tooltip="Direct output from a LoRA Loader for the specific IC-LoRA "
                            "from which to extract the metadata.",
                ),
            ],
            outputs=[
                ICLoRAParameters.Output(
                    "iclora_parameters",
                    tooltip="IC-LoRA parameters extracted from the LoRA metadata "
                            "(eg. reference_downscale_factor). Connect to LTXVAddGuide "
                            "if the LoRA requires special handling of the guides.",
                ),
            ],
        )

    @classmethod
    def execute(cls, iclora_model) -> io.NodeOutput:
        """Return ``{"reference_downscale_factor": int}`` read from the model's LoRA metadata.

        The value stored under "reference_downscale_factor" in the
        "lora_metadata" attachment is converted to float, rounded and clamped
        to at least 1. The factor falls back to 1 when the attachment is
        missing or empty, or when the value cannot be converted.
        """
        metadata = iclora_model.get_attachment("lora_metadata")
        factor = 1
        if metadata:
            try:
                factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
            except (TypeError, ValueError, OverflowError):
                # round() raises ValueError for NaN and OverflowError for
                # infinities (e.g. "inf" or "1e999"); treat both as invalid.
                factor = 1
        parameters = {"reference_downscale_factor": factor}
        return io.NodeOutput(parameters)
class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -132,7 +175,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
generate = execute # TODO: remove
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, attention_mask=None):
"""Append a guide_attention_entry to both positive and negative conditioning.
Each entry tracks one guide reference for per-reference attention control.
@ -141,9 +184,10 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
new_entry = {
"pre_filter_count": pre_filter_count,
"strength": strength,
"pixel_mask": None,
"pixel_mask": attention_mask.unsqueeze(0).unsqueeze(0) if attention_mask is not None else None, # reshape to (1, 1, F, H, W)
"latent_shape": latent_shape,
}
results = []
for cond in (positive, negative):
# Read existing entries from this specific conditioning
@ -153,8 +197,7 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
if found is not None:
existing = found
break
# Shallow copy and append (no deepcopy needed — entries contain
# only scalars and None for pixel_mask at this call site).
# Shallow copy only and append (pixel_mask is never mutated).
entries = [*existing, new_entry]
results.append(node_helpers.conditioning_set_values(
cond, {"guide_attention_entries": entries}
@ -220,6 +263,20 @@ class LTXVAddGuide(io.ComfyNode):
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Mask.Input(
"attention_mask",
optional=True,
tooltip="Optional pixel-space spatial mask. Controls per-region "
"conditioning influence via self-attention, multiplied by strength.",
),
ICLoRAParameters.Input(
"iclora_parameters",
optional=True,
tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. "
"Used for adjusting guide processing as required by certain IC-LoRAs "
"(eg. those with a reference_downscale_factor > 1). "
"When chained, each LTXVAddGuide uses only the parameters connected to it.",
),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
@ -229,14 +286,41 @@ class LTXVAddGuide(io.ComfyNode):
)
@classmethod
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t
@classmethod
def dilate_latent(cls, guide_latent, latent_downscale_factor):
    """Spread a 5-D guide latent over a grid enlarged along its last two dims.

    With a factor of 1 or less the input is returned unchanged together with
    ``None``. Otherwise the result is a zero tensor whose last two dimensions
    are multiplied by ``int(latent_downscale_factor)``, with the original
    values written at every ``scale``-th position, plus a single-channel mask
    that is 1.0 at those positions and -1.0 everywhere else.
    """
    if latent_downscale_factor <= 1:
        return guide_latent, None

    step = int(latent_downscale_factor)
    tensor_kwargs = {"device": guide_latent.device, "dtype": guide_latent.dtype}

    grown_height = guide_latent.shape[3] * step
    grown_width = guide_latent.shape[4] * step
    dilated = torch.zeros(guide_latent.shape[:3] + (grown_height, grown_width), **tensor_kwargs)
    dilated[..., ::step, ::step] = guide_latent

    # Same layout as the dilated latent but with one channel.
    mask_shape = (dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4])
    dilated_mask = torch.full(mask_shape, -1.0, **tensor_kwargs)
    dilated_mask[..., ::step, ::step] = 1.0
    return dilated, dilated_mask
@classmethod
def get_reference_downscale_factor(cls, iclora_parameters):
    """Return the integer reference downscale factor from IC-LoRA parameters.

    Args:
        iclora_parameters: mapping that may hold "reference_downscale_factor",
            or a falsy value when no parameters are connected.

    Returns:
        The value converted to float, rounded and clamped to at least 1.
        Falls back to 1 when the parameters are missing or the value cannot
        be converted.
    """
    if not iclora_parameters:
        return 1
    try:
        return max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1))))
    except (TypeError, ValueError, OverflowError):
        # round() raises ValueError for NaN and OverflowError for infinities
        # (e.g. "inf" or "1e999"); treat both as invalid.
        return 1
@classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
@ -332,13 +416,21 @@ class LTXVAddGuide(io.ComfyNode):
return latent_image, noise_mask
@classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, attention_mask=None, iclora_parameters=None) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape
latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters)
if latent_downscale_factor > 1:
if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
raise ValueError(
f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters."
)
# For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
time_scale_factor = scale_factors[0]
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
@ -351,12 +443,17 @@ class LTXVAddGuide(io.ComfyNode):
if not causal_fix:
image = torch.cat([image[:1], image], dim=0)
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
if not causal_fix:
t = t[:, :, 1:, :, :]
image = image[1:]
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling
guide_mask = None
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
@ -369,14 +466,16 @@ class LTXVAddGuide(io.ComfyNode):
t,
strength,
scale_factors,
guide_mask=guide_mask,
latent_downscale_factor=latent_downscale_factor,
causal_fix=causal_fix,
)
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
attention_mask=attention_mask,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
@ -794,6 +893,7 @@ class LtxvExtension(ComfyExtension):
ModelSamplingLTXV,
LTXVConditioning,
LTXVScheduler,
GetICLoRAParameters,
LTXVAddGuide,
LTXVPreprocess,
LTXVCropGuides,

View File

@ -330,7 +330,7 @@ class FeatherMask(IO.ComfyNode):
for x in range(right):
feather_rate = (x + 1) / right
output[:, :, -x] *= feather_rate
output[:, :, -(x + 1)] *= feather_rate
for y in range(top):
feather_rate = (y + 1) / top
@ -338,7 +338,7 @@ class FeatherMask(IO.ComfyNode):
for y in range(bottom):
feather_rate = (y + 1) / bottom
output[:, -y, :] *= feather_rate
output[:, -(y + 1), :] *= feather_rate
return IO.NodeOutput(output)

View File

@ -276,8 +276,8 @@ class CLIPSave:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
clip.load_model()
clip_sd = clip.state_dict_for_saving()
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))

View File

@ -626,7 +626,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if comfy.model_management.is_oom(ex):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.info("Memory summary:\n{}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:

View File

@ -700,17 +700,19 @@ class LoraLoader:
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
lora_metadata = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
else:
self.loaded_lora = None
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
self.loaded_lora = (lora_path, lora, lora_metadata)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
return (model_lora, clip_lora)
class LoraLoaderModelOnly(LoraLoader):