From b41ab53b6f289b4d7688ab96e0a06248ec1fd86b Mon Sep 17 00:00:00 2001 From: Bedovyy <137917911+bedovyy@users.noreply.github.com> Date: Thu, 16 Apr 2026 23:11:58 +0900 Subject: [PATCH 01/13] Use `ErnieTEModel_` not `ErnieTEModel`. (#13431) --- comfy/text_encoders/ernie.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/ernie.py b/comfy/text_encoders/ernie.py index 2c7df78fe..46d24d222 100644 --- a/comfy/text_encoders/ernie.py +++ b/comfy/text_encoders/ernie.py @@ -35,4 +35,4 @@ def te(dtype_llama=None, llama_quantization_metadata=None): model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) - return ErnieTEModel + return ErnieTEModel_ From d0c53c50c2a1edf11aa63967d09aa3efbfd43cfe Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 17 Apr 2026 04:32:04 +0300 Subject: [PATCH 02/13] feat(api-nodes): add 1080p resolution for SeeDance 2.0 model (#13437) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_bytedance.py | 38 ++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 1cca72f6e..429c32444 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1066,7 +1066,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( ) -def _seedance2_text_inputs(): +def _seedance2_text_inputs(resolutions: list[str]): return [ IO.String.Input( "prompt", @@ -1076,7 +1076,7 @@ def _seedance2_text_inputs(): ), IO.Combo.Input( "resolution", - options=["480p", "720p"], + options=resolutions, tooltip="Resolution of the output video.", ), IO.Combo.Input( @@ -1114,8 +1114,8 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ - IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()), - IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()), + IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])), + IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])), ], tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", ), @@ -1152,11 +1152,14 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): ( $rate480 := 10044; $rate720 := 21600; + $rate1080 := 48800; $m := widgets.model; $pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $rate := $res = "720p" ? $rate720 : $rate480; + $rate := $res = "1080p" ? $rate1080 : + $res = "720p" ? 
$rate720 : + $rate480; $cost := $dur * $rate * $pricePer1K / 1000; {"type": "usd", "usd": $cost, "format": {"approximate": true}} ) @@ -1195,6 +1198,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), poll_interval=9, + max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) @@ -1212,8 +1216,8 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ - IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs()), - IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs()), + IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])), + IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])), ], tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", ), @@ -1259,11 +1263,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): ( $rate480 := 10044; $rate720 := 21600; + $rate1080 := 48800; $m := widgets.model; $pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $rate := $res = "720p" ? $rate720 : $rate480; + $rate := $res = "1080p" ? $rate1080 : + $res = "720p" ? $rate720 : + $rate480; $cost := $dur * $rate * $pricePer1K / 1000; {"type": "usd", "usd": $cost, "format": {"approximate": true}} ) @@ -1324,13 +1331,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), poll_interval=9, + max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) -def _seedance2_reference_inputs(): +def _seedance2_reference_inputs(resolutions: list[str]): return [ - *_seedance2_text_inputs(), + *_seedance2_text_inputs(resolutions), IO.Autogrow.Input( "reference_images", template=IO.Autogrow.TemplateNames( @@ -1382,8 +1390,8 @@ class ByteDance2ReferenceNode(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ - IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs()), - IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs()), + IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])), + IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])), ], tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", ), @@ -1423,13 +1431,16 @@ class ByteDance2ReferenceNode(IO.ComfyNode): ( $rate480 := 10044; $rate720 := 21600; + $rate1080 := 48800; $m := widgets.model; $hasVideo := $lookup(inputGroups, "model.reference_videos") > 0; $noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; $videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149; $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $rate := $res = "720p" ? $rate720 : $rate480; + $rate := $res = "1080p" ? $rate1080 : + $res = "720p" ? 
$rate720 : + $rate480; $noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000; $minVideoFactor := $ceil($dur * 5 / 3); $minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000; @@ -1559,6 +1570,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode): status_extractor=lambda r: r.status, price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input), poll_interval=9, + max_poll_attempts=180, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) From 1391579c33db4921a5d40c7e0e71a938b28eb047 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 16 Apr 2026 21:20:16 -0700 Subject: [PATCH 03/13] Add JsonExtractString node. (#13435) --- comfy_extras/nodes_string.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 75a8bb4ee..604076c4e 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -1,4 +1,5 @@ import re +import json from typing_extensions import override from comfy_api.latest import ComfyExtension, io @@ -375,6 +376,39 @@ class RegexReplace(io.ComfyNode): return io.NodeOutput(result) +class JsonExtractString(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="JsonExtractString", + display_name="Extract String from JSON", + category="utils/string", + search_aliases=["json", "extract json", "parse json", "json value", "read json"], + inputs=[ + io.String.Input("json_string", multiline=True), + io.String.Input("key", multiline=False), + ], + outputs=[ + io.String.Output(), + ] + ) + + @classmethod + def execute(cls, json_string, key): + try: + data = json.loads(json_string) + if isinstance(data, dict) and key in data: + value = data[key] + if value is None: + return io.NodeOutput("") + + return io.NodeOutput(str(value)) + + return io.NodeOutput("") + + except (json.JSONDecodeError, TypeError): + return io.NodeOutput("") + class StringExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -390,6 +424,7 @@ class StringExtension(ComfyExtension): RegexMatch, RegexExtract, RegexReplace, + JsonExtractString, ] async def comfy_entrypoint() -> StringExtension: From c033bbf516ad8fcd079b45c318e73ee8b5e22962 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 17 Apr 2026 00:26:35 -0400 Subject: [PATCH 04/13] ComfyUI v0.19.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 3c6dac3d9..98b8337b4 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.19.1" +__version__ = "0.19.2" diff --git a/pyproject.toml b/pyproject.toml index 006ed9985..c4b006486 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.19.1" +version = "0.19.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 05f75311489c94e905d958c2bc4b22db5be78699 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 18 Apr 2026 02:20:09 +1000 Subject: [PATCH 05/13] nodes_textgen: Implement use_default_template for LTX (#13451) --- comfy_extras/nodes_textgen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index eed26c582..1f46d820f 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -161,12 +161,12 @@ class TextGenerateLTX2Prompt(TextGenerate): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: if image is None: formatted_prompt = f"system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}\nuser\nUser Raw Input Prompt: {prompt}.\nmodel\n" else: formatted_prompt = f"system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}\nuser\n\n\n\nUser Raw Input Prompt: {prompt}.\nmodel\n" - return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking) + return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template) class TextgenExtension(ComfyExtension): From 541fd10bbe5ac5e963619bb9594c4993f977e9e1 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 17 Apr 2026 19:44:08 +0300 Subject: [PATCH 06/13] fix(api-nodes): corrected StabilityAI price badges (#13454) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_stability.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 9ef13c83b..906d8ff35 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -401,7 +401,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.25}""", + expr="""{"type":"usd","usd":0.4}""", ), ) @@ -510,7 +510,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.25}""", + expr="""{"type":"usd","usd":0.6}""", ), ) @@ -593,7 +593,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.01}""", + expr="""{"type":"usd","usd":0.02}""", ), ) From 4f48be41388f67022d58f4f07f2f785adb8bfeea Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 17 Apr 2026 20:02:06 +0300 Subject: [PATCH 07/13] feat(api-nodes): add new "arrow-1.1" and "arrow-1.1-max" SVG models (#13447) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_quiver.py | 135 +++++++++++++++----------------- 1 file changed, 65 insertions(+), 70 deletions(-) diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py index 61533263f..28862e368 100644 --- a/comfy_api_nodes/nodes_quiver.py +++ b/comfy_api_nodes/nodes_quiver.py @@ -17,6 +17,44 @@ from comfy_api_nodes.util import ( ) from 
comfy_extras.nodes_images import SVG +_ARROW_MODELS = ["arrow-1.1", "arrow-1.1-max", "arrow-preview"] + + +def _arrow_sampling_inputs(): + """Shared sampling inputs for all Arrow model variants.""" + return [ + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Randomness control. Higher values increase randomness.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=1.0, + min=0.05, + max=1.0, + step=0.05, + display_mode=IO.NumberDisplay.slider, + tooltip="Nucleus sampling parameter.", + advanced=True, + ), + IO.Float.Input( + "presence_penalty", + default=0.0, + min=-2.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Token presence penalty.", + advanced=True, + ), + ] + class QuiverTextToSVGNode(IO.ComfyNode): @classmethod @@ -39,6 +77,7 @@ class QuiverTextToSVGNode(IO.ComfyNode): default="", tooltip="Additional style or formatting guidance.", optional=True, + advanced=True, ), IO.Autogrow.Input( "reference_images", @@ -53,43 +92,7 @@ class QuiverTextToSVGNode(IO.ComfyNode): ), IO.DynamicCombo.Input( "model", - options=[ - IO.DynamicCombo.Option( - "arrow-preview", - [ - IO.Float.Input( - "temperature", - default=1.0, - min=0.0, - max=2.0, - step=0.1, - display_mode=IO.NumberDisplay.slider, - tooltip="Randomness control. Higher values increase randomness.", - advanced=True, - ), - IO.Float.Input( - "top_p", - default=1.0, - min=0.05, - max=1.0, - step=0.05, - display_mode=IO.NumberDisplay.slider, - tooltip="Nucleus sampling parameter.", - advanced=True, - ), - IO.Float.Input( - "presence_penalty", - default=0.0, - min=-2.0, - max=2.0, - step=0.1, - display_mode=IO.NumberDisplay.slider, - tooltip="Token presence penalty.", - advanced=True, - ), - ], - ), - ], + options=[IO.DynamicCombo.Option(m, _arrow_sampling_inputs()) for m in _ARROW_MODELS], tooltip="Model to use for SVG generation.", ), IO.Int.Input( @@ -112,7 +115,16 @@ class QuiverTextToSVGNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.429}""", + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $contains(widgets.model, "max") + ? {"type":"usd","usd":0.3575} + : $contains(widgets.model, "preview") + ? {"type":"usd","usd":0.429} + : {"type":"usd","usd":0.286} + ) + """, ), ) @@ -176,12 +188,13 @@ class QuiverImageToSVGNode(IO.ComfyNode): "auto_crop", default=False, tooltip="Automatically crop to the dominant subject.", + advanced=True, ), IO.DynamicCombo.Input( "model", options=[ IO.DynamicCombo.Option( - "arrow-preview", + m, [ IO.Int.Input( "target_size", @@ -189,39 +202,12 @@ class QuiverImageToSVGNode(IO.ComfyNode): min=128, max=4096, tooltip="Square resize target in pixels.", - ), - IO.Float.Input( - "temperature", - default=1.0, - min=0.0, - max=2.0, - step=0.1, - display_mode=IO.NumberDisplay.slider, - tooltip="Randomness control. 
Higher values increase randomness.", - advanced=True, - ), - IO.Float.Input( - "top_p", - default=1.0, - min=0.05, - max=1.0, - step=0.05, - display_mode=IO.NumberDisplay.slider, - tooltip="Nucleus sampling parameter.", - advanced=True, - ), - IO.Float.Input( - "presence_penalty", - default=0.0, - min=-2.0, - max=2.0, - step=0.1, - display_mode=IO.NumberDisplay.slider, - tooltip="Token presence penalty.", advanced=True, ), + *_arrow_sampling_inputs(), ], - ), + ) + for m in _ARROW_MODELS ], tooltip="Model to use for SVG vectorization.", ), @@ -245,7 +231,16 @@ class QuiverImageToSVGNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.429}""", + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $contains(widgets.model, "max") + ? {"type":"usd","usd":0.3575} + : $contains(widgets.model, "preview") + ? {"type":"usd","usd":0.429} + : {"type":"usd","usd":0.286} + ) + """, ), ) From f8d92cf3138092050955fabf9172b3defcd89484 Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Sat, 18 Apr 2026 01:16:39 +0800 Subject: [PATCH 08/13] chore: update workflow templates to v0.9.57 (#13455) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e45a20aaf..3de845f48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.11 -comfyui-workflow-templates==0.9.54 +comfyui-workflow-templates==0.9.57 comfyui-embedded-docs==0.4.3 torch torchsde From 9635c2ec9b92f8fa1113660ace7660c1fea67e0e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 17 Apr 2026 20:31:37 +0300 Subject: [PATCH 09/13] fix(api-nodes): make "obj" output optional in Hunyuan3D Text and Image to 3D (#13449) Signed-off-by: bigcat88 Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/nodes_hunyuan3d.py | 34 ++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index 44c94a98e..5fc31bccd 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -221,14 +221,17 @@ class TencentTextToModelNode(IO.ComfyNode): response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) + obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False) + obj_result = None + if obj_file_response: + obj_result = await download_and_extract_obj_zip(obj_file_response.Url) return IO.NodeOutput( f"{task_id}.glb", await download_url_to_file_3d( get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id ), - obj_result.obj, - obj_result.texture, + obj_result.obj if obj_result else None, + obj_result.texture if obj_result else None, ) @@ -378,17 +381,30 @@ class TencentImageToModelNode(IO.ComfyNode): response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) + obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False) + if obj_file_response: + obj_result = await download_and_extract_obj_zip(obj_file_response.Url) + return IO.NodeOutput( + f"{task_id}.glb", + await download_url_to_file_3d( + get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", 
task_id=task_id + ), + obj_result.obj, + obj_result.texture, + obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3), + obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3), + obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3), + ) return IO.NodeOutput( f"{task_id}.glb", await download_url_to_file_3d( get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id ), - obj_result.obj, - obj_result.texture, - obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3), - obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3), - obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3), + None, + None, + None, + None, + None, ) From 3086026401180c9216bcb6ace442a4e3587d2c66 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 17 Apr 2026 13:35:01 -0400 Subject: [PATCH 10/13] ComfyUI v0.19.3 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 98b8337b4..2a1eb9905 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.19.2" +__version__ = "0.19.3" diff --git a/pyproject.toml b/pyproject.toml index c4b006486..8fa92ecbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.19.2" +version = "0.19.3" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From b9dedea57d9f8be9861811aef3ced3e221eb8068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sun, 19 Apr 2026 06:02:01 +0300 Subject: [PATCH 11/13] feat: SUPIR model support (CORE-17) (#13250) --- .../modules/diffusionmodules/openaimodel.py | 30 ++- comfy/ldm/supir/__init__.py | 0 comfy/ldm/supir/supir_modules.py | 226 ++++++++++++++++++ comfy/ldm/supir/supir_patch.py | 103 ++++++++ comfy/model_patcher.py | 4 + comfy_extras/nodes_model_patch.py | 104 ++++++++ comfy_extras/nodes_post_processing.py | 224 +++++++++++++++++ 7 files changed, 680 insertions(+), 11 deletions(-) create mode 100644 comfy/ldm/supir/__init__.py create mode 100644 comfy/ldm/supir/supir_modules.py create mode 100644 comfy/ldm/supir/supir_patch.py diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 295310df6..4b92c44cf 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -34,6 +34,16 @@ class TimestepBlock(nn.Module): #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: + if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: + found_patched = False + for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: + if isinstance(layer, class_type): + x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) + found_patched = True + break + if found_patched: + continue + if isinstance(layer, 
VideoResBlock): x = layer(x, emb, num_video_frames, image_only_indicator) elif isinstance(layer, TimestepBlock): @@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: - if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: - found_patched = False - for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: - if isinstance(layer, class_type): - x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) - found_patched = True - break - if found_patched: - continue x = layer(x) return x @@ -894,6 +895,12 @@ class UNetModel(nn.Module): h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') + if "middle_block_after_patch" in transformer_patches: + patch = transformer_patches["middle_block_after_patch"] + for p in patch: + out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y, + "timesteps": timesteps, "transformer_options": transformer_options}) + h = out["h"] for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) @@ -905,8 +912,9 @@ class UNetModel(nn.Module): for p in patch: h, hsp = p(h, hsp, transformer_options) - h = th.cat([h, hsp], dim=1) - del hsp + if hsp is not None: + h = th.cat([h, hsp], dim=1) + del hsp if len(hs) > 0: output_shape = hs[-1].shape else: diff --git a/comfy/ldm/supir/__init__.py b/comfy/ldm/supir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/supir/supir_modules.py b/comfy/ldm/supir/supir_modules.py new file mode 100644 index 000000000..7389b01d2 --- /dev/null +++ b/comfy/ldm/supir/supir_modules.py @@ -0,0 +1,226 @@ +import torch +import torch.nn as nn + +from comfy.ldm.modules.diffusionmodules.util import timestep_embedding +from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer +from comfy.ldm.modules.attention import optimized_attention + + +class ZeroSFT(nn.Module): + def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None): + super().__init__() + + ks = 3 + pw = ks // 2 + + self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device) + + nhidden = 128 + + self.mlp_shared = nn.Sequential( + operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device), + nn.SiLU() + ) + self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device) + self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device) + + self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device) + self.pre_concat = bool(concat_channels != 0) + + def forward(self, c, h, h_ori=None, control_scale=1): + if h_ori is not None and self.pre_concat: + h_raw = torch.cat([h_ori, h], dim=1) + else: + h_raw = h + + h = h + self.zero_conv(c) + if h_ori is not None and self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + actv = self.mlp_shared(c) + gamma = self.zero_mul(actv) + beta = self.zero_add(actv) + h = self.param_free_norm(h) + h = 
torch.addcmul(h + beta, h, gamma) + if h_ori is not None and not self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + return torch.lerp(h_raw, h, control_scale) + + +class _CrossAttnInner(nn.Module): + """Inner cross-attention module matching the state_dict layout of the original CrossAttention.""" + def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), + ) + + def forward(self, x, context): + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + return self.to_out(optimized_attention(q, k, v, self.heads)) + + +class ZeroCrossAttn(nn.Module): + def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None): + super().__init__() + heads = query_dim // 64 + dim_head = 64 + self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations) + self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device) + self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device) + + def forward(self, context, x, control_scale=1): + b, c, h, w = x.shape + x_in = x + + x = self.attn( + self.norm1(x).flatten(2).transpose(1, 2), + self.norm2(context).flatten(2).transpose(1, 2), + ).transpose(1, 2).unflatten(2, (h, w)) + + return x_in + x * control_scale + + +class GLVControl(nn.Module): + """SUPIR's Guided Latent Vector control encoder. 
Truncated UNet (input + middle blocks only).""" + def __init__( + self, + in_channels=4, + model_channels=320, + num_res_blocks=2, + attention_resolutions=(4, 2), + channel_mult=(1, 2, 4), + num_head_channels=64, + transformer_depth=(1, 2, 10), + context_dim=2048, + adm_in_channels=2816, + use_linear_in_transformer=True, + use_checkpoint=False, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__() + self.model_channels = model_channels + time_embed_dim = model_channels * 4 + + self.time_embed = nn.Sequential( + operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + + self.label_emb = nn.Sequential( + nn.Sequential( + operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + ) + + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + ]) + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(num_res_blocks): + layers = [ + ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels, + dtype=dtype, device=device, operations=operations) + ] + ch = mult * model_channels + if ds in attention_resolutions: + num_heads = ch // num_head_channels + layers.append( + SpatialTransformer(ch, num_heads, num_head_channels, + depth=transformer_depth[level], context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + dtype=dtype, device=device, operations=operations) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + if level != len(channel_mult) - 1: + self.input_blocks.append( + TimestepEmbedSequential( + Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations) + ) + ) + ds *= 2 + + num_heads = ch // num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations), + SpatialTransformer(ch, num_heads, num_head_channels, + depth=transformer_depth[-1], context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + dtype=dtype, device=device, operations=operations), + ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations), + ) + + self.input_hint_block = TimestepEmbedSequential( + operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + + def forward(self, x, timesteps, xt, context=None, y=None, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + self.label_emb(y) + + guided_hint = self.input_hint_block(x, emb, context) + + hs = [] + h = xt + for module in self.input_blocks: + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + hs.append(h) + return hs + + +class SUPIR(nn.Module): + """ + SUPIR model containing GLVControl (control encoder) and project_modules (adapters). 
+ State dict keys match the original SUPIR checkpoint layout: + control_model.* -> GLVControl + project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn + """ + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + + self.control_model = GLVControl(dtype=dtype, device=device, operations=operations) + + project_channel_scale = 2 + cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3 + project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3] + concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0] + cross_attn_insert_idx = [6, 3] + + self.project_modules = nn.ModuleList() + for i in range(len(cond_output_channels)): + self.project_modules.append(ZeroSFT( + project_channels[i], cond_output_channels[i], + concat_channels=concat_channels[i], + dtype=dtype, device=device, operations=operations, + )) + + for i in cross_attn_insert_idx: + self.project_modules.insert(i, ZeroCrossAttn( + cond_output_channels[i], concat_channels[i], + dtype=dtype, device=device, operations=operations, + )) diff --git a/comfy/ldm/supir/supir_patch.py b/comfy/ldm/supir/supir_patch.py new file mode 100644 index 000000000..b67ab4cd8 --- /dev/null +++ b/comfy/ldm/supir/supir_patch.py @@ -0,0 +1,103 @@ +import torch +from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample + + +class SUPIRPatch: + """ + Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters). + Runs GLVControl lazily on first patch invocation per step, applies adapters through + middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch. + """ + SIGMA_MAX = 14.6146 + + def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end): + self.model_patch = model_patch # CoreModelPatcher wrapping GLVControl + self.project_modules = project_modules # nn.ModuleList of ZeroSFT/ZeroCrossAttn + self.hint_latent = hint_latent # encoded LQ image latent + self.strength_start = strength_start + self.strength_end = strength_end + self.cached_features = None + self.adapter_idx = 0 + self.control_idx = 0 + self.current_control_idx = 0 + self.active = True + + def _ensure_features(self, kwargs): + """Run GLVControl on first call per step, cache results.""" + if self.cached_features is not None: + return + x = kwargs["x"] + b = x.shape[0] + hint = self.hint_latent.to(device=x.device, dtype=x.dtype) + if hint.shape[0] != b: + hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b] + self.cached_features = self.model_patch.model.control_model( + hint, kwargs["timesteps"], x, + kwargs["context"], kwargs["y"] + ) + self.adapter_idx = len(self.project_modules) - 1 + self.control_idx = len(self.cached_features) - 1 + + def _get_control_scale(self, kwargs): + if self.strength_start == self.strength_end: + return self.strength_end + sigma = kwargs["transformer_options"].get("sigmas") + if sigma is None: + return self.strength_end + s = sigma[0].item() if sigma.dim() > 0 else sigma.item() + t = min(s / self.SIGMA_MAX, 1.0) + return t * (self.strength_start - self.strength_end) + self.strength_end + + def middle_after(self, kwargs): + """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block.""" + self.cached_features = None # reset from previous step + self.current_scale = self._get_control_scale(kwargs) + self.active = self.current_scale > 0 + if not self.active: + return {"h": kwargs["h"]} + 
self._ensure_features(kwargs) + h = kwargs["h"] + h = self.project_modules[self.adapter_idx]( + self.cached_features[self.control_idx], h, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + self.control_idx -= 1 + return {"h": h} + + def output_block(self, h, hsp, transformer_options): + """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat.""" + if not self.active: + return h, hsp + self.current_control_idx = self.control_idx + h = self.project_modules[self.adapter_idx]( + self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + self.control_idx -= 1 + return h, None + + def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw): + """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample.""" + block_type, _ = transformer_options["block"] + if block_type == "output" and self.active and self.cached_features is not None: + x = self.project_modules[self.adapter_idx]( + self.cached_features[self.current_control_idx], x, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + return layer(x, output_shape=output_shape) + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.cached_features = None + if self.hint_latent is not None: + self.hint_latent = self.hint_latent.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + + def register(self, model_patcher): + """Register all patches on a cloned model patcher.""" + model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch") + model_patcher.set_model_output_block_patch(self.output_block) + model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6deb71e12..93d19d6fe 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -506,6 +506,10 @@ class ModelPatcher: def set_model_noise_refiner_patch(self, patch): self.set_model_patch(patch, "noise_refiner") + def set_model_middle_block_after_patch(self, patch): + self.set_model_patch(patch, "middle_block_after_patch") + + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options["scale_x"] = scale_x diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 176e6bc2f..748559a6b 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,7 +7,10 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +import comfy.ldm.supir.supir_modules from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel +from comfy_api.latest import io +from comfy.ldm.supir.supir_patch import SUPIRPatch class BlockWiseControlBlock(torch.nn.Module): @@ -266,6 +269,27 @@ class ModelPatchLoader: out_dim=sd["audio_proj.norm.weight"].shape[0], device=comfy.model_management.unet_offload_device(), operations=comfy.ops.manual_cast) + elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd: + prefix_replace = {} + if 'model.control_model.input_hint_block.0.weight' in sd: + prefix_replace["model.control_model."] = "control_model." + prefix_replace["model.diffusion_model.project_modules."] = "project_modules." 
+ else: + prefix_replace["control_model."] = "control_model." + prefix_replace["project_modules."] = "project_modules." + + # Extract denoise_encoder weights before filter_keys discards them + de_prefix = "first_stage_model.denoise_encoder." + denoise_encoder_sd = {} + for k in list(sd.keys()): + if k.startswith(de_prefix): + denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k) + + sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True) + sd.pop("control_model.mask_LQ", None) + model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + if denoise_encoder_sd: + model.denoise_encoder_sd = denoise_encoder_sd model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) model.load_state_dict(sd, assign=model_patcher.is_dynamic()) @@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module): ) +class SUPIRApply(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SUPIRApply", + category="model_patches/supir", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01, + tooltip="Control strength at the start of sampling (high sigma)."), + io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01, + tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."), + io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True, + tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 
0 to disable."), + io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Sigma threshold below which restore_cfg is disabled."), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def _encode_with_denoise_encoder(cls, vae, model_patch, image): + """Encode using denoise_encoder weights from SUPIR checkpoint if available.""" + denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None) + if not denoise_sd: + return vae.encode(image) + + # Clone VAE patcher, apply denoise_encoder weights to clone, encode + orig_patcher = vae.patcher + vae.patcher = orig_patcher.clone() + patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()} + vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0) + try: + return vae.encode(image) + finally: + vae.patcher = orig_patcher + + @classmethod + def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type, + strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput: + model_patched = model.clone() + hint_latent = model.get_model_object("latent_format").process_in( + cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3])) + patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end) + patch.register(model_patched) + + if restore_cfg > 0.0: + # Round-trip to match original pipeline: decode hint, re-encode with regular VAE + latent_format = model.get_model_object("latent_format") + decoded = vae.decode(latent_format.process_out(hint_latent)) + x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3])) + sigma_max = 14.6146 + + def restore_cfg_function(args): + denoised = args["denoised"] + sigma = args["sigma"] + if sigma.dim() > 0: + s = sigma[0].item() + else: + s = sigma.item() + if s > restore_cfg_s_tmin: + ref = x_center.to(device=denoised.device, dtype=denoised.dtype) + b = denoised.shape[0] + if ref.shape[0] != b: + ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b] + sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma + d_center = denoised - ref + denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg) + return denoised + + model_patched.set_model_sampler_post_cfg_function(restore_cfg_function) + + return io.NodeOutput(model_patched) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, "ZImageFunControlnet": ZImageFunControlnet, "USOStyleReference": USOStyleReference, + "SUPIRApply": SUPIRApply, } diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 9037c3d20..c932b747a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,6 +6,7 @@ from PIL import Image import math from enum import Enum from typing import TypedDict, Literal +import kornia import comfy.utils import comfy.model_management @@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode): return io.NodeOutput(batched) +class ColorTransfer(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ColorTransfer", + category="image/postprocessing", + description="Match the colors of one image to another using various algorithms.", + search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", 
"mkl", "reinhard", "histogram"], + inputs=[ + io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), + io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), + io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), + io.DynamicCombo.Input("source_stats", + tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", + options=[ + io.DynamicCombo.Option("per_frame", []), + io.DynamicCombo.Option("uniform", []), + io.DynamicCombo.Option("target_frame", [ + io.Int.Input("target_index", default=0, min=0, max=10000, + tooltip="Frame index used as the source baseline for computing the transform to image_ref"), + ]), + ]), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + io.Image.Output(display_name="image"), + ], + ) + + @staticmethod + def _to_lab(images, i, device): + return kornia.color.rgb_to_lab( + images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2)) + + @staticmethod + def _pool_stats(images, device, is_reinhard, eps): + """Two-pass pooled mean + std/cov across all frames.""" + N, C = images.shape[0], images.shape[3] + HW = images.shape[1] * images.shape[2] + mean = torch.zeros(C, 1, device=device, dtype=torch.float32) + for i in range(N): + mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True) + mean /= N + acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32) + for i in range(N): + centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean + if is_reinhard: + acc += (centered * centered).mean(dim=-1, keepdim=True) + else: + acc += centered @ centered.T / HW + if is_reinhard: + return mean, torch.sqrt(acc / N).clamp_min_(eps) + return mean, acc / N + + @staticmethod + def _frame_stats(lab_flat, hw, is_reinhard, eps): + """Per-frame mean + std/cov.""" + mean = lab_flat.mean(dim=-1, keepdim=True) + if is_reinhard: + return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps) + centered = lab_flat - mean + return mean, centered @ centered.T / hw + + @staticmethod + def _mkl_matrix(cov_s, cov_r, eps): + """Compute MKL 3x3 transform matrix from source and ref covariances.""" + eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s) + sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps) + + scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0) + mid = scaled_V.T @ cov_r @ scaled_V + eig_val_m, eig_vec_m = torch.linalg.eigh(mid) + sqrt_m = torch.sqrt(eig_val_m.clamp_min(0)) + + inv_sqrt_s = 1.0 / sqrt_val_s + inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0) + M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T + return inv_scaled_V @ M_half @ inv_scaled_V.T + + @staticmethod + def _histogram_lut(src, ref, bins=256): + """Build per-channel LUT from source and ref histograms. 
src/ref: (C, HW) in [0,1].""" + s_bins = (src * (bins - 1)).long().clamp(0, bins - 1) + r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1) + s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype) + r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype) + ones_s = torch.ones_like(src) + ones_r = torch.ones_like(ref) + s_hist.scatter_add_(1, s_bins, ones_s) + r_hist.scatter_add_(1, r_bins, ones_r) + s_cdf = s_hist.cumsum(1) + s_cdf = s_cdf / s_cdf[:, -1:] + r_cdf = r_hist.cumsum(1) + r_cdf = r_cdf / r_cdf[:, -1:] + return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1) + + @classmethod + def _pooled_cdf(cls, images, device, num_bins=256): + """Build pooled CDF across all frames, one frame at a time.""" + C = images.shape[3] + hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32) + for i in range(images.shape[0]): + frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1) + bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1) + hist.scatter_add_(1, bins, torch.ones_like(frame)) + cdf = hist.cumsum(1) + return cdf / cdf[:, -1:] + + @classmethod + def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B): + """Build per-frame or uniform LUT transform for histogram mode.""" + if stats_mode == 'per_frame': + return None # LUT computed per-frame in the apply loop + + r_cdf = cls._pooled_cdf(image_ref, device) + if stats_mode == 'target_frame': + ti = min(target_index, B - 1) + s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device) + else: + s_cdf = cls._pooled_cdf(image_target, device) + return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0 + + @classmethod + def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard): + """Build transform parameters for Lab-based methods. 
Returns a transform function.""" + eps = 1e-6 + B, H, W, C = image_target.shape + B_ref = image_ref.shape[0] + single_ref = B_ref == 1 + HW = H * W + HW_ref = image_ref.shape[1] * image_ref.shape[2] + + # Precompute ref stats + if single_ref or stats_mode in ('uniform', 'target_frame'): + ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps) + + # Uniform/target_frame: precompute single affine transform + if stats_mode in ('uniform', 'target_frame'): + if stats_mode == 'target_frame': + ti = min(target_index, B - 1) + s_lab = cls._to_lab(image_target, ti, device).view(C, -1) + s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps) + else: + s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps) + + if is_reinhard: + scale = ref_sc / s_sc + offset = ref_mean - scale * s_mean + return lambda src_flat, **_: src_flat * scale + offset + T = cls._mkl_matrix(s_sc, ref_sc, eps) + offset = ref_mean - T @ s_mean + return lambda src_flat, **_: T @ src_flat + offset + + # per_frame + def per_frame_transform(src_flat, frame_idx): + s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps) + + if single_ref: + r_mean, r_sc = ref_mean, ref_sc + else: + ri = min(frame_idx, B_ref - 1) + r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps) + + centered = src_flat - s_mean + if is_reinhard: + return centered * (r_sc / s_sc) + r_mean + T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps) + return T @ centered + r_mean + + return per_frame_transform + + @classmethod + def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput: + stats_mode = source_stats["source_stats"] + target_index = source_stats.get("target_index", 0) + + if strength == 0 or image_ref is None: + return io.NodeOutput(image_target) + + device = comfy.model_management.get_torch_device() + intermediate_device = comfy.model_management.intermediate_device() + intermediate_dtype = comfy.model_management.intermediate_dtype() + + B, H, W, C = image_target.shape + B_ref = image_ref.shape[0] + pbar = comfy.utils.ProgressBar(B) + out = torch.empty(B, H, W, C, device=intermediate_device, dtype=intermediate_dtype) + + if method == 'histogram': + uniform_lut = cls._build_histogram_transform( + image_target, image_ref, device, stats_mode, target_index, B) + + for i in range(B): + src = image_target[i].to(device, dtype=torch.float32).permute(2, 0, 1) + src_flat = src.reshape(C, -1) + if uniform_lut is not None: + lut = uniform_lut + else: + ri = min(i, B_ref - 1) + ref = image_ref[ri].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1) + lut = cls._histogram_lut(src_flat, ref) + bin_idx = (src_flat * 255).long().clamp(0, 255) + matched = lut.gather(1, bin_idx).view(C, H, W) + result = matched if strength == 1.0 else torch.lerp(src, matched, strength) + out[i] = result.permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype) + pbar.update(1) + else: + transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab") + + for i in range(B): + src_frame = cls._to_lab(image_target, i, device) + corrected = transform(src_frame.view(C, -1), frame_idx=i) + if strength == 1.0: + result = kornia.color.lab_to_rgb(corrected.view(1, C, H, W)) + else: + result = kornia.color.lab_to_rgb(torch.lerp(src_frame, corrected.view(1, C, H, W), strength)) + out[i] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 
1).to(device=intermediate_device, dtype=intermediate_dtype) + pbar.update(1) + + return io.NodeOutput(out) + + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension): BatchImagesNode, BatchMasksNode, BatchLatentsNode, + ColorTransfer, # BatchImagesMasksLatentsNode, ] From 3d816db07f9721525c8326bc8d525cd81f00a7fa Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 18 Apr 2026 20:02:29 -0700 Subject: [PATCH 12/13] Some optimizations to make Ernie inference a bit faster. (#13472) --- comfy/ldm/ernie/model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/ernie/model.py b/comfy/ldm/ernie/model.py index f7cdb51e6..eba661aec 100644 --- a/comfy/ldm/ernie/model.py +++ b/comfy/ldm/ernie/model.py @@ -118,8 +118,6 @@ class ErnieImageAttention(nn.Module): query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - query, key = query.to(x.dtype), key.to(x.dtype) - q_flat = query.reshape(B, S, -1) k_flat = key.reshape(B, S, -1) @@ -161,16 +159,16 @@ class ErnieImageSharedAdaLNBlock(nn.Module): residual = x x_norm = self.adaLN_sa_ln(x) - x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype) + x_norm = x_norm * (1 + scale_msa) + shift_msa attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) - x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype) + x = residual + gate_msa * attn_out residual = x x_norm = self.adaLN_mlp_ln(x) - x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype) + x_norm = x_norm * (1 + scale_mlp) + shift_mlp - return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype) + return residual + gate_mlp * self.mlp(x_norm) class ErnieImageAdaLNContinuous(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None): @@ -183,7 +181,7 @@ class ErnieImageAdaLNContinuous(nn.Module): def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: scale, shift = self.linear(conditioning).chunk(2, dim=-1) x = self.norm(x) - x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)) return x class ErnieImageModel(nn.Module): From 138571da955d85935baa09371ac2b67ea8b7a8ca Mon Sep 17 00:00:00 2001 From: Abdul Rehman <76230556+Abdulrehman-PIAIC80387@users.noreply.github.com> Date: Sun, 19 Apr 2026 08:21:22 +0500 Subject: [PATCH 13/13] fix: append directory type annotation to internal files endpoint response (#13078) (#13305) --- api_server/routes/internal/internal_routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index b224306da..1477afa01 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -67,7 +67,7 @@ class InternalRoutes: (entry for entry in os.scandir(directory) if is_visible_file(entry)), key=lambda entry: -entry.stat().st_mtime ) - return web.json_response([entry.name for entry in sorted_files], status=200) + return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200) def get_app(self):
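
Note on the SUPIR patch (PATCH 11/13): the control strength applied by SUPIRPatch is scheduled over the sampling trajectory rather than held constant. A minimal standalone sketch of the schedule implemented in _get_control_scale follows; the helper name and the assertions are illustrative only, not part of the patch.

SIGMA_MAX = 14.6146  # matches SUPIRPatch.SIGMA_MAX

def control_scale(sigma: float, strength_start: float, strength_end: float) -> float:
    # Linear ramp: strength_start at sigma == SIGMA_MAX, strength_end at sigma == 0.
    t = min(sigma / SIGMA_MAX, 1.0)
    return t * (strength_start - strength_end) + strength_end

assert abs(control_scale(SIGMA_MAX, 1.0, 0.5) - 1.0) < 1e-9
assert abs(control_scale(0.0, 1.0, 0.5) - 0.5) < 1e-9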
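
Note on the Ernie optimization (PATCH 12/13): the rewrite in ErnieImageAdaLNContinuous relies on torch.addcmul(input, tensor1, tensor2) computing input + tensor1 * tensor2, so the fused call reproduces the original modulation x * (1 + scale) + shift while avoiding an intermediate tensor. A quick equivalence check, illustrative only:

import torch

x = torch.randn(2, 4, 8)
scale = torch.randn(2, 8)
shift = torch.randn(2, 8)
old = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
new = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
assert torch.allclose(old, new)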
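
Note on the internal files endpoint fix (PATCH 13/13): each filename returned by the endpoint now carries its directory type as a bracketed suffix, for example "image.png [output]" instead of "image.png". Clients that previously consumed bare names must strip the annotation; a hypothetical helper (split_annotated_name is not part of the codebase):

import re

def split_annotated_name(entry: str) -> tuple[str, str | None]:
    # "image.png [output]" -> ("image.png", "output"); unannotated names pass through unchanged.
    m = re.match(r"^(.*) \[([^\]]+)\]$", entry)
    if m:
        return m.group(1), m.group(2)
    return entry, None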