From f66183a54142be693ab858e9f1f06ed62439a92e Mon Sep 17 00:00:00 2001 From: guill Date: Sun, 23 Nov 2025 22:56:20 -0800 Subject: [PATCH 01/39] [fix] Fixes non-async public API access (#10857) It looks like the synchronous version of the public API broke due to an addition of `from __future__ import annotations`. This change updates the async-to-sync adapter to work with both types of type annotations. --- comfy_api/internal/async_to_sync.py | 47 ++++++--- tests/execution/test_public_api.py | 153 ++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 16 deletions(-) create mode 100644 tests/execution/test_public_api.py diff --git a/comfy_api/internal/async_to_sync.py b/comfy_api/internal/async_to_sync.py index f5f805a62..257ade82e 100644 --- a/comfy_api/internal/async_to_sync.py +++ b/comfy_api/internal/async_to_sync.py @@ -8,7 +8,7 @@ import os import textwrap import threading from enum import Enum -from typing import Optional, Type, get_origin, get_args +from typing import Optional, Type, get_origin, get_args, get_type_hints class TypeTracker: @@ -220,11 +220,18 @@ class AsyncToSyncConverter: self._async_instance = async_class(*args, **kwargs) # Handle annotated class attributes (like execution: Execution) - # Get all annotations from the class hierarchy - all_annotations = {} - for base_class in reversed(inspect.getmro(async_class)): - if hasattr(base_class, "__annotations__"): - all_annotations.update(base_class.__annotations__) + # Get all annotations from the class hierarchy and resolve string annotations + try: + # get_type_hints resolves string annotations to actual type objects + # This handles classes using 'from __future__ import annotations' + all_annotations = get_type_hints(async_class) + except Exception: + # Fallback to raw annotations if get_type_hints fails + # (e.g., for undefined forward references) + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) # For each annotated attribute, check if it needs to be created or wrapped for attr_name, attr_type in all_annotations.items(): @@ -625,15 +632,19 @@ class AsyncToSyncConverter: """Extract class attributes that are classes themselves.""" class_attributes = [] + # Get resolved type hints to handle string annotations + try: + type_hints = get_type_hints(async_class) + except Exception: + type_hints = {} + # Look for class attributes that are classes for name, attr in sorted(inspect.getmembers(async_class)): if isinstance(attr, type) and not name.startswith("_"): class_attributes.append((name, attr)) - elif ( - hasattr(async_class, "__annotations__") - and name in async_class.__annotations__ - ): - annotation = async_class.__annotations__[name] + elif name in type_hints: + # Use resolved type hint instead of raw annotation + annotation = type_hints[name] if isinstance(annotation, type): class_attributes.append((name, annotation)) @@ -908,11 +919,15 @@ class AsyncToSyncConverter: attribute_mappings = {} # First check annotations for typed attributes (including from parent classes) - # Collect all annotations from the class hierarchy - all_annotations = {} - for base_class in reversed(inspect.getmro(async_class)): - if hasattr(base_class, "__annotations__"): - all_annotations.update(base_class.__annotations__) + # Resolve string annotations to actual types + try: + all_annotations = get_type_hints(async_class) + except Exception: + # Fallback to raw annotations + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) for attr_name, attr_type in sorted(all_annotations.items()): for class_name, class_type in class_attributes: diff --git a/tests/execution/test_public_api.py b/tests/execution/test_public_api.py new file mode 100644 index 000000000..52bc2fcd8 --- /dev/null +++ b/tests/execution/test_public_api.py @@ -0,0 +1,153 @@ +""" +Tests for public ComfyAPI and ComfyAPISync functions. + +These tests verify that the public API methods work correctly in both sync and async contexts, +ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly +handles string annotations from 'from __future__ import annotations'. +""" + +import pytest +import time +import subprocess +import torch +from pytest import fixture +from comfy_execution.graph_utils import GraphBuilder +from tests.execution.test_execution import ComfyClient + + +@pytest.mark.execution +class TestPublicAPI: + """Test suite for public ComfyAPI and ComfyAPISync methods.""" + + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + """Start ComfyUI server for testing.""" + pargs = [ + 'python', 'main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', + ] + p = subprocess.Popen(pargs) + yield + p.kill() + torch.cuda.empty_cache() + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + """Create shared client with connection retry.""" + client = ComfyClient() + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + break + except ConnectionRefusedError: + if i == n_tries - 1: + raise + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + """Set test name for each test.""" + shared_client.set_test_name(f"public_api[{request.node.name}]") + yield shared_client + + @fixture + def builder(self, request): + """Create GraphBuilder for each test.""" + yield GraphBuilder(prefix=request.node.name) + + def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder): + """Test that TestSyncProgressUpdate executes without errors. + + This test validates that api_sync.execution.set_progress() works correctly, + which is the primary code path fixed by adding get_type_hints() to async_to_sync.py. + """ + g = builder + image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + + # Use TestSyncProgressUpdate with short sleep + progress_node = g.node("TestSyncProgressUpdate", + value=image.out(0), + sleep_seconds=0.5) + output = g.node("SaveImage", images=progress_node.out(0)) + + # Execute workflow + result = client.run(g) + + # Verify execution + assert result.did_run(progress_node), "Progress node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify output + images = result.get_images(output) + assert len(images) == 1, "Should have produced 1 image" + + def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder): + """Test that TestAsyncProgressUpdate executes without errors. + + This test validates that await api.execution.set_progress() works correctly + in async contexts. + """ + g = builder + image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + + # Use TestAsyncProgressUpdate with short sleep + progress_node = g.node("TestAsyncProgressUpdate", + value=image.out(0), + sleep_seconds=0.5) + output = g.node("SaveImage", images=progress_node.out(0)) + + # Execute workflow + result = client.run(g) + + # Verify execution + assert result.did_run(progress_node), "Async progress node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify output + images = result.get_images(output) + assert len(images) == 1, "Should have produced 1 image" + + def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder): + """Test both sync and async progress updates in same workflow. + + This test ensures that both ComfyAPISync and ComfyAPI can coexist and work + correctly in the same workflow execution. + """ + g = builder + image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + + # Use both types of progress nodes + sync_progress = g.node("TestSyncProgressUpdate", + value=image1.out(0), + sleep_seconds=0.3) + async_progress = g.node("TestAsyncProgressUpdate", + value=image2.out(0), + sleep_seconds=0.3) + + # Create outputs + output1 = g.node("SaveImage", images=sync_progress.out(0)) + output2 = g.node("SaveImage", images=async_progress.out(0)) + + # Execute workflow + result = client.run(g) + + # Both should execute successfully + assert result.did_run(sync_progress), "Sync progress node should have executed" + assert result.did_run(async_progress), "Async progress node should have executed" + assert result.did_run(output1), "First output node should have executed" + assert result.did_run(output2), "Second output node should have executed" + + # Verify outputs + images1 = result.get_images(output1) + images2 = result.get_images(output2) + assert len(images1) == 1, "Should have produced 1 image from sync node" + assert len(images2) == 1, "Should have produced 1 image from async node" From 3bd71554a2df14b862cc5e1e875df37ba24af1ac Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 24 Nov 2025 19:48:37 +0200 Subject: [PATCH 02/39] fix(api-nodes): edge cases in responses for Gemini models (#10860) --- comfy_api_nodes/apis/gemini_api.py | 6 +++--- comfy_api_nodes/nodes_gemini.py | 21 +++++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index 710f173f1..d34590d28 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -113,9 +113,9 @@ class GeminiGenerationConfig(BaseModel): maxOutputTokens: int | None = Field(None, ge=16, le=8192) seed: int | None = Field(None) stopSequences: list[str] | None = Field(None) - temperature: float | None = Field(1, ge=0.0, le=2.0) - topK: int | None = Field(40, ge=1) - topP: float | None = Field(0.95, ge=0.0, le=1.0) + temperature: float | None = Field(None, ge=0.0, le=2.0) + topK: int | None = Field(None, ge=1) + topP: float | None = Field(None, ge=0.0, le=1.0) class GeminiImageConfig(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index be752c885..938a20f84 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -104,14 +104,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera List of response parts matching the requested type. """ if response.candidates is None: - if response.promptFeedback.blockReason: + if response.promptFeedback and response.promptFeedback.blockReason: feedback = response.promptFeedback raise ValueError( f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})" ) - raise NotImplementedError( - "Gemini returned no response candidates. " - "Please report to ComfyUI repository with the example of workflow to reproduce this." + raise ValueError( + "Gemini API returned no response candidates. If you are using the `IMAGE` modality, " + "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed." ) parts = [] for part in response.candidates[0].content.parts: @@ -182,11 +182,12 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price - for i in response.usageMetadata.candidatesTokensDetails: - if i.modality == Modality.IMAGE: - final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models - else: - final_price += output_text_tokens_price * i.tokenCount + if response.usageMetadata.candidatesTokensDetails: + for i in response.usageMetadata.candidatesTokensDetails: + if i.modality == Modality.IMAGE: + final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models + else: + final_price += output_text_tokens_price * i.tokenCount if response.usageMetadata.thoughtsTokenCount: final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount return final_price / 1_000_000.0 @@ -645,7 +646,7 @@ class GeminiImage2(IO.ComfyNode): options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], default="auto", tooltip="If set to 'auto', matches your input image's aspect ratio; " - "if no image is provided, generates a 1:1 square.", + "if no image is provided, a 16:9 square is usually generated.", ), IO.Combo.Input( "resolution", From 1286fcfe40b98052e4edbe9a02f12ad89ac74924 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:24:29 +0200 Subject: [PATCH 03/39] add get_frame_count and get_frame_rate methods to VideoInput class (#10851) --- comfy_api/latest/_input/video_types.py | 28 ++++++++ comfy_api/latest/_input_impl/video_types.py | 72 +++++++++++++++++++++ comfy_api_nodes/nodes_topaz.py | 15 ++--- 3 files changed, 106 insertions(+), 9 deletions(-) diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index a335df4d0..87c81d73a 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -1,5 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod +from fractions import Fraction from typing import Optional, Union, IO import io import av @@ -72,6 +73,33 @@ class VideoInput(ABC): frame_count = components.images.shape[0] return float(frame_count / components.frame_rate) + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video. + + Default implementation uses :meth:`get_components`, which may require + loading all frames into memory. File-based implementations should + override this method and use container/stream metadata instead. + + Returns: + Total number of frames as an integer. + """ + return int(self.get_components().images.shape[0]) + + def get_frame_rate(self) -> Fraction: + """ + Returns the frame rate of the video. + + Default implementation materializes the video into memory via + `get_components()`. Subclasses that can inspect the underlying + container (e.g. `VideoFromFile`) should override this with a more + efficient implementation. + + Returns: + Frame rate as a Fraction. + """ + return self.get_components().frame_rate + def get_container_format(self) -> str: """ Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index f646504c8..bde37f90a 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -121,6 +121,71 @@ class VideoFromFile(VideoInput): raise ValueError(f"Could not determine duration for file '{self.__file}'") + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video without materializing them as + torch tensors. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # 1. Prefer the frames field if available + if video_stream.frames and video_stream.frames > 0: + return int(video_stream.frames) + + # 2. Try to estimate from duration and average_rate using only metadata + if container.duration is not None and video_stream.average_rate: + duration_seconds = float(container.duration / av.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + if ( + getattr(video_stream, "duration", None) is not None + and getattr(video_stream, "time_base", None) is not None + and video_stream.average_rate + ): + duration_seconds = float(video_stream.duration * video_stream.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + # 3. Last resort: decode frames and count them (streaming) + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + + if frame_count == 0: + raise ValueError(f"Could not determine frame count for file '{self.__file}'") + return frame_count + + def get_frame_rate(self) -> Fraction: + """ + Returns the average frame rate of the video using container metadata + without decoding all frames. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # Preferred: use PyAV's average_rate (usually already a Fraction-like) + if video_stream.average_rate: + return Fraction(video_stream.average_rate) + + # Fallback: estimate from frames + duration if available + if video_stream.frames and container.duration: + duration_seconds = float(container.duration / av.time_base) + if duration_seconds > 0: + return Fraction(video_stream.frames / duration_seconds).limit_denominator() + + # Last resort: match get_components_internal default + return Fraction(1) + def get_container_format(self) -> str: """ Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). @@ -238,6 +303,13 @@ class VideoFromFile(VideoInput): packet.stream = stream_map[packet.stream] output_container.mux(packet) + def _get_first_video_stream(self, container: InputContainer): + video_stream = next((s for s in container.streams if s.type == "video"), None) + if video_stream is None: + raise ValueError(f"No video stream found in file '{self.__file}'") + return video_stream + + class VideoFromComponents(VideoInput): """ Class representing video input from tensors. diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index 79c7bf43d..f522756e5 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -5,8 +5,7 @@ import aiohttp import torch from typing_extensions import override -from comfy_api.input.video_types import VideoInput -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import topaz_api from comfy_api_nodes.util import ( ApiEndpoint, @@ -282,7 +281,7 @@ class TopazVideoEnhance(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, + video: Input.Video, upscaler_enabled: bool, upscaler_model: str, upscaler_resolution: str, @@ -297,12 +296,10 @@ class TopazVideoEnhance(IO.ComfyNode): ) -> IO.NodeOutput: if upscaler_enabled is False and interpolation_enabled is False: raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") - src_width, src_height = video.get_dimensions() - video_components = video.get_components() - src_frame_rate = int(video_components.frame_rate) - duration_sec = video.get_duration() - estimated_frames = int(duration_sec * src_frame_rate) validate_container_format_is_mp4(video) + src_width, src_height = video.get_dimensions() + src_frame_rate = int(video.get_frame_rate()) + duration_sec = video.get_duration() src_video_stream = video.get_stream_source() target_width = src_width target_height = src_height @@ -338,7 +335,7 @@ class TopazVideoEnhance(IO.ComfyNode): container="mp4", size=get_fs_object_size(src_video_stream), duration=int(duration_sec), - frameCount=estimated_frames, + frameCount=video.get_frame_count(), frameRate=src_frame_rate, resolution=topaz_api.Resolution(width=src_width, height=src_height), ), From 3d1fdaf9f448b34e4eba68bfd8e8de373ec0d22d Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Tue, 25 Nov 2025 02:30:40 +0800 Subject: [PATCH 04/39] block info (#10843) --- comfy/ldm/chroma/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 67bf70eb1..a72f8cc47 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -179,7 +179,10 @@ class Chroma(nn.Module): pe = self.pe_embedder(ids) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if i not in self.skip_mmdit: double_mod = ( self.get_modulations(mod_vectors, "double_img", idx=i), @@ -222,7 +225,10 @@ class Chroma(nn.Module): img = torch.cat((txt, img), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if i not in self.skip_dit: single_mod = self.get_modulations(mod_vectors, "single", idx=i) if ("single_block", i) in blocks_replace: From 6a6d456c88723538e3d0e5e942f78109ece5b73d Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Tue, 25 Nov 2025 02:38:38 +0800 Subject: [PATCH 05/39] block info (#10842) --- comfy/ldm/qwen_image/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 427ea19c1..8c75670cd 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -439,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.transformer_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} From b2ef58e2b17e73ca8cd376a1cdc976518ebbc168 Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Tue, 25 Nov 2025 02:40:09 +0800 Subject: [PATCH 06/39] block info (#10844) --- comfy/ldm/hunyuan_video/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index f75c6e0e1..2749c53f5 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -389,7 +389,10 @@ class HunyuanVideo(nn.Module): attn_mask = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -411,7 +414,10 @@ class HunyuanVideo(nn.Module): img = torch.cat((img, txt), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} From 22a2644e57530ee40e13486ccd7c953b87072093 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:45:54 -0800 Subject: [PATCH 07/39] Bump transformers version in requirements.txt (#10869) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8e308cd6c..b7014f956 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ torchvision torchaudio numpy>=1.25.0 einops -transformers>=4.37.2 +transformers>=4.50.3 tokenizers>=0.13.3 sentencepiece safetensors>=0.4.2 From 25022e0b0965975b35bcaf28b153184d60a4f9de Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 24 Nov 2025 22:48:53 -0800 Subject: [PATCH 08/39] Cleanup and fix issues with text encoder quants. (#10872) --- comfy/model_patcher.py | 3 +- comfy/ops.py | 168 +++++++++--------- comfy/quant_ops.py | 12 ++ comfy/sd.py | 9 +- comfy/sd1_clip.py | 18 +- comfy/text_encoders/hunyuan_video.py | 3 + .../comfy_quant/test_mixed_precision.py | 17 +- 7 files changed, 128 insertions(+), 102 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index cf1b0d441..6551ced5a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -231,7 +231,6 @@ class ModelPatcher: self.object_patches_backup = {} self.weight_wrapper_patches = {} self.model_options = {"transformer_options":{}} - self.model_size() self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update @@ -286,7 +285,7 @@ class ModelPatcher: return self.model.lowvram_patch_counter def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) + n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] diff --git a/comfy/ops.py b/comfy/ops.py index 640622fd1..af185ec24 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -540,113 +540,115 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== from .quant_ops import QuantizedTensor, QUANT_ALGOS -class MixedPrecisionOps(disable_weight_init): - _layer_quant_config = {} - _compute_dtype = torch.bfloat16 - class Linear(torch.nn.Module, CastWeightBiasOp): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super().__init__() +def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): + class MixedPrecisionOps(manual_cast): + _layer_quant_config = layer_quant_config + _compute_dtype = compute_dtype + _full_precision_mm = full_precision_mm - self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() - self.in_features = in_features - self.out_features = out_features - if bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) - else: - self.register_parameter("bias", None) + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} - self.tensor_class = None + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) - def reset_parameters(self): - return None + self.tensor_class = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): + def reset_parameters(self): + return None - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip('.') - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - raise ValueError(f"Missing weight for layer {layer_name}") + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): - manually_loaded_keys = [weight_key] + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") - if layer_name not in MixedPrecisionOps._layer_quant_config: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) - else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) - if quant_format is None: - raise ValueError(f"Unknown quantization format for layer {layer_name}") + manually_loaded_keys = [weight_key] - qconfig = QUANT_ALGOS[quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") - weight_scale_key = f"{prefix}weight_scale" - layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), - 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } - if layout_params['scale'] is not None: - manually_loaded_keys.append(weight_scale_key) + qconfig = QUANT_ALGOS[quant_format] + self.layout_type = qconfig["comfy_tensor_layout"] - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), - requires_grad=False - ) + weight_scale_key = f"{prefix}weight_scale" + layout_params = { + 'scale': state_dict.pop(weight_scale_key, None), + 'orig_dtype': MixedPrecisionOps._compute_dtype, + 'block_size': qconfig.get("group_size", None), + } + if layout_params['scale'] is not None: + manually_loaded_keys.append(weight_scale_key) - for param_name in qconfig["parameters"]: - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + requires_grad=False + ) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for param_name in qconfig["parameters"]: + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) - for key in manually_loaded_keys: - if key in missing_keys: - missing_keys.remove(key) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - def _forward(self, input, weight, bias): - return torch.nn.functional.linear(input, weight, bias) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - x = self._forward(input, weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) - def forward(self, input, *args, **kwargs): - run_every_op() + def forward_comfy_cast_weights(self, input): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) - if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and - not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) - return self._forward(input, self.weight, self.bias) + def forward(self, input, *args, **kwargs): + run_every_op() + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: - MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config - MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return MixedPrecisionOps + return mixed_precision_ops(model_config.layer_quant_config, compute_dtype) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 1d058bece..905b4729e 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -338,6 +338,18 @@ def generic_copy_(func, args, kwargs): return func(*args, **kwargs) +@register_generic_util(torch.ops.aten.to.dtype) +def generic_to_dtype(func, args, kwargs): + """Handle .to(dtype) calls - dtype conversion only.""" + src = args[0] + if isinstance(src, QuantizedTensor): + # For dtype-only conversion, just change the orig_dtype, no real cast is needed + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') + src._layout_params["orig_dtype"] = target_dtype + return src + return func(*args, **kwargs) + + @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) def generic_has_compatible_shallow_copy_type(func, args, kwargs): return True diff --git a/comfy/sd.py b/comfy/sd.py index dc0905ada..b6df0bd61 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -917,7 +917,12 @@ class CLIPType(Enum): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] for p in ckpt_paths: - clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) + sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) + if metadata is not None: + quant_metadata = metadata.get("_quantization_metadata", None) + if quant_metadata is not None: + sd["_quantization_metadata"] = quant_metadata + clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -1142,6 +1147,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters = 0 for c in clip_data: + if "_quantization_metadata" in c: + c.pop("_quantization_metadata") parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 3066de2d7..8f509bab1 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -109,13 +109,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): operations = model_options.get("custom_operations", None) scaled_fp8 = None + quantization_metadata = model_options.get("quantization_metadata", None) if operations is None: - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + layer_quant_config = None + if quantization_metadata is not None: + layer_quant_config = json.loads(quantization_metadata).get("layers", None) + + if layer_quant_config is not None: + operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) + logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") else: - operations = comfy.ops.manual_cast + # Fallback to scaled_fp8_ops for backward compatibility + scaled_fp8 = model_options.get("scaled_fp8", None) + if scaled_fp8 is not None: + operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + else: + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 557094f49..0110517bb 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""): if scaled_fp8_key in state_dict: out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + if "_quantization_metadata" in state_dict: + out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + return out diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index f8d1fd04e..63361309f 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase): def test_all_layers_standard(self): """Test that model with no quantization works normally""" - # Configure no quantization - ops.MixedPrecisionOps._layer_quant_config = {} - # Create model - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops({})) # Initialize weights manually model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) @@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } # Create model and load state dict (strict=False because custom loading pops keys) - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict, strict=False) # Verify weights are wrapped in QuantizedTensor @@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict1, strict=False) # Save state dict @@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) @@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict state_dict = { @@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False) From b6805429b9c2f3aa919035bea849ecd1de3ac8e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:48:20 -0800 Subject: [PATCH 09/39] Allow pinning quantized tensors. (#10873) --- comfy/model_management.py | 6 +++++- comfy/quant_ops.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a21df54b3..a9327ac80 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1098,13 +1098,14 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) +PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: return False - if type(tensor) is not torch.nn.parameter.Parameter: + if type(tensor).__name__ not in PINNING_ALLOWED_TYPES: return False if not is_device_cpu(tensor.device): @@ -1124,6 +1125,9 @@ def pin_memory(tensor): return False ptr = tensor.data_ptr() + if ptr == 0: + return False + if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 905b4729e..e938144a7 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -228,6 +228,14 @@ class QuantizedTensor(torch.Tensor): new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) + def data_ptr(self): + return self._qdata.data_ptr() + + def is_pinned(self): + return self._qdata.is_pinned() + + def is_contiguous(self): + return self._qdata.is_contiguous() # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) From acfaa5c4a132e1c01bc9d94e76b0d667c899bfd1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 24 Nov 2025 23:55:49 -0800 Subject: [PATCH 10/39] Don't try fp8 matrix mult in quantized ops if not supported by hardware. (#10874) --- comfy/ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index af185ec24..785aa1c9f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -646,11 +646,12 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return mixed_precision_ops(model_config.layer_quant_config, compute_dtype) + return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) From 015a0599d08f1072155b9213d488b73e502fea3c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 00:23:19 -0800 Subject: [PATCH 11/39] I found a case where this is needed (#10875) --- comfy/quant_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index e938144a7..0c16bcf8d 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -405,8 +405,8 @@ class TensorCoreFP8Layout(QuantizedLayout): tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) # TODO: uncomment this if it's actually needed because the clamp has a small performance penality' - # lp_amax = torch.finfo(dtype).max - # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + lp_amax = torch.finfo(dtype).max + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) layout_params = { From 6b573ae0cb11000a0330a35d9e31917c22c874a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 07:50:19 -0800 Subject: [PATCH 12/39] Flux 2 (#10879) --- comfy/latent_formats.py | 9 +++ comfy/ldm/flux/layers.py | 90 +++++++++++++++++++-------- comfy/ldm/flux/model.py | 80 ++++++++++++++++++------ comfy/ldm/models/autoencoder.py | 42 +++++++++++++ comfy/model_base.py | 23 +++++-- comfy/model_detection.py | 50 +++++++++++---- comfy/sd.py | 26 +++++++- comfy/supported_models.py | 34 +++++++++- comfy/text_encoders/flux.py | 107 +++++++++++++++++++++++++++++++- comfy/text_encoders/llama.py | 31 +++++++++ comfy_extras/nodes_flux.py | 80 +++++++++++++++++++++++- nodes.py | 2 +- 12 files changed, 506 insertions(+), 68 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 204fc048d..e98c7d6d8 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -178,6 +178,15 @@ class Flux(SD3): def process_out(self, latent): return (latent / self.scale_factor) + self.shift_factor +class Flux2(LatentFormat): + latent_channels = 128 + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 23150a712..2472ab79c 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 return embedding class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): + def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None): super().__init__() - self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device) self.silu = nn.SiLU() - self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) @@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module): class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device) @dataclass @@ -98,11 +98,11 @@ class ModulationOut: class Modulation(nn.Module): - def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): + def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 - self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) + self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device) def forward(self, vec: Tensor) -> tuple: if vec.ndim == 2: @@ -129,8 +129,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor +class SiLUActivation(nn.Module): + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -142,27 +152,44 @@ class DoubleStreamBlock(nn.Module): self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + if mlp_silu_act: + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + self.img_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) if self.modulation: self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + if mlp_silu_act: + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + self.txt_mlp = nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.flipped_img_txt = flipped_img_txt def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): @@ -246,6 +273,8 @@ class SingleStreamBlock(nn.Module): mlp_ratio: float = 4.0, qk_scale: float = None, modulation=True, + mlp_silu_act=False, + bias=True, dtype=None, device=None, operations=None @@ -257,17 +286,24 @@ class SingleStreamBlock(nn.Module): self.scale = qk_scale or head_dim**-0.5 self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp_hidden_dim_first = self.mlp_hidden_dim + if mlp_silu_act: + self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) + self.mlp_act = SiLUActivation() + else: + self.mlp_act = nn.GELU(approximate="tanh") + # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.hidden_size = hidden_size self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.mlp_act = nn.GELU(approximate="tanh") if modulation: self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) else: @@ -279,7 +315,7 @@ class SingleStreamBlock(nn.Module): else: mod = vec - qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) del qkv @@ -298,11 +334,11 @@ class SingleStreamBlock(nn.Module): class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None): super().__init__() self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device)) def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: if vec.ndim == 2: diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b9d36f202..1a24e6d95 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -15,6 +15,7 @@ from .layers import ( MLPEmbedder, SingleStreamBlock, timestep_embedding, + Modulation ) @dataclass @@ -33,6 +34,11 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + global_modulation: bool = False + mlp_silu_act: bool = False + ops_bias: bool = True + default_ref_method: str = "offset" + ref_index_scale: float = 1.0 class Flux(nn.Module): @@ -58,13 +64,17 @@ class Flux(nn.Module): self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) - self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) self.double_blocks = nn.ModuleList( [ @@ -73,6 +83,9 @@ class Flux(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=params.global_modulation is False, + mlp_silu_act=params.mlp_silu_act, + proj_bias=params.ops_bias, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -81,13 +94,30 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) if final_layer: - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + + if params.global_modulation: + self.double_stream_modulation_img = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.double_stream_modulation_txt = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.single_stream_modulation = Modulation( + self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations + ) def forward_orig( self, @@ -103,9 +133,6 @@ class Flux(nn.Module): attn_mask: Tensor = None, ) -> Tensor: - if y is None: - y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) - patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: @@ -118,9 +145,17 @@ class Flux(nn.Module): if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + txt = self.txt_in(txt) + vec_orig = vec + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) + if "post_input" in patches: for p in patches["post_input"]: out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) @@ -177,6 +212,9 @@ class Flux(nn.Module): img = torch.cat((txt, img), 1) + if self.params.global_modulation: + vec, _ = self.single_stream_modulation(vec_orig) + for i, block in enumerate(self.single_blocks): if ("single_block", i) in blocks_replace: def block_wrap(args): @@ -207,7 +245,7 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) return img def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): @@ -234,10 +272,10 @@ class Flux(nn.Module): h_offset += rope_options.get("shift_y", 0.0) w_offset += rope_options.get("shift_x", 0.0) - img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) + img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): @@ -259,10 +297,10 @@ class Flux(nn.Module): h = 0 w = 0 index = 0 - ref_latents_method = kwargs.get("ref_latents_method", "offset") + ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) for ref in ref_latents: if ref_latents_method == "index": - index += 1 + index += self.params.ref_index_scale h_offset = 0 w_offset = 0 elif ref_latents_method == "uxo": @@ -286,7 +324,11 @@ class Flux(nn.Module): img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) + + if len(self.params.axes_dim) == 4: # Flux 2 + txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 611d36a1b..4f50810dc 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri from comfy.ldm.util import get_obj_from_str, instantiate_from_config from comfy.ldm.modules.ema import LitEma import comfy.ops +from einops import rearrange +import comfy.model_management class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = False): @@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine): self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim + if ddconfig.get("batch_norm_latent", False): + self.bn_eps = 1e-4 + self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"], + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + self.bn.eval() + else: + self.bn = None + + def get_autoencoder_params(self) -> list: params = super().get_autoencoder_params() return params @@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine): z = torch.cat(z, 0) z, reg_log = self.regularization(z) + + if self.bn is not None: + z = rearrange(z, + "... c (i pi) (j pj) -> ... (c pi pj) i j", + pi=self.ps[0], + pj=self.ps[1], + ) + + z = torch.nn.functional.batch_norm(z, + comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device), + comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device), + momentum=self.bn_momentum, + eps=self.bn_eps) + if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.bn is not None: + s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps) + m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + z = z * s + m + z = rearrange( + z, + "... (c pi pj) i j -> ... c (i pi) (j pj)", + pi=self.ps[0], + pj=self.ps[1], + ) + if self.max_batch_size is None: dec = self.post_quant_conv(z) dec = self.decoder(dec, **decoder_kwargs) diff --git a/comfy/model_base.py b/comfy/model_base.py index e14b552c5..cad79ecbd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -898,12 +898,13 @@ class Flux(BaseModel): attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: shape = kwargs["noise"].shape - mask_ref_size = kwargs["attention_mask_img_shape"] - # the model will pad to the patch size, and then divide - # essentially dividing and rounding up - (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) - attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) - out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + mask_ref_size = kwargs.get("attention_mask_img_shape", None) + if mask_ref_size is not None: + # the model will pad to the patch size, and then divide + # essentially dividing and rounding up + (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) + attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) guidance = kwargs.get("guidance", 3.5) if guidance is not None: @@ -928,6 +929,16 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out +class Flux2(Flux): + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + target_text_len = 512 + if cross_attn.shape[1] < target_text_len: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0)) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0131ca25a..b2ba1459d 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -200,26 +200,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} - dit_config["image_model"] = "flux" + if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "flux2" + dit_config["axes_dim"] = [32, 32, 32, 32] + dit_config["num_heads"] = 48 + dit_config["mlp_ratio"] = 3.0 + dit_config["theta"] = 2000 + dit_config["out_channels"] = 128 + dit_config["global_modulation"] = True + dit_config["vec_in_dim"] = None + dit_config["mlp_silu_act"] = True + dit_config["qkv_bias"] = False + dit_config["ops_bias"] = False + dit_config["default_ref_method"] = "index" + dit_config["ref_index_scale"] = 10.0 + patch_size = 1 + else: + dit_config["image_model"] = "flux" + dit_config["axes_dim"] = [16, 56, 56] + dit_config["num_heads"] = 24 + dit_config["mlp_ratio"] = 4.0 + dit_config["theta"] = 10000 + dit_config["out_channels"] = 16 + dit_config["qkv_bias"] = True + patch_size = 2 + dit_config["in_channels"] = 16 - patch_size = 2 + dit_config["hidden_size"] = 3072 + dit_config["context_in_dim"] = 4096 + dit_config["patch_size"] = patch_size in_key = "{}img_in.weight".format(key_prefix) if in_key in state_dict_keys: - dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) - dit_config["out_channels"] = 16 + w = state_dict[in_key] + dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size) + dit_config["hidden_size"] = w.shape[0] + + txt_in_key = "{}txt_in.weight".format(key_prefix) + if txt_in_key in state_dict_keys: + w = state_dict[txt_in_key] + dit_config["context_in_dim"] = w.shape[1] + dit_config["hidden_size"] = w.shape[0] + vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 - dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] - dit_config["theta"] = 10000 - dit_config["qkv_bias"] = True if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 diff --git a/comfy/sd.py b/comfy/sd.py index b6df0bd61..14dd8944c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -356,7 +356,7 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif sd['decoder.conv_in.weight'].shape[1] == 32: + elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -382,6 +382,17 @@ class VAE: self.upscale_ratio = 4 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'decoder.post_quant_conv.weight' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."}) + + if 'bn.running_mean' in sd: + ddconfig["batch_norm_latent"] = True + self.downscale_ratio *= 2 + self.upscale_ratio *= 2 + self.latent_channels *= 4 + old_memory_used_decode = self.memory_used_decode + self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 + if 'post_quant_conv.weight' in sd: self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) else: @@ -940,6 +951,8 @@ class TEModel(Enum): QWEN25_7B = 11 BYT5_SMALL_GLYPH = 12 GEMMA_3_4B = 13 + MISTRAL3_24B = 14 + MISTRAL3_24B_PRUNED_FLUX2 = 15 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -972,6 +985,13 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: + weight = sd['model.layers.0.post_attention_layernorm.weight'] + if weight.shape[0] == 5120: + if "model.layers.39.post_attention_layernorm.weight" in sd: + return TEModel.MISTRAL3_24B + else: + return TEModel.MISTRAL3_24B_PRUNED_FLUX2 + return TEModel.LLAMA3_8 return None @@ -1086,6 +1106,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer + elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2: + clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) + clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer + tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2e64b85e8..8fe8e63f6 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -741,6 +741,37 @@ class FluxSchnell(Flux): out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device) return out +class Flux2(Flux): + unet_config = { + "image_model": "flux2", + } + + sampling_settings = { + "shift": 2.02, + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Flux2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -1422,6 +1453,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] + models += [SVD_img2vid] diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index d61ef6668..8dbbca16e 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -1,10 +1,13 @@ from comfy import sd1_clip import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip +import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast import torch import os +import json +import base64 class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -68,3 +71,105 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ + +def load_mistral_tokenizer(data): + if torch.is_tensor(data): + data = data.numpy().tobytes() + + try: + from transformers.integrations.mistral import MistralConverter + except ModuleNotFoundError: + from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter + + mistral_vocab = json.loads(data) + + special_tokens = {} + vocab = {} + + max_vocab = mistral_vocab["config"]["default_vocab_size"] + + for w in mistral_vocab["vocab"]: + r = w["rank"] + if r >= max_vocab: + continue + + vocab[base64.b64decode(w["token_bytes"])] = r + + for w in mistral_vocab["special_tokens"]: + if "token_bytes" in w: + special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"] + else: + special_tokens[w["token_str"]] = w["rank"] + + all_special = [] + for v in special_tokens: + all_special.append(v) + + special_tokens.update(vocab) + vocab = special_tokens + return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False} + +class MistralTokenizerClass: + @staticmethod + def from_pretrained(path, **kwargs): + return LlamaTokenizerFast(**kwargs) + +class Mistral3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.tekken_data = tokenizer_data.get("tekken_model", None) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"tekken_model": self.tekken_data} + +class Flux2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer) + self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]' + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Mistral3_24BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + textmodel_json_config = {} + num_layers = model_options.get("num_layers", None) + if num_layers is not None: + textmodel_json_config["num_hidden_layers"] = num_layers + if num_layers < 40: + textmodel_json_config["final_norm"] = False + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Flux2TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + + out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1) + out = out.movedim(1, 2) + out = out.reshape(out.shape[0], out.shape[1], -1) + return out, pooled, extra + +def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + if pruned: + model_options = model_options.copy() + model_options["num_layers"] = 30 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Flux2TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index feb44bbb0..749ff581b 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -34,6 +34,28 @@ class Llama2Config: rope_scale = None final_norm: bool = True +@dataclass +class Mistral3Small24BConfig: + vocab_size: int = 131072 + hidden_size: int = 5120 + intermediate_size: int = 32768 + num_hidden_layers: int = 40 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 8192 + rms_norm_eps: float = 1e-5 + rope_theta: float = 1000000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_3BConfig: vocab_size: int = 151936 @@ -465,6 +487,15 @@ class Llama2(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Mistral3Small24B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Mistral3Small24BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_3B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index ce1b2e89f..d9c4bba81 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -2,7 +2,10 @@ import node_helpers import comfy.utils from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import comfy.model_management +import torch +import math +import nodes class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -30,6 +33,27 @@ class CLIPTextEncodeFlux(io.ComfyNode): encode = execute # TODO: remove +class EmptyFlux2LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyFlux2LatentImage", + display_name="Empty Flux 2 Latent", + category="latent", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) class FluxGuidance(io.ComfyNode): @classmethod @@ -154,6 +178,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): append = execute # TODO: remove +def generalized_time_snr_shift(t, mu: float, sigma: float): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +def get_schedule(num_steps: int, image_seq_len: int) -> list[float]: + mu = compute_empirical_mu(image_seq_len, num_steps) + timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = generalized_time_snr_shift(timesteps, mu, 1.0) + return timesteps + + +class Flux2Scheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Flux2Scheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=4096), + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) + + @classmethod + def execute(cls, steps, width, height) -> io.NodeOutput: + seq_len = (width * height / (16 * 16)) + sigmas = get_schedule(steps, round(seq_len)) + return io.NodeOutput(sigmas) + + class FluxExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -163,6 +239,8 @@ class FluxExtension(ComfyExtension): FluxDisableGuidance, FluxKontextImageScale, FluxKontextMultiReferenceLatentMethod, + EmptyFlux2LatentImage, + Flux2Scheduler, ] diff --git a/nodes.py b/nodes.py index f023ae3b6..f4835c02e 100644 --- a/nodes.py +++ b/nodes.py @@ -929,7 +929,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From 5c7b08ca58f5412b3a814b374793cacdb5b5f0a7 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:09:07 +0200 Subject: [PATCH 13/39] [API Nodes] add Flux.2 Pro node (#10880) --- comfy_api_nodes/apis/bfl_api.py | 28 +++- comfy_api_nodes/nodes_bfl.py | 238 +++++++++++++++---------------- comfy_api_nodes/util/__init__.py | 2 + 3 files changed, 143 insertions(+), 125 deletions(-) diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py index 0fc8c0607..d8d3557b3 100644 --- a/comfy_api_nodes/apis/bfl_api.py +++ b/comfy_api_nodes/apis/bfl_api.py @@ -70,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel): # ) +class Flux2ProGenerateRequest(BaseModel): + prompt: str = Field(...) + width: int = Field(1024, description="Must be a multiple of 32.") + height: int = Field(768, description="Must be a multiple of 32.") + seed: int | None = Field(None) + prompt_upsampling: bool | None = Field(None) + input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + safety_tolerance: int | None = Field( + 5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 + ) + output_format: str | None = Field( + "png", description="Output format for the generated image. Can be 'jpeg' or 'png'." + ) + + class BFLFluxKontextProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for what you wannt to edit.') input_image: Optional[str] = Field(None, description='Image to edit in base64 format') @@ -109,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel): class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description='The unique identifier for the generation task.') - polling_url: str = Field(..., description='URL to poll for the generation result.') + id: str = Field(..., description="The unique identifier for the generation task.") + polling_url: str = Field(..., description="URL to poll for the generation result.") + cost: float | None = Field(None, description="Price in cents") class BFLStatus(str, Enum): diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 1740fb377..8826dea0c 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,7 +1,7 @@ from inspect import cleandoc -from typing import Optional import torch +from pydantic import BaseModel from typing_extensions import override from comfy_api.latest import IO, ComfyExtension @@ -9,15 +9,16 @@ from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, - BFLFluxProGenerateRequest, BFLFluxProGenerateResponse, BFLFluxProUltraGenerateRequest, BFLFluxStatusResponse, BFLStatus, + Flux2ProGenerateRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, + get_number_of_images, poll_op, resize_mask_to_image, sync_op, @@ -116,7 +117,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt_upsampling: bool = False, raw: bool = False, seed: int = 0, - image_prompt: Optional[torch.Tensor] = None, + image_prompt: torch.Tensor | None = None, image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: @@ -230,7 +231,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: Optional[torch.Tensor] = None, + input_image: torch.Tensor | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -280,124 +281,6 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode): DISPLAY_NAME = "Flux.1 Kontext [max] Image" -class FluxProImageNode(IO.ComfyNode): - """ - Generates images synchronously based on prompt and resolution. - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProImageNode", - display_name="Flux 1.1 [pro] Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. " - "If active, automatically modifies the prompt for more creative generation, " - "but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Int.Input( - "width", - default=1024, - min=256, - max=1440, - step=32, - ), - IO.Int.Input( - "height", - default=768, - min=256, - max=1440, - step=32, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - IO.Image.Input( - "image_prompt", - optional=True, - ), - # "image_prompt_strength": ( - # IO.FLOAT, - # { - # "default": 0.1, - # "min": 0.0, - # "max": 1.0, - # "step": 0.01, - # "tooltip": "Blend between the prompt and the image prompt.", - # }, - # ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - prompt: str, - prompt_upsampling, - width: int, - height: int, - seed=0, - image_prompt=None, - # image_prompt_strength=0.1, - ) -> IO.NodeOutput: - image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt) - initial_response = await sync_op( - cls, - ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1/generate", - method="POST", - ), - response_model=BFLFluxProGenerateResponse, - data=BFLFluxProGenerateRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - width=width, - height=height, - seed=seed, - image_prompt=image_prompt, - ), - ) - response = await poll_op( - cls, - ApiEndpoint(initial_response.polling_url), - response_model=BFLFluxStatusResponse, - status_extractor=lambda r: r.status, - progress_extractor=lambda r: r.progress, - completed_statuses=[BFLStatus.ready], - failed_statuses=[ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - BFLStatus.error, - BFLStatus.task_not_found, - ], - queued_statuses=[], - ) - return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) - - class FluxProExpandNode(IO.ComfyNode): """ Outpaints image based on prompt. @@ -640,16 +523,125 @@ class FluxProFillNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class Flux2ProImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Flux2ProImageNode", + display_name="Flux.2 [pro] Image", + category="api node/image/BFL", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation or edit", + ), + IO.Int.Input( + "width", + default=1024, + min=256, + max=2048, + step=32, + ), + IO.Int.Input( + "height", + default=768, + min=256, + max=2048, + step=32, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", + ), + IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + width: int, + height: int, + seed: int, + prompt_upsampling: bool, + images: torch.Tensor | None = None, + ) -> IO.NodeOutput: + reference_images = {} + if images is not None: + if get_number_of_images(images) > 9: + raise ValueError("The current maximum number of supported images is 9.") + for image_index in range(images.shape[0]): + key_name = f"input_image_{image_index + 1}" if image_index else "input_image" + reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=Flux2ProGenerateRequest( + prompt=prompt, + width=width, + height=height, + seed=seed, + prompt_upsampling=prompt_upsampling, + **reference_images, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + class BFLExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ FluxProUltraImageNode, - # FluxProImageNode, FluxKontextProImageNode, FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, + Flux2ProImageNode, ] diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 21013b591..80292fb3c 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -36,6 +36,7 @@ from .upload_helpers import ( upload_video_to_comfyapi, ) from .validation_utils import ( + get_image_dimensions, get_number_of_images, validate_aspect_ratio_string, validate_audio_duration, @@ -82,6 +83,7 @@ __all__ = [ "trim_video", "video_to_base64_string", # Validation utilities + "get_image_dimensions", "get_number_of_images", "validate_aspect_ratio_string", "validate_audio_duration", From af81cb962d9dd283ddb551962cc223b5a186a1ce Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 08:40:32 -0800 Subject: [PATCH 14/39] Add Flux 2 support to README. (#10882) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 28beec427..b9300ab07 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) + - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From 828b1b9953175b6df79459f417d1032869d0b46a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Nov 2025 12:40:58 -0500 Subject: [PATCH 15/39] ComfyUI version v0.3.72 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b4655d553..dac038c26 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.71" +__version__ = "0.3.72" diff --git a/pyproject.toml b/pyproject.toml index 280dbaf53..75df8fb7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.71" +version = "0.3.72" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From dff996ca39d86265bbabf15e666484e051f0b3d5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 11:30:24 -0800 Subject: [PATCH 16/39] Fix crash. (#10885) --- comfy/text_encoders/flux.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 8dbbca16e..024504a5b 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -87,6 +87,7 @@ def load_mistral_tokenizer(data): vocab = {} max_vocab = mistral_vocab["config"]["default_vocab_size"] + max_vocab -= len(mistral_vocab["special_tokens"]) for w in mistral_vocab["vocab"]: r = w["rank"] From 18b79acba95d44b4ea00bbbfc1856bc71bd58841 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 26 Nov 2025 03:58:21 +0800 Subject: [PATCH 17/39] Update workflow templates to v0.7.20 (#10883) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b7014f956..5f20816d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.30.6 -comfyui-workflow-templates==0.7.9 +comfyui-workflow-templates==0.7.20 comfyui-embedded-docs==0.3.1 torch torchsde From d196a905bb379a6d800d0c13f9b4fdea3965311a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 11:58:39 -0800 Subject: [PATCH 18/39] Lower vram usage for flux 2 text encoder. (#10887) --- comfy/sd1_clip.py | 7 ++++--- comfy/text_encoders/flux.py | 4 ++-- comfy/text_encoders/llama.py | 12 +++++++++--- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 8f509bab1..0fc9ab3db 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32 super().__init__() - assert layer in self.LAYERS if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") @@ -164,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if self.layer == "all": + if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" @@ -266,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - if self.layer == "all": + if isinstance(self.layer, list): + intermediate_output = self.layer + elif self.layer == "all": intermediate_output = "all" else: intermediate_output = self.layer_idx diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 024504a5b..99f4812bb 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -138,7 +138,7 @@ class Flux2Tokenizer(sd1_clip.SD1Tokenizer): return tokens class Mistral3_24BModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}): textmodel_json_config = {} num_layers = model_options.get("num_layers", None) if num_layers is not None: @@ -154,7 +154,7 @@ class Flux2TEModel(sd1_clip.SD1ClipModel): def encode_token_weights(self, token_weight_pairs): out, pooled, extra = super().encode_token_weights(token_weight_pairs) - out = torch.stack((out[:, 10], out[:, 20], out[:, 30]), dim=1) + out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1) out = out.movedim(1, 2) out = out.reshape(out.shape[0], out.shape[1], -1) return out, pooled, extra diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 749ff581b..d47ed27bc 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -434,8 +434,12 @@ class Llama2_(nn.Module): intermediate = None all_intermediate = None + only_layers = None if intermediate_output is not None: - if intermediate_output == "all": + if isinstance(intermediate_output, list): + all_intermediate = [] + only_layers = set(intermediate_output) + elif intermediate_output == "all": all_intermediate = [] intermediate_output = None elif intermediate_output < 0: @@ -443,7 +447,8 @@ class Llama2_(nn.Module): for i, layer in enumerate(self.layers): if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or (i in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -457,7 +462,8 @@ class Llama2_(nn.Module): x = self.norm(x) if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or ((i + 1) in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) if all_intermediate is not None: intermediate = torch.cat(all_intermediate, dim=1) From 0c18842acbdf546883b08808dd9feea7605d7649 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Nov 2025 14:59:37 -0500 Subject: [PATCH 19/39] ComfyUI v0.3.73 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index dac038c26..f8818838e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.72" +__version__ = "0.3.73" diff --git a/pyproject.toml b/pyproject.toml index 75df8fb7c..7e4bac12d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.72" +version = "0.3.73" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From e9aae31fa241a6a63a368800146ea91629d4e8c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:41:45 -0800 Subject: [PATCH 20/39] Z Image model. (#10892) --- comfy/ldm/lumina/model.py | 219 +++++++------------- comfy/ldm/modules/diffusionmodules/mmdit.py | 6 +- comfy/model_base.py | 4 + comfy/model_detection.py | 29 ++- comfy/sd.py | 8 + comfy/text_encoders/llama.py | 31 +++ comfy/text_encoders/z_image.py | 48 +++++ 7 files changed, 196 insertions(+), 149 deletions(-) create mode 100644 comfy/text_encoders/z_image.py diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index b4494a51d..c8643eb82 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -11,6 +11,7 @@ import comfy.ldm.common_dit from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension @@ -31,6 +32,7 @@ class JointAttention(nn.Module): n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, + out_bias: bool = False, operation_settings={}, ): """ @@ -59,7 +61,7 @@ class JointAttention(nn.Module): self.out = operation_settings.get("operations").Linear( n_heads * self.head_dim, dim, - bias=False, + bias=out_bias, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"), ) @@ -70,35 +72,6 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() - @staticmethod - def apply_rotary_emb( - x_in: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: - """ - Apply rotary embeddings to input tensors using the given frequency - tensor. - - This function applies rotary embeddings to the given query 'xq' and - key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The - input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors - contain rotary embeddings and are returned as real tensors. - - Args: - x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex - exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor - and key tensor with rotary embeddings. - """ - - t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x_in.shape) - def forward( self, x: torch.Tensor, @@ -134,8 +107,7 @@ class JointAttention(nn.Module): xq = self.q_norm(xq) xk = self.k_norm(xk) - xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + xq, xk = apply_rope(xq, xk, freqs_cis) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: @@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module): norm_eps: float, qk_norm: bool, modulation=True, + z_image_modulation=False, + attn_out_bias=False, operation_settings={}, ) -> None: """ @@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings) self.feed_forward = FeedForward( dim=dim, - hidden_dim=4 * dim, + hidden_dim=dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, operation_settings=operation_settings, @@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module): self.modulation = modulation if modulation: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operation_settings.get("operations").Linear( - min(dim, 1024), - 4 * dim, - bias=True, - device=operation_settings.get("device"), - dtype=operation_settings.get("dtype"), - ), - ) + if z_image_modulation: + self.adaLN_modulation = nn.Sequential( + operation_settings.get("operations").Linear( + min(dim, 256), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) def forward( self, @@ -323,7 +308,7 @@ class FinalLayer(nn.Module): The final layer of NextDiT. """ - def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): + def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}): super().__init__() self.norm_final = operation_settings.get("operations").LayerNorm( hidden_size, @@ -340,10 +325,15 @@ class FinalLayer(nn.Module): dtype=operation_settings.get("dtype"), ) + if z_image_modulation: + min_mod = 256 + else: + min_mod = 1024 + self.adaLN_modulation = nn.Sequential( nn.SiLU(), operation_settings.get("operations").Linear( - min(hidden_size, 1024), + min(hidden_size, min_mod), hidden_size, bias=True, device=operation_settings.get("device"), @@ -373,12 +363,16 @@ class NextDiT(nn.Module): n_heads: int = 32, n_kv_heads: Optional[int] = None, multiple_of: int = 256, - ffn_dim_multiplier: Optional[float] = None, + ffn_dim_multiplier: float = 4.0, norm_eps: float = 1e-5, qk_norm: bool = False, cap_feat_dim: int = 5120, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512), + rope_theta=10000.0, + z_image_modulation=False, + time_scale=1.0, + pad_tokens_multiple=None, image_model=None, device=None, dtype=None, @@ -390,6 +384,8 @@ class NextDiT(nn.Module): self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size + self.time_scale = time_scale + self.pad_tokens_multiple = pad_tokens_multiple self.x_embedder = operation_settings.get("operations").Linear( in_features=patch_size * patch_size * in_channels, @@ -411,6 +407,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, + z_image_modulation=z_image_modulation, operation_settings=operation_settings, ) for layer_id in range(n_refiner_layers) @@ -434,7 +431,7 @@ class NextDiT(nn.Module): ] ) - self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) + self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings) self.cap_embedder = nn.Sequential( operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear( @@ -457,18 +454,24 @@ class NextDiT(nn.Module): ffn_dim_multiplier, norm_eps, qk_norm, + z_image_modulation=z_image_modulation, + attn_out_bias=False, operation_settings=operation_settings, ) for layer_id in range(n_layers) ] ) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) + + if self.pad_tokens_multiple is not None: + self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) assert (dim // n_heads) == sum(axes_dims) self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims) self.dim = dim self.n_heads = n_heads @@ -503,108 +506,42 @@ class NextDiT(nn.Module): bsz = len(x) pH = pW = self.patch_size device = x[0].device - dtype = x[0].dtype - if cap_mask is not None: - l_effective_cap_len = cap_mask.sum(dim=1).tolist() - else: - l_effective_cap_len = [num_tokens] * bsz + if self.pad_tokens_multiple is not None: + pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) - if cap_mask is not None and not torch.is_floating_point(cap_mask): - cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max + cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + B, C, H, W = x.shape + x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - max_seq_len = max( - (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) - ) - max_cap_len = max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) + H_tokens, W_tokens = H // pH, W // pW + x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 + x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device) + if self.pad_tokens_multiple is not None: + pad_extra = (-x.shape[1]) % self.pad_tokens_multiple + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device) - position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, cap_len:cap_len+img_len, 1] = row_ids - position_ids[i, cap_len:cap_len+img_len, 2] = col_ids - - freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype) - - # build freqs_cis for cap and image individually - cap_freqs_cis_shape = list(freqs_cis.shape) - # cap_freqs_cis_shape[1] = max_cap_len - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) - # refine image - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device) - for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = x[i] - padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max - - padded_img_embed = self.x_embedder(padded_img_embed) - padded_img_mask = padded_img_mask.unsqueeze(1) + padded_img_mask = None for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) - - if cap_mask is not None: - mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) - mask[:, :max_cap_len] = cap_mask[:, :max_cap_len] - else: - mask = None - - padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] + x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + padded_full_embed = torch.cat((cap_feats, x), dim=1) + mask = None + img_sizes = [(H, W)] * bsz + l_effective_cap_len = [cap_feats.shape[1]] * bsz return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): @@ -627,7 +564,7 @@ class NextDiT(nn.Module): y: (N,) tensor of text tokens/features """ - t = self.t_embedder(t, dtype=x.dtype) # (N, D) + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 42f406f1a..0dc8fe789 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): super().__init__() + if output_size is None: + output_size = hidden_size self.mlp = nn.Sequential( operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), nn.SiLU(), - operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size diff --git a/comfy/model_base.py b/comfy/model_base.py index cad79ecbd..cc21b1de9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1114,9 +1114,13 @@ class Lumina2(BaseModel): if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) + cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + if 'num_tokens' not in out: + out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) + return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b2ba1459d..7afe4a798 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -416,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 - dit_config["dim"] = 2304 - dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1] + w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)] + dit_config["dim"] = w.shape[0] + dit_config["cap_feat_dim"] = w.shape[1] dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') - dit_config["n_heads"] = 24 - dit_config["n_kv_heads"] = 8 dit_config["qk_norm"] = True - dit_config["axes_dims"] = [32, 32, 32] - dit_config["axes_lens"] = [300, 512, 512] + + if dit_config["dim"] == 2304: # Original Lumina 2 + dit_config["n_heads"] = 24 + dit_config["n_kv_heads"] = 8 + dit_config["axes_dims"] = [32, 32, 32] + dit_config["axes_lens"] = [300, 512, 512] + dit_config["rope_theta"] = 10000.0 + dit_config["ffn_dim_multiplier"] = 4.0 + elif dit_config["dim"] == 3840: # Z image + dit_config["n_heads"] = 30 + dit_config["n_kv_heads"] = 30 + dit_config["axes_dims"] = [32, 48, 48] + dit_config["axes_lens"] = [1536, 512, 512] + dit_config["rope_theta"] = 256.0 + dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) + dit_config["z_image_modulation"] = True + dit_config["time_scale"] = 1000.0 + if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: + dit_config["pad_tokens_multiple"] = 32 + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/sd.py b/comfy/sd.py index 14dd8944c..350fae92b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -52,6 +52,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.z_image import comfy.model_patcher import comfy.lora @@ -953,6 +954,8 @@ class TEModel(Enum): GEMMA_3_4B = 13 MISTRAL3_24B = 14 MISTRAL3_24B_PRUNED_FLUX2 = 15 + QWEN3_4B = 16 + def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -985,6 +988,8 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: + if 'model.layers.0.self_attn.q_norm.weight' in sd: + return TEModel.QWEN3_4B weight = sd['model.layers.0.post_attention_layernorm.weight'] if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: @@ -1110,6 +1115,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.QWEN3_4B: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d47ed27bc..cd4b5f76c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -78,6 +78,28 @@ class Qwen25_3BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_4BConfig: + vocab_size: int = 151936 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_7BVLI_Config: vocab_size: int = 152064 @@ -511,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_4B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py new file mode 100644 index 000000000..bb9273b20 --- /dev/null +++ b/comfy/text_encoders/z_image.py @@ -0,0 +1,48 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class ZImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ZImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) + + +def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): + class ZImageTEModel_(ZImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ZImageTEModel_ From 0e24dbb19f34f242edb77c550396cf6806f7b22f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 16:02:51 -0800 Subject: [PATCH 21/39] Adjustments to Z Image. (#10893) --- comfy/supported_models.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 8fe8e63f6..af8120400 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -21,6 +21,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.z_image from . import supported_models_base from . import latent_formats @@ -994,7 +995,7 @@ class Lumina2(supported_models_base.BASE): "shift": 6.0, } - memory_usage_factor = 1.2 + memory_usage_factor = 1.4 unet_extra_config = {} latent_format = latent_formats.Flux @@ -1013,6 +1014,24 @@ class Lumina2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect)) +class ZImage(Lumina2): + unet_config = { + "image_model": "lumina2", + "dim": 3840, + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + memory_usage_factor = 1.7 + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1453,7 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] models += [SVD_img2vid] From bdb10a583f1b1e495ee00dbd1674f11016a6a93e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 21:07:58 -0800 Subject: [PATCH 22/39] Fix loras not working on mixed fp8. (#10899) --- comfy/model_patcher.py | 2 +- comfy/ops.py | 22 +++++++++++++++++++++- comfy/quant_ops.py | 21 ++++++++++++++------- comfy/weight_adapter/lora.py | 1 + 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6551ced5a..73adc7f70 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -132,7 +132,7 @@ class LowVramPatch: def __call__(self, weight): intermediate_dtype = weight.dtype if self.convert_func is not None: - weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True) + weight = self.convert_func(weight, inplace=False) if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops intermediate_dtype = torch.float32 diff --git a/comfy/ops.py b/comfy/ops.py index 785aa1c9f..a0ff4e8f1 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if weight_has_function or weight.dtype != dtype: with wf_context: weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() for f in s.weight_function: weight = f(weight) @@ -502,7 +504,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) return weight else: - return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) + return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32) def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) @@ -643,6 +645,24 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful not isinstance(input, QuantizedTensor)): input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) + + def convert_weight(self, weight, inplace=False, **kwargs): + if isinstance(weight, QuantizedTensor): + return weight.dequantize() + else: + return weight + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): + if getattr(self, 'layout_type', None) is not None: + weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + else: + weight = weight.to(self.weight.dtype) + if return_weight: + return weight + + assert inplace_update is False # TODO: eventually remove the inplace_update stuff + self.weight = torch.nn.Parameter(weight, requires_grad=False) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 0c16bcf8d..d2f3e7397 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,6 +1,7 @@ import torch import logging from typing import Tuple, Dict +import comfy.float _LAYOUT_REGISTRY = {} _GENERIC_UTILS = {} @@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout): - orig_dtype: Original dtype before quantization (for casting back) """ @classmethod - def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn): + def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype if scale is None: @@ -403,17 +404,23 @@ class TensorCoreFP8Layout(QuantizedLayout): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) - # TODO: uncomment this if it's actually needed because the clamp has a small performance penality' - lp_amax = torch.finfo(dtype).max - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) + + if stochastic_rounding > 0: + tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) + else: + lp_amax = torch.finfo(dtype).max + torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) + tensor = tensor.to(dtype, memory_format=torch.contiguous_format) layout_params = { 'scale': scale, 'orig_dtype': orig_dtype } - return qdata, layout_params + return tensor, layout_params @staticmethod def dequantize(qdata, scale, orig_dtype, **kwargs): diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 4db004e50..3cc60bb1b 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase): lora_diff = torch.mm( mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) ).reshape(weight.shape) + del mat1, mat2 if dora_scale is not None: weight = weight_decompose( dora_scale, From 90b3995ec842335e44d70e0521ff6ff6c3ff9aaa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Nov 2025 00:34:15 -0500 Subject: [PATCH 23/39] ComfyUI v0.3.74 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index f8818838e..b565c7367 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.73" +__version__ = "0.3.74" diff --git a/pyproject.toml b/pyproject.toml index 7e4bac12d..ccf0fcdb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.73" +version = "0.3.74" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 58b85746618e2bc2dd32024c89403926aad59f48 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 23:36:19 -0800 Subject: [PATCH 24/39] Fix Flux2 reference image mem estimation. (#10905) --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index cc21b1de9..9b76c285e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -926,7 +926,7 @@ class Flux(BaseModel): out = {} ref_latents = kwargs.get("reference_latents", None) if ref_latents is not None: - out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out class Flux2(Flux): From 8402c8700a29a97bc5d706d6a0b14c41bc2c2d8a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Nov 2025 02:41:13 -0500 Subject: [PATCH 25/39] ComfyUI version v0.3.75 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b565c7367..fa4b4f4b0 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.74" +__version__ = "0.3.75" diff --git a/pyproject.toml b/pyproject.toml index ccf0fcdb9..9009e65fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.74" +version = "0.3.75" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From f16219e3aadcb7a301a1a313ab8989c3ebe53764 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:00:43 -0800 Subject: [PATCH 26/39] Add cheap latent preview for flux 2. (#10907) Thank you to the person who calculated them. You saved me a percent of my time. --- comfy/latent_formats.py | 40 ++++++++++++++++++++++++++++++++++++++++ latent_preview.py | 7 +++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index e98c7d6d8..8e110f45d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -6,6 +6,7 @@ class LatentFormat: latent_dimensions = 2 latent_rgb_factors = None latent_rgb_factors_bias = None + latent_rgb_factors_reshape = None taesd_decoder_name = None def process_in(self, latent): @@ -181,6 +182,45 @@ class Flux(SD3): class Flux2(LatentFormat): latent_channels = 128 + def __init__(self): + self.latent_rgb_factors =[ + [0.0058, 0.0113, 0.0073], + [0.0495, 0.0443, 0.0836], + [-0.0099, 0.0096, 0.0644], + [0.2144, 0.3009, 0.3652], + [0.0166, -0.0039, -0.0054], + [0.0157, 0.0103, -0.0160], + [-0.0398, 0.0902, -0.0235], + [-0.0052, 0.0095, 0.0109], + [-0.3527, -0.2712, -0.1666], + [-0.0301, -0.0356, -0.0180], + [-0.0107, 0.0078, 0.0013], + [0.0746, 0.0090, -0.0941], + [0.0156, 0.0169, 0.0070], + [-0.0034, -0.0040, -0.0114], + [0.0032, 0.0181, 0.0080], + [-0.0939, -0.0008, 0.0186], + [0.0018, 0.0043, 0.0104], + [0.0284, 0.0056, -0.0127], + [-0.0024, -0.0022, -0.0030], + [0.1207, -0.0026, 0.0065], + [0.0128, 0.0101, 0.0142], + [0.0137, -0.0072, -0.0007], + [0.0095, 0.0092, -0.0059], + [0.0000, -0.0077, -0.0049], + [-0.0465, -0.0204, -0.0312], + [0.0095, 0.0012, -0.0066], + [0.0290, -0.0034, 0.0025], + [0.0220, 0.0169, -0.0048], + [-0.0332, -0.0457, -0.0468], + [-0.0085, 0.0389, 0.0609], + [-0.0076, 0.0003, -0.0043], + [-0.0111, -0.0460, -0.0614], + ] + + self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] + self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2) + def process_in(self, latent): return latent diff --git a/latent_preview.py b/latent_preview.py index 95d3cb733..ddf6dcf49 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -37,13 +37,16 @@ class TAESDPreviewerImpl(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None): + def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None): self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) self.latent_rgb_factors_bias = None if latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") + self.latent_rgb_factors_reshape = latent_rgb_factors_reshape def decode_latent_to_preview(self, x0): + if self.latent_rgb_factors_reshape is not None: + x0 = self.latent_rgb_factors_reshape(x0) self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) if self.latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) @@ -85,7 +88,7 @@ def get_previewer(device, latent_format): if previewer is None: if latent_format.latent_rgb_factors is not None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape) return previewer def prepare_callback(model, steps, x0_output_dict=None): From 8938aa3f3064415758fa8f3a628476535a676183 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 26 Nov 2025 19:14:02 +0200 Subject: [PATCH 27/39] add Veo3 First-Last-Frame node (#10878) --- comfy_api_nodes/apis/veo_api.py | 38 +++----- comfy_api_nodes/nodes_veo2.py | 155 ++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 25 deletions(-) diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py index a55137afb..8328d1aa4 100644 --- a/comfy_api_nodes/apis/veo_api.py +++ b/comfy_api_nodes/apis/veo_api.py @@ -1,34 +1,21 @@ -from typing import Optional, Union -from enum import Enum +from typing import Optional from pydantic import BaseModel, Field -class Image2(BaseModel): - bytesBase64Encoded: str - gcsUri: Optional[str] = None - mimeType: Optional[str] = None +class VeoRequestInstanceImage(BaseModel): + bytesBase64Encoded: str | None = Field(None) + gcsUri: str | None = Field(None) + mimeType: str | None = Field(None) -class Image3(BaseModel): - bytesBase64Encoded: Optional[str] = None - gcsUri: str - mimeType: Optional[str] = None - - -class Instance1(BaseModel): - image: Optional[Union[Image2, Image3]] = Field( - None, description='Optional image to guide video generation' - ) +class VeoRequestInstance(BaseModel): + image: VeoRequestInstanceImage | None = Field(None) + lastFrame: VeoRequestInstanceImage | None = Field(None) prompt: str = Field(..., description='Text description of the video') -class PersonGeneration1(str, Enum): - ALLOW = 'ALLOW' - BLOCK = 'BLOCK' - - -class Parameters1(BaseModel): +class VeoRequestParameters(BaseModel): aspectRatio: Optional[str] = Field(None, examples=['16:9']) durationSeconds: Optional[int] = None enhancePrompt: Optional[bool] = None @@ -37,17 +24,18 @@ class Parameters1(BaseModel): description='Generate audio for the video. Only supported by veo 3 models.', ) negativePrompt: Optional[str] = None - personGeneration: Optional[PersonGeneration1] = None + personGeneration: str | None = Field(None, description="ALLOW or BLOCK") sampleCount: Optional[int] = None seed: Optional[int] = None storageUri: Optional[str] = Field( None, description='Optional Cloud Storage URI to upload the video' ) + resolution: str | None = Field(None) class VeoGenVidRequest(BaseModel): - instances: Optional[list[Instance1]] = None - parameters: Optional[Parameters1] = None + instances: list[VeoRequestInstance] | None = Field(None) + parameters: VeoRequestParameters | None = Field(None) class VeoGenVidResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index d37e9e9b4..a54dc13ab 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,6 +1,7 @@ import base64 from io import BytesIO +import torch from typing_extensions import override from comfy_api.input_impl.video_types import VideoFromFile @@ -10,6 +11,9 @@ from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollResponse, VeoGenVidRequest, VeoGenVidResponse, + VeoRequestInstance, + VeoRequestInstanceImage, + VeoRequestParameters, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -346,12 +350,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): ) +class Veo3FirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Veo3FirstLastFrameNode", + display_name="Google Veo 3 First-Last-Frame to Video", + category="api node/video/Veo", + description="Generate video using prompt and first and last frames.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + IO.Int.Input( + "duration", + default=8, + min=4, + max=8, + step=2, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation", + ), + IO.Image.Input("first_frame", tooltip="Start frame"), + IO.Image.Input("last_frame", tooltip="End frame"), + IO.Combo.Input( + "model", + options=["veo-3.1-generate", "veo-3.1-fast-generate"], + default="veo-3.1-fast-generate", + ), + IO.Boolean.Input( + "generate_audio", + default=True, + tooltip="Generate audio for the video.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + first_frame: torch.Tensor, + last_frame: torch.Tensor, + model: str, + generate_audio: bool, + ): + model = MODELS_MAP[model] + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( + instances=[ + VeoRequestInstance( + prompt=prompt, + image=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png" + ), + lastFrame=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png" + ), + ), + ], + parameters=VeoRequestParameters( + aspectRatio=aspect_ratio, + personGeneration="ALLOW", + durationSeconds=duration, + enhancePrompt=True, # cannot be False for Veo3 + seed=seed, + generateAudio=generate_audio, + negativePrompt=negative_prompt, + resolution=resolution, + ), + ), + ) + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, + status_extractor=lambda r: "completed" if r.done else "pending", + data=VeoGenVidPollRequest( + operationName=initial_response.name, + ), + poll_interval=5.0, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, + ) + + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") + + response = poll_response.response + filtered_count = response.raiMediaFilteredCount + if filtered_count: + reasons = response.raiMediaFilteredReasons or [] + reason_part = f": {reasons[0]}" if reasons else "" + raise Exception( + f"Content blocked by Google's Responsible AI filters{reason_part} " + f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)." + ) + + if response.videos: + video = response.videos[0] + if video.bytesBase64Encoded: + return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + if video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") + + class VeoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ VeoVideoGenerationNode, Veo3VideoGenerationNode, + Veo3FirstLastFrameNode, ] From 1105e0d139001ad602d0f883406bfce41e54ae67 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 26 Nov 2025 19:23:14 +0200 Subject: [PATCH 28/39] improve UX for batch uploads in upload_images_to_comfyapi (#10913) --- comfy_api_nodes/util/upload_helpers.py | 43 +++++++++++++------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 632450d9b..b9019841f 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -4,7 +4,7 @@ import logging import time import uuid from io import BytesIO -from typing import Optional, Union +from typing import Optional from urllib.parse import urlparse import aiohttp @@ -48,8 +48,9 @@ async def upload_images_to_comfyapi( image: torch.Tensor, *, max_images: int = 8, - mime_type: Optional[str] = None, - wait_label: Optional[str] = "Uploading", + mime_type: str | None = None, + wait_label: str | None = "Uploading", + show_batch_index: bool = True, ) -> list[str]: """ Uploads images to ComfyUI API and returns download URLs. @@ -59,11 +60,18 @@ async def upload_images_to_comfyapi( download_urls: list[str] = [] is_batch = len(image.shape) > 3 batch_len = image.shape[0] if is_batch else 1 + num_to_upload = min(batch_len, max_images) + batch_start_ts = time.monotonic() - for idx in range(min(batch_len, max_images)): + for idx in range(num_to_upload): tensor = image[idx] if is_batch else image img_io = tensor_to_bytesio(tensor, mime_type=mime_type) - url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label) + + effective_label = wait_label + if wait_label and show_batch_index and num_to_upload > 1: + effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})" + + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts) download_urls.append(url) return download_urls @@ -126,8 +134,9 @@ async def upload_file_to_comfyapi( cls: type[IO.ComfyNode], file_bytes_io: BytesIO, filename: str, - upload_mime_type: Optional[str], - wait_label: Optional[str] = "Uploading", + upload_mime_type: str | None, + wait_label: str | None = "Uploading", + progress_origin_ts: float | None = None, ) -> str: """Uploads a single file to ComfyUI API and returns its download URL.""" if upload_mime_type is None: @@ -148,6 +157,7 @@ async def upload_file_to_comfyapi( file_bytes_io, content_type=upload_mime_type, wait_label=wait_label, + progress_origin_ts=progress_origin_ts, ) return create_resp.download_url @@ -155,27 +165,18 @@ async def upload_file_to_comfyapi( async def upload_file( cls: type[IO.ComfyNode], upload_url: str, - file: Union[BytesIO, str], + file: BytesIO | str, *, - content_type: Optional[str] = None, + content_type: str | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, - wait_label: Optional[str] = None, + wait_label: str | None = None, + progress_origin_ts: float | None = None, ) -> None: """ Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. - Args: - cls: Node class (provides auth context + UI progress hooks). - upload_url: Pre-signed PUT URL. - file: BytesIO or path string. - content_type: Explicit MIME type. If None, we *suppress* Content-Type. - max_retries: Maximum retry attempts. - retry_delay: Initial delay in seconds. - retry_backoff: Exponential backoff factor. - wait_label: Progress label shown in Comfy UI. - Raises: ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception """ @@ -198,7 +199,7 @@ async def upload_file( attempt = 0 delay = retry_delay - start_ts = time.monotonic() + start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic() op_uuid = uuid.uuid4().hex[:8] while True: attempt += 1 From 8908ee262862f1252d1363d55c59872fb3361066 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 26 Nov 2025 20:38:30 +0200 Subject: [PATCH 29/39] fix(gemini): use first 10 images as fileData (URLs) and remaining images as inline base64 (#10918) --- comfy_api_nodes/apis/gemini_api.py | 6 ++++ comfy_api_nodes/nodes_gemini.py | 55 ++++++++++++++++++++---------- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index d34590d28..a380ecc86 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel): mimeType: GeminiMimeType | None = Field(None) +class GeminiFileData(BaseModel): + fileUri: str | None = Field(None) + mimeType: GeminiMimeType | None = Field(None) + + class GeminiPart(BaseModel): inlineData: GeminiInlineData | None = Field(None) + fileData: GeminiFileData | None = Field(None) text: str | None = Field(None) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 938a20f84..976d9c225 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -20,6 +20,7 @@ from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.util import VideoCodec, VideoContainer from comfy_api_nodes.apis.gemini_api import ( GeminiContent, + GeminiFileData, GeminiGenerateContentRequest, GeminiGenerateContentResponse, GeminiImageConfig, @@ -38,6 +39,7 @@ from comfy_api_nodes.util import ( get_number_of_images, sync_op, tensor_to_base64_string, + upload_images_to_comfyapi, validate_string, video_to_base64_string, ) @@ -68,24 +70,43 @@ class GeminiImageModel(str, Enum): gemini_2_5_flash_image = "gemini-2.5-flash-image" -def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: - """ - Convert image tensor input to Gemini API compatible parts. - - Args: - image_input: Batch of image tensors from ComfyUI. - - Returns: - List of GeminiPart objects containing the encoded images. - """ +async def create_image_parts( + cls: type[IO.ComfyNode], + images: torch.Tensor, + image_limit: int = 0, +) -> list[GeminiPart]: image_parts: list[GeminiPart] = [] - for image_index in range(image_input.shape[0]): - image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0)) + if image_limit < 0: + raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.") + total_images = get_number_of_images(images) + if total_images <= 0: + raise ValueError("No images provided to create_image_parts; at least one image is required.") + + # If image_limit == 0 --> use all images; otherwise clamp to image_limit. + effective_max = total_images if image_limit == 0 else min(total_images, image_limit) + + # Number of images we'll send as URLs (fileData) + num_url_images = min(effective_max, 10) # Vertex API max number of image links + reference_images_urls = await upload_images_to_comfyapi( + cls, + images, + max_images=num_url_images, + ) + for reference_image_url in reference_images_urls: + image_parts.append( + GeminiPart( + fileData=GeminiFileData( + mimeType=GeminiMimeType.image_png, + fileUri=reference_image_url, + ) + ) + ) + for idx in range(num_url_images, effective_max): image_parts.append( GeminiPart( inlineData=GeminiInlineData( mimeType=GeminiMimeType.image_png, - data=image_as_b64, + data=tensor_to_base64_string(images[idx]), ) ) ) @@ -338,8 +359,7 @@ class GeminiNode(IO.ComfyNode): # Add other modal parts if images is not None: - image_parts = create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if audio is not None: parts.extend(cls.create_audio_parts(audio)) if video is not None: @@ -562,8 +582,7 @@ class GeminiImage(IO.ComfyNode): image_config = GeminiImageConfig(aspectRatio=aspect_ratio) if images is not None: - image_parts = create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if files is not None: parts.extend(files) @@ -702,7 +721,7 @@ class GeminiImage2(IO.ComfyNode): if images is not None: if get_number_of_images(images) > 14: raise ValueError("The current maximum number of supported images is 14.") - parts.extend(create_image_parts(images)) + parts.extend(await create_image_parts(cls, images)) if files is not None: parts.extend(files) From 234c3dc85f7e61a537bbf6d8999c5880c5e0b746 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 26 Nov 2025 11:58:08 -0800 Subject: [PATCH 30/39] Bump frontend to 1.32.9 (#10867) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5f20816d6..9291552d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.30.6 +comfyui-frontend-package==1.32.9 comfyui-workflow-templates==0.7.20 comfyui-embedded-docs==0.3.1 torch From 58c6ed541d5aaf6d9b12f63bc23c33164e1cf7a3 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Wed, 26 Nov 2025 14:58:27 -0500 Subject: [PATCH 31/39] Merge 3d animation node (#10025) --- comfy_extras/nodes_load_3d.py | 110 +++++++--------------------------- 1 file changed, 23 insertions(+), 87 deletions(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 899608149..54c66ef68 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -7,6 +7,10 @@ from comfy_api.input_impl import VideoFromFile from pathlib import Path +from PIL import Image +import numpy as np + +import uuid def normalize_path(path): return path.replace('\\', '/') @@ -34,58 +38,6 @@ class Load3D(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): - image_path = folder_paths.get_annotated_filepath(image['image']) - mask_path = folder_paths.get_annotated_filepath(image['mask']) - normal_path = folder_paths.get_annotated_filepath(image['normal']) - lineart_path = folder_paths.get_annotated_filepath(image['lineart']) - - load_image_node = nodes.LoadImage() - output_image, ignore_mask = load_image_node.load_image(image=image_path) - ignore_image, output_mask = load_image_node.load_image(image=mask_path) - normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - - video = None - - if image['recording'] != "": - recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - - video = VideoFromFile(recording_video_path) - - return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video - -class Load3DAnimation(): - @classmethod - def INPUT_TYPES(s): - input_dir = os.path.join(folder_paths.get_input_directory(), "3d") - - os.makedirs(input_dir, exist_ok=True) - - input_path = Path(input_dir) - base_path = Path(folder_paths.get_input_directory()) - - files = [ - normalize_path(str(file_path.relative_to(base_path))) - for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'} - ] - - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D_ANIMATION", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") @@ -120,7 +72,8 @@ class Preview3D(): "model_file": ("STRING", {"default": "", "multiline": False}), }, "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) + "camera_info": ("LOAD3D_CAMERA", {}), + "bg_image": ("IMAGE", {}) }} OUTPUT_NODE = True @@ -133,50 +86,33 @@ class Preview3D(): def process(self, model_file, **kwargs): camera_info = kwargs.get("camera_info", None) + bg_image = kwargs.get("bg_image", None) + + bg_image_path = None + if bg_image is not None: + + img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) + img = Image.fromarray(img_array) + + temp_dir = folder_paths.get_temp_directory() + filename = f"bg_{uuid.uuid4().hex}.png" + bg_image_path = os.path.join(temp_dir, filename) + img.save(bg_image_path, compress_level=1) + + bg_image_path = f"temp/{filename}" return { "ui": { - "result": [model_file, camera_info] - } - } - -class Preview3DAnimation(): - @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) - }} - - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): - camera_info = kwargs.get("camera_info", None) - - return { - "ui": { - "result": [model_file, camera_info] + "result": [model_file, camera_info, bg_image_path] } } NODE_CLASS_MAPPINGS = { "Load3D": Load3D, - "Load3DAnimation": Load3DAnimation, "Preview3D": Preview3D, - "Preview3DAnimation": Preview3DAnimation } NODE_DISPLAY_NAME_MAPPINGS = { - "Load3D": "Load 3D", - "Load3DAnimation": "Load 3D - Animation", - "Preview3D": "Preview 3D", - "Preview3DAnimation": "Preview 3D - Animation" + "Load3D": "Load 3D & Animation", + "Preview3D": "Preview 3D & Animation", } From 55f654db3ddaf5a10ac6dbe79774c23c350d279d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 12:16:40 -0800 Subject: [PATCH 32/39] Fix the CSP offline feature. (#10923) --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 0fd2e49e3..fca5050bd 100644 --- a/server.py +++ b/server.py @@ -174,7 +174,7 @@ def create_block_external_middleware(): else: response = await handler(request) - response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" return response return block_external_middleware From dd41b745497cdbbafb0bd745f590726b0e41f9f3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 12:36:38 -0800 Subject: [PATCH 33/39] Add Z Image to readme. (#10924) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b9300ab07..91fb510e1 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) + - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) From d8433c63fdacef24f40da401b02ebba272bf1fbb Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 27 Nov 2025 00:42:01 +0200 Subject: [PATCH 34/39] chore(api-nodes): remove chat widgets from OpenAI/Gemini nodes (#10861) --- comfy_api_nodes/nodes_gemini.py | 77 +-------------------------------- comfy_api_nodes/nodes_openai.py | 46 ++++---------------- 2 files changed, 11 insertions(+), 112 deletions(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 976d9c225..08f7b0f64 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer """ import base64 -import json import os -import time -import uuid from enum import Enum from io import BytesIO from typing import Literal @@ -43,7 +40,6 @@ from comfy_api_nodes.util import ( validate_string, video_to_base64_string, ) -from server import PromptServer GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB @@ -384,29 +380,6 @@ class GeminiNode(IO.ComfyNode): ) output_text = get_text_from_response(response) - if output_text: - # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(output_text or "Empty response from Gemini model...") @@ -601,30 +574,7 @@ class GeminiImage(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - - output_text = get_text_from_response(response) - if output_text: - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(get_image_from_response(response), output_text) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -744,30 +694,7 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - - output_text = get_text_from_response(response) - if output_text: - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(get_image_from_response(response), output_text) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index acf35d276..c8da5464b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,15 +1,10 @@ from io import BytesIO -from typing import Optional, Union -import json import os -import time -import uuid from enum import Enum from inspect import cleandoc import numpy as np import torch from PIL import Image -from server import PromptServer import folder_paths import base64 from comfy_api.latest import IO, ComfyExtension @@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode): def create_input_message_contents( cls, prompt: str, - image: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, + image: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, ) -> InputMessageContentList: """Create a list of input message contents from prompt and optional image.""" - content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [ + content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: @@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode): prompt: str, persist_context: bool = False, model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, - images: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, - advanced_options: Optional[CreateModelResponseProperties] = None, + images: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, + advanced_options: CreateModelResponseProperties | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode): status_extractor=lambda response: response.status, completed_statuses=["incomplete", "completed"] ) - output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)) - - # Update history - render_spec = { - "node_id": cls.hidden.unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - return IO.NodeOutput(output_text) + return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) class OpenAIInputFiles(IO.ComfyNode): @@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode): def execute( cls, truncation: bool, - instructions: Optional[str] = None, - max_output_tokens: Optional[int] = None, + instructions: str | None = None, + max_output_tokens: int | None = None, ) -> IO.NodeOutput: """ Configure advanced options for the OpenAI Chat Node. From a2d60aad0f8e03657d501842460123f6eaaf6791 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 27 Nov 2025 00:55:31 +0200 Subject: [PATCH 35/39] convert nodes_customer_sampler.py to V3 schema (#10206) --- comfy_extras/nodes_custom_sampler.py | 1182 ++++++++++++++------------ 1 file changed, 633 insertions(+), 549 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index d011f433b..fbb080886 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -3,272 +3,312 @@ import comfy.samplers import comfy.sample from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.k_diffusion import sa_solver -from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict import latent_preview import torch import comfy.utils import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class BasicScheduler: +class BasicScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BasicScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, scheduler, steps, denoise): + @classmethod + def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = int(steps/denoise) sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class KarrasScheduler: +class KarrasScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="KarrasScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExponentialScheduler: + get_sigmas = execute + +class ExponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="ExponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min): + @classmethod + def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) - return (sigmas, ) + return io.NodeOutput(sigmas) -class PolyexponentialScheduler: + get_sigmas = execute + +class PolyexponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="PolyexponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class LaplaceScheduler: + get_sigmas = execute + +class LaplaceScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}), - "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="LaplaceScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class SDTurboScheduler: +class SDTurboScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 1, "min": 1, "max": 10}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="SDTurboScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=1, min=1, max=10), + io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, denoise): + @classmethod + def execute(cls, model, steps, denoise) -> io.NodeOutput: start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] sigmas = model.get_model_object("model_sampling").sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) - return (sigmas, ) + return io.NodeOutput(sigmas) -class BetaSamplingScheduler: + get_sigmas = execute + +class BetaSamplingScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BetaSamplingScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, alpha, beta): + @classmethod + def execute(cls, model, steps, alpha, beta) -> io.NodeOutput: sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) -class VPScheduler: + get_sigmas = execute + +class VPScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values - "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="VPScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), #TODO: fix default values + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, beta_d, beta_min, eps_s): + @classmethod + def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) - return (sigmas, ) + return io.NodeOutput(sigmas) -class SplitSigmas: + get_sigmas = execute + +class SplitSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "step": ("INT", {"default": 0, "min": 0, "max": 10000}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("step", default=0, min=0, max=10000), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, step): + @classmethod + def execute(cls, sigmas, step) -> io.NodeOutput: sigmas1 = sigmas[:step + 1] sigmas2 = sigmas[step:] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class SplitSigmasDenoise: + get_sigmas = execute + +class SplitSigmasDenoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmasDenoise", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, denoise): + @classmethod + def execute(cls, sigmas, denoise) -> io.NodeOutput: steps = max(sigmas.shape[-1] - 1, 0) total_steps = round(steps * denoise) sigmas1 = sigmas[:-(total_steps)] sigmas2 = sigmas[-(total_steps + 1):] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class FlipSigmas: + get_sigmas = execute + +class FlipSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="FlipSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[io.Sigmas.Input("sigmas")], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas): + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: if len(sigmas) == 0: - return (sigmas,) + return io.NodeOutput(sigmas) sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 - return (sigmas,) + return io.NodeOutput(sigmas) -class SetFirstSigma: + get_sigmas = execute + +class SetFirstSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SetFirstSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "set_first_sigma" - - def set_first_sigma(self, sigmas, sigma): + @classmethod + def execute(cls, sigmas, sigma) -> io.NodeOutput: sigmas = sigmas.clone() sigmas[0] = sigma - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExtendIntermediateSigmas: + set_first_sigma = execute + +class ExtendIntermediateSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "steps": ("INT", {"default": 2, "min": 1, "max": 100}), - "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}), - "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}), - "spacing": (['linear', 'cosine', 'sine'],), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="ExtendIntermediateSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("steps", default=2, min=1, max=100), + io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False), + io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False), + io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "extend" - - def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str): + @classmethod + def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput: if start_at_sigma < 0: start_at_sigma = float("inf") @@ -299,27 +339,27 @@ class ExtendIntermediateSigmas: extended_sigmas = torch.FloatTensor(extended_sigmas) - return (extended_sigmas,) + return io.NodeOutput(extended_sigmas) + + extend = execute -class SamplingPercentToSigma: +class SamplingPercentToSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - "sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}), - "return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplingPercentToSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Model.Input("model"), + io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), + io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), + ], + outputs=[io.Float.Output(display_name="sigma_value")] + ) - RETURN_TYPES = (IO.FLOAT,) - RETURN_NAMES = ("sigma_value",) - CATEGORY = "sampling/custom_sampling/sigmas" - - FUNCTION = "get_sigma" - - def get_sigma(self, model, sampling_percent, return_actual_sigma): + @classmethod + def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") sigma_val = model_sampling.percent_to_sigma(sampling_percent) if return_actual_sigma: @@ -327,212 +367,234 @@ class SamplingPercentToSigma: sigma_val = model_sampling.sigma_max.item() elif sampling_percent == 1.0: sigma_val = model_sampling.sigma_min.item() - return (sigma_val,) + return io.NodeOutput(sigma_val) + + get_sigma = execute -class KSamplerSelect: +class KSamplerSelect(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="KSamplerSelect", + category="sampling/custom_sampling/samplers", + inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, sampler_name): + @classmethod + def execute(cls, sampler_name) -> io.NodeOutput: sampler = comfy.samplers.sampler_object(sampler_name) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_3M_SDE: + get_sampler = execute + +class SamplerDPMPP_3M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_3M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, noise_device): + @classmethod + def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_3m_sde" else: sampler_name = "dpmpp_3m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2M_SDE: + get_sampler = execute + +class SamplerDPMPP_2M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"solver_type": (['midpoint', 'heun'], ), - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=['midpoint', 'heun']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, eta, s_noise, noise_device): + @classmethod + def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) - return (sampler, ) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerDPMPP_SDE: +class SamplerDPMPP_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, r, noise_device): + @classmethod + def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2S_Ancestral: + get_sampler = execute + +class SamplerDPMPP_2S_Ancestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2S_Ancestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestral: + get_sampler = execute + +class SamplerEulerAncestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestralCFGPP: + get_sampler = execute + +class SamplerEulerAncestralCFGPP(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": False}), - }} - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestralCFGPP", + display_name="SamplerEulerAncestralCFG++", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler( "euler_ancestral_cfg_pp", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerLMS: + get_sampler = execute + +class SamplerLMS(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 4, "min": 1, "max": 100}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerLMS", + category="sampling/custom_sampling/samplers", + inputs=[io.Int.Input("order", default=4, min=1, max=100)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order): + @classmethod + def execute(cls, order) -> io.NodeOutput: sampler = comfy.samplers.ksampler("lms", {"order": order}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMAdaptative: + get_sampler = execute + +class SamplerDPMAdaptative(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 3, "min": 2, "max": 3}), - "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMAdaptative", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=3, min=2, max=3), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + @classmethod + def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, "s_noise":s_noise }) - return (sampler, ) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerER_SDE(ComfyNodeABC): +class SamplerER_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}), - "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}), - "eta": ( - IO.FLOAT, - {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."}, - ), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerER_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), + io.Int.Input("max_stage", default=3, min=1, max=3), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, max_stage, eta, s_noise): + @classmethod + def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput: if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): eta = 0 s_noise = 0 @@ -548,32 +610,33 @@ class SamplerER_SDE(ComfyNodeABC): sampler_name = "er_sde" sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerSASolver(ComfyNodeABC): +class SamplerSASolver(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - "eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},), - "sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},), - "sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},), - "predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}), - "corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}), - "use_pece": (IO.BOOLEAN, {}), - "simple_order_2": (IO.BOOLEAN, {}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerSASolver", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Model.Input("model"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("predictor_order", default=3, min=1, max=6), + io.Int.Input("corrector_order", default=4, min=0, max=6), + io.Boolean.Input("use_pece"), + io.Boolean.Input("simple_order_2"), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2): + @classmethod + def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") start_sigma = model_sampling.percent_to_sigma(sde_start_percent) end_sigma = model_sampling.percent_to_sigma(sde_end_percent) @@ -591,7 +654,9 @@ class SamplerSASolver(ComfyNodeABC): "simple_order_2": simple_order_2, }, ) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute class Noise_EmptyNoise: @@ -612,30 +677,31 @@ class Noise_RandomNoise: batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) -class SamplerCustom: +class SamplerCustom(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "add_noise": ("BOOLEAN", {"default": True}), - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerCustom", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") - - FUNCTION = "sample" - - CATEGORY = "sampling/custom_sampling" - - def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + @classmethod + def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -664,52 +730,58 @@ class SamplerCustom: out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) + + sample = execute class Guider_Basic(comfy.samplers.CFGGuider): def set_conds(self, positive): self.inner_set_conds({"positive": positive}) -class BasicGuider: +class BasicGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "conditioning": ("CONDITIONING", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="BasicGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("conditioning"), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, conditioning): + @classmethod + def execute(cls, model, conditioning) -> io.NodeOutput: guider = Guider_Basic(model) guider.set_conds(conditioning) - return (guider,) + return io.NodeOutput(guider) -class CFGGuider: + get_guider = execute + +class CFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="CFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, positive, negative, cfg): + @classmethod + def execute(cls, model, positive, negative, cfg) -> io.NodeOutput: guider = comfy.samplers.CFGGuider(model) guider.set_conds(positive, negative) guider.set_cfg(cfg) - return (guider,) + return io.NodeOutput(guider) + + get_guider = execute class Guider_DualCFG(comfy.samplers.CFGGuider): def set_cfg(self, cfg1, cfg2, nested=False): @@ -740,84 +812,88 @@ class Guider_DualCFG(comfy.samplers.CFGGuider): out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 -class DualCFGGuider: +class DualCFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "cond1": ("CONDITIONING", ), - "cond2": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "style": (["regular", "nested"],), - } - } + def define_schema(cls): + return io.Schema( + node_id="DualCFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("cond1"), + io.Conditioning.Input("cond2"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Combo.Input("style", options=["regular", "nested"]), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style): + @classmethod + def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput: guider = Guider_DualCFG(model) guider.set_conds(cond1, cond2, negative) guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) - return (guider,) + return io.NodeOutput(guider) -class DisableNoise: + get_guider = execute + +class DisableNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required":{ - } - } + def define_schema(cls): + return io.Schema( + node_id="DisableNoise", + category="sampling/custom_sampling/noise", + inputs=[], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("NOISE",) - FUNCTION = "get_noise" - CATEGORY = "sampling/custom_sampling/noise" - - def get_noise(self): - return (Noise_EmptyNoise(),) - - -class RandomNoise(DisableNoise): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "noise_seed": ("INT", { - "default": 0, - "min": 0, - "max": 0xffffffffffffffff, - "control_after_generate": True, - }), - } - } + def execute(cls) -> io.NodeOutput: + return io.NodeOutput(Noise_EmptyNoise()) - def get_noise(self, noise_seed): - return (Noise_RandomNoise(noise_seed),) + get_noise = execute -class SamplerCustomAdvanced: +class RandomNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"noise": ("NOISE", ), - "guider": ("GUIDER", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="RandomNoise", + category="sampling/custom_sampling/noise", + inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") + @classmethod + def execute(cls, noise_seed) -> io.NodeOutput: + return io.NodeOutput(Noise_RandomNoise(noise_seed)) - FUNCTION = "sample" + get_noise = execute - CATEGORY = "sampling/custom_sampling" - def sample(self, noise, guider, sampler, sigmas, latent_image): +class SamplerCustomAdvanced(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustomAdvanced", + category="sampling/custom_sampling", + inputs=[ + io.Noise.Input("noise"), + io.Guider.Input("guider"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -842,28 +918,32 @@ class SamplerCustomAdvanced: out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) -class AddNoise: + sample = execute + +class AddNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "noise": ("NOISE", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="AddNoise", + category="_for_testing/custom_sampling/noise", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Noise.Input("noise"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(), + ] + ) - RETURN_TYPES = ("LATENT",) - - FUNCTION = "add_noise" - - CATEGORY = "_for_testing/custom_sampling/noise" - - def add_noise(self, model, noise, sigmas, latent_image): + @classmethod + def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput: if len(sigmas) == 0: - return latent_image + return io.NodeOutput(latent_image) latent = latent_image latent_image = latent["samples"] @@ -887,46 +967,50 @@ class AddNoise: out = latent.copy() out["samples"] = noisy - return (out,) + return io.NodeOutput(out) + + add_noise = execute -NODE_CLASS_MAPPINGS = { - "SamplerCustom": SamplerCustom, - "BasicScheduler": BasicScheduler, - "KarrasScheduler": KarrasScheduler, - "ExponentialScheduler": ExponentialScheduler, - "PolyexponentialScheduler": PolyexponentialScheduler, - "LaplaceScheduler": LaplaceScheduler, - "VPScheduler": VPScheduler, - "BetaSamplingScheduler": BetaSamplingScheduler, - "SDTurboScheduler": SDTurboScheduler, - "KSamplerSelect": KSamplerSelect, - "SamplerEulerAncestral": SamplerEulerAncestral, - "SamplerEulerAncestralCFGPP": SamplerEulerAncestralCFGPP, - "SamplerLMS": SamplerLMS, - "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, - "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, - "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral, - "SamplerDPMAdaptative": SamplerDPMAdaptative, - "SamplerER_SDE": SamplerER_SDE, - "SamplerSASolver": SamplerSASolver, - "SplitSigmas": SplitSigmas, - "SplitSigmasDenoise": SplitSigmasDenoise, - "FlipSigmas": FlipSigmas, - "SetFirstSigma": SetFirstSigma, - "ExtendIntermediateSigmas": ExtendIntermediateSigmas, - "SamplingPercentToSigma": SamplingPercentToSigma, +class CustomSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerCustom, + BasicScheduler, + KarrasScheduler, + ExponentialScheduler, + PolyexponentialScheduler, + LaplaceScheduler, + VPScheduler, + BetaSamplingScheduler, + SDTurboScheduler, + KSamplerSelect, + SamplerEulerAncestral, + SamplerEulerAncestralCFGPP, + SamplerLMS, + SamplerDPMPP_3M_SDE, + SamplerDPMPP_2M_SDE, + SamplerDPMPP_SDE, + SamplerDPMPP_2S_Ancestral, + SamplerDPMAdaptative, + SamplerER_SDE, + SamplerSASolver, + SplitSigmas, + SplitSigmasDenoise, + FlipSigmas, + SetFirstSigma, + ExtendIntermediateSigmas, + SamplingPercentToSigma, + CFGGuider, + DualCFGGuider, + BasicGuider, + RandomNoise, + DisableNoise, + AddNoise, + SamplerCustomAdvanced, + ] - "CFGGuider": CFGGuider, - "DualCFGGuider": DualCFGGuider, - "BasicGuider": BasicGuider, - "RandomNoise": RandomNoise, - "DisableNoise": DisableNoise, - "AddNoise": AddNoise, - "SamplerCustomAdvanced": SamplerCustomAdvanced, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++", -} +async def comfy_entrypoint() -> CustomSamplersExtension: + return CustomSamplersExtension() From cc6a8dcd1ad9cc9ef7602ee141174a0cea0ed4ce Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 27 Nov 2025 08:18:08 +0800 Subject: [PATCH 36/39] Dataset Processing Nodes and Improved LoRA Trainer Nodes with multi resolution supports. (#10708) * Create nodes_dataset.py * Add encoded dataset caching mechanism * make training node to work with our dataset system * allow trainer node to get different resolution dataset * move all dataset related implementation to nodes_dataset * Rewrite dataset system with new io schema * Rewrite training system with new io schema * add ui pbar * Add outputs' id/name * Fix bad id/naming * use single process instead of input list when no need * fix wrong output_list flag * use torch.load/save and fix bad behaviors --- comfy_extras/nodes_dataset.py | 1532 +++++++++++++++++++++++++++++++++ comfy_extras/nodes_train.py | 967 ++++++++++----------- nodes.py | 1 + 3 files changed, 1980 insertions(+), 520 deletions(-) create mode 100644 comfy_extras/nodes_dataset.py diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py new file mode 100644 index 000000000..b23867505 --- /dev/null +++ b/comfy_extras/nodes_dataset.py @@ -0,0 +1,1532 @@ +import logging +import os +import math +import json + +import numpy as np +import torch +from PIL import Image +from typing_extensions import override + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def load_and_process_images(image_files, input_dir): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return output_images + + +class LoadImageDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageDataSetFromFolder", + display_name="Load Image Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ) + ], + ) + + @classmethod + def execute(cls, folder): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir) + return io.NodeOutput(output_tensor) + + +class LoadImageTextDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageTextDataSetFromFolder", + display_name="Load Image and Text Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="List of text captions", + ), + ], + ) + + @classmethod + def execute(cls, folder): + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend( + [ + os.path.join(path, f) + for f in os.listdir(path) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + * repeat + ) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + output_tensor = load_and_process_images(image_files, sub_input_dir) + + logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") + return io.NodeOutput(output_tensor, captions) + + +def save_images_to_folder(image_list, output_dir, prefix="image"): + """Utility function to save a list of image tensors to disk. + + Args: + image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W]) + output_dir: Directory to save images to + prefix: Filename prefix + + Returns: + List of saved filenames + """ + os.makedirs(output_dir, exist_ok=True) + saved_files = [] + + for idx, img_tensor in enumerate(image_list): + # Handle different tensor shapes + if isinstance(img_tensor, torch.Tensor): + # Remove batch dimension if present [1, H, W, C] -> [H, W, C] + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + + # If tensor is [C, H, W], permute to [H, W, C] + if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]: + if ( + img_tensor.shape[0] <= 4 + and img_tensor.shape[1] > 4 + and img_tensor.shape[2] > 4 + ): + img_tensor = img_tensor.permute(1, 2, 0) + + # Convert to numpy and scale to 0-255 + img_array = img_tensor.cpu().numpy() + img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8) + + # Convert to PIL Image + img = Image.fromarray(img_array) + else: + raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") + + # Save image + filename = f"{prefix}_{idx:05d}.png" + filepath = os.path.join(output_dir, filename) + img.save(filepath) + saved_files.append(filename) + + return saved_files + + +class SaveImageDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageDataSetToFolder", + display_name="Save Image Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive images as list + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + logging.info(f"Saved {len(saved_files)} images to {output_dir}.") + return io.NodeOutput() + + +class SaveImageTextDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageTextDataSetToFolder", + display_name="Save Image and Text Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive both images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input("texts", tooltip="List of text captions to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, texts, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + # Save captions + for idx, (filename, caption) in enumerate(zip(saved_files, texts)): + caption_filename = filename.replace(".png", ".txt") + caption_path = os.path.join(output_dir, caption_filename) + with open(caption_path, "w", encoding="utf-8") as f: + f.write(caption) + + logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.") + return io.NodeOutput() + + +# ========== Helper Functions for Transform Nodes ========== + + +def tensor_to_pil(img_tensor): + """Convert tensor to PIL Image.""" + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + return Image.fromarray(img_array) + + +def pil_to_tensor(img): + """Convert PIL Image to tensor.""" + img_array = np.array(img).astype(np.float32) / 255.0 + return torch.from_numpy(img_array)[None,] + + +# ========== Base Classes for Transform Nodes ========== + + +class ImageProcessingNode(io.ComfyNode): + """Base class for image processing nodes that operate on images. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "images" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, image, **kwargs) -> tensor (for single-item processing) + _group_process(cls, images, **kwargs) -> list[tensor] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = ImageProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + # Auto-detect is_output_list if not explicitly set + # Single processing: False (backend collects results into list) + # Group processing: True by default (can be False for single-output nodes) + output_is_list = ( + cls.is_output_list if cls.is_output_list is not None else is_group + ) + + inputs = [ + io.Image.Input( + "images", + tooltip=( + "List of images to process." if is_group else "Image to process." + ), + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/image", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=output_is_list, + tooltip="Processed images", + ) + ], + ) + + @classmethod + def execute(cls, images, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: images is list, call _group_process + result = cls._group_process(images, **params) + else: + # Individual processing: images is single item, call _process + result = cls._process(images, **params) + + return io.NodeOutput(result) + + @classmethod + def _process(cls, image, **kwargs): + """Override this method for single-item processing. + + Args: + image: tensor - Single image tensor + **kwargs: Additional parameters (already extracted from lists) + + Returns: + tensor - Processed image + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, images, **kwargs): + """Override this method for group processing. + + Args: + images: list[tensor] - List of image tensors + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[tensor] - Processed images + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +class TextProcessingNode(io.ComfyNode): + """Base class for text processing nodes that operate on texts. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "texts" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, text, **kwargs) -> str (for single-item processing) + _group_process(cls, texts, **kwargs) -> list[str] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = TextProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + inputs = [ + io.String.Input( + "texts", + tooltip="List of texts to process." if is_group else "Text to process.", + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/text", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.String.Output( + display_name="texts", + is_output_list=cls.is_output_list, + tooltip="Processed texts", + ) + ], + ) + + @classmethod + def execute(cls, texts, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: texts is list, call _group_process + result = cls._group_process(texts, **params) + else: + # Individual processing: texts is single item, call _process + result = cls._process(texts, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) + + @classmethod + def _process(cls, text, **kwargs): + """Override this method for single-item processing. + + Args: + text: str - Single text string + **kwargs: Additional parameters (already extracted from lists) + + Returns: + str - Processed text + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, texts, **kwargs): + """Override this method for group processing. + + Args: + texts: list[str] - List of text strings + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[str] - Processed texts + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +# ========== Image Transform Nodes ========== + + +class ResizeImagesToSameSizeNode(ImageProcessingNode): + node_id = "ResizeImagesToSameSize" + display_name = "Resize Images to Same Size" + description = "Resize all images to the same width and height." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."), + io.Combo.Input( + "mode", + options=["stretch", "crop_center", "pad"], + default="stretch", + tooltip="Resize mode.", + ), + ] + + @classmethod + def _process(cls, image, width, height, mode): + img = tensor_to_pil(image) + + if mode == "stretch": + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "crop_center": + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + if img.width != width or img.height != height: + img = img.resize((width, height), Image.Resampling.LANCZOS) + elif mode == "pad": + img.thumbnail((width, height), Image.Resampling.LANCZOS) + new_img = Image.new("RGB", (width, height), (0, 0, 0)) + paste_x = (width - img.width) // 2 + paste_y = (height - img.height) // 2 + new_img.paste(img, (paste_x, paste_y)) + img = new_img + + return pil_to_tensor(img) + + +class ResizeImagesToPixelCountNode(ImageProcessingNode): + node_id = "ResizeImagesToPixelCount" + display_name = "Resize Images to Pixel Count" + description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "pixel_count", + default=512 * 512, + min=1, + max=8192 * 8192, + tooltip="Target pixel count.", + ), + io.Int.Input( + "steps", + default=64, + min=1, + max=128, + tooltip="The stepping for resize width/height.", + ), + ] + + @classmethod + def _process(cls, image, pixel_count, steps): + img = tensor_to_pil(image) + w, h = img.size + pixel_count_ratio = math.sqrt(pixel_count / (w * h)) + new_w = int(w * pixel_count_ratio / steps) * steps + new_h = int(h * pixel_count_ratio / steps) * steps + logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByShorterEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByShorterEdge" + display_name = "Resize Images by Shorter Edge" + description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "shorter_edge", + default=512, + min=1, + max=8192, + tooltip="Target length for the shorter edge.", + ), + ] + + @classmethod + def _process(cls, image, shorter_edge): + img = tensor_to_pil(image) + w, h = img.size + if w < h: + new_w = shorter_edge + new_h = int(h * (shorter_edge / w)) + else: + new_h = shorter_edge + new_w = int(w * (shorter_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByLongerEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByLongerEdge" + display_name = "Resize Images by Longer Edge" + description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "longer_edge", + default=1024, + min=1, + max=8192, + tooltip="Target length for the longer edge.", + ), + ] + + @classmethod + def _process(cls, image, longer_edge): + img = tensor_to_pil(image) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class CenterCropImagesNode(ImageProcessingNode): + node_id = "CenterCropImages" + display_name = "Center Crop Images" + description = "Center crop all images to the specified dimensions." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + ] + + @classmethod + def _process(cls, image, width, height): + img = tensor_to_pil(image) + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class RandomCropImagesNode(ImageProcessingNode): + node_id = "RandomCropImages" + display_name = "Random Crop Images" + description = ( + "Randomly crop all images to the specified dimensions (for data augmentation)." + ) + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _process(cls, image, width, height, seed): + np.random.seed(seed % (2**32 - 1)) + img = tensor_to_pil(image) + max_left = max(0, img.width - width) + max_top = max(0, img.height - height) + left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 + top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class FlipImagesNode(ImageProcessingNode): + node_id = "FlipImages" + display_name = "Flip Images" + description = "Flip all images horizontally or vertically." + extra_inputs = [ + io.Combo.Input( + "direction", + options=["horizontal", "vertical"], + default="horizontal", + tooltip="Flip direction.", + ), + ] + + @classmethod + def _process(cls, image, direction): + img = tensor_to_pil(image) + if direction == "horizontal": + img = img.transpose(Image.FLIP_LEFT_RIGHT) + else: + img = img.transpose(Image.FLIP_TOP_BOTTOM) + return pil_to_tensor(img) + + +class NormalizeImagesNode(ImageProcessingNode): + node_id = "NormalizeImages" + display_name = "Normalize Images" + description = "Normalize images using mean and standard deviation." + extra_inputs = [ + io.Float.Input( + "mean", + default=0.5, + min=0.0, + max=1.0, + tooltip="Mean value for normalization.", + ), + io.Float.Input( + "std", + default=0.5, + min=0.001, + max=1.0, + tooltip="Standard deviation for normalization.", + ), + ] + + @classmethod + def _process(cls, image, mean, std): + return (image - mean) / std + + +class AdjustBrightnessNode(ImageProcessingNode): + node_id = "AdjustBrightness" + display_name = "Adjust Brightness" + description = "Adjust brightness of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return (image * factor).clamp(0.0, 1.0) + + +class AdjustContrastNode(ImageProcessingNode): + node_id = "AdjustContrast" + display_name = "Adjust Contrast" + description = "Adjust contrast of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0) + + +class ShuffleDatasetNode(ImageProcessingNode): + node_id = "ShuffleDataset" + display_name = "Shuffle Image Dataset" + description = "Randomly shuffle the order of images in the dataset." + is_group_process = True # Requires full list to shuffle + extra_inputs = [ + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _group_process(cls, images, seed): + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + return [images[i] for i in indices] + + +class ShuffleImageTextDatasetNode(io.ComfyNode): + """Special node that shuffles both images and texts together.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ShuffleImageTextDataset", + display_name="Shuffle Image-Text Dataset", + category="dataset/image", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("images", tooltip="List of images to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle."), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="Shuffled images", + ), + io.String.Output( + display_name="texts", is_output_list=True, tooltip="Shuffled texts" + ), + ], + ) + + @classmethod + def execute(cls, images, texts, seed): + seed = seed[0] # Extract scalar + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + shuffled_images = [images[i] for i in indices] + shuffled_texts = [texts[i] for i in indices] + return io.NodeOutput(shuffled_images, shuffled_texts) + + +# ========== Text Transform Nodes ========== + + +class TextToLowercaseNode(TextProcessingNode): + node_id = "TextToLowercase" + display_name = "Text to Lowercase" + description = "Convert all texts to lowercase." + + @classmethod + def _process(cls, text): + return text.lower() + + +class TextToUppercaseNode(TextProcessingNode): + node_id = "TextToUppercase" + display_name = "Text to Uppercase" + description = "Convert all texts to uppercase." + + @classmethod + def _process(cls, text): + return text.upper() + + +class TruncateTextNode(TextProcessingNode): + node_id = "TruncateText" + display_name = "Truncate Text" + description = "Truncate all texts to a maximum length." + extra_inputs = [ + io.Int.Input( + "max_length", default=77, min=1, max=10000, tooltip="Maximum text length." + ), + ] + + @classmethod + def _process(cls, text, max_length): + return text[:max_length] + + +class AddTextPrefixNode(TextProcessingNode): + node_id = "AddTextPrefix" + display_name = "Add Text Prefix" + description = "Add a prefix to all texts." + extra_inputs = [ + io.String.Input("prefix", default="", tooltip="Prefix to add."), + ] + + @classmethod + def _process(cls, text, prefix): + return prefix + text + + +class AddTextSuffixNode(TextProcessingNode): + node_id = "AddTextSuffix" + display_name = "Add Text Suffix" + description = "Add a suffix to all texts." + extra_inputs = [ + io.String.Input("suffix", default="", tooltip="Suffix to add."), + ] + + @classmethod + def _process(cls, text, suffix): + return text + suffix + + +class ReplaceTextNode(TextProcessingNode): + node_id = "ReplaceText" + display_name = "Replace Text" + description = "Replace text in all texts." + extra_inputs = [ + io.String.Input("find", default="", tooltip="Text to find."), + io.String.Input("replace", default="", tooltip="Text to replace with."), + ] + + @classmethod + def _process(cls, text, find, replace): + return text.replace(find, replace) + + +class StripWhitespaceNode(TextProcessingNode): + node_id = "StripWhitespace" + display_name = "Strip Whitespace" + description = "Strip leading and trailing whitespace from all texts." + + @classmethod + def _process(cls, text): + return text.strip() + + +# ========== Group Processing Example Nodes ========== + + +class ImageDeduplicationNode(ImageProcessingNode): + """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + + node_id = "ImageDeduplication" + display_name = "Image Deduplication" + description = "Remove duplicate or very similar images from the dataset." + is_group_process = True # Requires full list to compare images + extra_inputs = [ + io.Float.Input( + "similarity_threshold", + default=0.95, + min=0.0, + max=1.0, + tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.", + ), + ] + + @classmethod + def _group_process(cls, images, similarity_threshold): + """Remove duplicate images using perceptual hashing.""" + if len(images) == 0: + return [] + + # Compute simple perceptual hash for each image + def compute_hash(img_tensor): + """Compute a simple perceptual hash by resizing to 8x8 and comparing to average.""" + img = tensor_to_pil(img_tensor) + # Resize to 8x8 + img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L") + # Get pixels + pixels = list(img_small.getdata()) + # Compute average + avg = sum(pixels) / len(pixels) + # Create hash (1 if above average, 0 otherwise) + hash_bits = "".join("1" if p > avg else "0" for p in pixels) + return hash_bits + + def hamming_distance(hash1, hash2): + """Compute Hamming distance between two hash strings.""" + return sum(c1 != c2 for c1, c2 in zip(hash1, hash2)) + + # Compute hashes for all images + hashes = [compute_hash(img) for img in images] + + # Find duplicates + keep_indices = [] + for i in range(len(images)): + is_duplicate = False + for j in keep_indices: + # Compare hashes + distance = hamming_distance(hashes[i], hashes[j]) + similarity = 1.0 - (distance / 64.0) # 64 bits total + if similarity >= similarity_threshold: + is_duplicate = True + logging.info( + f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping" + ) + break + + if not is_duplicate: + keep_indices.append(i) + + # Return only unique images + unique_images = [images[i] for i in keep_indices] + logging.info( + f"Deduplication: kept {len(unique_images)} out of {len(images)} images" + ) + return unique_images + + +class ImageGridNode(ImageProcessingNode): + """Combine multiple images into a single grid/collage.""" + + node_id = "ImageGrid" + display_name = "Image Grid" + description = "Arrange multiple images into a grid layout." + is_group_process = True # Requires full list to create grid + is_output_list = False # Outputs single grid image + extra_inputs = [ + io.Int.Input( + "columns", + default=4, + min=1, + max=20, + tooltip="Number of columns in the grid.", + ), + io.Int.Input( + "cell_width", + default=256, + min=32, + max=2048, + tooltip="Width of each cell in the grid.", + ), + io.Int.Input( + "cell_height", + default=256, + min=32, + max=2048, + tooltip="Height of each cell in the grid.", + ), + io.Int.Input( + "padding", default=4, min=0, max=50, tooltip="Padding between images." + ), + ] + + @classmethod + def _group_process(cls, images, columns, cell_width, cell_height, padding): + """Arrange images into a grid.""" + if len(images) == 0: + raise ValueError("Cannot create grid from empty image list") + + # Calculate grid dimensions + num_images = len(images) + rows = (num_images + columns - 1) // columns # Ceiling division + + # Calculate total grid size + grid_width = columns * cell_width + (columns - 1) * padding + grid_height = rows * cell_height + (rows - 1) * padding + + # Create blank grid + grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0)) + + # Place images + for idx, img_tensor in enumerate(images): + row = idx // columns + col = idx % columns + + # Convert to PIL and resize to cell size + img = tensor_to_pil(img_tensor) + img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS) + + # Calculate position + x = col * (cell_width + padding) + y = row * (cell_height + padding) + + # Paste into grid + grid.paste(img, (x, y)) + + logging.info( + f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})" + ) + return pil_to_tensor(grid) + + +class MergeImageListsNode(ImageProcessingNode): + """Merge multiple image lists into a single list.""" + + node_id = "MergeImageLists" + display_name = "Merge Image Lists" + description = "Concatenate multiple image lists into one." + is_group_process = True # Receives images as list + + @classmethod + def _group_process(cls, images): + """Simply return the images list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged image list contains {len(images)} images") + return images + + +class MergeTextListsNode(TextProcessingNode): + """Merge multiple text lists into a single list.""" + + node_id = "MergeTextLists" + display_name = "Merge Text Lists" + description = "Concatenate multiple text lists into one." + is_group_process = True # Receives texts as list + + @classmethod + def _group_process(cls, texts): + """Simply return the texts list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged text list contains {len(texts)} texts") + return texts + + +# ========== Training Dataset Nodes ========== + + +class MakeTrainingDataset(io.ComfyNode): + """Encode images with VAE and texts with CLIP to create a training dataset.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MakeTrainingDataset", + display_name="Make Training Dataset", + category="dataset", + is_experimental=True, + is_input_list=True, # images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to encode."), + io.Vae.Input( + "vae", tooltip="VAE model for encoding images to latents." + ), + io.Clip.Input( + "clip", tooltip="CLIP model for encoding text to conditioning." + ), + io.String.Input( + "texts", + optional=True, + tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, images, vae, clip, texts=None): + # Extract scalars (vae and clip are single values wrapped in lists) + vae = vae[0] + clip = clip[0] + + # Handle text list + num_images = len(images) + + if texts is None or len(texts) == 0: + # Treat as [""] for unconditional training + texts = [""] + + if len(texts) == 1 and num_images > 1: + # Repeat single text for all images + texts = texts * num_images + elif len(texts) != num_images: + raise ValueError( + f"Number of texts ({len(texts)}) does not match number of images ({num_images}). " + f"Text list should have length {num_images}, 1, or 0." + ) + + # Encode images with VAE + logging.info(f"Encoding {num_images} images with VAE...") + latents_list = [] # list[{"samples": tensor}] + for img_tensor in images: + # img_tensor is [1, H, W, 3] + latent_tensor = vae.encode(img_tensor[:, :, :, :3]) + latents_list.append({"samples": latent_tensor}) + + # Encode texts with CLIP + logging.info(f"Encoding {len(texts)} texts with CLIP...") + conditioning_list = [] # list[list[cond]] + for text in texts: + if text == "": + cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + else: + tokens = clip.tokenize(text) + cond = clip.encode_from_tokens_scheduled(tokens) + conditioning_list.append(cond) + + logging.info( + f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning." + ) + return io.NodeOutput(latents_list, conditioning_list) + + +class SaveTrainingDataset(io.ComfyNode): + """Save encoded training dataset (latents + conditioning) to disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveTrainingDataset", + display_name="Save Training Dataset", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive lists + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts from MakeTrainingDataset.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists from MakeTrainingDataset.", + ), + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder to save dataset (inside output directory).", + ), + io.Int.Input( + "shard_size", + default=1000, + min=1, + max=100000, + tooltip="Number of samples per shard file.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, latents, conditioning, folder_name, shard_size): + # Extract scalars + folder_name = folder_name[0] + shard_size = shard_size[0] + + # latents: list[{"samples": tensor}] + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). " + f"Something went wrong in dataset preparation." + ) + + # Create output directory + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + os.makedirs(output_dir, exist_ok=True) + + # Prepare data pairs + num_samples = len(latents) + num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division + + logging.info( + f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..." + ) + + # Save data in shards + for shard_idx in range(num_shards): + start_idx = shard_idx * shard_size + end_idx = min(start_idx + shard_size, num_samples) + + # Get shard data (list of latent dicts and conditioning lists) + shard_data = { + "latents": latents[start_idx:end_idx], + "conditioning": conditioning[start_idx:end_idx], + } + + # Save shard + shard_filename = f"shard_{shard_idx:04d}.pkl" + shard_path = os.path.join(output_dir, shard_filename) + + with open(shard_path, "wb") as f: + torch.save(shard_data, f) + + logging.info( + f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" + ) + + # Save metadata + metadata = { + "num_samples": num_samples, + "num_shards": num_shards, + "shard_size": shard_size, + } + metadata_path = os.path.join(output_dir, "metadata.json") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logging.info(f"Successfully saved {num_samples} samples to {output_dir}.") + return io.NodeOutput() + + +class LoadTrainingDataset(io.ComfyNode): + """Load encoded training dataset from disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadTrainingDataset", + display_name="Load Training Dataset", + category="dataset", + is_experimental=True, + inputs=[ + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder containing the saved dataset (inside output directory).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, folder_name): + # Get dataset directory + dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + + if not os.path.exists(dataset_dir): + raise ValueError(f"Dataset directory not found: {dataset_dir}") + + # Find all shard files + shard_files = sorted( + [ + f + for f in os.listdir(dataset_dir) + if f.startswith("shard_") and f.endswith(".pkl") + ] + ) + + if not shard_files: + raise ValueError(f"No shard files found in {dataset_dir}") + + logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") + + # Load all shards + all_latents = [] # list[{"samples": tensor}] + all_conditioning = [] # list[list[cond]] + + for shard_file in shard_files: + shard_path = os.path.join(dataset_dir, shard_file) + + with open(shard_path, "rb") as f: + shard_data = torch.load(f) + + all_latents.extend(shard_data["latents"]) + all_conditioning.extend(shard_data["conditioning"]) + + logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples") + + logging.info( + f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." + ) + return io.NodeOutput(all_latents, all_conditioning) + + +# ========== Extension Setup ========== + + +class DatasetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # Data loading/saving nodes + LoadImageDataSetFromFolderNode, + LoadImageTextDataSetFromFolderNode, + SaveImageDataSetToFolderNode, + SaveImageTextDataSetToFolderNode, + # Image transform nodes + ResizeImagesToSameSizeNode, + ResizeImagesToPixelCountNode, + ResizeImagesByShorterEdgeNode, + ResizeImagesByLongerEdgeNode, + CenterCropImagesNode, + RandomCropImagesNode, + FlipImagesNode, + NormalizeImagesNode, + AdjustBrightnessNode, + AdjustContrastNode, + ShuffleDatasetNode, + ShuffleImageTextDatasetNode, + # Text transform nodes + TextToLowercaseNode, + TextToUppercaseNode, + TruncateTextNode, + AddTextPrefixNode, + AddTextSuffixNode, + ReplaceTextNode, + StripWhitespaceNode, + # Group processing examples + ImageDeduplicationNode, + ImageGridNode, + MergeImageListsNode, + MergeTextListsNode, + # Training dataset nodes + MakeTrainingDataset, + SaveTrainingDataset, + LoadTrainingDataset, + ] + + +async def comfy_entrypoint() -> DatasetExtension: + return DatasetExtension() diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 9e6ec6780..cb24ab709 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1,15 +1,13 @@ -import datetime -import json import logging import os import numpy as np import safetensors import torch -from PIL import Image, ImageDraw, ImageFont -from PIL.PngImagePlugin import PngInfo import torch.utils.checkpoint -import tqdm +from tqdm.auto import trange +from PIL import Image, ImageDraw, ImageFont +from typing_extensions import override import comfy.samplers import comfy.sd @@ -18,9 +16,9 @@ import comfy.model_management import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers -from comfy.cli_args import args -from comfy.comfy_types.node_typing import IO from comfy.weight_adapter import adapters, adapter_maps +from comfy_api.latest import ComfyExtension, io, ui +from comfy.utils import ProgressBar def make_batch_extra_option_dict(d, indicies, full_size=None): @@ -56,7 +54,18 @@ def process_cond_list(d, prefix=""): class TrainSampler(comfy.samplers.Sampler): - def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): + def __init__( + self, + loss_fn, + optimizer, + loss_callback=None, + batch_size=1, + grad_acc=1, + total_steps=1, + seed=0, + training_dtype=torch.bfloat16, + real_dataset=None, + ): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback @@ -65,54 +74,138 @@ class TrainSampler(comfy.samplers.Sampler): self.grad_acc = grad_acc self.seed = seed self.training_dtype = training_dtype + self.real_dataset: list[torch.Tensor] | None = real_dataset - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + def fwd_bwd( + self, + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ): + xt = model_wrap.inner_model.model_sampling.noise_scaling( + batch_sigmas, batch_noise, batch_latent, False + ) + x0 = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(batch_sigmas), + torch.zeros_like(batch_noise), + batch_latent, + False, + ) + + model_wrap.conds["positive"] = [cond[i] for i in indicies] + batch_extra_args = make_batch_extra_option_dict( + extra_args, indicies, full_size=dataset_size + ) + + with torch.autocast(xt.device.type, dtype=self.training_dtype): + x0_pred = model_wrap( + xt.requires_grad_(True), + batch_sigmas.requires_grad_(True), + **batch_extra_args, + ) + loss = self.loss_fn(x0_pred, x0) + if bwd: + bwd_loss = loss / self.grad_acc + bwd_loss.backward() + return loss + + def sample( + self, + model_wrap, + sigmas, + extra_args, + callback, + noise, + latent_image=None, + denoise_mask=None, + disable_pbar=False, + ): model_wrap.conds = process_cond_list(model_wrap.conds) cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache() - for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): - noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000) - indicies = torch.randperm(dataset_size)[:self.batch_size].tolist() - - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - xt = model_wrap.inner_model.model_sampling.noise_scaling( - batch_sigmas, - batch_noise, - batch_latent, - False + ui_pbar = ProgressBar(self.total_steps) + for i in ( + pbar := trange( + self.total_steps, + desc="Training LoRA", + smoothing=0.01, + disable=not comfy.utils.PROGRESS_BAR_ENABLED, ) - x0 = model_wrap.inner_model.model_sampling.noise_scaling( - torch.zeros_like(batch_sigmas), - torch.zeros_like(batch_noise), - batch_latent, - False + ): + noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise( + self.seed + i * 1000 ) + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - model_wrap.conds["positive"] = [ - cond[i] for i in indicies - ] - batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size) + if self.real_dataset is None: + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(min(self.batch_size, dataset_size)) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - with torch.autocast(xt.device.type, dtype=self.training_dtype): - x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args) - loss = self.loss_fn(x0_pred, x0) - loss.backward() - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + else: + total_loss = 0 + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) - if (i+1) % self.grad_acc == 0: + if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -134,233 +227,6 @@ class BiasDiff(torch.nn.Module): return self.passive_memory_usage() -def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None): - """Utility function to load and process a list of images. - - Args: - image_files: List of image filenames - input_dir: Base directory containing the images - resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") - - Returns: - torch.Tensor: Batch of processed images - """ - if not image_files: - raise ValueError("No valid images found in input") - - output_images = [] - - for file in image_files: - image_path = os.path.join(input_dir, file) - img = node_helpers.pillow(Image.open, image_path) - - if img.mode == "I": - img = img.point(lambda i: i * (1 / 255)) - img = img.convert("RGB") - - if w is None and h is None: - w, h = img.size[0], img.size[1] - - # Resize image to first image - if img.size[0] != w or img.size[1] != h: - if resize_method == "Stretch": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "Crop": - img = img.crop((0, 0, w, h)) - elif resize_method == "Pad": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "None": - raise ValueError( - "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." - ) - - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array)[None,] - output_images.append(img_tensor) - - return torch.cat(output_images, dim=0) - - -class LoadImageSetNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "images": ( - [ - f - for f in os.listdir(folder_paths.get_input_directory()) - if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) - ], - {"image_upload": True, "allow_batch": True}, - ) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - INPUT_IS_LIST = True - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - @classmethod - def VALIDATE_INPUTS(s, images, resize_method): - filenames = images[0] if isinstance(images[0], list) else images - - for image in filenames: - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - return True - - def load_images(self, input_files, resize_method): - input_dir = folder_paths.get_input_directory() - valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] - image_files = [ - f - for f in input_files - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, input_dir, resize_method) - return (output_tensor,) - - -class LoadImageSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - def load_images(self, folder, resize_method): - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - image_files = [ - f - for f in os.listdir(sub_input_dir) - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) - return (output_tensor,) - - -class LoadImageTextSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}), - "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}), - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - "width": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The width to resize the images to. -1 means use the original width.", - }, - ), - "height": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The height to resize the images to. -1 means use the original height.", - }, - ) - }, - } - - RETURN_TYPES = ("IMAGE", IO.CONDITIONING,) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images and caption from a directory for training." - - def load_images(self, folder, clip, resize_method, width=None, height=None): - if clip is None: - raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") - - logging.info(f"Loading images from folder: {folder}") - - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - - image_files = [] - for item in os.listdir(sub_input_dir): - path = os.path.join(sub_input_dir, item) - if any(item.lower().endswith(ext) for ext in valid_extensions): - image_files.append(path) - elif os.path.isdir(path): - # Support kohya-ss/sd-scripts folder structure - repeat = 1 - if item.split("_")[0].isdigit(): - repeat = int(item.split("_")[0]) - image_files.extend([ - os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions) - ] * repeat) - - caption_file_path = [ - f.replace(os.path.splitext(f)[1], ".txt") - for f in image_files - ] - captions = [] - for caption_file in caption_file_path: - caption_path = os.path.join(sub_input_dir, caption_file) - if os.path.exists(caption_path): - with open(caption_path, "r", encoding="utf-8") as f: - caption = f.read().strip() - captions.append(caption) - else: - captions.append("") - - width = width if width != -1 else None - height = height if height != -1 else None - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height) - - logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") - - logging.info(f"Encoding captions from {sub_input_dir}.") - conditions = [] - empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) - for text in captions: - if text == "": - conditions.append(empty_cond) - tokens = clip.tokenize(text) - conditions.extend(clip.encode_from_tokens_scheduled(tokens)) - logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.") - return (output_tensor, conditions) - - def draw_loss_graph(loss_map, steps): width, height = 500, 300 img = Image.new("RGB", (width, height), "white") @@ -379,10 +245,14 @@ def draw_loss_graph(loss_map, steps): return img -def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): +def find_all_highest_child_module_with_forward( + model: torch.nn.Module, result=None, name=None +): if result is None: result = [] - elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + elif hasattr(model, "forward") and not isinstance( + model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) + ): result.append(model) logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") return result @@ -396,12 +266,13 @@ def patch(m): if not hasattr(m, "forward"): return org_forward = m.forward + def fwd(args, kwargs): return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): - return torch.utils.checkpoint.checkpoint( - fwd, args, kwargs, use_reentrant=False - ) + return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) + m.org_forward = org_forward m.forward = checkpointing_fwd @@ -412,130 +283,126 @@ def unpatch(m): del m.org_forward -class TrainLoraNode: +class TrainLoraNode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), - "latents": ( - "LATENT", - { - "tooltip": "The Latents to use for training, serve as dataset/input of the model." - }, + def define_schema(cls): + return io.Schema( + node_id="TrainLoraNode", + display_name="Train LoRA", + category="training", + is_experimental=True, + is_input_list=True, # All inputs become lists + inputs=[ + io.Model.Input("model", tooltip="The model to train the LoRA on."), + io.Latent.Input( + "latents", + tooltip="The Latents to use for training, serve as dataset/input of the model.", ), - "positive": ( - IO.CONDITIONING, - {"tooltip": "The positive conditioning to use for training."}, + io.Conditioning.Input( + "positive", tooltip="The positive conditioning to use for training." ), - "batch_size": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 10000, - "step": 1, - "tooltip": "The batch size to use for training.", - }, + io.Int.Input( + "batch_size", + default=1, + min=1, + max=10000, + tooltip="The batch size to use for training.", ), - "grad_accumulation_steps": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 1024, - "step": 1, - "tooltip": "The number of gradient accumulation steps to use for training.", - } + io.Int.Input( + "grad_accumulation_steps", + default=1, + min=1, + max=1024, + tooltip="The number of gradient accumulation steps to use for training.", ), - "steps": ( - IO.INT, - { - "default": 16, - "min": 1, - "max": 100000, - "tooltip": "The number of steps to train the LoRA for.", - }, + io.Int.Input( + "steps", + default=16, + min=1, + max=100000, + tooltip="The number of steps to train the LoRA for.", ), - "learning_rate": ( - IO.FLOAT, - { - "default": 0.0005, - "min": 0.0000001, - "max": 1.0, - "step": 0.000001, - "tooltip": "The learning rate to use for training.", - }, + io.Float.Input( + "learning_rate", + default=0.0005, + min=0.0000001, + max=1.0, + step=0.0000001, + tooltip="The learning rate to use for training.", ), - "rank": ( - IO.INT, - { - "default": 8, - "min": 1, - "max": 128, - "tooltip": "The rank of the LoRA layers.", - }, + io.Int.Input( + "rank", + default=8, + min=1, + max=128, + tooltip="The rank of the LoRA layers.", ), - "optimizer": ( - ["AdamW", "Adam", "SGD", "RMSprop"], - { - "default": "AdamW", - "tooltip": "The optimizer to use for training.", - }, + io.Combo.Input( + "optimizer", + options=["AdamW", "Adam", "SGD", "RMSprop"], + default="AdamW", + tooltip="The optimizer to use for training.", ), - "loss_function": ( - ["MSE", "L1", "Huber", "SmoothL1"], - { - "default": "MSE", - "tooltip": "The loss function to use for training.", - }, + io.Combo.Input( + "loss_function", + options=["MSE", "L1", "Huber", "SmoothL1"], + default="MSE", + tooltip="The loss function to use for training.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", - }, + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", ), - "training_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for training."}, + io.Combo.Input( + "training_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for training.", ), - "lora_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for lora."}, + io.Combo.Input( + "lora_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for lora.", ), - "algorithm": ( - list(adapter_maps.keys()), - {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."}, + io.Combo.Input( + "algorithm", + options=list(adapter_maps.keys()), + default=list(adapter_maps.keys())[0], + tooltip="The algorithm to use for training.", ), - "gradient_checkpointing": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Use gradient checkpointing for training.", - } + io.Boolean.Input( + "gradient_checkpointing", + default=True, + tooltip="Use gradient checkpointing for training.", ), - "existing_lora": ( - folder_paths.get_filename_list("loras") + ["[None]"], - { - "default": "[None]", - "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", - }, + io.Combo.Input( + "existing_lora", + options=folder_paths.get_filename_list("loras") + ["[None]"], + default="[None]", + tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), - }, - } + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="Model with LoRA applied" + ), + io.Custom("LORA_MODEL").Output( + display_name="lora", tooltip="LoRA weights" + ), + io.Custom("LOSS_MAP").Output( + display_name="loss_map", tooltip="Loss history" + ), + io.Int.Output(display_name="steps", tooltip="Total training steps"), + ], + ) - RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) - RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") - FUNCTION = "train" - CATEGORY = "training" - EXPERIMENTAL = True - - def train( - self, + @classmethod + def execute( + cls, model, latents, positive, @@ -553,13 +420,74 @@ class TrainLoraNode: gradient_checkpointing, existing_lora, ): + # Extract scalars from lists (due to is_input_list=True) + model = model[0] + batch_size = batch_size[0] + steps = steps[0] + grad_accumulation_steps = grad_accumulation_steps[0] + learning_rate = learning_rate[0] + rank = rank[0] + optimizer = optimizer[0] + loss_function = loss_function[0] + seed = seed[0] + training_dtype = training_dtype[0] + lora_dtype = lora_dtype[0] + algorithm = algorithm[0] + gradient_checkpointing = gradient_checkpointing[0] + existing_lora = existing_lora[0] + + # Handle latents - either single dict or list of dicts + if len(latents) == 1: + latents = latents[0]["samples"] # Single latent dict + else: + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + latents = latent_list + + # Handle conditioning - either single list or list of lists + if len(positive) == 1: + positive = positive[0] # Single conditioning list + else: + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + positive = flat_positive + mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - latents = latents["samples"].to(dtype) - num_images = latents.shape[0] + # latents here can be list of different size latent or one large batch + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + else: + logging.error(f"Invalid latents type: {type(latents)}") + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: positive = positive * num_images @@ -591,9 +519,7 @@ class TrainLoraNode: shape = m.weight.shape if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get( - f"{key}.dora_scale", None - ) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) for adapter_cls in adapters: existing_adapter = adapter_cls.load( n, existing_weights, alpha, dora_scale @@ -605,7 +531,9 @@ class TrainLoraNode: adapter_cls = adapter_maps[algorithm] if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to(lora_dtype) + train_adapter = existing_adapter.to_train().to( + lora_dtype + ) else: # Use LoRA with alpha=1.0 by default train_adapter = adapter_cls.create_train( @@ -629,7 +557,9 @@ class TrainLoraNode: if hasattr(m, "bias") and m.bias is not None: key = "{}.bias".format(n) bias = torch.nn.Parameter( - torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + torch.zeros( + m.bias.shape, dtype=lora_dtype, requires_grad=True + ) ) bias_module = BiasDiff(bias) lora_sd["{}.diff_b".format(n)] = bias @@ -657,24 +587,31 @@ class TrainLoraNode: # setup models if gradient_checkpointing: - for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + for m in find_all_highest_child_module_with_forward( + mp.model.diffusion_model + ): patch(m) mp.model.requires_grad_(False) - comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + comfy.model_management.load_models_gpu( + [mp], memory_required=1e20, force_full_load=True + ) # Setup sampler and guider like in test script loss_map = {"loss": []} + def loss_callback(loss): loss_map["loss"].append(loss) + train_sampler = TrainSampler( criterion, optimizer, loss_callback=loss_callback, batch_size=batch_size, grad_acc=grad_accumulation_steps, - total_steps=steps*grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, seed=seed, - training_dtype=dtype + training_dtype=dtype, + real_dataset=latents if multi_res else None, ) guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) guider.set_conds(positive) # Set conditioning from input @@ -684,12 +621,15 @@ class TrainLoraNode: # Generate dummy sigmas and noise sigmas = torch.tensor(range(num_images)) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + if multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat(num_images, 1, 1, 1) guider.sample( noise.generate_noise({"samples": latents}), latents, train_sampler, sigmas, - seed=noise.seed + seed=noise.seed, ) finally: for m in mp.model.modules(): @@ -702,111 +642,118 @@ class TrainLoraNode: for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return (mp, lora_sd, loss_map, steps + existing_steps) + return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader: - def __init__(self): - self.loaded_lora = None +class LoraModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraModelLoader", + display_name="Load LoRA Model", + category="loaders", + is_experimental=True, + inputs=[ + io.Model.Input( + "model", tooltip="The diffusion model the LoRA will be applied to." + ), + io.Custom("LORA_MODEL").Input( + "lora", tooltip="The LoRA model to apply to the diffusion model." + ), + io.Float.Input( + "strength_model", + default=1.0, + min=-100.0, + max=100.0, + tooltip="How strongly to modify the diffusion model. This value can be negative.", + ), + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="The modified diffusion model." + ), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - } - } - - RETURN_TYPES = ("MODEL",) - OUTPUT_TOOLTIPS = ("The modified diffusion model.",) - FUNCTION = "load_lora_model" - - CATEGORY = "loaders" - DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." - EXPERIMENTAL = True - - def load_lora_model(self, model, lora, strength_model): + def execute(cls, model, lora, strength_model): if strength_model == 0: - return (model, ) + return io.NodeOutput(model) - model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) - return (model_lora, ) + model_lora, _ = comfy.sd.load_lora_for_models( + model, None, lora, strength_model, 0 + ) + return io.NodeOutput(model_lora) -class SaveLoRA: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class SaveLoRA(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveLoRA", + display_name="Save LoRA Weights", + category="loaders", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LORA_MODEL").Input( + "lora", + tooltip="The LoRA model to save. Do not use the model with LoRA layers.", + ), + io.String.Input( + "prefix", + default="loras/ComfyUI_trained_lora", + tooltip="The prefix to use for the saved LoRA file.", + ), + io.Int.Input( + "steps", + optional=True, + tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + ), + ], + outputs=[], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "lora": ( - IO.LORA_MODEL, - { - "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." - }, - ), - "prefix": ( - "STRING", - { - "default": "loras/ComfyUI_trained_lora", - "tooltip": "The prefix to use for the saved LoRA file.", - }, - ), - }, - "optional": { - "steps": ( - IO.INT, - { - "forceInput": True, - "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", - }, - ), - }, - } - - RETURN_TYPES = () - FUNCTION = "save" - CATEGORY = "loaders" - EXPERIMENTAL = True - OUTPUT_NODE = True - - def save(self, lora, prefix, steps=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + def execute(cls, lora, prefix, steps=None): + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(prefix, output_dir) + ) if steps is None: output_checkpoint = f"{filename}_{counter:05}_.safetensors" else: output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) safetensors.torch.save_file(lora, output_checkpoint) - return {} + return io.NodeOutput() -class LossGraphNode: - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() +class LossGraphNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LossGraphNode", + display_name="Plot Loss Graph", + category="training", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LOSS_MAP").Input( + "loss", tooltip="Loss map from training node." + ), + io.String.Input( + "filename_prefix", + default="loss_graph", + tooltip="Prefix for the saved loss graph image.", + ), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "loss": (IO.LOSS_MAP, {"default": {}}), - "filename_prefix": (IO.STRING, {"default": "loss_graph"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "plot_loss" - OUTPUT_NODE = True - CATEGORY = "training" - EXPERIMENTAL = True - DESCRIPTION = "Plots the loss graph and saves it to the output directory." - - def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None): loss_values = loss["loss"] width, height = 800, 480 margin = 40 @@ -849,47 +796,27 @@ class LossGraphNode: (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" ) - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + # Convert PIL image to tensor for PreviewImage + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3] - date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - img.save( - os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), - pnginfo=metadata, - ) - return { - "ui": { - "images": [ - { - "filename": f"{filename_prefix}_{date}.png", - "subfolder": "", - "type": "temp", - } - ] - } - } + # Return preview UI + return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls)) -NODE_CLASS_MAPPINGS = { - "TrainLoraNode": TrainLoraNode, - "SaveLoRANode": SaveLoRA, - "LoraModelLoader": LoraModelLoader, - "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, - "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode, - "LossGraphNode": LossGraphNode, -} +# ========== Extension Setup ========== -NODE_DISPLAY_NAME_MAPPINGS = { - "TrainLoraNode": "Train LoRA", - "SaveLoRANode": "Save LoRA Weights", - "LoraModelLoader": "Load LoRA Model", - "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", - "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder", - "LossGraphNode": "Plot Loss Graph", -} + +class TrainingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TrainLoraNode, + LoraModelLoader, + SaveLoRA, + LossGraphNode, + ] + + +async def comfy_entrypoint() -> TrainingExtension: + return TrainingExtension() diff --git a/nodes.py b/nodes.py index f4835c02e..bf73eb90e 100644 --- a/nodes.py +++ b/nodes.py @@ -2278,6 +2278,7 @@ async def init_builtin_extra_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_train.py", + "nodes_dataset.py", "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", From eaf68c9b5bbfbcdac8988741f3948678c9465c1d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:25:32 -0800 Subject: [PATCH 37/39] Make lora training work on Z Image and remove some redundant nodes. (#10927) --- comfy/ldm/lumina/model.py | 4 +- comfy_extras/nodes_dataset.py | 102 +--------------------------------- 2 files changed, 3 insertions(+), 103 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index c8643eb82..565400b54 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -509,7 +509,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 @@ -525,7 +525,7 @@ class NextDiT(nn.Module): if self.pad_tokens_multiple is not None: pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index b23867505..4789d7d53 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1,6 +1,5 @@ import logging import os -import math import json import numpy as np @@ -624,79 +623,6 @@ class TextProcessingNode(io.ComfyNode): # ========== Image Transform Nodes ========== -class ResizeImagesToSameSizeNode(ImageProcessingNode): - node_id = "ResizeImagesToSameSize" - display_name = "Resize Images to Same Size" - description = "Resize all images to the same width and height." - extra_inputs = [ - io.Int.Input("width", default=512, min=1, max=8192, tooltip="Target width."), - io.Int.Input("height", default=512, min=1, max=8192, tooltip="Target height."), - io.Combo.Input( - "mode", - options=["stretch", "crop_center", "pad"], - default="stretch", - tooltip="Resize mode.", - ), - ] - - @classmethod - def _process(cls, image, width, height, mode): - img = tensor_to_pil(image) - - if mode == "stretch": - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "crop_center": - left = max(0, (img.width - width) // 2) - top = max(0, (img.height - height) // 2) - right = min(img.width, left + width) - bottom = min(img.height, top + height) - img = img.crop((left, top, right, bottom)) - if img.width != width or img.height != height: - img = img.resize((width, height), Image.Resampling.LANCZOS) - elif mode == "pad": - img.thumbnail((width, height), Image.Resampling.LANCZOS) - new_img = Image.new("RGB", (width, height), (0, 0, 0)) - paste_x = (width - img.width) // 2 - paste_y = (height - img.height) // 2 - new_img.paste(img, (paste_x, paste_y)) - img = new_img - - return pil_to_tensor(img) - - -class ResizeImagesToPixelCountNode(ImageProcessingNode): - node_id = "ResizeImagesToPixelCount" - display_name = "Resize Images to Pixel Count" - description = "Resize images so that the total pixel count matches the specified number while preserving aspect ratio." - extra_inputs = [ - io.Int.Input( - "pixel_count", - default=512 * 512, - min=1, - max=8192 * 8192, - tooltip="Target pixel count.", - ), - io.Int.Input( - "steps", - default=64, - min=1, - max=128, - tooltip="The stepping for resize width/height.", - ), - ] - - @classmethod - def _process(cls, image, pixel_count, steps): - img = tensor_to_pil(image) - w, h = img.size - pixel_count_ratio = math.sqrt(pixel_count / (w * h)) - new_w = int(w * pixel_count_ratio / steps) * steps - new_h = int(h * pixel_count_ratio / steps) * steps - logging.info(f"Resizing from {w}x{h} to {new_w}x{new_h}") - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) - - class ResizeImagesByShorterEdgeNode(ImageProcessingNode): node_id = "ResizeImagesByShorterEdge" display_name = "Resize Images by Shorter Edge" @@ -801,29 +727,6 @@ class RandomCropImagesNode(ImageProcessingNode): return pil_to_tensor(img) -class FlipImagesNode(ImageProcessingNode): - node_id = "FlipImages" - display_name = "Flip Images" - description = "Flip all images horizontally or vertically." - extra_inputs = [ - io.Combo.Input( - "direction", - options=["horizontal", "vertical"], - default="horizontal", - tooltip="Flip direction.", - ), - ] - - @classmethod - def _process(cls, image, direction): - img = tensor_to_pil(image) - if direction == "horizontal": - img = img.transpose(Image.FLIP_LEFT_RIGHT) - else: - img = img.transpose(Image.FLIP_TOP_BOTTOM) - return pil_to_tensor(img) - - class NormalizeImagesNode(ImageProcessingNode): node_id = "NormalizeImages" display_name = "Normalize Images" @@ -1470,7 +1373,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = torch.load(f) + shard_data = torch.load(f, weights_only=True) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) @@ -1496,13 +1399,10 @@ class DatasetExtension(ComfyExtension): SaveImageDataSetToFolderNode, SaveImageTextDataSetToFolderNode, # Image transform nodes - ResizeImagesToSameSizeNode, - ResizeImagesToPixelCountNode, ResizeImagesByShorterEdgeNode, ResizeImagesByLongerEdgeNode, CenterCropImagesNode, RandomCropImagesNode, - FlipImagesNode, NormalizeImagesNode, AdjustBrightnessNode, AdjustContrastNode, From c38e7d6599be1bdce580ccfdbb20b928315af05e Mon Sep 17 00:00:00 2001 From: Haoming <73768377+Haoming02@users.noreply.github.com> Date: Thu, 27 Nov 2025 12:28:44 +0800 Subject: [PATCH 38/39] block info (#10841) --- comfy/ldm/flux/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1a24e6d95..d5674dea6 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -171,7 +171,10 @@ class Flux(nn.Module): pe = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -215,7 +218,10 @@ class Flux(nn.Module): if self.params.global_modulation: vec, _ = self.single_stream_modulation(vec_orig) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} From f17251bec65b5760cfedec29eace7d77f4b35130 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 27 Nov 2025 16:03:03 +1000 Subject: [PATCH 39/39] Account for the VRAM cost of weight offloading (#10733) * mm: default to 0 for NUM_STREAMS Dont count the compute stream as an offload stream. This makes async offload accounting easier. * mm: remove 128MB minimum This is from a previous offloading system requirement. Remove it to make behaviour of the loader and partial unloader consistent. * mp: order the module list by offload expense Calculate an approximate offloading temporary VRAM cost to offload a weight and primary order the module load list by that. In the simple case this is just the same as the module weight, but with Loras, a weight with a lora consumes considerably more VRAM to do the Lora application on-the-fly. This will slightly prioritize lora weights, but is really for proper VRAM offload accounting. * mp: Account for the VRAM cost of weight offloading when checking the VRAM headroom, assume that the weight needs to be offloaded, and only load if it has space for both the load and offload * the number of streams. As the weights are ordered from largest to smallest by offload cost this is guaranteed to fit in VRAM (tm), as all weights that follow will be smaller. Make the partial unload aware of this system as well by saving the budget for offload VRAM to the model state and accounting accordingly. Its possible that partial unload increases the size of the largest offloaded weights, and thus needs to unload a little bit more than asked to accomodate the bigger temp buffers. Honor the existing codes floor on model weight loading of 128MB by having the patcher honor this separately withough regard to offloading. Otherwise when MM specifies its 128MB minimum, MP will see the biggest weights, and budget that 128MB to only offload buffer and load nothing which isnt the intent of these minimums. The same clamp applies in case of partial offload of the currently loading model. --- comfy/model_management.py | 6 ++-- comfy/model_patcher.py | 59 +++++++++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a9327ac80..9c403d580 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = lowvram_model_memory - loaded_memory if lowvram_model_memory == 0: @@ -1012,7 +1012,7 @@ def force_channels_last(): STREAMS = {} -NUM_STREAMS = 1 +NUM_STREAMS = 0 if args.async_offload: NUM_STREAMS = 2 logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) @@ -1030,7 +1030,7 @@ def current_stream(device): stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) - if NUM_STREAMS <= 1: + if NUM_STREAMS == 0: return None if device in STREAMS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 73adc7f70..3eac77275 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -148,6 +148,15 @@ class LowVramPatch: else: return out +#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 + +def low_vram_patch_estimate_vram(model, key): + weight, set_func, convert_func = get_key_weight(model, key) + if weight is None: + return 0 + return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + def get_key_weight(model, key): set_func = None convert_func = None @@ -269,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'current_weight_patches_uuid'): self.model.current_weight_patches_uuid = None + if not hasattr(self.model, 'model_offload_buffer_memory'): + self.model.model_offload_buffer_memory = 0 + def model_size(self): if self.size > 0: return self.size @@ -662,7 +674,16 @@ class ModelPatcher: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + module_mem = comfy.model_management.module_size(m) + module_offload_mem = module_mem + if hasattr(m, "comfy_cast_weights"): + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if weight_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) + if bias_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + loading.append((module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -676,20 +697,22 @@ class ModelPatcher: load_completely = [] offloaded = [] + offload_buffer = 0 loading.sort(reverse=True) for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + module_offload_mem, module_mem, n, m, params = x lowvram_weight = False + potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: + if not lowvram_fits: + offload_buffer = potential_offload lowvram_weight = True lowvram_counter += 1 lowvram_mem_counter += module_mem @@ -723,9 +746,11 @@ class ModelPatcher: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if full_load or lowvram_fits: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + else: + offload_buffer = potential_offload if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -766,7 +791,7 @@ class ModelPatcher: self.pin_weight_to_device("{}.{}".format(n, param)) if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) @@ -778,6 +803,7 @@ class ModelPatcher: self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter + self.model.model_offload_buffer_memory = offload_buffer self.model.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): @@ -831,6 +857,7 @@ class ModelPatcher: self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 + self.model.model_offload_buffer_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): @@ -849,13 +876,14 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + for unload in unload_list: - if memory_to_free < memory_freed: + if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + module_offload_mem, module_mem, n, m, params = unload + + potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -906,15 +934,18 @@ class ModelPatcher: m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem + offload_buffer = max(offload_buffer, potential_offload) logging.debug("freed {}".format(n)) for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed - logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) + self.model.model_offload_buffer_memory = offload_buffer + logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):