Compare commits

...

4 Commits

Author SHA1 Message Date
drozbay
292814c31e
feat: Add optional attention_mask input to LTXVAddGuide (CORE-220) (#13965)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
2026-05-19 05:07:04 +08:00
Yousef R. Gamaleldin
187e5237e1
Fix BiRefNet issue (#13966) 2026-05-19 05:03:22 +08:00
Alexander Piskun
164a9d4bbb
[Partner Nodes] add ByteDance Seed LLM node (#13919)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-18 13:06:13 -07:00
rattus
16f862f02a
implement dynamic clip saving (#13959)
Fix clip saving by doing the same patching process as diffusion
models.
2026-05-18 11:46:40 -07:00
7 changed files with 420 additions and 20 deletions

View File

@ -44,7 +44,14 @@ class BackgroundRemovalModel():
comfy.model_management.load_model_gpu(self.patcher)
H, W = image.shape[1], image.shape[2]
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
out = self.model(pixel_values=pixel_values)
if pixel_values.shape[0] > 1:
out = torch.cat([
self.model(pixel_values=pixel_values[i:i+1])
for i in range(pixel_values.shape[0])
], dim=0)
else:
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())

View File

@ -1493,27 +1493,30 @@ class ModelPatcher:
self.unpatch_hooks()
self.clear_cached_hook_weights()
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
original_state_dict = self.model.diffusion_model.state_dict()
unet_state_dict = {}
def model_state_dict_for_saving(self, model=None, prefix=""):
if model is None:
model = self.model
original_state_dict = model.state_dict()
output_state_dict = {}
keys = list(original_state_dict)
while len(keys) > 0:
k = keys.pop(0)
v = original_state_dict[k]
op_keys = k.rsplit('.', 1)
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
unet_state_dict[k] = v
output_state_dict[k] = v
continue
try:
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
op = comfy.utils.get_attr(model, op_keys[0])
except:
unet_state_dict[k] = v
output_state_dict[k] = v
continue
if not op or not hasattr(op, "comfy_cast_weights") or \
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
unet_state_dict[k] = v
output_state_dict[k] = v
continue
key = "diffusion_model." + k
key = prefix + k
weight = comfy.utils.get_attr(self.model, key)
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
qt_state_dict = weight.state_dict(k)
@ -1521,10 +1524,14 @@ class ModelPatcher:
for group_key in (x for x in qt_state_dict if x in original_state_dict):
if group_key in keys:
keys.remove(group_key)
unet_state_dict.pop(group_key, "")
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
output_state_dict.pop(group_key, "")
output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key])
continue
unet_state_dict[k] = LazyCastingParam(self, key, weight)
output_state_dict[k] = LazyCastingParam(self, key, weight)
return output_state_dict
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.")
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def __del__(self):

View File

@ -423,6 +423,13 @@ class CLIP:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def state_dict_for_saving(self):
    """Return the state dict to write when saving this CLIP object.

    Starts from ``self.patcher.model_state_dict_for_saving()`` and then
    copies every entry of ``self.tokenizer.state_dict()`` into it, so a
    tokenizer entry wins when both define the same key.
    """
    merged = self.patcher.model_state_dict_for_saving()
    tokenizer_sd = self.tokenizer.state_dict()
    # Feed explicit (key, value) pairs so only iteration and indexing are
    # required of the tokenizer state dict, exactly like a plain loop.
    merged.update((name, tokenizer_sd[name]) for name in tokenizer_sd)
    return merged
def load_model(self, tokens={}):
memory_used = 0
if hasattr(self.cond_stage_model, "memory_estimation_function"):
@ -1908,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
clip_sd = clip.state_dict_for_saving()
vae_sd = None
if vae is not None:
vae_sd = vae.get_sd()

View File

@ -0,0 +1,101 @@
"""Pydantic models for BytePlus ModelArk Responses API.
See: https://docs.byteplus.com/en/docs/ModelArk/1585128 (request)
https://docs.byteplus.com/en/docs/ModelArk/1783703 (response)
"""
from typing import Literal
from pydantic import BaseModel, Field
class BytePlusInputText(BaseModel):
    """Text content part of an input message (``type`` is ``"input_text"``)."""

    type: Literal["input_text"] = "input_text"
    # Required: the text itself.
    text: str
class BytePlusInputImage(BaseModel):
    """Image content part of an input message (``type`` is ``"input_image"``)."""

    type: Literal["input_image"] = "input_image"
    image_url: str = Field(
        ...,
        description="Image URL or `data:image/...;base64,...` payload",
    )
    detail: str = Field(
        default="auto",
        description="One of high, low, auto",
    )
class BytePlusInputVideo(BaseModel):
    """Video content part of an input message (``type`` is ``"input_video"``)."""

    type: Literal["input_video"] = "input_video"
    video_url: str = Field(
        ...,
        description="Video URL or `data:video/...;base64,...` payload",
    )
    # Optional; when given it must lie within [0.2, 5.0].
    fps: float | None = Field(default=None, ge=0.2, le=5.0)
# Union of the content-part models accepted in BytePlusInputMessage.content.
BytePlusMessageContent = BytePlusInputText | BytePlusInputImage | BytePlusInputVideo
class BytePlusInputMessage(BaseModel):
    """One input message: a role plus an ordered list of content parts."""

    type: Literal["message"] = "message"
    role: str = Field(
        ...,
        description="One of user, system, assistant, developer",
    )
    # Required list of text / image / video parts.
    content: list[BytePlusMessageContent]
class BytePlusResponseCreateRequest(BaseModel):
    """Request body for creating a response.

    ``model`` and ``input`` are required; every other field is optional.
    """

    model: str
    input: list[BytePlusInputMessage]
    instructions: str | None = None
    # Must be at least 1 when provided.
    max_output_tokens: int | None = Field(default=None, ge=1)
    # Must lie within [0.0, 2.0] when provided.
    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    store: bool | None = False
    stream: bool | None = False
class BytePlusOutputText(BaseModel):
    """Text output block (``type`` is ``"output_text"``)."""

    type: Literal["output_text"] = "output_text"
    # Required: the generated text.
    text: str
class BytePlusOutputRefusal(BaseModel):
    """Refusal output block (``type`` is ``"refusal"``)."""

    type: Literal["refusal"] = "refusal"
    # Required: the refusal message.
    refusal: str
class BytePlusOutputContent(BaseModel):
    """Loosely typed output content block.

    ``type`` is a free-form string; ``text`` and ``refusal`` are both
    optional, so one model can hold either kind of block.
    """

    type: str
    text: str | None = None
    refusal: str | None = None
class BytePlusOutputMessage(BaseModel):
    """One item of a response's ``output`` list.

    Only ``type`` is required; all remaining fields default to ``None``.
    """

    type: str
    id: str | None = None
    role: str | None = None
    status: str | None = None
    content: list[BytePlusOutputContent] | None = None
class BytePlusInputTokensDetails(BaseModel):
    """Breakdown of input-token usage."""

    # Optional count of cached input tokens.
    cached_tokens: int | None = None
class BytePlusOutputTokensDetails(BaseModel):
    """Breakdown of output-token usage."""

    # Optional count of reasoning tokens.
    reasoning_tokens: int | None = None
class BytePlusResponseUsage(BaseModel):
    """Token usage reported with a response; every field is optional."""

    input_tokens: int | None = None
    output_tokens: int | None = None
    total_tokens: int | None = None
    input_tokens_details: BytePlusInputTokensDetails | None = None
    output_tokens_details: BytePlusOutputTokensDetails | None = None
class BytePlusResponseError(BaseModel):
    """Error payload of a response: a required code and message."""

    code: str
    message: str
class BytePlusResponseObject(BaseModel):
    """Top-level response object.

    All fields are optional and default to ``None``, so a partial payload
    still validates.
    """

    id: str | None = None
    object: str | None = None
    created_at: int | None = None
    model: str | None = None
    status: str | None = None
    error: BytePlusResponseError | None = None
    output: list[BytePlusOutputMessage] | None = None
    usage: BytePlusResponseUsage | None = None

View File

@ -0,0 +1,271 @@
"""API Nodes for ByteDance Seed LLM via the BytePlus ModelArk Responses API.
See: https://docs.byteplus.com/en/docs/ModelArk/1585128
"""
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_llm import (
BytePlusInputImage,
BytePlusInputMessage,
BytePlusInputText,
BytePlusInputVideo,
BytePlusMessageContent,
BytePlusResponseCreateRequest,
BytePlusResponseObject,
)
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
)
# Path the node POSTs its request to.
BYTEPLUS_RESPONSES_ENDPOINT = "/proxy/byteplus/api/v3/responses"
# Per-request attachment limits checked in ByteDanceSeedNode.execute.
SEED_MAX_IMAGES = 20
SEED_MAX_VIDEOS = 4
# Display label shown in the node -> model id sent in the request.
SEED_MODELS: dict[str, str] = {
    "Seed 2.0 Pro": "seed-2-0-pro-260328",
    "Seed 2.0 Lite": "seed-2-0-lite-260228",
    "Seed 2.0 Mini": "seed-2-0-mini-260215",
}
# USD per 1M tokens: (input, cache_hit_input, output)
# Keyed by model id (the values of SEED_MODELS).
_SEED_PRICES_PER_MILLION: dict[str, tuple[float, float, float]] = {
    "seed-2-0-pro-260328": (0.50, 0.10, 3.00),
    "seed-2-0-lite-260228": (0.25, 0.05, 2.00),
    "seed-2-0-mini-260215": (0.10, 0.02, 0.40),
}
def _seed_model_inputs(max_images: int = SEED_MAX_IMAGES, max_videos: int = SEED_MAX_VIDEOS):
    """Build the per-model sub-inputs: images, videos and temperature.

    Args:
        max_images: Number of ``image_N`` slots offered by the images input.
        max_videos: Number of ``video_N`` slots offered by the videos input.

    Returns:
        A list with the "images" autogrow input, the "videos" autogrow
        input and the "temperature" float input, in that order.
    """
    image_slot_names = [f"image_{i}" for i in range(1, max_images + 1)]
    images_input = IO.Autogrow.Input(
        "images",
        template=IO.Autogrow.TemplateNames(IO.Image.Input("image"), names=image_slot_names, min=0),
        tooltip=f"Optional image(s) to use as context for the model. Up to {max_images} images.",
    )

    video_slot_names = [f"video_{i}" for i in range(1, max_videos + 1)]
    videos_input = IO.Autogrow.Input(
        "videos",
        template=IO.Autogrow.TemplateNames(IO.Video.Input("video"), names=video_slot_names, min=0),
        tooltip=f"Optional video(s) to use as context for the model. Up to {max_videos} videos.",
    )

    temperature_input = IO.Float.Input(
        "temperature",
        default=1.0,
        min=0.0,
        max=2.0,
        step=0.01,
        tooltip="Controls randomness. 0.0 is deterministic, higher values are more random.",
        advanced=True,
    )
    return [images_input, videos_input, temperature_input]
