diff --git a/README.md b/README.md index 64d494f20..0eecd8a4b 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more... - ComfyUI natively supports the latest open-source state of the art models. - API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc. -- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud. +- It is available on Windows, Linux, and macOS, locally with our [desktop application](https://www.comfy.org/download), our [portable install](#installing) or on our [cloud](https://www.comfy.org/cloud). - The most sophisticated workflows can be exposed through a simple UI thanks to App Mode. - It integrates seamlessly into production pipelines with our API endpoints. diff --git a/comfy/bg_removal_model.py b/comfy/bg_removal_model.py index 7877afd7f..6dec65e63 100644 --- a/comfy/bg_removal_model.py +++ b/comfy/bg_removal_model.py @@ -44,7 +44,14 @@ class BackgroundRemovalModel(): comfy.model_management.load_model_gpu(self.patcher) H, W = image.shape[1], image.shape[2] pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False) - out = self.model(pixel_values=pixel_values) + + if pixel_values.shape[0] > 1: + out = torch.cat([ + self.model(pixel_values=pixel_values[i:i+1]) + for i in range(pixel_values.shape[0]) + ], dim=0) + else: + out = self.model(pixel_values=pixel_values) out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False) mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0736321b3..c22705655 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1691,6 +1691,13 @@ class HiDreamO1(BaseModel): if text_input_ids is None or noise is None: return out + # handle area conds + area = kwargs.get("area", None) + if area is not None: + crop_h = min(noise.shape[-2] - area[2], area[0]) + crop_w = min(noise.shape[-1] - area[3], area[1]) + noise = torch.empty((noise.shape[0], 3, crop_h, crop_w), dtype=noise.dtype, device=noise.device) + conds = build_extra_conds( text_input_ids, noise, ref_images=kwargs.get("reference_latents", None), diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2ea14bc2c..4f9d8403e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1493,27 +1493,30 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - original_state_dict = self.model.diffusion_model.state_dict() - unet_state_dict = {} + def model_state_dict_for_saving(self, model=None, prefix=""): + if model is None: + model = self.model + + original_state_dict = model.state_dict() + output_state_dict = {} keys = list(original_state_dict) while len(keys) > 0: k = keys.pop(0) v = original_state_dict[k] op_keys = k.rsplit('.', 1) if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: - unet_state_dict[k] = v + output_state_dict[k] = v continue try: - op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + op = comfy.utils.get_attr(model, op_keys[0]) except: - unet_state_dict[k] = v + output_state_dict[k] = v continue if not op or not hasattr(op, "comfy_cast_weights") or \ (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): - unet_state_dict[k] = v + output_state_dict[k] = v continue - key = "diffusion_model." + k + key = prefix + k weight = comfy.utils.get_attr(self.model, key) if isinstance(weight, QuantizedTensor) and k in original_state_dict: qt_state_dict = weight.state_dict(k) @@ -1521,10 +1524,14 @@ class ModelPatcher: for group_key in (x for x in qt_state_dict if x in original_state_dict): if group_key in keys: keys.remove(group_key) - unet_state_dict.pop(group_key, "") - unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key]) + output_state_dict.pop(group_key, "") + output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key]) continue - unet_state_dict[k] = LazyCastingParam(self, key, weight) + output_state_dict[k] = LazyCastingParam(self, key, weight) + return output_state_dict + + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.") return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) def __del__(self): diff --git a/comfy/ops.py b/comfy/ops.py index 117cdd327..f9456854b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1376,6 +1376,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ if not fp8_compute: disabled.add("float8_e4m3fn") disabled.add("float8_e5m2") + logging.info("Native ops: {} {}".format(", ".join(QUANT_ALGOS.keys() - disabled), ", emulated ops: {}".format(", ".join(disabled)) if len(disabled) > 0 else "")) return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( diff --git a/comfy/sd.py b/comfy/sd.py index ab2718892..2443353a4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -79,7 +79,7 @@ import comfy.latent_formats import comfy.ldm.flux.redux -def load_lora_for_models(model, clip, lora, strength_model, strength_clip): +def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None): key_map = {} if model is not None: key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) @@ -91,6 +91,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): if model is not None: new_modelpatcher = model.clone() k = new_modelpatcher.add_patches(loaded, strength_model) + if lora_metadata: + new_modelpatcher.set_attachments("lora_metadata", lora_metadata) else: k = () new_modelpatcher = None @@ -98,6 +100,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): if clip is not None: new_clip = clip.clone() k1 = new_clip.add_patches(loaded, strength_clip) + if lora_metadata: + new_clip.patcher.set_attachments("lora_metadata", lora_metadata) else: k1 = () new_clip = None @@ -419,6 +423,13 @@ class CLIP: sd_clip[k] = sd_tokenizer[k] return sd_clip + def state_dict_for_saving(self): + sd_clip = self.patcher.model_state_dict_for_saving() + sd_tokenizer = self.tokenizer.state_dict() + for k in sd_tokenizer: + sd_clip[k] = sd_tokenizer[k] + return sd_clip + def load_model(self, tokens={}): memory_used = 0 if hasattr(self.cond_stage_model, "memory_estimation_function"): @@ -1904,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m load_models = [model] if clip is not None: load_models.append(clip.load_model()) - clip_sd = clip.get_sd() + clip_sd = clip.state_dict_for_saving() vae_sd = None if vae is not None: vae_sd = vae.get_sd() diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index b022009b1..416ce9d18 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -760,7 +760,7 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer): def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs): image = kwargs.get("image", None) if image is not None and len(images) == 0: - images = [image] + images = [image[i:i + 1] for i in range(image.shape[0])] skip_template = False if text.startswith('<|im_start|>'): @@ -771,13 +771,16 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer): if skip_template: llama_text = text else: - if llama_template is None: - if len(images) > 0: - llama_text = self.llama_template_images.format(text) - else: - llama_text = self.llama_template.format(text) + if llama_template is not None: + template = llama_template + elif len(images) == 0: + template = self.llama_template else: - llama_text = llama_template.format(text) + template = self.llama_template_images + if len(images) > 1: + vision_block = "<|vision_start|><|image_pad|><|vision_end|>" + template = template.replace(vision_block, vision_block * len(images), 1) + llama_text = template.format(text) if not thinking: llama_text += "\n\n" diff --git a/comfy_api_nodes/apis/bytedance_llm.py b/comfy_api_nodes/apis/bytedance_llm.py new file mode 100644 index 000000000..654c875fc --- /dev/null +++ b/comfy_api_nodes/apis/bytedance_llm.py @@ -0,0 +1,101 @@ +"""Pydantic models for BytePlus ModelArk Responses API. + +See: https://docs.byteplus.com/en/docs/ModelArk/1585128 (request) + https://docs.byteplus.com/en/docs/ModelArk/1783703 (response) +""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class BytePlusInputText(BaseModel): + type: Literal["input_text"] = "input_text" + text: str = Field(...) + + +class BytePlusInputImage(BaseModel): + type: Literal["input_image"] = "input_image" + image_url: str = Field(..., description="Image URL or `data:image/...;base64,...` payload") + detail: str = Field("auto", description="One of high, low, auto") + + +class BytePlusInputVideo(BaseModel): + type: Literal["input_video"] = "input_video" + video_url: str = Field(..., description="Video URL or `data:video/...;base64,...` payload") + fps: float | None = Field(None, ge=0.2, le=5.0) + + +BytePlusMessageContent = BytePlusInputText | BytePlusInputImage | BytePlusInputVideo + + +class BytePlusInputMessage(BaseModel): + type: Literal["message"] = "message" + role: str = Field(..., description="One of user, system, assistant, developer") + content: list[BytePlusMessageContent] = Field(...) + + +class BytePlusResponseCreateRequest(BaseModel): + model: str = Field(...) + input: list[BytePlusInputMessage] = Field(...) + instructions: str | None = Field(None) + max_output_tokens: int | None = Field(None, ge=1) + temperature: float | None = Field(None, ge=0.0, le=2.0) + store: bool | None = Field(False) + stream: bool | None = Field(False) + + +class BytePlusOutputText(BaseModel): + type: Literal["output_text"] = "output_text" + text: str = Field(...) + + +class BytePlusOutputRefusal(BaseModel): + type: Literal["refusal"] = "refusal" + refusal: str = Field(...) + + +class BytePlusOutputContent(BaseModel): + type: str = Field(...) + text: str | None = Field(None) + refusal: str | None = Field(None) + + +class BytePlusOutputMessage(BaseModel): + type: str = Field(...) + id: str | None = Field(None) + role: str | None = Field(None) + status: str | None = Field(None) + content: list[BytePlusOutputContent] | None = Field(None) + + +class BytePlusInputTokensDetails(BaseModel): + cached_tokens: int | None = Field(None) + + +class BytePlusOutputTokensDetails(BaseModel): + reasoning_tokens: int | None = Field(None) + + +class BytePlusResponseUsage(BaseModel): + input_tokens: int | None = Field(None) + output_tokens: int | None = Field(None) + total_tokens: int | None = Field(None) + input_tokens_details: BytePlusInputTokensDetails | None = Field(None) + output_tokens_details: BytePlusOutputTokensDetails | None = Field(None) + + +class BytePlusResponseError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class BytePlusResponseObject(BaseModel): + id: str | None = Field(None) + object: str | None = Field(None) + created_at: int | None = Field(None) + model: str | None = Field(None) + status: str | None = Field(None) + error: BytePlusResponseError | None = Field(None) + output: list[BytePlusOutputMessage] | None = Field(None) + usage: BytePlusResponseUsage | None = Field(None) diff --git a/comfy_api_nodes/nodes_anthropic.py b/comfy_api_nodes/nodes_anthropic.py index 60e1624f7..28dd70d4e 100644 --- a/comfy_api_nodes/nodes_anthropic.py +++ b/comfy_api_nodes/nodes_anthropic.py @@ -49,7 +49,7 @@ def _claude_model_inputs(): min=0.0, max=1.0, step=0.01, - tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random.", + tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.", advanced=True, ), ] @@ -208,7 +208,7 @@ class ClaudeNode(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) model_label = model["model"] max_tokens = model["max_tokens"] - temperature = model["temperature"] + temperature = None if model_label == "Opus 4.7" else model["temperature"] image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None] if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES: diff --git a/comfy_api_nodes/nodes_bytedance_llm.py b/comfy_api_nodes/nodes_bytedance_llm.py new file mode 100644 index 000000000..fa7fe370a --- /dev/null +++ b/comfy_api_nodes/nodes_bytedance_llm.py @@ -0,0 +1,271 @@ +"""API Nodes for ByteDance Seed LLM via the BytePlus ModelArk Responses API. + +See: https://docs.byteplus.com/en/docs/ModelArk/1585128 +""" + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bytedance_llm import ( + BytePlusInputImage, + BytePlusInputMessage, + BytePlusInputText, + BytePlusInputVideo, + BytePlusMessageContent, + BytePlusResponseCreateRequest, + BytePlusResponseObject, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + get_number_of_images, + sync_op, + upload_images_to_comfyapi, + upload_video_to_comfyapi, + validate_string, +) + +BYTEPLUS_RESPONSES_ENDPOINT = "/proxy/byteplus/api/v3/responses" +SEED_MAX_IMAGES = 20 +SEED_MAX_VIDEOS = 4 + +SEED_MODELS: dict[str, str] = { + "Seed 2.0 Pro": "seed-2-0-pro-260328", + "Seed 2.0 Lite": "seed-2-0-lite-260228", + "Seed 2.0 Mini": "seed-2-0-mini-260215", +} + +# USD per 1M tokens: (input, cache_hit_input, output) +_SEED_PRICES_PER_MILLION: dict[str, tuple[float, float, float]] = { + "seed-2-0-pro-260328": (0.50, 0.10, 3.00), + "seed-2-0-lite-260228": (0.25, 0.05, 2.00), + "seed-2-0-mini-260215": (0.10, 0.02, 0.40), +} + + +def _seed_model_inputs(max_images: int = SEED_MAX_IMAGES, max_videos: int = SEED_MAX_VIDEOS): + return [ + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("image"), + names=[f"image_{i}" for i in range(1, max_images + 1)], + min=0, + ), + tooltip=f"Optional image(s) to use as context for the model. Up to {max_images} images.", + ), + IO.Autogrow.Input( + "videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("video"), + names=[f"video_{i}" for i in range(1, max_videos + 1)], + min=0, + ), + tooltip=f"Optional video(s) to use as context for the model. Up to {max_videos} videos.", + ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.01, + tooltip="Controls randomness. 0.0 is deterministic, higher values are more random.", + advanced=True, + ), + ] + + +def _calculate_price(model_id: str, response: BytePlusResponseObject) -> float | None: + """Compute approximate USD price from response usage.""" + if not response.usage: + return None + rates = _SEED_PRICES_PER_MILLION.get(model_id) + if rates is None: + return None + input_rate, cache_hit_rate, output_rate = rates + input_tokens = response.usage.input_tokens or 0 + output_tokens = response.usage.output_tokens or 0 + cached = 0 + if response.usage.input_tokens_details: + cached = response.usage.input_tokens_details.cached_tokens or 0 + fresh_input = max(0, input_tokens - cached) + total = fresh_input * input_rate + cached * cache_hit_rate + output_tokens * output_rate + return total / 1_000_000.0 + + +def _get_text_from_response(response: BytePlusResponseObject) -> str: + """Extract concatenated text from all assistant message output_text blocks.""" + if not response.output: + return "" + chunks: list[str] = [] + for item in response.output: + if item.type != "message" or not item.content: + continue + for block in item.content: + if block.type == "output_text" and block.text: + chunks.append(block.text) + elif block.type == "refusal" and block.refusal: + raise ValueError(f"Model refused to respond: {block.refusal}") + return "\n".join(chunks) + + +async def _build_image_content_blocks( + cls: type[IO.ComfyNode], + image_tensors: list[Input.Image], +) -> list[BytePlusInputImage]: + urls = await upload_images_to_comfyapi( + cls, + image_tensors, + max_images=SEED_MAX_IMAGES, + wait_label="Uploading reference images", + ) + return [BytePlusInputImage(image_url=url) for url in urls] + + +async def _build_video_content_blocks( + cls: type[IO.ComfyNode], + videos: list[Input.Video], +) -> list[BytePlusInputVideo]: + blocks: list[BytePlusInputVideo] = [] + total = len(videos) + for idx, video in enumerate(videos): + label = "Uploading reference video" + if total > 1: + label = f"{label} ({idx + 1}/{total})" + url = await upload_video_to_comfyapi(cls, video, wait_label=label) + blocks.append(BytePlusInputVideo(video_url=url)) + return blocks + + +class ByteDanceSeedNode(IO.ComfyNode): + """Generate text responses from a ByteDance Seed 2.0 model.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDanceSeedNode", + display_name="ByteDance Seed", + category="api node/text/ByteDance", + essentials_category="Text Generation", + description="Generate text responses with ByteDance's Seed 2.0 models. " + "Provide a text prompt and optionally one or more images or videos for multimodal context.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text input to the model.", + ), + IO.DynamicCombo.Input( + "model", + options=[IO.DynamicCombo.Option(label, _seed_model_inputs()) for label in SEED_MODELS], + tooltip="The Seed model used to generate the response.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + advanced=True, + tooltip="Foundational instructions that dictate the model's behavior.", + ), + ], + outputs=[IO.String.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "mini") ? { + "type": "list_usd", + "usd": [0.00025, 0.0009], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "lite") ? { + "type": "list_usd", + "usd": [0.0003, 0.002], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "pro") ? { + "type": "list_usd", + "usd": [0.0005, 0.003], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : {"type":"text", "text":"Token-based"} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + model_label = model["model"] + temperature = model["temperature"] + model_id = SEED_MODELS[model_label] + + image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None] + if sum(get_number_of_images(t) for t in image_tensors) > SEED_MAX_IMAGES: + raise ValueError(f"Up to {SEED_MAX_IMAGES} images are supported per request.") + + video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None] + if len(video_inputs) > SEED_MAX_VIDEOS: + raise ValueError(f"Up to {SEED_MAX_VIDEOS} videos are supported per request.") + + content: list[BytePlusMessageContent] = [] + if image_tensors: + content.extend(await _build_image_content_blocks(cls, image_tensors)) + if video_inputs: + content.extend(await _build_video_content_blocks(cls, video_inputs)) + content.append(BytePlusInputText(text=prompt)) + + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_RESPONSES_ENDPOINT, method="POST"), + response_model=BytePlusResponseObject, + data=BytePlusResponseCreateRequest( + model=model_id, + input=[BytePlusInputMessage(role="user", content=content)], + instructions=system_prompt or None, + temperature=temperature, + store=False, + stream=False, + ), + price_extractor=lambda r: _calculate_price(model_id, r), + ) + if response.error: + raise ValueError(f"Seed API error ({response.error.code}): {response.error.message}") + result = _get_text_from_response(response) + if not result: + raise ValueError("Empty response from Seed model.") + return IO.NodeOutput(result) + + +class ByteDanceLLMExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ByteDanceSeedNode] + + +async def comfy_entrypoint() -> ByteDanceLLMExtension: + return ByteDanceLLMExtension() diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 8b32d22ba..50e07e89a 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -14,6 +14,49 @@ from typing_extensions import override from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from comfy_api.latest import ComfyExtension, io +ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS") + + +class GetICLoRAParameters(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="GetICLoRAParameters", + display_name="Get IC-LoRA Parameters", + description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded " + "model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).", + category="conditioning/video_models", + search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"], + inputs=[ + io.Model.Input( + "iclora_model", + tooltip="Direct output from a LoRA Loader for the specific IC-LoRA " + "from which to extract the metadata.", + ), + ], + outputs=[ + ICLoRAParameters.Output( + "iclora_parameters", + tooltip="IC-LoRA parameters extracted from the LoRA metadata " + "(eg. reference_downscale_factor). Connect to LTXVAddGuide " + "if the LoRA requires special handling of the guides.", + ), + ], + ) + + @classmethod + def execute(cls, iclora_model) -> io.NodeOutput: + metadata = iclora_model.get_attachment("lora_metadata") + factor = 1 + if metadata: + try: + factor = max(1, round(float(metadata.get("reference_downscale_factor", 1)))) + except (TypeError, ValueError): + factor = 1 + parameters = {"reference_downscale_factor": factor} + return io.NodeOutput(parameters) + + class EmptyLTXVLatentVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -132,7 +175,7 @@ class LTXVImgToVideoInplace(io.ComfyNode): generate = execute # TODO: remove -def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0): +def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, attention_mask=None): """Append a guide_attention_entry to both positive and negative conditioning. Each entry tracks one guide reference for per-reference attention control. @@ -141,9 +184,10 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s new_entry = { "pre_filter_count": pre_filter_count, "strength": strength, - "pixel_mask": None, + "pixel_mask": attention_mask.unsqueeze(0).unsqueeze(0) if attention_mask is not None else None, # reshape to (1, 1, F, H, W) "latent_shape": latent_shape, } + results = [] for cond in (positive, negative): # Read existing entries from this specific conditioning @@ -153,8 +197,7 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s if found is not None: existing = found break - # Shallow copy and append (no deepcopy needed — entries contain - # only scalars and None for pixel_mask at this call site). + # Shallow copy only and append (pixel_mask is never mutated). entries = [*existing, new_entry] results.append(node_helpers.conditioning_set_values( cond, {"guide_attention_entries": entries} @@ -220,6 +263,20 @@ class LTXVAddGuide(io.ComfyNode): "down to the nearest multiple of 8. Negative values are counted from the end of the video.", ), io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + io.Mask.Input( + "attention_mask", + optional=True, + tooltip="Optional pixel-space spatial mask. Controls per-region " + "conditioning influence via self-attention, multiplied by strength.", + ), + ICLoRAParameters.Input( + "iclora_parameters", + optional=True, + tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. " + "Used for adjusting guide processing as required by certain IC-LoRAs " + "(eg. those with a reference_downscale_factor > 1). " + "When chained, each LTXVAddGuide uses only the parameters connected to it.", + ), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -229,14 +286,41 @@ class LTXVAddGuide(io.ComfyNode): ) @classmethod - def encode(cls, vae, latent_width, latent_height, images, scale_factors): + def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1): time_scale_factor, width_scale_factor, height_scale_factor = scale_factors images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] - pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1) + target_width = int(latent_width * width_scale_factor / latent_downscale_factor) + target_height = int(latent_height * height_scale_factor / latent_downscale_factor) + pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1) encode_pixels = pixels[:, :, :, :3] t = vae.encode(encode_pixels) return encode_pixels, t + @classmethod + def dilate_latent(cls, guide_latent, latent_downscale_factor): + if latent_downscale_factor <= 1: + return guide_latent, None + scale = int(latent_downscale_factor) + dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale) + dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype) + dilated[..., ::scale, ::scale] = guide_latent + dilated_mask = torch.full( + (dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]), + -1.0, device=guide_latent.device, dtype=guide_latent.dtype, + ) + dilated_mask[..., ::scale, ::scale] = 1.0 + return dilated, dilated_mask + + @classmethod + def get_reference_downscale_factor(cls, iclora_parameters): + if not iclora_parameters: + return 1 + try: + factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1)))) + except (TypeError, ValueError): + factor = 1 + return factor + @classmethod def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): time_scale_factor, _, _ = scale_factors @@ -332,13 +416,21 @@ class LTXVAddGuide(io.ComfyNode): return latent_image, noise_mask @classmethod - def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput: + def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, attention_mask=None, iclora_parameters=None) -> io.NodeOutput: scale_factors = vae.downscale_index_formula latent_image = latent["samples"] noise_mask = get_noise_mask(latent) _, _, latent_length, latent_height, latent_width = latent_image.shape + latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters) + if latent_downscale_factor > 1: + if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0: + raise ValueError( + f"Latent spatial size {latent_width}x{latent_height} must be divisible by " + f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters." + ) + # For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot time_scale_factor = scale_factors[0] num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1 @@ -351,12 +443,17 @@ class LTXVAddGuide(io.ComfyNode): if not causal_fix: image = torch.cat([image[:1], image], dim=0) - image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors) + image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor) if not causal_fix: t = t[:, :, 1:, :, :] image = image[1:] + guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling + guide_mask = None + if latent_downscale_factor > 1: + t, guide_mask = cls.dilate_latent(t, latent_downscale_factor) + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." @@ -369,14 +466,16 @@ class LTXVAddGuide(io.ComfyNode): t, strength, scale_factors, + guide_mask=guide_mask, + latent_downscale_factor=latent_downscale_factor, causal_fix=causal_fix, ) # Track this guide for per-reference attention control. pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] - guide_latent_shape = list(t.shape[2:]) # [F, H, W] positive, negative = _append_guide_attention_entry( positive, negative, pre_filter_count, guide_latent_shape, strength=strength, + attention_mask=attention_mask, ) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) @@ -794,6 +893,7 @@ class LtxvExtension(ComfyExtension): ModelSamplingLTXV, LTXVConditioning, LTXVScheduler, + GetICLoRAParameters, LTXVAddGuide, LTXVPreprocess, LTXVCropGuides, diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 96ee1a0f8..419e561ba 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -330,7 +330,7 @@ class FeatherMask(IO.ComfyNode): for x in range(right): feather_rate = (x + 1) / right - output[:, :, -x] *= feather_rate + output[:, :, -(x + 1)] *= feather_rate for y in range(top): feather_rate = (y + 1) / top @@ -338,7 +338,7 @@ class FeatherMask(IO.ComfyNode): for y in range(bottom): feather_rate = (y + 1) / bottom - output[:, -y, :] *= feather_rate + output[:, -(y + 1), :] *= feather_rate return IO.NodeOutput(output) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index 5384ed531..b6b29e34a 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -276,8 +276,8 @@ class CLIPSave: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) - comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True) - clip_sd = clip.get_sd() + clip.load_model() + clip_sd = clip.state_dict_for_saving() for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]: k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys())) diff --git a/execution.py b/execution.py index f37d0360d..4c7de2e84 100644 --- a/execution.py +++ b/execution.py @@ -626,7 +626,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." - logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) + logging.info("Memory summary:\n{}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type: diff --git a/nodes.py b/nodes.py index a59e8ebde..374217eea 100644 --- a/nodes.py +++ b/nodes.py @@ -700,17 +700,19 @@ class LoraLoader: lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) lora = None + lora_metadata = None if self.loaded_lora is not None: if self.loaded_lora[0] == lora_path: lora = self.loaded_lora[1] + lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None else: self.loaded_lora = None if lora is None: - lora = comfy.utils.load_torch_file(lora_path, safe_load=True) - self.loaded_lora = (lora_path, lora) + lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True) + self.loaded_lora = (lora_path, lora, lora_metadata) - model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) + model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata) return (model_lora, clip_lora) class LoraLoaderModelOnly(LoraLoader):