From 6265a239f379f1a5cf2bfdcd3a9631d4c11e50fb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 22 Mar 2026 15:46:18 -0700 Subject: [PATCH 01/13] Add warning for users who disable dynamic vram. (#13113) --- main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.py b/main.py index f99aee38e..cd4483c67 100644 --- a/main.py +++ b/main.py @@ -471,6 +471,9 @@ if __name__ == "__main__": if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") + if args.disable_dynamic_vram: + logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.") + event_loop, _, start_all_func = start_comfyui() try: x = start_all_func() From da6edb5a4e5745869d64ae05b96263da42d5364e Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 24 Mar 2026 01:59:21 +0900 Subject: [PATCH 02/13] bump manager version to 4.1b8 (#13108) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 5b06b56f6..90a2be84e 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b6 \ No newline at end of file +comfyui_manager==4.1b8 From e87858e9743f92222cdb478f1f835135750b6a0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 24 Mar 2026 00:22:24 +0200 Subject: [PATCH 03/13] feat: LTX2: Support reference audio (ID-LoRA) (#13111) --- comfy/ldm/lightricks/av_model.py | 42 +++++++++++++++++ comfy/model_base.py | 4 ++ comfy_extras/nodes_lt.py | 80 ++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 08d686b7b..6f2ba41ef 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel): additional_args["has_spatial_mask"] = has_spatial_mask ax, a_latent_coords = self.a_patchifier.patchify(ax) + + # Inject reference audio for ID-LoRA in-context conditioning + ref_audio = kwargs.get("ref_audio", None) + ref_audio_seq_len = 0 + if ref_audio is not None: + ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device) + if ref_tokens.shape[0] < ax.shape[0]: + ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1) + ref_audio_seq_len = ref_tokens.shape[1] + B = ax.shape[0] + + # Compute negative temporal positions matching ID-LoRA convention: + # offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0 + p = self.a_patchifier + tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate + ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device) + ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device) + time_offset = ref_end[-1].item() + tpl + ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1) + ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1) + ref_pos = torch.stack([ref_start, ref_end], dim=-1) + + additional_args["ref_audio_seq_len"] = ref_audio_seq_len + additional_args["target_audio_seq_len"] = ax.shape[1] + ax = torch.cat([ref_tokens, ax], dim=1) + a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2) + ax = self.audio_patchify_proj(ax) # additional_args.update({"av_orig_shape": list(x.shape)}) @@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel): # Prepare audio timestep a_timestep = kwargs.get("a_timestep") + ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0) + if ref_audio_seq_len > 0 and a_timestep is not None: + # Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma. + target_len = kwargs.get("target_audio_seq_len") + if a_timestep.dim() <= 1: + a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len) + ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype) + a_timestep = torch.cat([ref_ts, a_timestep], dim=1) if a_timestep is not None: a_timestep_scaled = a_timestep * self.timestep_scale_multiplier a_timestep_flat = a_timestep_scaled.flatten() @@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel): v_embedded_timestep = embedded_timestep[0] a_embedded_timestep = embedded_timestep[1] + # Trim reference audio tokens before unpatchification + ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0) + if ref_audio_seq_len > 0: + ax = ax[:, ref_audio_seq_len:] + if a_embedded_timestep.shape[1] > 1: + a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:] + # Expand compressed video timestep if needed if isinstance(v_embedded_timestep, CompressedTimestep): v_embedded_timestep = v_embedded_timestep.expand() diff --git a/comfy/model_base.py b/comfy/model_base.py index bfffe2402..70aff886e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1061,6 +1061,10 @@ class LTXAV(BaseModel): if guide_attention_entries is not None: out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + ref_audio = kwargs.get("ref_audio", None) + if ref_audio is not None: + out['ref_audio'] = comfy.conds.CONDConstant(ref_audio) + return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index c05571143..d7c2e8744 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -3,6 +3,7 @@ import node_helpers import torch import comfy.model_management import comfy.model_sampling +import comfy.samplers import comfy.utils import math import numpy as np @@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode): return io.NodeOutput(video_latent, audio_latent) +class LTXVReferenceAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVReferenceAudio", + display_name="LTXV Reference Audio (ID-LoRA)", + category="conditioning/audio", + description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."), + io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."), + io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."), + ], + outputs=[ + io.Model.Output(), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: + # Encode reference audio to latents and patchify + audio_latents = audio_vae.encode(reference_audio) + b, c, t, f = audio_latents.shape + ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) + ref_audio = {"tokens": ref_tokens} + + positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio}) + negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio}) + + # Patch model with identity guidance + m = model.clone() + scale = identity_guidance_scale + model_sampling = m.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + + def post_cfg_function(args): + if scale == 0: + return args["denoised"] + + sigma = args["sigma"] + sigma_ = sigma[0].item() + if sigma_ > sigma_start or sigma_ < sigma_end: + return args["denoised"] + + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + model_options = args["model_options"].copy() + x = args["input"] + + # Strip ref_audio from conditioning for the no-reference pass + noref_cond = [] + for entry in cond: + new_entry = entry.copy() + mc = new_entry.get("model_conds", {}).copy() + mc.pop("ref_audio", None) + new_entry["model_conds"] = mc + noref_cond.append(new_entry) + + (pred_noref,) = comfy.samplers.calc_cond_batch( + args["model"], [noref_cond], x, sigma, model_options + ) + + return cfg_result + (cond_pred - pred_noref) * scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return io.NodeOutput(m, positive, negative) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension): LTXVCropGuides, LTXVConcatAVLatent, LTXVSeparateAVLatent, + LTXVReferenceAudio, ] From 2d4970ff677970fbca9f9f562296eda46de8aa4c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:43:41 -0700 Subject: [PATCH 04/13] Update frontend version to 1.42.8 (#13126) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ad0344ed4..26cc94354 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.41.21 +comfyui-frontend-package==1.42.8 comfyui-workflow-templates==0.9.26 comfyui-embedded-docs==0.4.3 torch From 2d5fd3f5dde51574d77601dbe4c163a95a56121a Mon Sep 17 00:00:00 2001 From: Kelly Yang <124ykl@gmail.com> Date: Tue, 24 Mar 2026 11:22:30 -0700 Subject: [PATCH 05/13] fix: set default values of Color Adjustment node to zero (#13084) Co-authored-by: Jedrzej Kosinski --- blueprints/Color Adjustment.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blueprints/Color Adjustment.json b/blueprints/Color Adjustment.json index c599f7213..47f3df783 100644 --- a/blueprints/Color Adjustment.json +++ b/blueprints/Color Adjustment.json @@ -1 +1 @@ -{"revision": 0, "last_node_id": 14, "last_link_id": 0, "nodes": [{"id": 14, "type": "36677b92-5dd8-47a5-9380-4da982c1894f", "pos": [3610, -2630], "size": [270, 150], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["4", "value"], ["5", "value"], ["7", "value"], ["6", "value"]]}, "widgets_values": [], "title": "Color Adjustment"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "36677b92-5dd8-47a5-9380-4da982c1894f", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 16, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Color Adjustment", "inputNode": {"id": -10, "bounding": [3110, -3560, 120, 60]}, "outputNode": {"id": -20, "bounding": [4070, -3560, 120, 60]}, "inputs": [{"id": "0431d493-5f28-4430-bd00-84733997fc08", "name": "images.image0", "type": "IMAGE", "linkIds": [29], "localized_name": "images.image0", "label": "image", "pos": [3210, -3540]}], "outputs": [{"id": "bee8ea06-a114-4612-8937-939f2c927bdb", "name": "IMAGE0", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [4090, -3540]}], "widgets": [], "nodes": [{"id": 15, "type": "GLSLShader", "pos": [3590, -3940], "size": [420, 252], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 29}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 34}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 30}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 31}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": 33}, {"label": "u_float4", "localized_name": "floats.u_float4", "name": "floats.u_float4", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [28]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // temperature (-100 to 100)\nuniform float u_float1; // tint (-100 to 100)\nuniform float u_float2; // vibrance (-100 to 100)\nuniform float u_float3; // saturation (-100 to 100)\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst float INPUT_SCALE = 0.01;\nconst float TEMP_TINT_PRIMARY = 0.3;\nconst float TEMP_TINT_SECONDARY = 0.15;\nconst float VIBRANCE_BOOST = 2.0;\nconst float SATURATION_BOOST = 2.0;\nconst float SKIN_PROTECTION = 0.5;\nconst float EPSILON = 0.001;\nconst vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);\n\nvoid main() {\n vec4 tex = texture(u_image0, v_texCoord);\n vec3 color = tex.rgb;\n \n // Scale inputs: -100/100 \u2192 -1/1\n float temperature = u_float0 * INPUT_SCALE;\n float tint = u_float1 * INPUT_SCALE;\n float vibrance = u_float2 * INPUT_SCALE;\n float saturation = u_float3 * INPUT_SCALE;\n \n // Temperature (warm/cool): positive = warm, negative = cool\n color.r += temperature * TEMP_TINT_PRIMARY;\n color.b -= temperature * TEMP_TINT_PRIMARY;\n \n // Tint (green/magenta): positive = green, negative = magenta\n color.g += tint * TEMP_TINT_PRIMARY;\n color.r -= tint * TEMP_TINT_SECONDARY;\n color.b -= tint * TEMP_TINT_SECONDARY;\n \n // Single clamp after temperature/tint\n color = clamp(color, 0.0, 1.0);\n \n // Vibrance with skin protection\n if (vibrance != 0.0) {\n float maxC = max(color.r, max(color.g, color.b));\n float minC = min(color.r, min(color.g, color.b));\n float sat = maxC - minC;\n float gray = dot(color, LUMA_WEIGHTS);\n \n if (vibrance < 0.0) {\n // Desaturate: -100 \u2192 gray\n color = mix(vec3(gray), color, 1.0 + vibrance);\n } else {\n // Boost less saturated colors more\n float vibranceAmt = vibrance * (1.0 - sat);\n \n // Branchless skin tone protection\n float isWarmTone = step(color.b, color.g) * step(color.g, color.r);\n float warmth = (color.r - color.b) / max(maxC, EPSILON);\n float skinTone = isWarmTone * warmth * sat * (1.0 - sat);\n vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);\n \n color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);\n }\n }\n \n // Saturation\n if (saturation != 0.0) {\n float gray = dot(color, LUMA_WEIGHTS);\n float satMix = saturation < 0.0\n ? 1.0 + saturation // -100 \u2192 gray\n : 1.0 + saturation * SATURATION_BOOST; // +100 \u2192 3x boost\n color = mix(vec3(gray), color, satMix);\n }\n \n fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);\n}", "from_input"]}, {"id": 6, "type": "PrimitiveFloat", "pos": [3290, -3610], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "vibrance", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [26, 31]}], "title": "Vibrance", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 7, "type": "PrimitiveFloat", "pos": [3290, -3720], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "saturation", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [33]}], "title": "Saturation", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 5, "type": "PrimitiveFloat", "pos": [3290, -3830], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "tint", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [30]}], "title": "Tint", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 255, 0]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 0, 255]}]}, "widgets_values": [0]}, {"id": 4, "type": "PrimitiveFloat", "pos": [3290, -3940], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "temperature", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [34]}], "title": "Temperature", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [68, 136, 255]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 136, 0]}]}, "widgets_values": [100]}], "groups": [], "links": [{"id": 34, "origin_id": 4, "origin_slot": 0, "target_id": 15, "target_slot": 2, "type": "FLOAT"}, {"id": 30, "origin_id": 5, "origin_slot": 0, "target_id": 15, "target_slot": 3, "type": "FLOAT"}, {"id": 31, "origin_id": 6, "origin_slot": 0, "target_id": 15, "target_slot": 4, "type": "FLOAT"}, {"id": 33, "origin_id": 7, "origin_slot": 0, "target_id": 15, "target_slot": 5, "type": "FLOAT"}, {"id": 29, "origin_id": -10, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "IMAGE"}, {"id": 28, "origin_id": 15, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} +{"revision": 0, "last_node_id": 14, "last_link_id": 0, "nodes": [{"id": 14, "type": "36677b92-5dd8-47a5-9380-4da982c1894f", "pos": [3610, -2630], "size": [270, 150], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["4", "value"], ["5", "value"], ["7", "value"], ["6", "value"]]}, "widgets_values": [], "title": "Color Adjustment"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "36677b92-5dd8-47a5-9380-4da982c1894f", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 16, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Color Adjustment", "inputNode": {"id": -10, "bounding": [3110, -3560, 120, 60]}, "outputNode": {"id": -20, "bounding": [4070, -3560, 120, 60]}, "inputs": [{"id": "0431d493-5f28-4430-bd00-84733997fc08", "name": "images.image0", "type": "IMAGE", "linkIds": [29], "localized_name": "images.image0", "label": "image", "pos": [3210, -3540]}], "outputs": [{"id": "bee8ea06-a114-4612-8937-939f2c927bdb", "name": "IMAGE0", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [4090, -3540]}], "widgets": [], "nodes": [{"id": 15, "type": "GLSLShader", "pos": [3590, -3940], "size": [420, 252], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 29}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 34}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 30}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 31}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": 33}, {"label": "u_float4", "localized_name": "floats.u_float4", "name": "floats.u_float4", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [28]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // temperature (-100 to 100)\nuniform float u_float1; // tint (-100 to 100)\nuniform float u_float2; // vibrance (-100 to 100)\nuniform float u_float3; // saturation (-100 to 100)\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst float INPUT_SCALE = 0.01;\nconst float TEMP_TINT_PRIMARY = 0.3;\nconst float TEMP_TINT_SECONDARY = 0.15;\nconst float VIBRANCE_BOOST = 2.0;\nconst float SATURATION_BOOST = 2.0;\nconst float SKIN_PROTECTION = 0.5;\nconst float EPSILON = 0.001;\nconst vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);\n\nvoid main() {\n vec4 tex = texture(u_image0, v_texCoord);\n vec3 color = tex.rgb;\n \n // Scale inputs: -100/100 \u2192 -1/1\n float temperature = u_float0 * INPUT_SCALE;\n float tint = u_float1 * INPUT_SCALE;\n float vibrance = u_float2 * INPUT_SCALE;\n float saturation = u_float3 * INPUT_SCALE;\n \n // Temperature (warm/cool): positive = warm, negative = cool\n color.r += temperature * TEMP_TINT_PRIMARY;\n color.b -= temperature * TEMP_TINT_PRIMARY;\n \n // Tint (green/magenta): positive = green, negative = magenta\n color.g += tint * TEMP_TINT_PRIMARY;\n color.r -= tint * TEMP_TINT_SECONDARY;\n color.b -= tint * TEMP_TINT_SECONDARY;\n \n // Single clamp after temperature/tint\n color = clamp(color, 0.0, 1.0);\n \n // Vibrance with skin protection\n if (vibrance != 0.0) {\n float maxC = max(color.r, max(color.g, color.b));\n float minC = min(color.r, min(color.g, color.b));\n float sat = maxC - minC;\n float gray = dot(color, LUMA_WEIGHTS);\n \n if (vibrance < 0.0) {\n // Desaturate: -100 \u2192 gray\n color = mix(vec3(gray), color, 1.0 + vibrance);\n } else {\n // Boost less saturated colors more\n float vibranceAmt = vibrance * (1.0 - sat);\n \n // Branchless skin tone protection\n float isWarmTone = step(color.b, color.g) * step(color.g, color.r);\n float warmth = (color.r - color.b) / max(maxC, EPSILON);\n float skinTone = isWarmTone * warmth * sat * (1.0 - sat);\n vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);\n \n color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);\n }\n }\n \n // Saturation\n if (saturation != 0.0) {\n float gray = dot(color, LUMA_WEIGHTS);\n float satMix = saturation < 0.0\n ? 1.0 + saturation // -100 \u2192 gray\n : 1.0 + saturation * SATURATION_BOOST; // +100 \u2192 3x boost\n color = mix(vec3(gray), color, satMix);\n }\n \n fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);\n}", "from_input"]}, {"id": 6, "type": "PrimitiveFloat", "pos": [3290, -3610], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "vibrance", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [26, 31]}], "title": "Vibrance", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 7, "type": "PrimitiveFloat", "pos": [3290, -3720], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "saturation", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [33]}], "title": "Saturation", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 5, "type": "PrimitiveFloat", "pos": [3290, -3830], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "tint", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [30]}], "title": "Tint", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 255, 0]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 0, 255]}]}, "widgets_values": [0]}, {"id": 4, "type": "PrimitiveFloat", "pos": [3290, -3940], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "temperature", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [34]}], "title": "Temperature", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [68, 136, 255]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 136, 0]}]}, "widgets_values": [0]}], "groups": [], "links": [{"id": 34, "origin_id": 4, "origin_slot": 0, "target_id": 15, "target_slot": 2, "type": "FLOAT"}, {"id": 30, "origin_id": 5, "origin_slot": 0, "target_id": 15, "target_slot": 3, "type": "FLOAT"}, {"id": 31, "origin_id": 6, "origin_slot": 0, "target_id": 15, "target_slot": 4, "type": "FLOAT"}, {"id": 33, "origin_id": 7, "origin_slot": 0, "target_id": 15, "target_slot": 5, "type": "FLOAT"}, {"id": 29, "origin_id": -10, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "IMAGE"}, {"id": 28, "origin_id": 15, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} From f9ec85f739aeab3fbc0f89baaa1e0fc196f2ff2c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:27:39 +0200 Subject: [PATCH 06/13] feat(api-nodes): update xAI Grok nodes (#13140) --- comfy_api_nodes/apis/grok.py | 10 +- comfy_api_nodes/nodes_grok.py | 251 ++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py index c56c8aecc..fbedb53e0 100644 --- a/comfy_api_nodes/apis/grok.py +++ b/comfy_api_nodes/apis/grok.py @@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel): class VideoGenerationRequest(BaseModel): model: str = Field(...) prompt: str = Field(...) - image: InputUrlObject | None = Field(...) + image: InputUrlObject | None = Field(None) + reference_images: list[InputUrlObject] | None = Field(None) duration: int = Field(...) aspect_ratio: str | None = Field(...) resolution: str = Field(...) seed: int = Field(...) +class VideoExtensionRequest(BaseModel): + prompt: str = Field(...) + video: InputUrlObject = Field(...) + duration: int = Field(default=6) + model: str | None = Field(default=None) + + class VideoEditRequest(BaseModel): model: str = Field(...) prompt: str = Field(...) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index 0716d6239..dabc899d6 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import ( ImageGenerationResponse, InputUrlObject, VideoEditRequest, + VideoExtensionRequest, VideoGenerationRequest, VideoGenerationResponse, VideoStatusResponse, @@ -21,6 +22,7 @@ from comfy_api_nodes.util import ( poll_op, sync_op, tensor_to_base64_string, + upload_images_to_comfyapi, upload_video_to_comfyapi, validate_string, validate_video_duration, @@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None: return None +def _extract_grok_video_price(response) -> float | None: + price = _extract_grok_price(response) + if price is not None: + return price * 1.43 + return None + + class GrokImageNode(IO.ComfyNode): @classmethod @@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode): seed: int, image: Input.Image | None = None, ) -> IO.NodeOutput: + if model == "grok-imagine-video-beta": + model = "grok-imagine-video" image_url = None if image is not None: if get_number_of_images(image) != 1: @@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.video.url)) +class GrokVideoReferenceNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoReferenceNode", + display_name="Grok Reference-to-Video", + category="api node/video/Grok", + description="Generate video guided by reference images as style and content references.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "grok-imagine-video", + [ + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="reference_", + min=1, + max=7, + ), + tooltip="Up to 7 reference images to guide the video generation.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=6, + min=2, + max=10, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + ], + tooltip="The model to use for video generation.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model.duration", "model.resolution"], + input_groups=["model.reference_images"], + ), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $refs := inputGroups["model.reference_images"]; + $rate := $res = "720p" ? 0.07 : 0.05; + $price := ($rate * $dur + 0.002 * $refs) * 1.43; + {"type":"usd","usd": $price} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + ref_image_urls = await upload_images_to_comfyapi( + cls, + list(model["reference_images"].values()), + mime_type="image/png", + wait_label="Uploading base images", + max_images=7, + ) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), + data=VideoGenerationRequest( + model=model["model"], + reference_images=[InputUrlObject(url=i) for i in ref_image_urls], + prompt=prompt, + resolution=model["resolution"], + duration=model["duration"], + aspect_ratio=model["aspect_ratio"], + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + price_extractor=_extract_grok_video_price, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokVideoExtendNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoExtendNode", + display_name="Grok Video Extend", + category="api node/video/Grok", + description="Extend an existing video with a seamless continuation based on a text prompt.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of what should happen next in the video.", + ), + IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "grok-imagine-video", + [ + IO.Int.Input( + "duration", + default=8, + min=2, + max=10, + step=1, + tooltip="Length of the extension in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + ], + tooltip="The model to use for video extension.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]), + expr=""" + ( + $dur := $lookup(widgets, "model.duration"); + { + "type": "range_usd", + "min_usd": (0.02 + 0.05 * $dur) * 1.43, + "max_usd": (0.15 + 0.05 * $dur) * 1.43 + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + video: Input.Video, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + validate_video_duration(video, min_duration=2, max_duration=15) + video_size = get_fs_object_size(video.get_stream_source()) + if video_size > 50 * 1024 * 1024: + raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"), + data=VideoExtensionRequest( + prompt=prompt, + video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), + duration=model["duration"], + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + price_extractor=_extract_grok_video_price, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + class GrokExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension): GrokImageNode, GrokImageEditNode, GrokVideoNode, + GrokVideoReferenceNode, GrokVideoEditNode, + GrokVideoExtendNode, ] From c2862b24af49ff40b251ea2a4e22b92af9e92982 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:36:12 -0700 Subject: [PATCH 07/13] Update templates package version. (#13141) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 26cc94354..76f824906 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.8 -comfyui-workflow-templates==0.9.26 +comfyui-workflow-templates==0.9.36 comfyui-embedded-docs==0.4.3 torch torchsde From 8e73678dae6e5763bc860d6f98566243a494f9c2 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 24 Mar 2026 17:47:28 -0400 Subject: [PATCH 08/13] CURVE node (#12757) * CURVE node * remove curve to sigmas node * feat: add CurveInput ABC with MonotoneCubicCurve implementation (#12986) CurveInput is an abstract base class so future curve representations (bezier, LUT-based, analytical functions) can be added without breaking downstream nodes that type-check against CurveInput. MonotoneCubicCurve is the concrete implementation that: - Mirrors frontend createMonotoneInterpolator (curveUtils.ts) exactly - Pre-computes slopes as numpy arrays at construction time - Provides vectorised interp_array() using numpy for batch evaluation - interp() for single-value evaluation - to_lut() for generating lookup tables CurveEditor node wraps raw widget points in MonotoneCubicCurve. * linear curve * refactor: move CurveEditor to comfy_extras/nodes_curve.py with V3 schema * feat: add HISTOGRAM type and histogram support to CurveEditor * code improve --------- Co-authored-by: Christian Byrne --- comfy_api/input/__init__.py | 8 + comfy_api/latest/_input/__init__.py | 5 + comfy_api/latest/_input/curve_types.py | 219 +++++++++++++++++++++++++ comfy_api/latest/_io.py | 20 ++- comfy_extras/nodes_curve.py | 42 +++++ nodes.py | 1 + 6 files changed, 292 insertions(+), 3 deletions(-) create mode 100644 comfy_api/latest/_input/curve_types.py create mode 100644 comfy_extras/nodes_curve.py diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 68ff78270..16d4acfd1 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -5,6 +5,10 @@ from comfy_api.latest._input import ( MaskInput, LatentInput, VideoInput, + CurvePoint, + CurveInput, + MonotoneCubicCurve, + LinearCurve, ) __all__ = [ @@ -13,4 +17,8 @@ __all__ = [ "MaskInput", "LatentInput", "VideoInput", + "CurvePoint", + "CurveInput", + "MonotoneCubicCurve", + "LinearCurve", ] diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py index 14f0e72f4..05cd3d40a 100644 --- a/comfy_api/latest/_input/__init__.py +++ b/comfy_api/latest/_input/__init__.py @@ -1,4 +1,5 @@ from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput +from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve from .video_types import VideoInput __all__ = [ @@ -7,4 +8,8 @@ __all__ = [ "VideoInput", "MaskInput", "LatentInput", + "CurvePoint", + "CurveInput", + "MonotoneCubicCurve", + "LinearCurve", ] diff --git a/comfy_api/latest/_input/curve_types.py b/comfy_api/latest/_input/curve_types.py new file mode 100644 index 000000000..b6dd7adf9 --- /dev/null +++ b/comfy_api/latest/_input/curve_types.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import logging +import math +from abc import ABC, abstractmethod +import numpy as np + +logger = logging.getLogger(__name__) + + +CurvePoint = tuple[float, float] + + +class CurveInput(ABC): + """Abstract base class for curve inputs. + + Subclasses represent different curve representations (control-point + interpolation, analytical functions, LUT-based, etc.) while exposing a + uniform evaluation interface to downstream nodes. + """ + + @property + @abstractmethod + def points(self) -> list[CurvePoint]: + """The control points that define this curve.""" + + @abstractmethod + def interp(self, x: float) -> float: + """Evaluate the curve at a single *x* value in [0, 1].""" + + def interp_array(self, xs: np.ndarray) -> np.ndarray: + """Vectorised evaluation over a numpy array of x values. + + Subclasses should override this for better performance. The default + falls back to scalar ``interp`` calls. + """ + return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs)) + + def to_lut(self, size: int = 256) -> np.ndarray: + """Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1].""" + return self.interp_array(np.linspace(0.0, 1.0, size)) + + @staticmethod + def from_raw(data) -> CurveInput: + """Convert raw curve data (dict or point list) to a CurveInput instance. + + Accepts: + - A ``CurveInput`` instance (returned as-is). + - A dict with ``"points"`` and optional ``"interpolation"`` keys. + - A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic). + """ + if isinstance(data, CurveInput): + return data + if isinstance(data, dict): + raw_points = data["points"] + interpolation = data.get("interpolation", "monotone_cubic") + else: + raw_points = data + interpolation = "monotone_cubic" + points = [(float(x), float(y)) for x, y in raw_points] + if interpolation == "linear": + return LinearCurve(points) + if interpolation != "monotone_cubic": + logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation) + return MonotoneCubicCurve(points) + + +class MonotoneCubicCurve(CurveInput): + """Monotone cubic Hermite interpolation over control points. + + Mirrors the frontend ``createMonotoneInterpolator`` in + ``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that + backend evaluation matches the editor preview exactly. + + All heavy work (sorting, slope computation) happens once at construction. + ``interp_array`` is fully vectorised with numpy. + """ + + def __init__(self, control_points: list[CurvePoint]): + sorted_pts = sorted(control_points, key=lambda p: p[0]) + self._points = [(float(x), float(y)) for x, y in sorted_pts] + self._xs = np.array([p[0] for p in self._points], dtype=np.float64) + self._ys = np.array([p[1] for p in self._points], dtype=np.float64) + self._slopes = self._compute_slopes() + + @property + def points(self) -> list[CurvePoint]: + return list(self._points) + + def _compute_slopes(self) -> np.ndarray: + xs, ys = self._xs, self._ys + n = len(xs) + if n < 2: + return np.zeros(n, dtype=np.float64) + + dx = np.diff(xs) + dy = np.diff(ys) + dx_safe = np.where(dx == 0, 1.0, dx) + deltas = np.where(dx == 0, 0.0, dy / dx_safe) + + slopes = np.empty(n, dtype=np.float64) + slopes[0] = deltas[0] + slopes[-1] = deltas[-1] + for i in range(1, n - 1): + if deltas[i - 1] * deltas[i] <= 0: + slopes[i] = 0.0 + else: + slopes[i] = (deltas[i - 1] + deltas[i]) / 2 + + for i in range(n - 1): + if deltas[i] == 0: + slopes[i] = 0.0 + slopes[i + 1] = 0.0 + else: + alpha = slopes[i] / deltas[i] + beta = slopes[i + 1] / deltas[i] + s = alpha * alpha + beta * beta + if s > 9: + t = 3 / math.sqrt(s) + slopes[i] = t * alpha * deltas[i] + slopes[i + 1] = t * beta * deltas[i] + return slopes + + def interp(self, x: float) -> float: + xs, ys, slopes = self._xs, self._ys, self._slopes + n = len(xs) + if n == 0: + return 0.0 + if n == 1: + return float(ys[0]) + if x <= xs[0]: + return float(ys[0]) + if x >= xs[-1]: + return float(ys[-1]) + + hi = int(np.searchsorted(xs, x, side='right')) + hi = min(hi, n - 1) + lo = hi - 1 + + dx = xs[hi] - xs[lo] + if dx == 0: + return float(ys[lo]) + + t = (x - xs[lo]) / dx + t2 = t * t + t3 = t2 * t + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]) + + def interp_array(self, xs_in: np.ndarray) -> np.ndarray: + """Fully vectorised evaluation using numpy.""" + xs, ys, slopes = self._xs, self._ys, self._slopes + n = len(xs) + if n == 0: + return np.zeros_like(xs_in, dtype=np.float64) + if n == 1: + return np.full_like(xs_in, ys[0], dtype=np.float64) + + hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1) + lo = hi - 1 + + dx = xs[hi] - xs[lo] + dx_safe = np.where(dx == 0, 1.0, dx) + t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe) + t2 = t * t + t3 = t2 * t + + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + + result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi] + result = np.where(xs_in <= xs[0], ys[0], result) + result = np.where(xs_in >= xs[-1], ys[-1], result) + return result + + def __repr__(self) -> str: + return f"MonotoneCubicCurve(points={self._points})" + + +class LinearCurve(CurveInput): + """Piecewise linear interpolation over control points. + + Mirrors the frontend ``createLinearInterpolator`` in + ``ComfyUI_frontend/src/components/curve/curveUtils.ts``. + """ + + def __init__(self, control_points: list[CurvePoint]): + sorted_pts = sorted(control_points, key=lambda p: p[0]) + self._points = [(float(x), float(y)) for x, y in sorted_pts] + self._xs = np.array([p[0] for p in self._points], dtype=np.float64) + self._ys = np.array([p[1] for p in self._points], dtype=np.float64) + + @property + def points(self) -> list[CurvePoint]: + return list(self._points) + + def interp(self, x: float) -> float: + xs, ys = self._xs, self._ys + n = len(xs) + if n == 0: + return 0.0 + if n == 1: + return float(ys[0]) + return float(np.interp(x, xs, ys)) + + def interp_array(self, xs_in: np.ndarray) -> np.ndarray: + if len(self._xs) == 0: + return np.zeros_like(xs_in, dtype=np.float64) + if len(self._xs) == 1: + return np.full_like(xs_in, self._ys[0], dtype=np.float64) + return np.interp(xs_in, self._xs, self._ys) + + def __repr__(self) -> str: + return f"LinearCurve(points={self._points})" diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 7ca8f4e0c..1cbc8ed26 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from comfy.samplers import CFGGuider, Sampler from comfy.sd import CLIP, VAE from comfy.sd import StyleModel as StyleModel_ - from comfy_api.input import VideoInput + from comfy_api.input import VideoInput, CurveInput as CurveInput_ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker @@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO): @comfytype(io_type="CURVE") class Curve(ComfyTypeIO): - CurvePoint = tuple[float, float] - Type = list[CurvePoint] + from comfy_api.input import CurvePoint + if TYPE_CHECKING: + Type = CurveInput_ class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, @@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO): if default is None: self.default = [(0.0, 0.0), (1.0, 1.0)] + def as_dict(self): + d = super().as_dict() + if self.default is not None: + d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"} + return d + + +@comfytype(io_type="HISTOGRAM") +class Histogram(ComfyTypeIO): + """A histogram represented as a list of bin counts.""" + Type = list[int] + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): @@ -2240,5 +2253,6 @@ __all__ = [ "PriceBadge", "BoundingBox", "Curve", + "Histogram", "NodeReplace", ] diff --git a/comfy_extras/nodes_curve.py b/comfy_extras/nodes_curve.py new file mode 100644 index 000000000..9016a84f9 --- /dev/null +++ b/comfy_extras/nodes_curve.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from comfy_api.latest import ComfyExtension, io +from comfy_api.input import CurveInput +from typing_extensions import override + + +class CurveEditor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CurveEditor", + display_name="Curve Editor", + category="utils", + inputs=[ + io.Curve.Input("curve"), + io.Histogram.Input("histogram", optional=True), + ], + outputs=[ + io.Curve.Output("curve"), + ], + ) + + @classmethod + def execute(cls, curve, histogram=None) -> io.NodeOutput: + result = CurveInput.from_raw(curve) + + ui = {} + if histogram is not None: + ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram) + + return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result) + + +class CurveExtension(ComfyExtension): + @override + async def get_node_list(self): + return [CurveEditor] + + +async def comfy_entrypoint(): + return CurveExtension() diff --git a/nodes.py b/nodes.py index 2c4650a20..79874d051 100644 --- a/nodes.py +++ b/nodes.py @@ -2455,6 +2455,7 @@ async def init_builtin_extra_nodes(): "nodes_sdpose.py", "nodes_math.py", "nodes_painter.py", + "nodes_curve.py", ] import_failed = [] From a0a64c679fca700a087d0cdfa419912a3e8b3bf8 Mon Sep 17 00:00:00 2001 From: Dante Date: Wed, 25 Mar 2026 07:38:08 +0900 Subject: [PATCH 09/13] Add Number Convert node (#13041) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Number Convert node for unified numeric type conversion Consolidates fragmented IntToFloat/FloatToInt nodes (previously only available via third-party packs like ComfyMath, FillNodes, etc.) into a single core node. - Single input accepting INT, FLOAT, STRING, and BOOL types - Two outputs: FLOAT and INT - Conversion: bool→0/1, string→parsed number, float↔int standard cast - Follows Math Expression node patterns (comfy_api, io.Schema, etc.) Refs: COM-16925 * Register nodes_number_convert.py in extras_files list Without this entry in nodes.py, the Number Convert node file would not be discovered and loaded at startup. * Add isfinite guard, exception chaining, and unit tests for Number Convert node - Add math.isfinite() check to prevent int() crash on inf/nan string inputs - Use 'from None' for cleaner exception chaining on string parse failure - Add 21 unit tests covering all input types and error paths --- comfy_extras/nodes_number_convert.py | 79 +++++++++++ nodes.py | 1 + .../nodes_number_convert_test.py | 123 ++++++++++++++++++ 3 files changed, 203 insertions(+) create mode 100644 comfy_extras/nodes_number_convert.py create mode 100644 tests-unit/comfy_extras_test/nodes_number_convert_test.py diff --git a/comfy_extras/nodes_number_convert.py b/comfy_extras/nodes_number_convert.py new file mode 100644 index 000000000..b2822c856 --- /dev/null +++ b/comfy_extras/nodes_number_convert.py @@ -0,0 +1,79 @@ +"""Number Convert node for unified numeric type conversion. + +Provides a single node that converts INT, FLOAT, STRING, and BOOL +inputs into FLOAT and INT outputs. +""" + +from __future__ import annotations + +import math + +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class NumberConvertNode(io.ComfyNode): + """Converts various types to numeric FLOAT and INT outputs.""" + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ComfyNumberConvert", + display_name="Number Convert", + category="math", + search_aliases=[ + "int to float", "float to int", "number convert", + "int2float", "float2int", "cast", "parse number", + "string to number", "bool to int", + ], + inputs=[ + io.MultiType.Input( + "value", + [io.Int, io.Float, io.String, io.Boolean], + display_name="value", + ), + ], + outputs=[ + io.Float.Output(display_name="FLOAT"), + io.Int.Output(display_name="INT"), + ], + ) + + @classmethod + def execute(cls, value) -> io.NodeOutput: + if isinstance(value, bool): + float_val = 1.0 if value else 0.0 + elif isinstance(value, (int, float)): + float_val = float(value) + elif isinstance(value, str): + text = value.strip() + if not text: + raise ValueError("Cannot convert empty string to number.") + try: + float_val = float(text) + except ValueError: + raise ValueError( + f"Cannot convert string to number: {value!r}" + ) from None + else: + raise TypeError( + f"Unsupported input type: {type(value).__name__}" + ) + + if not math.isfinite(float_val): + raise ValueError( + f"Cannot convert non-finite value to number: {float_val}" + ) + + return io.NodeOutput(float_val, int(float_val)) + + +class NumberConvertExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [NumberConvertNode] + + +async def comfy_entrypoint() -> NumberConvertExtension: + return NumberConvertExtension() diff --git a/nodes.py b/nodes.py index 79874d051..37ceac2fc 100644 --- a/nodes.py +++ b/nodes.py @@ -2454,6 +2454,7 @@ async def init_builtin_extra_nodes(): "nodes_nag.py", "nodes_sdpose.py", "nodes_math.py", + "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", ] diff --git a/tests-unit/comfy_extras_test/nodes_number_convert_test.py b/tests-unit/comfy_extras_test/nodes_number_convert_test.py new file mode 100644 index 000000000..0046fa8f4 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_number_convert_test.py @@ -0,0 +1,123 @@ +import pytest +from unittest.mock import patch, MagicMock + +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 +mock_server = MagicMock() + +with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}): + from comfy_extras.nodes_number_convert import NumberConvertNode + + +class TestNumberConvertExecute: + @staticmethod + def _exec(value) -> object: + return NumberConvertNode.execute(value) + + # --- INT input --- + + def test_int_input(self): + result = self._exec(42) + assert result[0] == 42.0 + assert result[1] == 42 + + def test_int_zero(self): + result = self._exec(0) + assert result[0] == 0.0 + assert result[1] == 0 + + def test_int_negative(self): + result = self._exec(-7) + assert result[0] == -7.0 + assert result[1] == -7 + + # --- FLOAT input --- + + def test_float_input(self): + result = self._exec(3.14) + assert result[0] == 3.14 + assert result[1] == 3 + + def test_float_truncation_toward_zero(self): + result = self._exec(-2.9) + assert result[0] == -2.9 + assert result[1] == -2 # int() truncates toward zero, not floor + + def test_float_output_type(self): + result = self._exec(5) + assert isinstance(result[0], float) + + def test_int_output_type(self): + result = self._exec(5.7) + assert isinstance(result[1], int) + + # --- BOOL input --- + + def test_bool_true(self): + result = self._exec(True) + assert result[0] == 1.0 + assert result[1] == 1 + + def test_bool_false(self): + result = self._exec(False) + assert result[0] == 0.0 + assert result[1] == 0 + + # --- STRING input --- + + def test_string_integer(self): + result = self._exec("42") + assert result[0] == 42.0 + assert result[1] == 42 + + def test_string_float(self): + result = self._exec("3.14") + assert result[0] == 3.14 + assert result[1] == 3 + + def test_string_negative(self): + result = self._exec("-5.5") + assert result[0] == -5.5 + assert result[1] == -5 + + def test_string_with_whitespace(self): + result = self._exec(" 7.0 ") + assert result[0] == 7.0 + assert result[1] == 7 + + def test_string_scientific_notation(self): + result = self._exec("1e3") + assert result[0] == 1000.0 + assert result[1] == 1000 + + # --- STRING error paths --- + + def test_empty_string_raises(self): + with pytest.raises(ValueError, match="Cannot convert empty string"): + self._exec("") + + def test_whitespace_only_string_raises(self): + with pytest.raises(ValueError, match="Cannot convert empty string"): + self._exec(" ") + + def test_non_numeric_string_raises(self): + with pytest.raises(ValueError, match="Cannot convert string to number"): + self._exec("abc") + + def test_string_inf_raises(self): + with pytest.raises(ValueError, match="non-finite"): + self._exec("inf") + + def test_string_nan_raises(self): + with pytest.raises(ValueError, match="non-finite"): + self._exec("nan") + + def test_string_negative_inf_raises(self): + with pytest.raises(ValueError, match="non-finite"): + self._exec("-inf") + + # --- Unsupported type --- + + def test_unsupported_type_raises(self): + with pytest.raises(TypeError, match="Unsupported input type"): + self._exec([1, 2, 3]) From 5ebb0c2e0b72945c271a2fb4db749585aa32a13c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Mar 2026 08:39:04 +0800 Subject: [PATCH 10/13] FP8 bwd training (#13121) --- comfy/model_management.py | 1 + comfy/ops.py | 65 ++++++++++++++++++++++++++++--------- comfy_extras/nodes_train.py | 9 +++++ 3 files changed, 59 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2c250dacc..9617d8388 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -55,6 +55,7 @@ total_vram = 0 # Training Related State in_training = False +training_fp8_bwd = False def get_supported_float8_types(): diff --git a/comfy/ops.py b/comfy/ops.py index 1518ec9de..ca25693db 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -777,8 +777,16 @@ from .quant_ops import ( class QuantLinearFunc(torch.autograd.Function): - """Custom autograd function for quantized linear: quantized forward, compute_dtype backward. - Handles any input rank by flattening to 2D for matmul and restoring shape after. + """Custom autograd function for quantized linear: quantized forward, optionally FP8 backward. + + When training_fp8_bwd is enabled: + - Forward: quantize input per layout (FP8/NVFP4), use quantized matmul + - Backward: all matmuls use FP8 tensor cores via torch.mm dispatch + - Cached input is FP8 (half the memory of bf16) + + When training_fp8_bwd is disabled: + - Forward: quantize input per layout, use quantized matmul + - Backward: dequantize weight to compute_dtype, use standard matmul """ @staticmethod @@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function): input_shape = input_float.shape inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D - # Quantize input (same as inference path) + # Quantize input for forward (same layout as weight) if layout_type is not None: q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) else: @@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function): output = torch.nn.functional.linear(q_input, w, b) - # Restore original input shape + # Unflatten output to match original input shape if len(input_shape) > 2: output = output.unflatten(0, input_shape[:-1]) - ctx.save_for_backward(input_float, weight) + # Save for backward ctx.input_shape = input_shape ctx.has_bias = bias is not None ctx.compute_dtype = compute_dtype ctx.weight_requires_grad = weight.requires_grad + ctx.fp8_bwd = comfy.model_management.training_fp8_bwd + + if ctx.fp8_bwd: + # Cache FP8 quantized input — half the memory of bf16 + if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'): + ctx.q_input = q_input # already FP8, reuse + else: + # NVFP4 or other layout — quantize input to FP8 for backward + ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout") + ctx.save_for_backward(weight) + else: + ctx.q_input = None + ctx.save_for_backward(input_float, weight) return output @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, grad_output): - input_float, weight = ctx.saved_tensors compute_dtype = ctx.compute_dtype grad_2d = grad_output.flatten(0, -2).to(compute_dtype) - # Dequantize weight to compute dtype for backward matmul - if isinstance(weight, QuantizedTensor): - weight_f = weight.dequantize().to(compute_dtype) + # Value casting — only difference between fp8 and non-fp8 paths + if ctx.fp8_bwd: + weight, = ctx.saved_tensors + # Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm + grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout") + if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"): + weight_mm = weight + elif isinstance(weight, QuantizedTensor): + weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout") + else: + weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout") + input_mm = ctx.q_input else: - weight_f = weight.to(compute_dtype) + input_float, weight = ctx.saved_tensors + # Standard tensors → torch.mm does regular matmul + grad_mm = grad_2d + if isinstance(weight, QuantizedTensor): + weight_mm = weight.dequantize().to(compute_dtype) + else: + weight_mm = weight.to(compute_dtype) + input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None - # grad_input = grad_output @ weight - grad_input = torch.mm(grad_2d, weight_f) + # Computation — same for both paths, dispatch handles the rest + grad_input = torch.mm(grad_mm, weight_mm) if len(ctx.input_shape) > 2: grad_input = grad_input.unflatten(0, ctx.input_shape[:-1]) - # grad_weight (only if weight requires grad, typically frozen for quantized training) grad_weight = None if ctx.weight_requires_grad: - input_f = input_float.flatten(0, -2).to(compute_dtype) - grad_weight = torch.mm(grad_2d.t(), input_f) + grad_weight = torch.mm(grad_mm.t(), input_mm) - # grad_bias grad_bias = None if ctx.has_bias: grad_bias = grad_2d.sum(dim=0) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 0ad0acee6..df1b39fd5 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode): default="bf16", tooltip="The dtype to use for lora.", ), + io.Boolean.Input( + "quantized_backward", + default=False, + tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.", + ), io.Combo.Input( "algorithm", options=list(adapter_maps.keys()), @@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode): seed, training_dtype, lora_dtype, + quantized_backward, algorithm, gradient_checkpointing, checkpoint_depth, @@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode): seed = seed[0] training_dtype = training_dtype[0] lora_dtype = lora_dtype[0] + quantized_backward = quantized_backward[0] algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] offloading = offloading[0] @@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode): bucket_mode = bucket_mode[0] bypass_mode = bypass_mode[0] + comfy.model_management.training_fp8_bwd = quantized_backward + # Process latents based on mode if bucket_mode: latents = _process_latents_bucket_mode(latents) From 7d5534d8e516e0d4cd53d6abcdcb7f1f6d51ea97 Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Tue, 24 Mar 2026 20:48:55 -0700 Subject: [PATCH 11/13] feat(assets): register output files as assets after prompt execution (#12812) --- app/assets/database/queries/__init__.py | 4 + app/assets/database/queries/asset.py | 12 + .../database/queries/asset_reference.py | 17 ++ app/assets/scanner.py | 15 ++ app/assets/seeder.py | 66 ++++- app/assets/services/__init__.py | 4 + app/assets/services/bulk_ingest.py | 3 + app/assets/services/ingest.py | 102 ++++++- main.py | 43 ++- tests-unit/assets_test/services/conftest.py | 17 +- .../assets_test/services/test_enrich.py | 11 +- .../assets_test/services/test_ingest.py | 51 +++- tests-unit/seeder_test/test_seeder.py | 183 +++++++++++++ tests/test_asset_seeder.py | 250 ++++++++++++++++++ 14 files changed, 764 insertions(+), 14 deletions(-) create mode 100644 tests/test_asset_seeder.py diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py index 1632937b2..9949e84e1 100644 --- a/app/assets/database/queries/__init__.py +++ b/app/assets/database/queries/__init__.py @@ -1,6 +1,7 @@ from app.assets.database.queries.asset import ( asset_exists_by_hash, bulk_insert_assets, + create_stub_asset, get_asset_by_hash, get_existing_asset_ids, reassign_asset_references, @@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import ( UnenrichedReferenceRow, bulk_insert_references_ignore_conflicts, bulk_update_enrichment_level, + count_active_siblings, bulk_update_is_missing, bulk_update_needs_verify, convert_metadata_to_rows, @@ -80,6 +82,8 @@ __all__ = [ "bulk_insert_references_ignore_conflicts", "bulk_insert_tags_and_meta", "bulk_update_enrichment_level", + "count_active_siblings", + "create_stub_asset", "bulk_update_is_missing", "bulk_update_needs_verify", "convert_metadata_to_rows", diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py index 594d1f1b2..cc7168431 100644 --- a/app/assets/database/queries/asset.py +++ b/app/assets/database/queries/asset.py @@ -78,6 +78,18 @@ def upsert_asset( return asset, created, updated +def create_stub_asset( + session: Session, + size_bytes: int, + mime_type: str | None = None, +) -> Asset: + """Create a new asset with no hash (stub for later enrichment).""" + asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None) + session.add(asset) + session.flush() + return asset + + def bulk_insert_assets( session: Session, rows: list[dict], diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 084a32512..8b90ae511 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -114,6 +114,23 @@ def get_reference_by_file_path( ) +def count_active_siblings( + session: Session, + asset_id: str, + exclude_reference_id: str, +) -> int: + """Count active (non-deleted) references to an asset, excluding one reference.""" + return ( + session.query(AssetReference) + .filter( + AssetReference.asset_id == asset_id, + AssetReference.id != exclude_reference_id, + AssetReference.deleted_at.is_(None), + ) + .count() + ) + + def reference_exists_for_asset_id( session: Session, asset_id: str, diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 4e05a97b5..ebb6869af 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -13,6 +13,7 @@ from app.assets.database.queries import ( delete_references_by_ids, ensure_tags_exist, get_asset_by_hash, + get_reference_by_id, get_references_for_prefixes, get_unenriched_references, mark_references_missing_outside_prefixes, @@ -338,6 +339,7 @@ def build_asset_specs( "metadata": metadata, "hash": asset_hash, "mime_type": mime_type, + "job_id": None, } ) tag_pool.update(tags) @@ -426,6 +428,7 @@ def enrich_asset( except OSError: return new_level + initial_mtime_ns = get_mtime_ns(stat_p) rel_fname = compute_relative_filename(file_path) mime_type: str | None = None metadata = None @@ -489,6 +492,18 @@ def enrich_asset( except Exception as e: logging.warning("Failed to hash %s: %s", file_path, e) + # Optimistic guard: if the reference's mtime_ns changed since we + # started (e.g. ingest_existing_file updated it), our results are + # stale — discard them to avoid overwriting fresh registration data. + ref = get_reference_by_id(session, reference_id) + if ref is None or ref.mtime_ns != initial_mtime_ns: + session.rollback() + logging.info( + "Ref %s mtime changed during enrichment, discarding stale result", + reference_id, + ) + return ENRICHMENT_STUB + if extract_metadata and metadata: system_metadata = metadata.to_user_metadata() set_reference_system_metadata(session, reference_id, system_metadata) diff --git a/app/assets/seeder.py b/app/assets/seeder.py index 029448464..2262928e5 100644 --- a/app/assets/seeder.py +++ b/app/assets/seeder.py @@ -77,7 +77,9 @@ class _AssetSeeder: """ def __init__(self) -> None: - self._lock = threading.Lock() + # RLock is required because _run_scan() drains pending work while + # holding _lock and re-enters start() which also acquires _lock. + self._lock = threading.RLock() self._state = State.IDLE self._progress: Progress | None = None self._last_progress: Progress | None = None @@ -92,6 +94,7 @@ class _AssetSeeder: self._prune_first: bool = False self._progress_callback: ProgressCallback | None = None self._disabled: bool = False + self._pending_enrich: dict | None = None def disable(self) -> None: """Disable the asset seeder, preventing any scans from starting.""" @@ -196,6 +199,42 @@ class _AssetSeeder: compute_hashes=compute_hashes, ) + def enqueue_enrich( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + compute_hashes: bool = False, + ) -> bool: + """Start an enrichment scan now, or queue it for after the current scan. + + If the seeder is idle, starts immediately. Otherwise, the enrich + request is stored and will run automatically when the current scan + finishes. + + Args: + roots: Tuple of root types to scan + compute_hashes: If True, compute blake3 hashes + + Returns: + True if started immediately, False if queued for later + """ + with self._lock: + if self.start_enrich(roots=roots, compute_hashes=compute_hashes): + return True + if self._pending_enrich is not None: + existing_roots = set(self._pending_enrich["roots"]) + existing_roots.update(roots) + self._pending_enrich["roots"] = tuple(existing_roots) + self._pending_enrich["compute_hashes"] = ( + self._pending_enrich["compute_hashes"] or compute_hashes + ) + else: + self._pending_enrich = { + "roots": roots, + "compute_hashes": compute_hashes, + } + logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"]) + return False + def cancel(self) -> bool: """Request cancellation of the current scan. @@ -381,9 +420,13 @@ class _AssetSeeder: return marked finally: with self._lock: - self._last_progress = self._progress - self._state = State.IDLE - self._progress = None + self._reset_to_idle() + + def _reset_to_idle(self) -> None: + """Reset state to IDLE, preserving last progress. Caller must hold _lock.""" + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None def _is_cancelled(self) -> bool: """Check if cancellation has been requested.""" @@ -594,9 +637,18 @@ class _AssetSeeder: }, ) with self._lock: - self._last_progress = self._progress - self._state = State.IDLE - self._progress = None + self._reset_to_idle() + pending = self._pending_enrich + if pending is not None: + self._pending_enrich = None + if not self.start_enrich( + roots=pending["roots"], + compute_hashes=pending["compute_hashes"], + ): + logging.warning( + "Pending enrich scan could not start (roots=%s)", + pending["roots"], + ) def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]: """Run phase 1: fast scan to create stub records. diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py index 11fcb4122..03990966b 100644 --- a/app/assets/services/__init__.py +++ b/app/assets/services/__init__.py @@ -23,6 +23,8 @@ from app.assets.services.ingest import ( DependencyMissingError, HashMismatchError, create_from_hash, + ingest_existing_file, + register_output_files, upload_from_temp_path, ) from app.assets.database.queries import ( @@ -72,6 +74,8 @@ __all__ = [ "delete_asset_reference", "get_asset_by_hash", "get_asset_detail", + "ingest_existing_file", + "register_output_files", "get_mtime_ns", "get_size_and_mtime_ns", "list_assets_page", diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py index 54e72730c..67aad838f 100644 --- a/app/assets/services/bulk_ingest.py +++ b/app/assets/services/bulk_ingest.py @@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict): metadata: ExtractedMetadata | None hash: str | None mime_type: str | None + job_id: str | None class AssetRow(TypedDict): @@ -60,6 +61,7 @@ class ReferenceRow(TypedDict): name: str preview_id: str | None user_metadata: dict[str, Any] | None + job_id: str | None created_at: datetime updated_at: datetime last_access_time: datetime @@ -167,6 +169,7 @@ def batch_insert_seed_assets( "name": spec["info_name"], "preview_id": None, "user_metadata": user_metadata, + "job_id": spec.get("job_id"), "created_at": current_time, "updated_at": current_time, "last_access_time": current_time, diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index 90c51994f..f0b070517 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -9,6 +9,9 @@ from sqlalchemy.orm import Session import app.assets.services.hashing as hashing from app.assets.database.queries import ( add_tags_to_reference, + count_active_siblings, + create_stub_asset, + ensure_tags_exist, fetch_reference_and_asset, get_asset_by_hash, get_reference_by_file_path, @@ -23,7 +26,8 @@ from app.assets.database.queries import ( upsert_reference, validate_tags_exist, ) -from app.assets.helpers import normalize_tags +from app.assets.helpers import get_utc_now, normalize_tags +from app.assets.services.bulk_ingest import batch_insert_seed_assets from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.path_utils import ( compute_relative_filename, @@ -130,6 +134,102 @@ def _ingest_file_from_path( ) +def register_output_files( + file_paths: Sequence[str], + user_metadata: UserMetadata = None, + job_id: str | None = None, +) -> int: + """Register a batch of output file paths as assets. + + Returns the number of files successfully registered. + """ + registered = 0 + for abs_path in file_paths: + if not os.path.isfile(abs_path): + continue + try: + if ingest_existing_file( + abs_path, user_metadata=user_metadata, job_id=job_id + ): + registered += 1 + except Exception: + logging.exception("Failed to register output: %s", abs_path) + return registered + + +def ingest_existing_file( + abs_path: str, + user_metadata: UserMetadata = None, + extra_tags: Sequence[str] = (), + owner_id: str = "", + job_id: str | None = None, +) -> bool: + """Register an existing on-disk file as an asset stub. + + If a reference already exists for this path, updates mtime_ns, job_id, + size_bytes, and resets enrichment so the enricher will re-hash it. + + For brand-new paths, inserts a stub record (hash=NULL) for immediate + UX visibility. + + Returns True if a row was inserted or updated, False otherwise. + """ + locator = os.path.abspath(abs_path) + size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path) + mime_type = mimetypes.guess_type(abs_path, strict=False)[0] + name, path_tags = get_name_and_tags_from_asset_path(abs_path) + tags = list(dict.fromkeys(path_tags + list(extra_tags))) + + with create_session() as session: + existing_ref = get_reference_by_file_path(session, locator) + if existing_ref is not None: + now = get_utc_now() + existing_ref.mtime_ns = mtime_ns + existing_ref.job_id = job_id + existing_ref.is_missing = False + existing_ref.deleted_at = None + existing_ref.updated_at = now + existing_ref.enrichment_level = 0 + + asset = existing_ref.asset + if asset: + # If other refs share this asset, detach to a new stub + # instead of mutating the shared row. + siblings = count_active_siblings(session, asset.id, existing_ref.id) + if siblings > 0: + new_asset = create_stub_asset( + session, + size_bytes=size_bytes, + mime_type=mime_type or asset.mime_type, + ) + existing_ref.asset_id = new_asset.id + else: + asset.hash = None + asset.size_bytes = size_bytes + if mime_type: + asset.mime_type = mime_type + session.commit() + return True + + spec = { + "abs_path": abs_path, + "size_bytes": size_bytes, + "mtime_ns": mtime_ns, + "info_name": name, + "tags": tags, + "fname": os.path.basename(abs_path), + "metadata": None, + "hash": None, + "mime_type": mime_type, + "job_id": job_id, + } + if tags: + ensure_tags_exist(session, tags) + result = batch_insert_seed_assets(session, [spec], owner_id=owner_id) + session.commit() + return result.won_paths > 0 + + def _register_existing_asset( asset_hash: str, name: str, diff --git a/main.py b/main.py index cd4483c67..058e8e2de 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,8 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger +from app.assets.seeder import asset_seeder +from app.assets.services import register_output_files import itertools import utils.extra_config from utils.mime_types import init_mime_types @@ -192,7 +194,6 @@ if 'torch' in sys.modules: import comfy.utils -from app.assets.seeder import asset_seeder import execution import server @@ -240,6 +241,38 @@ def cuda_malloc_warning(): logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") +def _collect_output_absolute_paths(history_result: dict) -> list[str]: + """Extract absolute file paths for output items from a history result.""" + paths: list[str] = [] + seen: set[str] = set() + for node_output in history_result.get("outputs", {}).values(): + for items in node_output.values(): + if not isinstance(items, list): + continue + for item in items: + if not isinstance(item, dict): + continue + item_type = item.get("type") + if item_type not in ("output", "temp"): + continue + base_dir = folder_paths.get_directory_by_type(item_type) + if base_dir is None: + continue + base_dir = os.path.abspath(base_dir) + filename = item.get("filename") + if not filename: + continue + abs_path = os.path.abspath( + os.path.join(base_dir, item.get("subfolder", ""), filename) + ) + if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir: + continue + if abs_path not in seen: + seen.add(abs_path) + paths.append(abs_path) + return paths + + def prompt_worker(q, server_instance): current_time: float = 0.0 cache_type = execution.CacheType.CLASSIC @@ -274,6 +307,7 @@ def prompt_worker(q, server_instance): asset_seeder.pause() e.execute(item[2], prompt_id, extra_data, item[4]) + need_gc = True remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] @@ -296,6 +330,10 @@ def prompt_worker(q, server_instance): else: logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) + if not asset_seeder.is_disabled(): + paths = _collect_output_absolute_paths(e.history_result) + register_output_files(paths, job_id=prompt_id) + flags = q.get_flags() free_memory = flags.get("free_memory", False) @@ -317,6 +355,9 @@ def prompt_worker(q, server_instance): last_gc_collect = current_time need_gc = False hook_breaker_ac10a0.restore_functions() + + if not asset_seeder.is_disabled(): + asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True) asset_seeder.resume() diff --git a/tests-unit/assets_test/services/conftest.py b/tests-unit/assets_test/services/conftest.py index 31c763d48..bc0723e61 100644 --- a/tests-unit/assets_test/services/conftest.py +++ b/tests-unit/assets_test/services/conftest.py @@ -3,7 +3,7 @@ from pathlib import Path from unittest.mock import patch import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, event from sqlalchemy.orm import Session from app.assets.database.models import Base @@ -23,6 +23,21 @@ def db_engine(): return engine +@pytest.fixture +def db_engine_fk(): + """In-memory SQLite engine with foreign key enforcement enabled.""" + engine = create_engine("sqlite:///:memory:") + + @event.listens_for(engine, "connect") + def _set_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + Base.metadata.create_all(engine) + return engine + + @pytest.fixture def session(db_engine): """Session fixture for tests that need direct DB access.""" diff --git a/tests-unit/assets_test/services/test_enrich.py b/tests-unit/assets_test/services/test_enrich.py index 2bd79a01a..6a6561f7f 100644 --- a/tests-unit/assets_test/services/test_enrich.py +++ b/tests-unit/assets_test/services/test_enrich.py @@ -1,9 +1,11 @@ """Tests for asset enrichment (mime_type and hash population).""" +import os from pathlib import Path from sqlalchemy.orm import Session from app.assets.database.models import Asset, AssetReference +from app.assets.services.file_utils import get_mtime_ns from app.assets.scanner import ( ENRICHMENT_HASHED, ENRICHMENT_METADATA, @@ -20,6 +22,13 @@ def _create_stub_asset( name: str | None = None, ) -> tuple[Asset, AssetReference]: """Create a stub asset with reference for testing enrichment.""" + # Use the real file's mtime so the optimistic guard in enrich_asset passes + try: + stat_result = os.stat(file_path, follow_symlinks=True) + mtime_ns = get_mtime_ns(stat_result) + except OSError: + mtime_ns = 1234567890000000000 + asset = Asset( id=asset_id, hash=None, @@ -35,7 +44,7 @@ def _create_stub_asset( name=name or f"test-asset-{asset_id}", owner_id="system", file_path=file_path, - mtime_ns=1234567890000000000, + mtime_ns=mtime_ns, enrichment_level=ENRICHMENT_STUB, ) session.add(ref) diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py index dbb8441c2..b153f9795 100644 --- a/tests-unit/assets_test/services/test_ingest.py +++ b/tests-unit/assets_test/services/test_ingest.py @@ -1,12 +1,18 @@ """Tests for ingest services.""" +from contextlib import contextmanager from pathlib import Path +from unittest.mock import patch import pytest -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session as SASession, Session -from app.assets.database.models import Asset, AssetReference, Tag +from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag from app.assets.database.queries import get_reference_tags -from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset +from app.assets.services.ingest import ( + _ingest_file_from_path, + _register_existing_asset, + ingest_existing_file, +) class TestIngestFileFromPath: @@ -235,3 +241,42 @@ class TestRegisterExistingAsset: assert result.created is True assert set(result.tags) == {"alpha", "beta"} + + +class TestIngestExistingFileTagFK: + """Regression: ingest_existing_file must seed Tag rows before inserting + AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError.""" + + def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path): + """With PRAGMA foreign_keys=ON, tags must exist in the tags table + before they can be referenced in asset_reference_tags.""" + + @contextmanager + def _create_session(): + with SASession(db_engine_fk) as sess: + yield sess + + file_path = temp_dir / "output.png" + file_path.write_bytes(b"image data") + + with patch("app.assets.services.ingest.create_session", _create_session), \ + patch( + "app.assets.services.ingest.get_name_and_tags_from_asset_path", + return_value=("output.png", ["output"]), + ): + result = ingest_existing_file( + abs_path=str(file_path), + extra_tags=["my-job"], + ) + + assert result is True + + with SASession(db_engine_fk) as sess: + tag_names = {t.name for t in sess.query(Tag).all()} + assert "output" in tag_names + assert "my-job" in tag_names + + ref_tags = sess.query(AssetReferenceTag).all() + ref_tag_names = {rt.tag_name for rt in ref_tags} + assert "output" in ref_tag_names + assert "my-job" in ref_tag_names diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py index db3795e48..6aed6d6f3 100644 --- a/tests-unit/seeder_test/test_seeder.py +++ b/tests-unit/seeder_test/test_seeder.py @@ -1,6 +1,7 @@ """Unit tests for the _AssetSeeder background scanning class.""" import threading +import time from unittest.mock import patch import pytest @@ -771,6 +772,188 @@ class TestSeederStopRestart: assert collected_roots[1] == ("input",) +class TestEnqueueEnrichHandoff: + """Test that the drain of _pending_enrich is atomic with start_enrich.""" + + def test_pending_enrich_runs_after_scan_completes( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + """A queued enrich request runs automatically when a scan finishes.""" + enrich_roots_seen: list[tuple] = [] + original_start = fresh_seeder.start + + def tracking_start(*args, **kwargs): + phase = kwargs.get("phase") + roots = kwargs.get("roots", args[0] if args else None) + result = original_start(*args, **kwargs) + if phase == ScanPhase.ENRICH and result: + enrich_roots_seen.append(roots) + return result + + fresh_seeder.start = tracking_start + + # Start a fast scan, then enqueue an enrich while it's running + barrier = threading.Event() + reached = threading.Event() + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST) + assert reached.wait(timeout=2.0) + + queued = fresh_seeder.enqueue_enrich( + roots=("input",), compute_hashes=True + ) + assert queued is False # queued, not started immediately + + barrier.set() + + # Wait for the original scan + the auto-started enrich scan + deadline = time.monotonic() + 5.0 + while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline: + time.sleep(0.05) + + assert enrich_roots_seen == [("input",)] + + def test_enqueue_enrich_during_drain_does_not_lose_work( + self, fresh_seeder: _AssetSeeder, mock_dependencies + ): + """enqueue_enrich called concurrently with drain cannot drop work. + + Simulates the race: another thread calls enqueue_enrich right as the + scan thread is draining _pending_enrich. The enqueue must either be + picked up by the draining scan or successfully start its own scan. + """ + barrier = threading.Event() + reached = threading.Event() + enrich_started = threading.Event() + + enrich_call_count = 0 + + def slow_collect(*args): + reached.set() + barrier.wait(timeout=5.0) + return [] + + # Track how many times start_enrich actually fires + real_start_enrich = fresh_seeder.start_enrich + enrich_roots_seen: list[tuple] = [] + + def tracking_start_enrich(**kwargs): + nonlocal enrich_call_count + enrich_call_count += 1 + enrich_roots_seen.append(kwargs.get("roots")) + result = real_start_enrich(**kwargs) + if result: + enrich_started.set() + return result + + fresh_seeder.start_enrich = tracking_start_enrich + + with patch( + "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect + ): + # Start a scan + fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST) + assert reached.wait(timeout=2.0) + + # Queue an enrich while scan is running + fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False) + + # Let scan finish — drain will fire start_enrich atomically + barrier.set() + + # Wait for drain to complete and the enrich scan to start + assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain" + assert ("output",) in enrich_roots_seen + + def test_concurrent_enqueue_during_drain_not_lost( + self, fresh_seeder: _AssetSeeder, + ): + """A second enqueue_enrich arriving while drain is in progress is not lost. + + Because the drain now holds _lock through the start_enrich call, + a concurrent enqueue_enrich will block until start_enrich has + transitioned state to RUNNING, then the enqueue will queue its + payload as _pending_enrich for the *next* drain. + """ + scan_barrier = threading.Event() + scan_reached = threading.Event() + enrich_barrier = threading.Event() + enrich_reached = threading.Event() + + collect_call = 0 + + def gated_collect(*args): + nonlocal collect_call + collect_call += 1 + if collect_call == 1: + # First call: the initial fast scan + scan_reached.set() + scan_barrier.wait(timeout=5.0) + return [] + + enrich_call = 0 + + def gated_get_unenriched(*args, **kwargs): + nonlocal enrich_call + enrich_call += 1 + if enrich_call == 1: + # First enrich batch: signal and block + enrich_reached.set() + enrich_barrier.wait(timeout=5.0) + return [] + + with ( + patch("app.assets.seeder.dependencies_available", return_value=True), + patch("app.assets.seeder.sync_root_safely", return_value=set()), + patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.insert_asset_specs", return_value=0), + patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched), + patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), + ): + # 1. Start fast scan + fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST) + assert scan_reached.wait(timeout=2.0) + + # 2. Queue enrich while fast scan is running + queued = fresh_seeder.enqueue_enrich( + roots=("input",), compute_hashes=False + ) + assert queued is False + + # 3. Let the fast scan finish — drain will start the enrich scan + scan_barrier.set() + + # 4. Wait until the drained enrich scan is running + assert enrich_reached.wait(timeout=5.0) + + # 5. Now enqueue another enrich while the drained scan is running + queued2 = fresh_seeder.enqueue_enrich( + roots=("output",), compute_hashes=True + ) + assert queued2 is False # should be queued, not started + + # Verify _pending_enrich was set (the second enqueue was captured) + with fresh_seeder._lock: + assert fresh_seeder._pending_enrich is not None + assert "output" in fresh_seeder._pending_enrich["roots"] + + # Let the enrich scan finish + enrich_barrier.set() + + deadline = time.monotonic() + 5.0 + while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline: + time.sleep(0.05) + + def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow: return UnenrichedReferenceRow( reference_id=ref_id, asset_id=asset_id, diff --git a/tests/test_asset_seeder.py b/tests/test_asset_seeder.py new file mode 100644 index 000000000..4274dab8e --- /dev/null +++ b/tests/test_asset_seeder.py @@ -0,0 +1,250 @@ +"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour.""" + +import threading +from unittest.mock import patch + +import pytest + +from app.assets.seeder import Progress, _AssetSeeder, State + + +@pytest.fixture() +def seeder(): + """Fresh seeder instance for each test.""" + return _AssetSeeder() + + +# --------------------------------------------------------------------------- +# _reset_to_idle +# --------------------------------------------------------------------------- + + +class TestResetToIdle: + def test_sets_idle_and_clears_progress(self, seeder): + """_reset_to_idle should move state to IDLE and snapshot progress.""" + progress = Progress(scanned=10, total=20, created=5, skipped=3) + seeder._state = State.RUNNING + seeder._progress = progress + + with seeder._lock: + seeder._reset_to_idle() + + assert seeder._state is State.IDLE + assert seeder._progress is None + assert seeder._last_progress is progress + + def test_noop_when_progress_already_none(self, seeder): + """_reset_to_idle should handle None progress gracefully.""" + seeder._state = State.CANCELLING + seeder._progress = None + + with seeder._lock: + seeder._reset_to_idle() + + assert seeder._state is State.IDLE + assert seeder._progress is None + assert seeder._last_progress is None + + +# --------------------------------------------------------------------------- +# enqueue_enrich – immediate start when idle +# --------------------------------------------------------------------------- + + +class TestEnqueueEnrichStartsImmediately: + def test_starts_when_idle(self, seeder): + """enqueue_enrich should delegate to start_enrich and return True when idle.""" + with patch.object(seeder, "start_enrich", return_value=True) as mock: + assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True + mock.assert_called_once_with(roots=("output",), compute_hashes=True) + + def test_no_pending_when_started_immediately(self, seeder): + """No pending request should be stored when start_enrich succeeds.""" + with patch.object(seeder, "start_enrich", return_value=True): + seeder.enqueue_enrich(roots=("output",)) + assert seeder._pending_enrich is None + + +# --------------------------------------------------------------------------- +# enqueue_enrich – queuing when busy +# --------------------------------------------------------------------------- + + +class TestEnqueueEnrichQueuesWhenBusy: + def test_queues_when_busy(self, seeder): + """enqueue_enrich should store a pending request when seeder is busy.""" + with patch.object(seeder, "start_enrich", return_value=False): + result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False) + + assert result is False + assert seeder._pending_enrich == { + "roots": ("models",), + "compute_hashes": False, + } + + def test_queues_preserves_compute_hashes_true(self, seeder): + with patch.object(seeder, "start_enrich", return_value=False): + seeder.enqueue_enrich(roots=("input",), compute_hashes=True) + + assert seeder._pending_enrich["compute_hashes"] is True + + +# --------------------------------------------------------------------------- +# enqueue_enrich – merging when a pending request already exists +# --------------------------------------------------------------------------- + + +class TestEnqueueEnrichMergesPending: + def _make_busy(self, seeder): + """Patch start_enrich to always return False (seeder busy).""" + return patch.object(seeder, "start_enrich", return_value=False) + + def test_merges_roots(self, seeder): + """A second enqueue should merge roots with the existing pending request.""" + with self._make_busy(seeder): + seeder.enqueue_enrich(roots=("models",)) + seeder.enqueue_enrich(roots=("output",)) + + merged = set(seeder._pending_enrich["roots"]) + assert merged == {"models", "output"} + + def test_merges_overlapping_roots(self, seeder): + """Duplicate roots should be deduplicated.""" + with self._make_busy(seeder): + seeder.enqueue_enrich(roots=("models", "input")) + seeder.enqueue_enrich(roots=("input", "output")) + + merged = set(seeder._pending_enrich["roots"]) + assert merged == {"models", "input", "output"} + + def test_compute_hashes_sticky_true(self, seeder): + """Once compute_hashes is True it should stay True after merging.""" + with self._make_busy(seeder): + seeder.enqueue_enrich(roots=("models",), compute_hashes=True) + seeder.enqueue_enrich(roots=("output",), compute_hashes=False) + + assert seeder._pending_enrich["compute_hashes"] is True + + def test_compute_hashes_upgrades_to_true(self, seeder): + """A later enqueue with compute_hashes=True should upgrade the pending request.""" + with self._make_busy(seeder): + seeder.enqueue_enrich(roots=("models",), compute_hashes=False) + seeder.enqueue_enrich(roots=("output",), compute_hashes=True) + + assert seeder._pending_enrich["compute_hashes"] is True + + def test_compute_hashes_stays_false(self, seeder): + """If both enqueues have compute_hashes=False it stays False.""" + with self._make_busy(seeder): + seeder.enqueue_enrich(roots=("models",), compute_hashes=False) + seeder.enqueue_enrich(roots=("output",), compute_hashes=False) + + assert seeder._pending_enrich["compute_hashes"] is False + + def test_triple_merge(self, seeder): + """Three successive enqueues should all merge correctly.""" + with self._make_busy(seeder): + seeder.enqueue_enrich(roots=("models",), compute_hashes=False) + seeder.enqueue_enrich(roots=("input",), compute_hashes=False) + seeder.enqueue_enrich(roots=("output",), compute_hashes=True) + + merged = set(seeder._pending_enrich["roots"]) + assert merged == {"models", "input", "output"} + assert seeder._pending_enrich["compute_hashes"] is True + + +# --------------------------------------------------------------------------- +# Pending enrich drains after scan completes +# --------------------------------------------------------------------------- + + +class TestPendingEnrichDrain: + """Verify that _run_scan drains _pending_enrich via start_enrich.""" + + @patch("app.assets.seeder.dependencies_available", return_value=True) + @patch("app.assets.seeder.get_all_known_prefixes", return_value=[]) + @patch("app.assets.seeder.sync_root_safely", return_value=set()) + @patch("app.assets.seeder.collect_paths_for_roots", return_value=[]) + @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0)) + def test_pending_enrich_starts_after_scan(self, *_mocks): + """After a fast scan finishes, the pending enrich should be started.""" + seeder = _AssetSeeder() + + seeder._pending_enrich = { + "roots": ("output",), + "compute_hashes": True, + } + + with patch.object(seeder, "start_enrich", return_value=True) as mock_start: + seeder.start_fast(roots=("models",)) + seeder.wait(timeout=5) + + mock_start.assert_called_once_with( + roots=("output",), + compute_hashes=True, + ) + + assert seeder._pending_enrich is None + + @patch("app.assets.seeder.dependencies_available", return_value=True) + @patch("app.assets.seeder.get_all_known_prefixes", return_value=[]) + @patch("app.assets.seeder.sync_root_safely", return_value=set()) + @patch("app.assets.seeder.collect_paths_for_roots", return_value=[]) + @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0)) + def test_pending_cleared_even_when_start_fails(self, *_mocks): + """_pending_enrich should be cleared even if start_enrich returns False.""" + seeder = _AssetSeeder() + seeder._pending_enrich = { + "roots": ("output",), + "compute_hashes": False, + } + + with patch.object(seeder, "start_enrich", return_value=False): + seeder.start_fast(roots=("models",)) + seeder.wait(timeout=5) + + assert seeder._pending_enrich is None + + @patch("app.assets.seeder.dependencies_available", return_value=True) + @patch("app.assets.seeder.get_all_known_prefixes", return_value=[]) + @patch("app.assets.seeder.sync_root_safely", return_value=set()) + @patch("app.assets.seeder.collect_paths_for_roots", return_value=[]) + @patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0)) + def test_no_drain_when_no_pending(self, *_mocks): + """start_enrich should not be called when there is no pending request.""" + seeder = _AssetSeeder() + assert seeder._pending_enrich is None + + with patch.object(seeder, "start_enrich", return_value=True) as mock_start: + seeder.start_fast(roots=("models",)) + seeder.wait(timeout=5) + + mock_start.assert_not_called() + + +# --------------------------------------------------------------------------- +# Thread-safety of enqueue_enrich +# --------------------------------------------------------------------------- + + +class TestEnqueueEnrichThreadSafety: + def test_concurrent_enqueues(self, seeder): + """Multiple threads enqueuing should not lose roots.""" + with patch.object(seeder, "start_enrich", return_value=False): + barrier = threading.Barrier(3) + + def enqueue(root): + barrier.wait() + seeder.enqueue_enrich(roots=(root,), compute_hashes=False) + + threads = [ + threading.Thread(target=enqueue, args=(r,)) + for r in ("models", "input", "output") + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + + merged = set(seeder._pending_enrich["roots"]) + assert merged == {"models", "input", "output"} From b53b10ea61ef7fc54fbde7c1e7b7c36565bacf82 Mon Sep 17 00:00:00 2001 From: Krishna Chaitanya Date: Tue, 24 Mar 2026 20:53:44 -0700 Subject: [PATCH 12/13] Fix Train LoRA crash when training_dtype is "none" with bfloat16 LoRA weights (#13145) When training_dtype is set to "none" and the model's native dtype is float16, GradScaler was unconditionally enabled. However, GradScaler does not support bfloat16 gradients (only float16/float32), causing a NotImplementedError when lora_dtype is "bf16" (the default). Fix by only enabling GradScaler when LoRA parameters are not in bfloat16, since bfloat16 has the same exponent range as float32 and does not need gradient scaling to avoid underflow. Fixes #13124 --- comfy_extras/nodes_train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index df1b39fd5..0616dfc2d 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1146,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode): # Setup model and dtype mp = model.clone() use_grad_scaler = False + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) if training_dtype != "none": dtype = node_helpers.string_to_torch_dtype(training_dtype) mp.set_model_compute_dtype(dtype) @@ -1154,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode): model_dtype = mp.model.get_dtype() if model_dtype == torch.float16: dtype = torch.float16 - use_grad_scaler = True + # GradScaler only supports float16 gradients, not bfloat16. + # Only enable it when lora params will also be in float16. + if lora_dtype != torch.bfloat16: + use_grad_scaler = True # Warn about fp16 accumulation instability during training if PerformanceFeature.Fp16Accumulation in args.fast: logging.warning( @@ -1165,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode): else: # For fp8, bf16, or other dtypes, use bf16 autocast dtype = torch.bfloat16 - lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) # Prepare latents and compute counts latents_dtype = dtype if dtype not in (None,) else torch.bfloat16 From a55835f10c29c1acdc9158bf9e092656ae1a2188 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 25 Mar 2026 20:05:49 +0200 Subject: [PATCH 13/13] fix(api-nodes): made Reve node price badges more precise (#13154) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_reve.py | 43 +++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py index 608d9f058..a87395394 100644 --- a/comfy_api_nodes/nodes_reve.py +++ b/comfy_api_nodes/nodes_reve.py @@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""", + depends_on=IO.PriceBadgeDepends( + widgets=["upscale", "upscale.upscale_factor"], + ), + expr=""" + ( + $factor := $lookup(widgets, "upscale.upscale_factor"); + $fmt := {"approximate": true, "note": "(base)"}; + widgets.upscale = "enabled" ? ( + $factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt} + : $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt} + : {"type": "usd", "usd": 0.0457, "format": $fmt} + ) : {"type": "usd", "usd": 0.03432, "format": $fmt} + ) + """, ), ) @@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode): is_api_node=True, price_badge=IO.PriceBadge( depends_on=IO.PriceBadgeDepends( - widgets=["model"], + widgets=["model", "upscale", "upscale.upscale_factor"], ), expr=""" ( + $fmt := {"approximate": true, "note": "(base)"}; $isFast := $contains(widgets.model, "fast"); - $base := $isFast ? 0.01001 : 0.0572; - {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + $enabled := widgets.upscale = "enabled"; + $factor := $lookup(widgets, "upscale.upscale_factor"); + $isFast + ? {"type": "usd", "usd": 0.01001, "format": $fmt} + : $enabled ? ( + $factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt} + : $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt} + : {"type": "usd", "usd": 0.0686, "format": $fmt} + ) : {"type": "usd", "usd": 0.0572, "format": $fmt} ) """, ), @@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode): is_api_node=True, price_badge=IO.PriceBadge( depends_on=IO.PriceBadgeDepends( - widgets=["model"], + widgets=["model", "upscale", "upscale.upscale_factor"], ), expr=""" ( + $fmt := {"approximate": true, "note": "(base)"}; $isFast := $contains(widgets.model, "fast"); - $base := $isFast ? 0.01001 : 0.0572; - {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + $enabled := widgets.upscale = "enabled"; + $factor := $lookup(widgets, "upscale.upscale_factor"); + $isFast + ? {"type": "usd", "usd": 0.01001, "format": $fmt} + : $enabled ? ( + $factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt} + : $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt} + : {"type": "usd", "usd": 0.0686, "format": $fmt} + ) : {"type": "usd", "usd": 0.0572, "format": $fmt} ) """, ),