def _calculate_price(model_id: str, response: BytePlusResponseObject) -> float | None:
    """Estimate the USD cost of a response from its token usage.

    Args:
        model_id: Key into ``_SEED_PRICES_PER_MILLION``.
        response: Response whose ``usage`` block supplies the token counts.

    Returns:
        The approximate price in USD, or ``None`` when the response has no
        usage data or the model id has no price entry.
    """
    usage = response.usage
    if not usage:
        return None
    rates = _SEED_PRICES_PER_MILLION.get(model_id)
    if rates is None:
        return None
    input_rate, cache_hit_rate, output_rate = rates

    details = usage.input_tokens_details
    cached = (details.cached_tokens or 0) if details else 0
    # Cached tokens are billed at the cache-hit rate, so take them out of the
    # regular input count; clamp at zero in case cached exceeds the total.
    fresh_input = max(0, (usage.input_tokens or 0) - cached)
    output_tokens = usage.output_tokens or 0

    total = fresh_input * input_rate + cached * cache_hit_rate + output_tokens * output_rate
    return total / 1_000_000.0
def _get_text_from_response(response: BytePlusResponseObject) -> str:
    """Join the text of every ``output_text`` block in the response.

    Only output items whose ``type`` is ``"message"`` are inspected, and
    empty text blocks are skipped. Pieces are joined with newlines.

    Returns:
        The joined text, or ``""`` when there is nothing to collect.

    Raises:
        ValueError: If a ``refusal`` block with a non-empty message is found.
    """
    pieces: list[str] = []
    for item in response.output or []:
        if item.type != "message":
            continue
        for block in item.content or []:
            if block.type == "refusal" and block.refusal:
                raise ValueError(f"Model refused to respond: {block.refusal}")
            if block.type == "output_text" and block.text:
                pieces.append(block.text)
    return "\n".join(pieces)
