mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-19 22:39:24 +08:00
Compare commits
4 Commits
d4c6c9eff8
...
292814c31e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
292814c31e | ||
|
|
187e5237e1 | ||
|
|
164a9d4bbb | ||
|
|
16f862f02a |
@ -44,7 +44,14 @@ class BackgroundRemovalModel():
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
H, W = image.shape[1], image.shape[2]
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
|
||||
out = self.model(pixel_values=pixel_values)
|
||||
|
||||
if pixel_values.shape[0] > 1:
|
||||
out = torch.cat([
|
||||
self.model(pixel_values=pixel_values[i:i+1])
|
||||
for i in range(pixel_values.shape[0])
|
||||
], dim=0)
|
||||
else:
|
||||
out = self.model(pixel_values=pixel_values)
|
||||
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
||||
|
||||
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
|
||||
@ -1493,27 +1493,30 @@ class ModelPatcher:
|
||||
self.unpatch_hooks()
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
original_state_dict = self.model.diffusion_model.state_dict()
|
||||
unet_state_dict = {}
|
||||
def model_state_dict_for_saving(self, model=None, prefix=""):
|
||||
if model is None:
|
||||
model = self.model
|
||||
|
||||
original_state_dict = model.state_dict()
|
||||
output_state_dict = {}
|
||||
keys = list(original_state_dict)
|
||||
while len(keys) > 0:
|
||||
k = keys.pop(0)
|
||||
v = original_state_dict[k]
|
||||
op_keys = k.rsplit('.', 1)
|
||||
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
||||
unet_state_dict[k] = v
|
||||
output_state_dict[k] = v
|
||||
continue
|
||||
try:
|
||||
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
|
||||
op = comfy.utils.get_attr(model, op_keys[0])
|
||||
except:
|
||||
unet_state_dict[k] = v
|
||||
output_state_dict[k] = v
|
||||
continue
|
||||
if not op or not hasattr(op, "comfy_cast_weights") or \
|
||||
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
||||
unet_state_dict[k] = v
|
||||
output_state_dict[k] = v
|
||||
continue
|
||||
key = "diffusion_model." + k
|
||||
key = prefix + k
|
||||
weight = comfy.utils.get_attr(self.model, key)
|
||||
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
|
||||
qt_state_dict = weight.state_dict(k)
|
||||
@ -1521,10 +1524,14 @@ class ModelPatcher:
|
||||
for group_key in (x for x in qt_state_dict if x in original_state_dict):
|
||||
if group_key in keys:
|
||||
keys.remove(group_key)
|
||||
unet_state_dict.pop(group_key, "")
|
||||
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
|
||||
output_state_dict.pop(group_key, "")
|
||||
output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key])
|
||||
continue
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, weight)
|
||||
output_state_dict[k] = LazyCastingParam(self, key, weight)
|
||||
return output_state_dict
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.")
|
||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@ -423,6 +423,13 @@ class CLIP:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def state_dict_for_saving(self):
|
||||
sd_clip = self.patcher.model_state_dict_for_saving()
|
||||
sd_tokenizer = self.tokenizer.state_dict()
|
||||
for k in sd_tokenizer:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def load_model(self, tokens={}):
|
||||
memory_used = 0
|
||||
if hasattr(self.cond_stage_model, "memory_estimation_function"):
|
||||
@ -1908,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
clip_sd = clip.state_dict_for_saving()
|
||||
vae_sd = None
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
101
comfy_api_nodes/apis/bytedance_llm.py
Normal file
101
comfy_api_nodes/apis/bytedance_llm.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""Pydantic models for BytePlus ModelArk Responses API.
|
||||
|
||||
See: https://docs.byteplus.com/en/docs/ModelArk/1585128 (request)
|
||||
https://docs.byteplus.com/en/docs/ModelArk/1783703 (response)
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BytePlusInputText(BaseModel):
|
||||
type: Literal["input_text"] = "input_text"
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class BytePlusInputImage(BaseModel):
|
||||
type: Literal["input_image"] = "input_image"
|
||||
image_url: str = Field(..., description="Image URL or `data:image/...;base64,...` payload")
|
||||
detail: str = Field("auto", description="One of high, low, auto")
|
||||
|
||||
|
||||
class BytePlusInputVideo(BaseModel):
|
||||
type: Literal["input_video"] = "input_video"
|
||||
video_url: str = Field(..., description="Video URL or `data:video/...;base64,...` payload")
|
||||
fps: float | None = Field(None, ge=0.2, le=5.0)
|
||||
|
||||
|
||||
BytePlusMessageContent = BytePlusInputText | BytePlusInputImage | BytePlusInputVideo
|
||||
|
||||
|
||||
class BytePlusInputMessage(BaseModel):
|
||||
type: Literal["message"] = "message"
|
||||
role: str = Field(..., description="One of user, system, assistant, developer")
|
||||
content: list[BytePlusMessageContent] = Field(...)
|
||||
|
||||
|
||||
class BytePlusResponseCreateRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: list[BytePlusInputMessage] = Field(...)
|
||||
instructions: str | None = Field(None)
|
||||
max_output_tokens: int | None = Field(None, ge=1)
|
||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||
store: bool | None = Field(False)
|
||||
stream: bool | None = Field(False)
|
||||
|
||||
|
||||
class BytePlusOutputText(BaseModel):
|
||||
type: Literal["output_text"] = "output_text"
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class BytePlusOutputRefusal(BaseModel):
|
||||
type: Literal["refusal"] = "refusal"
|
||||
refusal: str = Field(...)
|
||||
|
||||
|
||||
class BytePlusOutputContent(BaseModel):
|
||||
type: str = Field(...)
|
||||
text: str | None = Field(None)
|
||||
refusal: str | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusOutputMessage(BaseModel):
|
||||
type: str = Field(...)
|
||||
id: str | None = Field(None)
|
||||
role: str | None = Field(None)
|
||||
status: str | None = Field(None)
|
||||
content: list[BytePlusOutputContent] | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusInputTokensDetails(BaseModel):
|
||||
cached_tokens: int | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusOutputTokensDetails(BaseModel):
|
||||
reasoning_tokens: int | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusResponseUsage(BaseModel):
|
||||
input_tokens: int | None = Field(None)
|
||||
output_tokens: int | None = Field(None)
|
||||
total_tokens: int | None = Field(None)
|
||||
input_tokens_details: BytePlusInputTokensDetails | None = Field(None)
|
||||
output_tokens_details: BytePlusOutputTokensDetails | None = Field(None)
|
||||
|
||||
|
||||
class BytePlusResponseError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class BytePlusResponseObject(BaseModel):
|
||||
id: str | None = Field(None)
|
||||
object: str | None = Field(None)
|
||||
created_at: int | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
status: str | None = Field(None)
|
||||
error: BytePlusResponseError | None = Field(None)
|
||||
output: list[BytePlusOutputMessage] | None = Field(None)
|
||||
usage: BytePlusResponseUsage | None = Field(None)
|
||||
271
comfy_api_nodes/nodes_bytedance_llm.py
Normal file
271
comfy_api_nodes/nodes_bytedance_llm.py
Normal file
@ -0,0 +1,271 @@
|
||||
"""API Nodes for ByteDance Seed LLM via the BytePlus ModelArk Responses API.
|
||||
|
||||
See: https://docs.byteplus.com/en/docs/ModelArk/1585128
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_llm import (
|
||||
BytePlusInputImage,
|
||||
BytePlusInputMessage,
|
||||
BytePlusInputText,
|
||||
BytePlusInputVideo,
|
||||
BytePlusMessageContent,
|
||||
BytePlusResponseCreateRequest,
|
||||
BytePlusResponseObject,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
BYTEPLUS_RESPONSES_ENDPOINT = "/proxy/byteplus/api/v3/responses"
|
||||
SEED_MAX_IMAGES = 20
|
||||
SEED_MAX_VIDEOS = 4
|
||||
|
||||
SEED_MODELS: dict[str, str] = {
|
||||
"Seed 2.0 Pro": "seed-2-0-pro-260328",
|
||||
"Seed 2.0 Lite": "seed-2-0-lite-260228",
|
||||
"Seed 2.0 Mini": "seed-2-0-mini-260215",
|
||||
}
|
||||
|
||||
# USD per 1M tokens: (input, cache_hit_input, output)
|
||||
_SEED_PRICES_PER_MILLION: dict[str, tuple[float, float, float]] = {
|
||||
"seed-2-0-pro-260328": (0.50, 0.10, 3.00),
|
||||
"seed-2-0-lite-260228": (0.25, 0.05, 2.00),
|
||||
"seed-2-0-mini-260215": (0.10, 0.02, 0.40),
|
||||
}
|
||||
|
||||
|
||||
def _seed_model_inputs(max_images: int = SEED_MAX_IMAGES, max_videos: int = SEED_MAX_VIDEOS):
|
||||
return [
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, max_images + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional image(s) to use as context for the model. Up to {max_images} images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=[f"video_{i}" for i in range(1, max_videos + 1)],
|
||||
min=0,
|
||||
),
|
||||
tooltip=f"Optional video(s) to use as context for the model. Up to {max_videos} videos.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. 0.0 is deterministic, higher values are more random.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _calculate_price(model_id: str, response: BytePlusResponseObject) -> float | None:
|
||||
"""Compute approximate USD price from response usage."""
|
||||
if not response.usage:
|
||||
return None
|
||||
rates = _SEED_PRICES_PER_MILLION.get(model_id)
|
||||
if rates is None:
|
||||
return None
|
||||
input_rate, cache_hit_rate, output_rate = rates
|
||||
input_tokens = response.usage.input_tokens or 0
|
||||
output_tokens = response.usage.output_tokens or 0
|
||||
cached = 0
|
||||
if response.usage.input_tokens_details:
|
||||
cached = response.usage.input_tokens_details.cached_tokens or 0
|
||||
fresh_input = max(0, input_tokens - cached)
|
||||
total = fresh_input * input_rate + cached * cache_hit_rate + output_tokens * output_rate
|
||||
return total / 1_000_000.0
|
||||
|
||||
|
||||
def _get_text_from_response(response: BytePlusResponseObject) -> str:
|
||||
"""Extract concatenated text from all assistant message output_text blocks."""
|
||||
if not response.output:
|
||||
return ""
|
||||
chunks: list[str] = []
|
||||
for item in response.output:
|
||||
if item.type != "message" or not item.content:
|
||||
continue
|
||||
for block in item.content:
|
||||
if block.type == "output_text" and block.text:
|
||||
chunks.append(block.text)
|
||||
elif block.type == "refusal" and block.refusal:
|
||||
raise ValueError(f"Model refused to respond: {block.refusal}")
|
||||
return "\n".join(chunks)
|
||||
|
||||
|
||||
async def _build_image_content_blocks(
|
||||
cls: type[IO.ComfyNode],
|
||||
image_tensors: list[Input.Image],
|
||||
) -> list[BytePlusInputImage]:
|
||||
urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image_tensors,
|
||||
max_images=SEED_MAX_IMAGES,
|
||||
wait_label="Uploading reference images",
|
||||
)
|
||||
return [BytePlusInputImage(image_url=url) for url in urls]
|
||||
|
||||
|
||||
async def _build_video_content_blocks(
|
||||
cls: type[IO.ComfyNode],
|
||||
videos: list[Input.Video],
|
||||
) -> list[BytePlusInputVideo]:
|
||||
blocks: list[BytePlusInputVideo] = []
|
||||
total = len(videos)
|
||||
for idx, video in enumerate(videos):
|
||||
label = "Uploading reference video"
|
||||
if total > 1:
|
||||
label = f"{label} ({idx + 1}/{total})"
|
||||
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
|
||||
blocks.append(BytePlusInputVideo(video_url=url))
|
||||
return blocks
|
||||
|
||||
|
||||
class ByteDanceSeedNode(IO.ComfyNode):
|
||||
"""Generate text responses from a ByteDance Seed 2.0 model."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedNode",
|
||||
display_name="ByteDance Seed",
|
||||
category="api node/text/ByteDance",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with ByteDance's Seed 2.0 models. "
|
||||
"Provide a text prompt and optionally one or more images or videos for multimodal context.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text input to the model.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[IO.DynamicCombo.Option(label, _seed_model_inputs()) for label in SEED_MODELS],
|
||||
tooltip="The Seed model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Foundational instructions that dictate the model's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.String.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m, "mini") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00025, 0.0009],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "lite") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0003, 0.002],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "pro") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0005, 0.003],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: {"type":"text", "text":"Token-based"}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_label = model["model"]
|
||||
temperature = model["temperature"]
|
||||
model_id = SEED_MODELS[model_label]
|
||||
|
||||
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
if sum(get_number_of_images(t) for t in image_tensors) > SEED_MAX_IMAGES:
|
||||
raise ValueError(f"Up to {SEED_MAX_IMAGES} images are supported per request.")
|
||||
|
||||
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
|
||||
if len(video_inputs) > SEED_MAX_VIDEOS:
|
||||
raise ValueError(f"Up to {SEED_MAX_VIDEOS} videos are supported per request.")
|
||||
|
||||
content: list[BytePlusMessageContent] = []
|
||||
if image_tensors:
|
||||
content.extend(await _build_image_content_blocks(cls, image_tensors))
|
||||
if video_inputs:
|
||||
content.extend(await _build_video_content_blocks(cls, video_inputs))
|
||||
content.append(BytePlusInputText(text=prompt))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_RESPONSES_ENDPOINT, method="POST"),
|
||||
response_model=BytePlusResponseObject,
|
||||
data=BytePlusResponseCreateRequest(
|
||||
model=model_id,
|
||||
input=[BytePlusInputMessage(role="user", content=content)],
|
||||
instructions=system_prompt or None,
|
||||
temperature=temperature,
|
||||
store=False,
|
||||
stream=False,
|
||||
),
|
||||
price_extractor=lambda r: _calculate_price(model_id, r),
|
||||
)
|
||||
if response.error:
|
||||
raise ValueError(f"Seed API error ({response.error.code}): {response.error.message}")
|
||||
result = _get_text_from_response(response)
|
||||
if not result:
|
||||
raise ValueError("Empty response from Seed model.")
|
||||
return IO.NodeOutput(result)
|
||||
|
||||
|
||||
class ByteDanceLLMExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [ByteDanceSeedNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ByteDanceLLMExtension:
|
||||
return ByteDanceLLMExtension()
|
||||
@ -175,7 +175,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, attention_mask=None):
|
||||
"""Append a guide_attention_entry to both positive and negative conditioning.
|
||||
|
||||
Each entry tracks one guide reference for per-reference attention control.
|
||||
@ -184,9 +184,10 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
|
||||
new_entry = {
|
||||
"pre_filter_count": pre_filter_count,
|
||||
"strength": strength,
|
||||
"pixel_mask": None,
|
||||
"pixel_mask": attention_mask.unsqueeze(0).unsqueeze(0) if attention_mask is not None else None, # reshape to (1, 1, F, H, W)
|
||||
"latent_shape": latent_shape,
|
||||
}
|
||||
|
||||
results = []
|
||||
for cond in (positive, negative):
|
||||
# Read existing entries from this specific conditioning
|
||||
@ -196,8 +197,7 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
|
||||
if found is not None:
|
||||
existing = found
|
||||
break
|
||||
# Shallow copy and append (no deepcopy needed — entries contain
|
||||
# only scalars and None for pixel_mask at this call site).
|
||||
# Shallow copy only and append (pixel_mask is never mutated).
|
||||
entries = [*existing, new_entry]
|
||||
results.append(node_helpers.conditioning_set_values(
|
||||
cond, {"guide_attention_entries": entries}
|
||||
@ -263,6 +263,12 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||
),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||
io.Mask.Input(
|
||||
"attention_mask",
|
||||
optional=True,
|
||||
tooltip="Optional pixel-space spatial mask. Controls per-region "
|
||||
"conditioning influence via self-attention, multiplied by strength.",
|
||||
),
|
||||
ICLoRAParameters.Input(
|
||||
"iclora_parameters",
|
||||
optional=True,
|
||||
@ -410,7 +416,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return latent_image, noise_mask
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, attention_mask=None, iclora_parameters=None) -> io.NodeOutput:
|
||||
scale_factors = vae.downscale_index_formula
|
||||
latent_image = latent["samples"]
|
||||
noise_mask = get_noise_mask(latent)
|
||||
@ -469,6 +475,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
|
||||
positive, negative = _append_guide_attention_entry(
|
||||
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
@ -276,8 +276,8 @@ class CLIPSave:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
||||
clip_sd = clip.get_sd()
|
||||
clip.load_model()
|
||||
clip_sd = clip.state_dict_for_saving()
|
||||
|
||||
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
|
||||
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user