async def _build_image_content_blocks(
    cls: type[IO.ComfyNode],
    image_tensors: list[Input.Image],
) -> list[BytePlusInputImage]:
    """Upload the images and wrap each returned URL in an input-image block.

    Args:
        cls: Node class passed through to the upload helper.
        image_tensors: Images to upload (at most ``SEED_MAX_IMAGES`` in total).

    Returns:
        One ``BytePlusInputImage`` per URL returned by the upload helper,
        in the same order.
    """
    uploaded_urls = await upload_images_to_comfyapi(
        cls,
        image_tensors,
        max_images=SEED_MAX_IMAGES,
        wait_label="Uploading reference images",
    )
    image_blocks: list[BytePlusInputImage] = []
    for uploaded_url in uploaded_urls:
        image_blocks.append(BytePlusInputImage(image_url=uploaded_url))
    return image_blocks
async def _build_video_content_blocks(
    cls: type[IO.ComfyNode],
    videos: list[Input.Video],
) -> list[BytePlusInputVideo]:
    """Upload each video in turn and wrap its URL in an input-video block.

    Videos are uploaded one at a time, in order. When there is more than
    one, the progress label gets an ``(index/total)`` suffix.

    Args:
        cls: Node class passed through to the upload helper.
        videos: Videos to upload.

    Returns:
        One ``BytePlusInputVideo`` per input video, in input order.
    """
    base_label = "Uploading reference video"
    count = len(videos)
    video_blocks: list[BytePlusInputVideo] = []
    for position, video in enumerate(videos, start=1):
        wait_label = f"{base_label} ({position}/{count})" if count > 1 else base_label
        uploaded_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label)
        video_blocks.append(BytePlusInputVideo(video_url=uploaded_url))
    return video_blocks
class ByteDanceSeedNode(IO.ComfyNode):
    """API node that asks a ByteDance Seed 2.0 model for a text response.

    The prompt is sent as a single user message, optionally preceded by
    uploaded images and videos; the system prompt goes in ``instructions``.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs, output, hidden fields and price badge."""
        node_inputs = [
            IO.String.Input(
                "prompt",
                multiline=True,
                default="",
                tooltip="Text input to the model.",
            ),
            IO.DynamicCombo.Input(
                "model",
                options=[IO.DynamicCombo.Option(label, _seed_model_inputs()) for label in SEED_MODELS],
                tooltip="The Seed model used to generate the response.",
            ),
            IO.Int.Input(
                "seed",
                default=0,
                min=0,
                max=2147483647,
                control_after_generate=True,
                tooltip="Seed controls whether the node should re-run; "
                "results are non-deterministic regardless of seed.",
            ),
            IO.String.Input(
                "system_prompt",
                multiline=True,
                default="",
                optional=True,
                advanced=True,
                tooltip="Foundational instructions that dictate the model's behavior.",
            ),
        ]
        node_outputs = [IO.String.Output()]
        hidden_fields = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        # NOTE(review): the "mini" and "lite" ranges below are not equal to the
        # _SEED_PRICES_PER_MILLION entries divided by 1000 - confirm which
        # figures are the intended ones.
        badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(widgets=["model"]),
            expr="""
(
$m := widgets.model;
$contains($m, "mini") ? {
"type": "list_usd",
"usd": [0.00025, 0.0009],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "lite") ? {
"type": "list_usd",
"usd": [0.0003, 0.002],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "pro") ? {
"type": "list_usd",
"usd": [0.0005, 0.003],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type":"text", "text":"Token-based"}
)
""",
        )
        return IO.Schema(
            node_id="ByteDanceSeedNode",
            display_name="ByteDance Seed",
            category="api node/text/ByteDance",
            essentials_category="Text Generation",
            description="Generate text responses with ByteDance's Seed 2.0 models. "
            "Provide a text prompt and optionally one or more images or videos for multimodal context.",
            inputs=node_inputs,
            outputs=node_outputs,
            hidden=hidden_fields,
            is_api_node=True,
            price_badge=badge,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int,
        system_prompt: str = "",
    ) -> IO.NodeOutput:
        """Validate the inputs, call the Responses endpoint and return its text.

        Args:
            prompt: User prompt; must be non-empty after stripping whitespace.
            model: Dynamic-combo value holding the ``model`` label, the
                ``temperature`` and optional ``images`` / ``videos`` mappings.
            seed: Not read by this method.
            system_prompt: Sent as ``instructions``; empty becomes ``None``.

        Returns:
            A node output wrapping the concatenated response text.

        Raises:
            ValueError: On too many images or videos, an error reported in
                the response, a model refusal, or an empty result.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1)
        model_label = model["model"]
        temperature = model["temperature"]
        model_id = SEED_MODELS[model_label]

        # Unconnected slots arrive as None and are dropped.
        image_tensors: list[Input.Image] = [
            tensor for tensor in (model.get("images") or {}).values() if tensor is not None
        ]
        total_images = sum(get_number_of_images(tensor) for tensor in image_tensors)
        if total_images > SEED_MAX_IMAGES:
            raise ValueError(f"Up to {SEED_MAX_IMAGES} images are supported per request.")

        video_inputs: list[Input.Video] = [
            clip for clip in (model.get("videos") or {}).values() if clip is not None
        ]
        if len(video_inputs) > SEED_MAX_VIDEOS:
            raise ValueError(f"Up to {SEED_MAX_VIDEOS} videos are supported per request.")

        # Message layout: images, then videos, then the text prompt last.
        parts: list[BytePlusMessageContent] = []
        if image_tensors:
            parts.extend(await _build_image_content_blocks(cls, image_tensors))
        if video_inputs:
            parts.extend(await _build_video_content_blocks(cls, video_inputs))
        parts.append(BytePlusInputText(text=prompt))

        def _price_of(api_response):
            return _calculate_price(model_id, api_response)

        endpoint = ApiEndpoint(path=BYTEPLUS_RESPONSES_ENDPOINT, method="POST")
        request = BytePlusResponseCreateRequest(
            model=model_id,
            input=[BytePlusInputMessage(role="user", content=parts)],
            instructions=system_prompt or None,
            temperature=temperature,
            store=False,
            stream=False,
        )
        response = await sync_op(
            cls,
            endpoint,
            response_model=BytePlusResponseObject,
            data=request,
            price_extractor=_price_of,
        )

        if response.error:
            raise ValueError(f"Seed API error ({response.error.code}): {response.error.message}")
        result = _get_text_from_response(response)
        if not result:
            raise ValueError("Empty response from Seed model.")
        return IO.NodeOutput(result)
class ByteDanceLLMExtension(ComfyExtension):
    """Extension exposing the ByteDance Seed LLM node."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[IO.ComfyNode]] = [ByteDanceSeedNode]
        return node_classes
async def comfy_entrypoint() -> ByteDanceLLMExtension:
    """Create and return a new ``ByteDanceLLMExtension`` instance.

    Presumably the hook the host application awaits to discover this
    module's extension - confirm against the loader.
    """
    extension = ByteDanceLLMExtension()
    return extension

View File

@ -175,7 +175,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
generate = execute # TODO: remove
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, attention_mask=None):
"""Append a guide_attention_entry to both positive and negative conditioning.
Each entry tracks one guide reference for per-reference attention control.
@ -184,9 +184,10 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
new_entry = {
"pre_filter_count": pre_filter_count,
"strength": strength,
"pixel_mask": None,
"pixel_mask": attention_mask.unsqueeze(0).unsqueeze(0) if attention_mask is not None else None, # reshape to (1, 1, F, H, W)
"latent_shape": latent_shape,
}
results = []
for cond in (positive, negative):
# Read existing entries from this specific conditioning
@ -196,8 +197,7 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
if found is not None:
existing = found
break
# Shallow copy and append (no deepcopy needed — entries contain
# only scalars and None for pixel_mask at this call site).
# Shallow copy only and append (pixel_mask is never mutated).
entries = [*existing, new_entry]
results.append(node_helpers.conditioning_set_values(
cond, {"guide_attention_entries": entries}
@ -263,6 +263,12 @@ class LTXVAddGuide(io.ComfyNode):
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Mask.Input(
"attention_mask",
optional=True,
tooltip="Optional pixel-space spatial mask. Controls per-region "
"conditioning influence via self-attention, multiplied by strength.",
),
ICLoRAParameters.Input(
"iclora_parameters",
optional=True,
@ -410,7 +416,7 @@ class LTXVAddGuide(io.ComfyNode):
return latent_image, noise_mask
@classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, attention_mask=None, iclora_parameters=None) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
@ -469,6 +475,7 @@ class LTXVAddGuide(io.ComfyNode):
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
attention_mask=attention_mask,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

View File

@ -276,8 +276,8 @@ class CLIPSave:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
clip.load_model()
clip_sd = clip.state_dict_for_saving()
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))