From e75f775ae8e9b1a1fd2b78806c86338fd830bcd7 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Tue, 21 Apr 2026 16:43:11 +0900 Subject: [PATCH 01/35] Bump comfyui-frontend-package to 1.42.12 (#13489) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 63d7c41cf..671bd5693 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.42.11 +comfyui-frontend-package==1.42.12 comfyui-workflow-templates==0.9.57 comfyui-embedded-docs==0.4.3 torch From ad94d472216ba52ab2660536af44faa92cf4b5d0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 21 Apr 2026 08:02:42 -0700 Subject: [PATCH 02/35] Make the ltx audio vae more native. (#13486) --- comfy/ldm/lightricks/vae/audio_vae.py | 55 +++------------------------ comfy/sd.py | 18 +++++++++ comfy_extras/nodes_audio.py | 2 +- comfy_extras/nodes_lt_audio.py | 36 ++++++++---------- 4 files changed, 41 insertions(+), 70 deletions(-) diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index fa0a00748..dd5320c8f 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -4,9 +4,6 @@ import math import torch import torchaudio -import comfy.model_management -import comfy.model_patcher -import comfy.utils as utils from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( @@ -43,30 +40,6 @@ class AudioVAEComponentConfig: return cls(autoencoder=audio_config, vocoder=vocoder_config) - -class ModelDeviceManager: - """Manages device placement and GPU residency for the composed model.""" - - def __init__(self, module: torch.nn.Module): - load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.vae_offload_device() - self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) - - def ensure_model_loaded(self) -> None: - comfy.model_management.free_memory( - self.patcher.model_size(), - self.patcher.load_device, - ) - comfy.model_management.load_model_gpu(self.patcher) - - def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(self.patcher.load_device) - - @property - def load_device(self): - return self.patcher.load_device - - class AudioLatentNormalizer: """Applies per-channel statistics in patch space and restores original layout.""" @@ -132,23 +105,17 @@ class AudioPreprocessor: class AudioVAE(torch.nn.Module): """High-level Audio VAE wrapper exposing encode and decode entry points.""" - def __init__(self, state_dict: dict, metadata: dict): + def __init__(self, metadata: dict): super().__init__() component_config = AudioVAEComponentConfig.from_metadata(metadata) - vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) - vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) - self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) if "bwe" in component_config.vocoder: self.vocoder = VocoderWithBWE(config=component_config.vocoder) else: self.vocoder = Vocoder(config=component_config.vocoder) - self.autoencoder.load_state_dict(vae_sd, strict=False) - self.vocoder.load_state_dict(vocoder_sd, strict=False) - autoencoder_config = self.autoencoder.get_config() self.normalizer = AudioLatentNormalizer( AudioPatchifier( @@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module): n_fft=autoencoder_config["n_fft"], ) - self.device_manager = ModelDeviceManager(self) - - def encode(self, audio: dict) -> torch.Tensor: + def encode(self, audio, sample_rate=44100) -> torch.Tensor: """Encode a waveform dictionary into normalized latent tensors.""" - waveform = audio["waveform"] - waveform_sample_rate = audio["sample_rate"] + waveform = audio + waveform_sample_rate = sample_rate input_device = waveform.device - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - waveform = self.device_manager.move_to_load_device(waveform) expected_channels = self.autoencoder.encoder.in_channels if waveform.shape[1] != expected_channels: if waveform.shape[1] == 1: @@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module): ) mel_spec = self.preprocessor.waveform_to_mel( - waveform, waveform_sample_rate, device=self.device_manager.load_device + waveform, waveform_sample_rate, device=waveform.device ) latents = self.autoencoder.encode(mel_spec) @@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module): """Decode normalized latent tensors into an audio waveform.""" original_shape = latents.shape - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - latents = self.device_manager.move_to_load_device(latents) latents = self.normalizer.denormalize(latents) target_shape = self.target_shape_from_latents(original_shape) mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) waveform = self.run_vocoder(mel_spec) - return self.device_manager.move_to_load_device(waveform) + return waveform def target_shape_from_latents(self, latents_shape): batch, _, time, _ = latents_shape diff --git a/comfy/sd.py b/comfy/sd.py index e573804a5..a4d3ee269 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.audio.autoencoder import AudioOobleckVAE import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder +import comfy.ldm.lightricks.vae.audio_vae import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 @@ -805,6 +806,23 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio + self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata) + self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) + self.latent_channels = self.first_stage_model.latent_channels + self.audio_sample_rate_output = self.first_stage_model.output_sample_rate + self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes + self.output_channels = 2 + self.pad_channel_value = "replicate" + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 + self.latent_dim = 2 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.float32] + self.disable_offload = True + self.extra_1d_channel = 16 else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index a395392d8..5f514716f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None): std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100)) return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 3e4222264..3ec635c75 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -3,9 +3,8 @@ import comfy.utils import comfy.model_management import torch -from comfy.ldm.lightricks.vae.audio_vae import AudioVAE from comfy_api.latest import ComfyExtension, io - +from comfy_extras.nodes_audio import VAEEncodeAudio class LTXVAudioVAELoader(io.ComfyNode): @classmethod @@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode): def execute(cls, ckpt_name: str) -> io.NodeOutput: ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) - return io.NodeOutput(AudioVAE(sd, metadata)) + sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) + vae.throw_exception_if_invalid() + + return io.NodeOutput(vae) -class LTXVAudioVAEEncode(io.ComfyNode): +class LTXVAudioVAEEncode(VAEEncodeAudio): @classmethod def define_schema(cls) -> io.Schema: return io.Schema( @@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode): ) @classmethod - def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: - audio_latents = audio_vae.encode(audio) - return io.NodeOutput( - { - "samples": audio_latents, - "sample_rate": int(audio_vae.sample_rate), - "type": "audio", - } - ) + def execute(cls, audio, audio_vae) -> io.NodeOutput: + return super().execute(audio_vae, audio) class LTXVAudioVAEDecode(io.ComfyNode): @@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode): ) @classmethod - def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + def execute(cls, samples, audio_vae) -> io.NodeOutput: audio_latent = samples["samples"] if audio_latent.is_nested: audio_latent = audio_latent.unbind()[-1] - audio = audio_vae.decode(audio_latent).to(audio_latent.device) - output_audio_sample_rate = audio_vae.output_sample_rate + audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device) + output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate return io.NodeOutput( { "waveform": audio, @@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode): frames_number: int, frame_rate: int, batch_size: int, - audio_vae: AudioVAE, + audio_vae, ) -> io.NodeOutput: """Generate empty audio latents matching the reference pipeline structure.""" assert audio_vae is not None, "Audio VAE model is required" z_channels = audio_vae.latent_channels - audio_freq = audio_vae.latent_frequency_bins - sampling_rate = int(audio_vae.sample_rate) + audio_freq = audio_vae.first_stage_model.latent_frequency_bins + sampling_rate = int(audio_vae.first_stage_model.sample_rate) - num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate) audio_latents = torch.zeros( (batch_size, z_channels, num_audio_latents, audio_freq), From b38dd0ff23037cd4f03f12274a5c7e04d224febd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 21 Apr 2026 20:45:10 +0300 Subject: [PATCH 03/35] feat(api-nodes): add automatic downscaling of videos for ByteDance 2 nodes (#13465) --- comfy_api_nodes/apis/bytedance.py | 13 +++- comfy_api_nodes/nodes_bytedance.py | 33 +++++++-- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/conversions.py | 104 +++++++++++++++++++++++++--- 4 files changed, 134 insertions(+), 18 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 3755323ac..dc3bc3213 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -158,10 +158,17 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [ ("Custom", None, None), ] -# Seedance 2.0 reference video pixel count limits per model. +# Seedance 2.0 reference video pixel count limits per model and output resolution. SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = { - "dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408}, - "dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408}, + "dreamina-seedance-2-0-260128": { + "480p": {"min": 409_600, "max": 927_408}, + "720p": {"min": 409_600, "max": 927_408}, + "1080p": {"min": 409_600, "max": 2_073_600}, + }, + "dreamina-seedance-2-0-fast-260128": { + "480p": {"min": 409_600, "max": 927_408}, + "720p": {"min": 409_600, "max": 927_408}, + }, } # The time in this dictionary are given for 10 seconds duration. diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 429c32444..bc564782d 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -35,6 +35,7 @@ from comfy_api_nodes.util import ( get_number_of_images, image_tensor_pair_to_batch, poll_op, + resize_video_to_pixel_budget, sync_op, upload_audio_to_comfyapi, upload_image_to_comfyapi, @@ -69,9 +70,12 @@ DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-2504 logger = logging.getLogger(__name__) -def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None: - """Validate reference video pixel count against Seedance 2.0 model limits.""" - limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id) +def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: str, index: int) -> None: + """Validate reference video pixel count against Seedance 2.0 model limits for the selected resolution.""" + model_limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id) + if not model_limits: + return + limits = model_limits.get(resolution) if not limits: return try: @@ -1373,6 +1377,14 @@ def _seedance2_reference_inputs(resolutions: list[str]): min=0, ), ), + IO.Boolean.Input( + "auto_downscale", + default=False, + advanced=True, + optional=True, + tooltip="Automatically downscale reference videos that exceed the model's pixel budget " + "for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.", + ), ] @@ -1480,10 +1492,23 @@ class ByteDance2ReferenceNode(IO.ComfyNode): model_id = SEEDANCE_MODELS[model["model"]] has_video_input = len(reference_videos) > 0 + + if model.get("auto_downscale") and reference_videos: + max_px = ( + SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}) + .get(model["resolution"], {}) + .get("max") + ) + if max_px: + for key in reference_videos: + reference_videos[key] = resize_video_to_pixel_budget( + reference_videos[key], max_px + ) + total_video_duration = 0.0 for i, key in enumerate(reference_videos, 1): video = reference_videos[key] - _validate_ref_video_pixels(video, model_id, i) + _validate_ref_video_pixels(video, model_id, model["resolution"], i) try: dur = video.get_duration() if dur < 1.8: diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 0cb9a47c7..f3584aba9 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -19,6 +19,7 @@ from .conversions import ( image_tensor_pair_to_batch, pil_to_bytesio, resize_mask_to_image, + resize_video_to_pixel_budget, tensor_to_base64_string, tensor_to_bytesio, tensor_to_pil, @@ -90,6 +91,7 @@ __all__ = [ "image_tensor_pair_to_batch", "pil_to_bytesio", "resize_mask_to_image", + "resize_video_to_pixel_budget", "tensor_to_base64_string", "tensor_to_bytesio", "tensor_to_pil", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 82b6d22a5..be5d5719b 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -129,22 +129,38 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: return img_byte_arr +def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None: + """Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits. + + Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions + are rounded down to even values (many codecs require divisible-by-2). + """ + pixels = src_w * src_h + if pixels <= total_pixels: + return None + scale = math.sqrt(total_pixels / pixels) + new_w = max(2, int(src_w * scale)) + new_h = max(2, int(src_h * scale)) + new_w -= new_w % 2 + new_h -= new_h % 2 + return new_w, new_h + + def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor: - """Downscale input image tensor to roughly the specified total pixels.""" + """Downscale input image tensor to roughly the specified total pixels. + + Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels`` + and is compatible with codecs that require even dimensions (e.g. yuv420p). + """ samples = image.movedim(-1, 1) - total = int(total_pixels) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: + dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels)) + if dims is None: return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = s.movedim(1, -1) - return s + new_w, new_h = dims + return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1) -def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: +def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: """Downscale input image tensor so the largest dimension is at most max_side pixels.""" samples = image.movedim(-1, 1) height, width = samples.shape[2], samples.shape[3] @@ -399,6 +415,72 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: raise RuntimeError(f"Failed to trim video: {str(e)}") from e +def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video: + """Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio. + + Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio. + Aspect ratio is preserved up to a fraction of a percent (even-dim rounding). + """ + src_w, src_h = video.get_dimensions() + scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels) + if scale_dims is None: + return video + return _apply_video_scale(video, scale_dims) + + +def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video: + """Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass.""" + out_w, out_h = scale_dims + output_buffer = BytesIO() + input_container = None + output_container = None + + try: + input_source = video.get_stream_source() + input_container = av.open(input_source, mode="r") + output_container = av.open(output_buffer, mode="w", format="mp4") + + video_stream = output_container.add_stream("h264", rate=video.get_frame_rate()) + video_stream.width = out_w + video_stream.height = out_h + video_stream.pix_fmt = "yuv420p" + + audio_stream = None + for stream in input_container.streams: + if isinstance(stream, av.AudioStream): + audio_stream = output_container.add_stream("aac", rate=stream.sample_rate) + audio_stream.sample_rate = stream.sample_rate + audio_stream.layout = stream.layout + break + + for frame in input_container.decode(video=0): + frame = frame.reformat(width=out_w, height=out_h, format="yuv420p") + for packet in video_stream.encode(frame): + output_container.mux(packet) + for packet in video_stream.encode(): + output_container.mux(packet) + + if audio_stream is not None: + input_container.seek(0) + for audio_frame in input_container.decode(audio=0): + for packet in audio_stream.encode(audio_frame): + output_container.mux(packet) + for packet in audio_stream.encode(): + output_container.mux(packet) + + output_container.close() + input_container.close() + output_buffer.seek(0) + return InputImpl.VideoFromFile(output_buffer) + + except Exception as e: + if input_container is not None: + input_container.close() + if output_container is not None: + output_container.close() + raise RuntimeError(f"Failed to resize video: {str(e)}") from e + + def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" if wav.dtype.is_floating_point: From eb2222538739c4ebd396cd0a40cb6d80befd04fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 21 Apr 2026 20:46:37 +0300 Subject: [PATCH 04/35] Support standalone LTXV audio VAEs (#13499) --- comfy/sd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/sd.py b/comfy/sd.py index a4d3ee269..736fe35de 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -807,6 +807,7 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio + sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."}) self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata) self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) From 1e1d4f12548ceb296e61af1dff217bb8e4414345 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 21 Apr 2026 21:27:35 +0300 Subject: [PATCH 05/35] [Partner Nodes] added 4K resolution for Veo models; added Veo 3 Lite model (#13330) * feat(api nodes): added 4K resolution for Veo models; added Veo 3 Lite model Signed-off-by: bigcat88 * increase poll_interval from 5 to 9 --------- Signed-off-by: bigcat88 Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/nodes_veo2.py | 171 +++++++++++++++++++++++++--------- 1 file changed, 128 insertions(+), 43 deletions(-) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 13fc1cc36..084b086a8 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -24,8 +24,9 @@ from comfy_api_nodes.util import ( AVERAGE_DURATION_VIDEO_GEN = 32 MODELS_MAP = { "veo-2.0-generate-001": "veo-2.0-generate-001", - "veo-3.1-generate": "veo-3.1-generate-preview", - "veo-3.1-fast-generate": "veo-3.1-fast-generate-preview", + "veo-3.1-generate": "veo-3.1-generate-001", + "veo-3.1-fast-generate": "veo-3.1-fast-generate-001", + "veo-3.1-lite": "veo-3.1-lite-generate-001", "veo-3.0-generate-001": "veo-3.0-generate-001", "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", } @@ -247,17 +248,8 @@ class VeoVideoGenerationNode(IO.ComfyNode): raise Exception("Video generation completed but no video was returned") -class Veo3VideoGenerationNode(VeoVideoGenerationNode): - """ - Generates videos from text prompts using Google's Veo 3 API. - - Supported models: - - veo-3.0-generate-001 - - veo-3.0-fast-generate-001 - - This node extends the base Veo node with Veo 3 specific features including - audio generation and fixed 8-second duration. - """ +class Veo3VideoGenerationNode(IO.ComfyNode): + """Generates videos from text prompts using Google's Veo 3 API.""" @classmethod def define_schema(cls): @@ -279,6 +271,13 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): default="16:9", tooltip="Aspect ratio of the output video", ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p", "4k"], + default="720p", + tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.", + optional=True, + ), IO.String.Input( "negative_prompt", multiline=True, @@ -289,11 +288,11 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Int.Input( "duration_seconds", default=8, - min=8, + min=4, max=8, - step=1, + step=2, display_mode=IO.NumberDisplay.number, - tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)", + tooltip="Duration of the output video in seconds", optional=True, ), IO.Boolean.Input( @@ -332,10 +331,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): options=[ "veo-3.1-generate", "veo-3.1-fast-generate", + "veo-3.1-lite", "veo-3.0-generate-001", "veo-3.0-fast-generate-001", ], - default="veo-3.0-generate-001", tooltip="Veo 3 model to use for video generation", optional=True, ), @@ -356,21 +355,111 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]), expr=""" ( $m := widgets.model; + $r := widgets.resolution; $a := widgets.generate_audio; - ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) - ? {"type":"usd","usd": ($a ? 1.2 : 0.8)} - : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) - ? {"type":"usd","usd": ($a ? 3.2 : 1.6)} - : {"type":"range_usd","min_usd":0.8,"max_usd":3.2} + $seconds := widgets.duration_seconds; + $pps := + $contains($m, "lite") + ? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03)) + : $contains($m, "3.1-fast") + ? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08)) + : $contains($m, "3.1-generate") + ? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20)) + : $contains($m, "3.0-fast") + ? ($a ? 0.15 : 0.10) + : ($a ? 0.40 : 0.20); + {"type":"usd","usd": $pps * $seconds} ) """, ), ) + @classmethod + async def execute( + cls, + prompt, + aspect_ratio="16:9", + resolution="720p", + negative_prompt="", + duration_seconds=8, + enhance_prompt=True, + person_generation="ALLOW", + seed=0, + image=None, + model="veo-3.0-generate-001", + generate_audio=False, + ): + if "lite" in model and resolution == "4k": + raise Exception("4K resolution is not supported by the veo-3.1-lite model.") + + model = MODELS_MAP[model] + + instances = [{"prompt": prompt}] + if image is not None: + image_base64 = tensor_to_base64_string(image) + if image_base64: + instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"} + + parameters = { + "aspectRatio": aspect_ratio, + "personGeneration": person_generation, + "durationSeconds": duration_seconds, + "enhancePrompt": True, + "generateAudio": generate_audio, + } + if negative_prompt: + parameters["negativePrompt"] = negative_prompt + if seed > 0: + parameters["seed"] = seed + if "veo-3.1" in model: + parameters["resolution"] = resolution + + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( + instances=instances, + parameters=parameters, + ), + ) + + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, + status_extractor=lambda r: "completed" if r.done else "pending", + data=VeoGenVidPollRequest(operationName=initial_response.name), + poll_interval=9.0, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, + ) + + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") + + response = poll_response.response + filtered_count = response.raiMediaFilteredCount + if filtered_count: + reasons = response.raiMediaFilteredReasons or [] + reason_part = f": {reasons[0]}" if reasons else "" + raise Exception( + f"Content blocked by Google's Responsible AI filters{reason_part} " + f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)." + ) + + if response.videos: + video = response.videos[0] + if video.bytesBase64Encoded: + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + if video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") + class Veo3FirstLastFrameNode(IO.ComfyNode): @@ -394,7 +483,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): default="", tooltip="Negative text prompt to guide what to avoid in the video", ), - IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]), IO.Combo.Input( "aspect_ratio", options=["16:9", "9:16"], @@ -424,8 +513,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): IO.Image.Input("last_frame", tooltip="End frame"), IO.Combo.Input( "model", - options=["veo-3.1-generate", "veo-3.1-fast-generate"], - default="veo-3.1-fast-generate", + options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"], ), IO.Boolean.Input( "generate_audio", @@ -443,26 +531,20 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]), expr=""" ( - $prices := { - "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 }, - "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 } - }; $m := widgets.model; - $ga := (widgets.generate_audio = "true"); + $r := widgets.resolution; + $ga := widgets.generate_audio; $seconds := widgets.duration; - $modelKey := - $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : - $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : - ""; - $audioKey := $ga ? "audio" : "no_audio"; - $modelPrices := $lookup($prices, $modelKey); - $pps := $lookup($modelPrices, $audioKey); - ($pps != null) - ? {"type":"usd","usd": $pps * $seconds} - : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2} + $pps := + $contains($m, "lite") + ? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03)) + : $contains($m, "fast") + ? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08)) + : ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20)); + {"type":"usd","usd": $pps * $seconds} ) """, ), @@ -482,6 +564,9 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): model: str, generate_audio: bool, ): + if "lite" in model and resolution == "4k": + raise Exception("4K resolution is not supported by the veo-3.1-lite model.") + model = MODELS_MAP[model] initial_response = await sync_op( cls, @@ -519,7 +604,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): data=VeoGenVidPollRequest( operationName=initial_response.name, ), - poll_interval=5.0, + poll_interval=9.0, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) From 102773cd2c13bdbe8729fdc897031dfb61dea346 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Wed, 22 Apr 2026 03:35:45 +0900 Subject: [PATCH 06/35] Bump comfyui-frontend-package to 1.42.14 (#13493) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 671bd5693..ccdd47674 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.42.12 +comfyui-frontend-package==1.42.14 comfyui-workflow-templates==0.9.57 comfyui-embedded-docs==0.4.3 torch From 43a1263b609b923b2f69a0510bcf7ac95097e41b Mon Sep 17 00:00:00 2001 From: AustinMroz Date: Tue, 21 Apr 2026 17:58:59 -0700 Subject: [PATCH 07/35] Add gpt-image-2 as version option (#13501) --- comfy_api_nodes/nodes_openai.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 4ee896fa8..90a29c2f2 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -363,7 +363,7 @@ class OpenAIGPTImage1(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIGPTImage1", - display_name="OpenAI GPT Image 1.5", + display_name="OpenAI GPT Image 2", category="api node/image/OpenAI", description="Generates images synchronously via OpenAI's GPT Image endpoint.", inputs=[ @@ -427,8 +427,8 @@ class OpenAIGPTImage1(IO.ComfyNode): ), IO.Combo.Input( "model", - options=["gpt-image-1", "gpt-image-1.5"], - default="gpt-image-1.5", + options=["gpt-image-1", "gpt-image-1.5", 'gpt-image-2'], + default="gpt-image-2", optional=True, ), ], @@ -487,6 +487,8 @@ class OpenAIGPTImage1(IO.ComfyNode): price_extractor = calculate_tokens_price_image_1 elif model == "gpt-image-1.5": price_extractor = calculate_tokens_price_image_1_5 + elif model == "gpt-image-2": + price_extractor = calculate_tokens_price_image_1_5 else: raise ValueError(f"Unknown model: {model}") From 529c80255f3f2370c39780c62a9454d95344014d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 21 Apr 2026 19:59:31 -0700 Subject: [PATCH 08/35] Allow logging in comfy app files. (#13505) --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 12b04719d..dbaf2745c 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,8 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger +setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + from app.assets.seeder import asset_seeder from app.assets.services import register_output_files import itertools @@ -27,8 +29,6 @@ if __name__ == "__main__": os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) - faulthandler.enable(file=sys.stderr, all_threads=False) import comfy_aimdo.control From 6045c11d8b32d5f761c555d6ca026e4d731ac8d5 Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Wed, 22 Apr 2026 11:45:25 +0800 Subject: [PATCH 09/35] chore: update workflow templates to v0.9.59 (#13507) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ccdd47674..a25bc0667 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.14 -comfyui-workflow-templates==0.9.57 +comfyui-workflow-templates==0.9.59 comfyui-embedded-docs==0.4.3 torch torchsde From 91e1f45d80fba14d992269b0b98de7a4a14c81b9 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 21 Apr 2026 22:31:36 -0700 Subject: [PATCH 10/35] fix(veo): reject 4K resolution for veo-3.0 models in Veo3VideoGenerationNode (#13504) The tooltip on the resolution input states that 4K is not available for veo-3.1-lite or veo-3.0 models, but the execute guard only rejected the lite combination. Selecting 4K with veo-3.0-generate-001 or veo-3.0-fast-generate-001 would fall through and hit the upstream API with an invalid request. Broaden the guard to match the documented behavior and update the error message accordingly. Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/nodes_veo2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 084b086a8..2ff75d9b2 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -393,8 +393,8 @@ class Veo3VideoGenerationNode(IO.ComfyNode): model="veo-3.0-generate-001", generate_audio=False, ): - if "lite" in model and resolution == "4k": - raise Exception("4K resolution is not supported by the veo-3.1-lite model.") + if resolution == "4k" and ("lite" in model or "3.0" in model): + raise Exception("4K resolution is not supported by the veo-3.1-lite or veo-3.0 models.") model = MODELS_MAP[model] From db85cf03ff33f5be09d02f2a52334971209d25d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 22 Apr 2026 14:16:02 +0300 Subject: [PATCH 11/35] feat: RIFE and FILM frame interpolation model support (CORE-29) (#13258) * initial RIFE support * Also support FILM * Better RAM usage, reduce FILM VRAM peak * Add model folder placeholder * Fix oom fallback frame loss * Remove torch.compile for now * Rename model input * Shorter input type name --------- --- .../frame_interpolation_models/film_net.py | 258 ++++++++++++++++++ .../frame_interpolation_models/ifnet.py | 128 +++++++++ comfy_extras/nodes_frame_interpolation.py | 211 ++++++++++++++ folder_paths.py | 2 + .../put_frame_interpolation_models_here | 0 nodes.py | 3 +- 6 files changed, 601 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/frame_interpolation_models/film_net.py create mode 100644 comfy_extras/frame_interpolation_models/ifnet.py create mode 100644 comfy_extras/nodes_frame_interpolation.py create mode 100644 models/frame_interpolation/put_frame_interpolation_models_here diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py new file mode 100644 index 000000000..cf4f6e1e1 --- /dev/null +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -0,0 +1,258 @@ +"""FILM: Frame Interpolation for Large Motion (ECCV 2022).""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +class FilmConv2d(nn.Module): + """Conv2d with optional LeakyReLU and FILM-style padding.""" + + def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops): + super().__init__() + self.even_pad = not size % 2 + self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype) + self.activation = nn.LeakyReLU(0.2) if activation else None + + def forward(self, x): + if self.even_pad: + x = F.pad(x, (0, 1, 0, 1)) + x = self.conv(x) + if self.activation is not None: + x = self.activation(x) + return x + + +def _warp_core(image, flow, grid_x, grid_y): + dtype = image.dtype + H, W = flow.shape[2], flow.shape[3] + dx = flow[:, 0].float() / (W * 0.5) + dy = flow[:, 1].float() / (H * 0.5) + grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3) + return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype) + + +def build_image_pyramid(image, pyramid_levels): + pyramid = [image] + for _ in range(1, pyramid_levels): + image = F.avg_pool2d(image, 2, 2) + pyramid.append(image) + return pyramid + + +def flow_pyramid_synthesis(residual_pyramid): + flow = residual_pyramid[-1] + flow_pyramid = [flow] + for residual_flow in residual_pyramid[:-1][::-1]: + flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow) + flow_pyramid.append(flow) + flow_pyramid.reverse() + return flow_pyramid + + +def multiply_pyramid(pyramid, scalar): + return [image * scalar[:, None, None, None] for image in pyramid] + + +def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn): + return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)] + + +def concatenate_pyramids(pyramid1, pyramid2): + return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)] + + +class SubTreeExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops): + super().__init__() + convs = [] + for i in range(n_layers): + out_ch = channels << i + convs.append(nn.Sequential( + FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations))) + in_channels = out_ch + self.convs = nn.ModuleList(convs) + + def forward(self, image, n): + head = image + pyramid = [] + for i, layer in enumerate(self.convs): + head = layer(head) + pyramid.append(head) + if i < n - 1: + head = F.avg_pool2d(head, 2, 2) + return pyramid + + +class FeatureExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops): + super().__init__() + self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations) + self.sub_levels = sub_levels + + def forward(self, image_pyramid): + sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels)) + for i in range(len(image_pyramid))] + feature_pyramid = [] + for i in range(len(image_pyramid)): + features = sub_pyramids[i][0] + for j in range(1, self.sub_levels): + if j <= i: + features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) + feature_pyramid.append(features) + # Free sub-pyramids no longer needed by future levels + if i >= self.sub_levels - 1: + sub_pyramids[i - self.sub_levels + 1] = None + return feature_pyramid + + +class FlowEstimator(nn.Module): + def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops): + super().__init__() + self._convs = nn.ModuleList() + for _ in range(num_convs): + self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations)) + in_channels = num_filters + self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations)) + self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations)) + + def forward(self, features_a, features_b): + net = torch.cat([features_a, features_b], dim=1) + for conv in self._convs: + net = conv(net) + return net + + +class PyramidFlowEstimator(nn.Module): + def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + in_channels = filters << 1 + predictors = [] + for i in range(len(flow_convs)): + predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations)) + in_channels += filters << (i + 2) + self._predictor = predictors[-1] + self._predictors = nn.ModuleList(predictors[:-1][::-1]) + + def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn): + levels = len(feature_pyramid_a) + v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) + residuals = [v] + # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels + steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)] + steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)] + for i, predictor in steps: + v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2) + v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v)) + residuals.append(v_residual) + v = v.add_(v_residual) + residuals.reverse() + return residuals + + +def _get_fusion_channels(level, filters): + # Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions + return (sum(filters << i for i in range(level)) + 3 + 2) * 2 + + +class Fusion(nn.Module): + def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops): + super().__init__() + self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype) + self.convs = nn.ModuleList() + in_channels = _get_fusion_channels(n_layers, filters) + increase = 0 + for i in range(n_layers)[::-1]: + num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers) + self.convs.append(nn.ModuleList([ + FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations), + FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)])) + in_channels = num_filters + increase = _get_fusion_channels(i, filters) - num_filters // 2 + + def forward(self, pyramid): + net = pyramid[-1] + for k, layers in enumerate(self.convs): + i = len(self.convs) - 1 - k + net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest")) + net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1))) + return self.output_conv(net) + + +class FILMNet(nn.Module): + def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4, + filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + self.pyramid_levels = pyramid_levels + self.fusion_pyramid_levels = fusion_pyramid_levels + self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations) + self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations) + self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations) + self._warp_grids = {} + + def get_dtype(self): + return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + + def _build_warp_grids(self, H, W, device): + """Pre-compute warp grids for all pyramid levels.""" + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + for _ in range(self.pyramid_levels): + self._warp_grids[(H, W)] = ( + torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device), + torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device), + ) + H, W = H // 2, W // 2 + + def warp(self, image, flow): + grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])] + return _warp_core(image, flow, grid_x, grid_y) + + def extract_features(self, img): + """Extract image and feature pyramids for a single frame. Can be cached across pairs.""" + image_pyramid = build_image_pyramid(img, self.pyramid_levels) + feature_pyramid = self.extract(image_pyramid) + return image_pyramid, feature_pyramid + + def forward(self, img0, img1, timestep=0.5, cache=None): + # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported) + t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep + return self.forward_multi_timestep(img0, img1, [t], cache=cache) + + def forward_multi_timestep(self, img0, img1, timesteps, cache=None): + """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.""" + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0) + image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1) + + fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] + bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] + + # Build warp targets and free full pyramids (only first fpl levels needed from here) + fpl = self.fusion_pyramid_levels + p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), + concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] + del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1 + + results = [] + dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) + for idx in range(len(timesteps)): + batch_dt = dt_tensors[idx:idx + 1] + bwd_scaled = multiply_pyramid(bwd_flow, batch_dt) + fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt) + fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp) + bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) + aligned = [torch.cat([fw, bw, bf, ff], dim=1) + for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] + del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled + results.append(self.fuse(aligned)) + del aligned + return torch.cat(results, dim=0) diff --git a/comfy_extras/frame_interpolation_models/ifnet.py b/comfy_extras/frame_interpolation_models/ifnet.py new file mode 100644 index 000000000..03cb34c50 --- /dev/null +++ b/comfy_extras/frame_interpolation_models/ifnet.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +def _warp(img, flow, warp_grids): + B, _, H, W = img.shape + base_grid, flow_div = warp_grids[(H, W)] + flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float() + grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1) + return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype) + + +class Head(nn.Module): + def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): + super().__init__() + self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) + self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + x = self.relu(self.cnn0(x)) + x = self.relu(self.cnn1(x)) + x = self.relu(self.cnn2(x)) + return self.cnn3(x) + + +class ResConv(nn.Module): + def __init__(self, c, device=None, dtype=None, operations=ops): + super().__init__() + self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(torch.addcmul(x, self.conv(x), self.beta)) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): + super().__init__() + self.conv0 = nn.Sequential( + nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)), + nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True))) + self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8))) + self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) + x = torch.cat((x, flow), 1) + feat = self.convblock(self.conv0(x)) + tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear") + return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:] + + +class IFNet(nn.Module): + def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops): + super().__init__() + self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) + block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 + self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)]) + self.scale_list = [16, 8, 4, 2, 1] + self.pad_align = 64 + self._warp_grids = {} + + def get_dtype(self): + return self.encode.cnn0.weight.dtype + + def _build_warp_grids(self, H, W, device): + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + grid_y, grid_x = torch.meshgrid( + torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32), + torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij") + self._warp_grids[(H, W)] = ( + torch.stack((grid_x, grid_y), dim=0).unsqueeze(0), + torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device)) + + def warp(self, img, flow): + return _warp(img, flow, self._warp_grids) + + def extract_features(self, img): + """Extract head features for a single frame. Can be cached across pairs.""" + return self.encode(img) + + def forward(self, img0, img1, timestep=0.5, cache=None): + if not isinstance(timestep, torch.Tensor): + timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype) + + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + B = img0.shape[0] + f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0) + f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1) + flow = mask = feat = None + warped_img0, warped_img1 = img0, img1 + for i, block in enumerate(self.blocks): + if flow is None: + flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i]) + else: + fd, mask, feat = block( + torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1), + flow, scale=self.scale_list[i]) + flow = flow.add_(fd) + warped_img0 = self.warp(img0, flow[:, :2]) + warped_img1 = self.warp(img1, flow[:, 2:4]) + return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask)) + + +def detect_rife_config(state_dict): + head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW) + channels = [] + for i in range(5): + key = f"blocks.{i}.conv0.1.0.weight" + if key in state_dict: + channels.append(state_dict[key].shape[0]) + if len(channels) != 5: + raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") + return head_ch, channels diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py new file mode 100644 index 000000000..a3b00d36e --- /dev/null +++ b/comfy_extras/nodes_frame_interpolation.py @@ -0,0 +1,211 @@ +import torch +from tqdm import tqdm +from typing_extensions import override + +import comfy.model_patcher +import comfy.utils +import folder_paths +from comfy import model_management +from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config +from comfy_extras.frame_interpolation_models.film_net import FILMNet +from comfy_api.latest import ComfyExtension, io + +FrameInterpolationModel = io.Custom("INTERP_MODEL") + + +class FrameInterpolationModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolationModelLoader", + display_name="Load Frame Interpolation Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"), + tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."), + ], + outputs=[ + FrameInterpolationModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + model = cls._detect_and_load(sd) + dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 + model.eval().to(dtype) + patcher = comfy.model_patcher.ModelPatcher( + model, + load_device=model_management.get_torch_device(), + offload_device=model_management.unet_offload_device(), + ) + return io.NodeOutput(patcher) + + @classmethod + def _detect_and_load(cls, sd): + # Try FILM + if "extract.extract_sublevels.convs.0.0.conv.weight" in sd: + model = FILMNet() + model.load_state_dict(sd) + return model + + # Try RIFE (needs key remapping for raw checkpoints) + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) + key_map = {} + for k in sd: + for i in range(5): + if k.startswith(f"block{i}."): + key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}" + if key_map: + sd = {key_map.get(k, k): v for k, v in sd.items()} + sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))} + + try: + head_ch, channels = detect_rife_config(sd) + except (KeyError, ValueError): + raise ValueError("Unrecognized frame interpolation model format") + model = IFNet(head_ch=head_ch, channels=channels) + model.load_state_dict(sd) + return model + + +class FrameInterpolate(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolate", + display_name="Frame Interpolate", + category="image/video", + search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], + inputs=[ + FrameInterpolationModel.Input("interp_model"), + io.Image.Input("images"), + io.Int.Input("multiplier", default=2, min=2, max=16), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, interp_model, images, multiplier) -> io.NodeOutput: + offload_device = model_management.intermediate_device() + + num_frames = images.shape[0] + if num_frames < 2 or multiplier < 2: + return io.NodeOutput(images) + + model_management.load_model_gpu(interp_model) + device = interp_model.load_device + dtype = interp_model.model_dtype() + inference_model = interp_model.model + + # Free VRAM for inference activations (model weights + ~20x a single frame's worth) + H, W = images.shape[1], images.shape[2] + activation_mem = H * W * 3 * images.element_size() * 20 + model_management.free_memory(activation_mem, device) + align = getattr(inference_model, "pad_align", 1) + + # Prepare a single padded frame on device for determining output dimensions + def prepare_frame(idx): + frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device) + if align > 1: + from comfy.ldm.common_dit import pad_to_patch_size + frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect") + return frame + + # Count total interpolation passes for progress bar + total_pairs = num_frames - 1 + num_interp = multiplier - 1 + total_steps = total_pairs * num_interp + pbar = comfy.utils.ProgressBar(total_steps) + tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation") + + batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) + t_values = [t / multiplier for t in range(1, multiplier)] + + out_dtype = model_management.intermediate_dtype() + total_out_frames = total_pairs * multiplier + 1 + result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device) + result[0] = images[0].movedim(-1, 0).to(out_dtype) + out_idx = 1 + + # Pre-compute timestep tensor on device (padded dimensions needed) + sample = prepare_frame(0) + pH, pW = sample.shape[2], sample.shape[3] + ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) + ts_full = ts_full.expand(-1, 1, pH, pW) + del sample + + multi_fn = getattr(inference_model, "forward_multi_timestep", None) + feat_cache = {} + prev_frame = None + + try: + for i in range(total_pairs): + img0_single = prev_frame if prev_frame is not None else prepare_frame(i) + img1_single = prepare_frame(i + 1) + prev_frame = img1_single + + # Cache features: img1 of pair N becomes img0 of pair N+1 + feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) + feat_cache["img1"] = inference_model.extract_features(img1_single) + feat_cache["next"] = feat_cache["img1"] + + used_multi = False + if multi_fn is not None: + # Models with timestep-independent flow can compute it once for all timesteps + try: + mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) + result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype) + out_idx += num_interp + pbar.update(num_interp) + tqdm_bar.update(num_interp) + used_multi = True + except model_management.OOM_EXCEPTION: + model_management.soft_empty_cache() + multi_fn = None # fall through to single-timestep path + + if not used_multi: + j = 0 + while j < num_interp: + b = min(batch, num_interp - j) + try: + img0 = img0_single.expand(b, -1, -1, -1) + img1 = img1_single.expand(b, -1, -1, -1) + mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) + result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype) + out_idx += b + pbar.update(b) + tqdm_bar.update(b) + j += b + except model_management.OOM_EXCEPTION: + if batch <= 1: + raise + batch = max(1, batch // 2) + model_management.soft_empty_cache() + + result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype) + out_idx += 1 + finally: + tqdm_bar.close() + + # BCHW -> BHWC + result = result.movedim(1, -1).clamp_(0.0, 1.0) + return io.NodeOutput(result) + + +class FrameInterpolationExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FrameInterpolationModelLoader, + FrameInterpolate, + ] + + +async def comfy_entrypoint() -> FrameInterpolationExtension: + return FrameInterpolationExtension() diff --git a/folder_paths.py b/folder_paths.py index 9c96540e3..80f4b291a 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) +folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/models/frame_interpolation/put_frame_interpolation_models_here b/models/frame_interpolation/put_frame_interpolation_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 299b3d758..bb38e07b8 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", - "nodes_rtdetr.py" + "nodes_rtdetr.py", + "nodes_frame_interpolation.py", ] import_failed = [] From cc6f9500a1b972e9dca14e769f4b70a8927ffa43 Mon Sep 17 00:00:00 2001 From: Octopus Date: Thu, 23 Apr 2026 06:05:43 +0800 Subject: [PATCH 12/35] fix: use Parameter assignment for Stable_Zero123 cc_projection weights (fixes #13492) (#13518) On Windows with aimdo enabled, disable_weight_init.Linear uses lazy initialization that sets weight and bias to None to avoid unnecessary memory allocation. This caused a crash when copy_() was called on the None weight attribute in Stable_Zero123.__init__. Replace copy_() with direct torch.nn.Parameter assignment, which works correctly on both Windows (aimdo enabled) and other platforms. --- comfy/model_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..1c7695761 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -578,8 +578,8 @@ class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device) - self.cc_projection.weight.copy_(cc_projection_weight) - self.cc_projection.bias.copy_(cc_projection_bias) + self.cc_projection.weight = torch.nn.Parameter(cc_projection_weight.clone()) + self.cc_projection.bias = torch.nn.Parameter(cc_projection_bias.clone()) def extra_conds(self, **kwargs): out = {} From 9949c19c632eb6cad50024e02816df86e7d41b27 Mon Sep 17 00:00:00 2001 From: blepping <157360029+blepping@users.noreply.github.com> Date: Wed, 22 Apr 2026 16:08:19 -0600 Subject: [PATCH 13/35] Derive InterruptProcessingException from BaseException (#13523) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index bcf1399c4..3b39d6080 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1801,7 +1801,7 @@ def debug_memory_summary(): return torch.cuda.memory.memory_summary() return "" -class InterruptProcessingException(Exception): +class InterruptProcessingException(BaseException): pass interrupt_processing_mutex = threading.RLock() From cb388e2912f9d3adf50e3510ed1c470ad5c9bc79 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Thu, 23 Apr 2026 07:12:06 +0900 Subject: [PATCH 14/35] bump manager version to 4.2.1 (#13516) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index f770ec933..a079d3492 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1 +comfyui_manager==4.2.1 From ec4b1659ab751b7da07bfff8fa28660c7e82c00b Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 23 Apr 2026 08:13:38 +1000 Subject: [PATCH 15/35] ModelPatcherDynamic: force cast stray weights on comfy layers (#13487) the mixed_precision ops can have input_scale parameters that are used in tensor math but arent a weight or bias so dont get proper VRAM management. Treat these as force-castable parameters like the non comfy weight, random params are buffers already are. --- comfy/model_patcher.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d19d6fe..ee56f8523 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -685,9 +685,9 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False): weight, set_func, convert_func = get_key_weight(self.model, key) - if key not in self.patches: + if key not in self.patches and not force_cast: return weight inplace_update = self.weight_inplace_update or inplace_update @@ -695,7 +695,7 @@ class ModelPatcher: if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) - temp_dtype = comfy.model_management.lora_compute_dtype(device_to) + temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None if device_to is not None: temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: @@ -703,9 +703,10 @@ class ModelPatcher: if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) - out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) + out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if key in self.patches: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) if return_weight: return out_weight elif inplace_update: @@ -1584,7 +1585,7 @@ class ModelPatcherDynamic(ModelPatcher): key = key_param_name_to_key(n, param_key) if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) - self.patch_weight_to_device(key, device_to=device_to) + self.patch_weight_to_device(key, device_to=device_to, force_cast=True) weight, _, _ = get_key_weight(self.model, key) if weight is not None: self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() @@ -1609,6 +1610,10 @@ class ModelPatcherDynamic(ModelPatcher): m._v = vbar.alloc(v_weight_size) allocated_size += v_weight_size + for param in params: + if param not in ("weight", "bias"): + force_load_param(self, param, device_to) + else: for param in params: key = key_param_name_to_key(n, param) From 0be87b082a68bca19ea25a9208120ba5090bea8d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 22 Apr 2026 17:21:43 -0700 Subject: [PATCH 16/35] Update logging level for invalid version format (#13526) --- utils/install_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/install_util.py b/utils/install_util.py index 34489aec5..fdba23a8f 100644 --- a/utils/install_util.py +++ b/utils/install_util.py @@ -39,7 +39,7 @@ def get_required_packages_versions(): if len(s) == 2: version_str = s[-1] if not is_valid_version(version_str): - logging.error(f"Invalid version format in requirements.txt: {version_str}") + logging.debug(f"Invalid version format for {s[0]} in requirements.txt: {version_str}") continue out[s[0]] = version_str return out.copy() From e988df72f8828085c1671d49f96ec50382f11c80 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 23 Apr 2026 03:59:55 +0300 Subject: [PATCH 17/35] [Partner Nodes] add SD2 real human support (#13509) * feat(api-nodes): add SD2 real human support Signed-off-by: bigcat88 * fix: add validation before uploading Assets Signed-off-by: bigcat88 * Add asset_id and group_id displaying on the node Signed-off-by: bigcat88 * extend poll_op to use instead of custom async cycle Signed-off-by: bigcat88 * added the polling for the "Active" status after asset creation Signed-off-by: bigcat88 * updated tooltip for group_id * allow usage of real human in the ByteDance2FirstLastFrame node * add reference count limits * corrected price in status when input assets contain video Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 --- comfy_api_nodes/apis/bytedance.py | 35 +++ comfy_api_nodes/nodes_bytedance.py | 468 +++++++++++++++++++++++++++-- comfy_api_nodes/util/client.py | 9 +- 3 files changed, 494 insertions(+), 18 deletions(-) diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index dc3bc3213..eafabbefe 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -122,6 +122,41 @@ class TaskStatusResponse(BaseModel): usage: TaskStatusUsage | None = Field(None) +class GetAssetResponse(BaseModel): + id: str = Field(...) + name: str | None = Field(None) + url: str | None = Field(None) + asset_type: str = Field(...) + group_id: str = Field(...) + status: str = Field(...) + error: TaskStatusError | None = Field(None) + + +class SeedanceCreateVisualValidateSessionResponse(BaseModel): + session_id: str = Field(...) + h5_link: str = Field(...) + + +class SeedanceGetVisualValidateSessionResponse(BaseModel): + session_id: str = Field(...) + status: str = Field(...) + group_id: str | None = Field(None) + error_code: str | None = Field(None) + error_message: str | None = Field(None) + + +class SeedanceCreateAssetRequest(BaseModel): + group_id: str = Field(...) + url: str = Field(...) + asset_type: str = Field(...) + name: str | None = Field(None, max_length=64) + project_name: str | None = Field(None) + + +class SeedanceCreateAssetResponse(BaseModel): + asset_id: str = Field(...) + + # Dollars per 1K tokens, keyed by (model_id, has_video_input). SEEDANCE2_PRICE_PER_1K_TOKENS = { ("dreamina-seedance-2-0-260128", False): 0.007, diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index bc564782d..de192c5ac 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,5 +1,6 @@ import logging import math +import re import torch from typing_extensions import override @@ -11,9 +12,14 @@ from comfy_api_nodes.apis.bytedance import ( SEEDANCE2_PRICE_PER_1K_TOKENS, SEEDANCE2_REF_VIDEO_PIXEL_LIMITS, VIDEO_TASKS_EXECUTION_TIME, + GetAssetResponse, Image2VideoTaskCreationRequest, ImageTaskCreationResponse, Seedance2TaskCreationRequest, + SeedanceCreateAssetRequest, + SeedanceCreateAssetResponse, + SeedanceCreateVisualValidateSessionResponse, + SeedanceGetVisualValidateSessionResponse, Seedream4Options, Seedream4TaskCreationRequest, TaskAudioContent, @@ -44,10 +50,16 @@ from comfy_api_nodes.util import ( validate_image_aspect_ratio, validate_image_dimensions, validate_string, + validate_video_dimensions, + validate_video_duration, ) +from server import PromptServer BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" +_VERIFICATION_POLL_TIMEOUT_SEC = 120 +_VERIFICATION_POLL_INTERVAL_SEC = 3 + SEEDREAM_MODELS = { "seedream 5.0 lite": "seedream-5-0-260128", "seedream-4-5-251128": "seedream-4-5-251128", @@ -96,6 +108,169 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st ) +async def _resolve_reference_assets( + cls: type[IO.ComfyNode], + asset_ids: list[str], +) -> tuple[dict[str, str], dict[str, str], dict[str, str]]: + """Look up each asset, validate Active status, group by asset_type. + + Returns (image_assets, video_assets, audio_assets), each mapping asset_id -> "asset://". + """ + image_assets: dict[str, str] = {} + video_assets: dict[str, str] = {} + audio_assets: dict[str, str] = {} + for i, raw_id in enumerate(asset_ids, 1): + asset_id = (raw_id or "").strip() + if not asset_id: + continue + result = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"), + response_model=GetAssetResponse, + ) + if result.status != "Active": + extra = f" {result.error.code}: {result.error.message}" if result.error else "" + raise ValueError(f"Reference asset {i} (Id={asset_id}) is not Active (Status={result.status}).{extra}") + asset_uri = f"asset://{asset_id}" + if result.asset_type == "Image": + image_assets[asset_id] = asset_uri + elif result.asset_type == "Video": + video_assets[asset_id] = asset_uri + elif result.asset_type == "Audio": + audio_assets[asset_id] = asset_uri + return image_assets, video_assets, audio_assets + + +_ASSET_REF_RE = re.compile(r"\basset ?(\d{1,2})\b", re.IGNORECASE) + + +def _build_asset_labels( + reference_assets: dict[str, str], + image_asset_uris: dict[str, str], + video_asset_uris: dict[str, str], + audio_asset_uris: dict[str, str], + n_reference_images: int, + n_reference_videos: int, + n_reference_audios: int, +) -> dict[int, str]: + """Map asset slot number (from 'asset_N' keys) to its positional label. + + Asset entries are appended to `content` after the reference_images/videos/audios, + so their 1-indexed labels continue from the count of existing same-type refs: + one reference_images entry + one Image-type asset -> asset labelled "Image 2". + """ + image_n = n_reference_images + video_n = n_reference_videos + audio_n = n_reference_audios + labels: dict[int, str] = {} + for slot_key, raw_id in reference_assets.items(): + asset_id = (raw_id or "").strip() + if not asset_id: + continue + try: + slot_num = int(slot_key.rsplit("_", 1)[-1]) + except ValueError: + continue + if asset_id in image_asset_uris: + image_n += 1 + labels[slot_num] = f"Image {image_n}" + elif asset_id in video_asset_uris: + video_n += 1 + labels[slot_num] = f"Video {video_n}" + elif asset_id in audio_asset_uris: + audio_n += 1 + labels[slot_num] = f"Audio {audio_n}" + return labels + + +def _rewrite_asset_refs(prompt: str, labels: dict[int, str]) -> str: + """Case-insensitively replace 'assetNN' (1-2 digit) tokens with their labels.""" + if not labels: + return prompt + + def _sub(m: "re.Match[str]") -> str: + return labels.get(int(m.group(1)), m.group(0)) + + return _ASSET_REF_RE.sub(_sub, prompt) + + +async def _obtain_group_id_via_h5_auth(cls: type[IO.ComfyNode]) -> str: + session = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/visual-validate/sessions", method="POST"), + response_model=SeedanceCreateVisualValidateSessionResponse, + ) + logger.warning("Seedance authentication required. Open link: %s", session.h5_link) + + h5_text = f"Open this link in your browser and complete face verification:\n\n{session.h5_link}" + + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"), + response_model=SeedanceGetVisualValidateSessionResponse, + status_extractor=lambda r: r.status, + completed_statuses=["completed"], + failed_statuses=["failed"], + poll_interval=_VERIFICATION_POLL_INTERVAL_SEC, + max_poll_attempts=(_VERIFICATION_POLL_TIMEOUT_SEC // _VERIFICATION_POLL_INTERVAL_SEC) - 1, + estimated_duration=_VERIFICATION_POLL_TIMEOUT_SEC - 1, + extra_text=h5_text, + ) + + if not result.group_id: + raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id") + + logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id) + PromptServer.instance.send_progress_text( + f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id + ) + return result.group_id + + +async def _resolve_group_id(cls: type[IO.ComfyNode], group_id: str) -> str: + if group_id and group_id.strip(): + return group_id.strip() + return await _obtain_group_id_via_h5_auth(cls) + + +async def _create_seedance_asset( + cls: type[IO.ComfyNode], + *, + group_id: str, + url: str, + name: str, + asset_type: str, +) -> str: + req = SeedanceCreateAssetRequest( + group_id=group_id, + url=url, + asset_type=asset_type, + name=name or None, + ) + result = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/assets", method="POST"), + response_model=SeedanceCreateAssetResponse, + data=req, + ) + return result.asset_id + + +async def _wait_for_asset_active(cls: type[IO.ComfyNode], asset_id: str, group_id: str) -> GetAssetResponse: + """Poll the newly created asset until its status becomes Active.""" + return await poll_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"), + response_model=GetAssetResponse, + status_extractor=lambda r: r.status, + completed_statuses=["Active"], + failed_statuses=["Failed"], + poll_interval=5, + max_poll_attempts=1200, + extra_text=f"Waiting for asset pre-processing...\n\nasset_id: {asset_id}\n\ngroup_id: {group_id}", + ) + + def _seedance2_price_extractor(model_id: str, has_video_input: bool): """Returns a price_extractor closure for Seedance 2.0 poll_op.""" rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) @@ -1228,12 +1403,27 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): IO.Image.Input( "first_frame", tooltip="First frame image for the video.", + optional=True, ), IO.Image.Input( "last_frame", tooltip="Last frame image for the video.", optional=True, ), + IO.String.Input( + "first_frame_asset_id", + default="", + tooltip="Seedance asset_id to use as the first frame. " + "Mutually exclusive with the first_frame image input.", + optional=True, + ), + IO.String.Input( + "last_frame_asset_id", + default="", + tooltip="Seedance asset_id to use as the last frame. " + "Mutually exclusive with the last_frame image input.", + optional=True, + ), IO.Int.Input( "seed", default=0, @@ -1286,24 +1476,54 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): async def execute( cls, model: dict, - first_frame: Input.Image, seed: int, watermark: bool, + first_frame: Input.Image | None = None, last_frame: Input.Image | None = None, + first_frame_asset_id: str = "", + last_frame_asset_id: str = "", ) -> IO.NodeOutput: validate_string(model["prompt"], strip_whitespace=True, min_length=1) model_id = SEEDANCE_MODELS[model["model"]] + first_frame_asset_id = first_frame_asset_id.strip() + last_frame_asset_id = last_frame_asset_id.strip() + + if first_frame is not None and first_frame_asset_id: + raise ValueError("Provide only one of first_frame or first_frame_asset_id, not both.") + if first_frame is None and not first_frame_asset_id: + raise ValueError("Either first_frame or first_frame_asset_id is required.") + if last_frame is not None and last_frame_asset_id: + raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.") + + asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a] + image_assets: dict[str, str] = {} + if asset_ids_to_resolve: + image_assets, _, _ = await _resolve_reference_assets(cls, asset_ids_to_resolve) + for aid in asset_ids_to_resolve: + if aid not in image_assets: + raise ValueError(f"Asset {aid} is not an Image asset.") + + if first_frame_asset_id: + first_frame_url = image_assets[first_frame_asset_id] + else: + first_frame_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.") + content: list[TaskTextContent | TaskImageContent] = [ TaskTextContent(text=model["prompt"]), TaskImageContent( - image_url=TaskImageContentUrl( - url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.") - ), + image_url=TaskImageContentUrl(url=first_frame_url), role="first_frame", ), ] - if last_frame is not None: + if last_frame_asset_id: + content.append( + TaskImageContent( + image_url=TaskImageContentUrl(url=image_assets[last_frame_asset_id]), + role="last_frame", + ), + ) + elif last_frame is not None: content.append( TaskImageContent( image_url=TaskImageContentUrl( @@ -1385,6 +1605,24 @@ def _seedance2_reference_inputs(resolutions: list[str]): tooltip="Automatically downscale reference videos that exceed the model's pixel budget " "for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.", ), + IO.Autogrow.Input( + "reference_assets", + template=IO.Autogrow.TemplateNames( + IO.String.Input("reference_asset"), + names=[ + "asset_1", + "asset_2", + "asset_3", + "asset_4", + "asset_5", + "asset_6", + "asset_7", + "asset_8", + "asset_9", + ], + min=0, + ), + ), ] @@ -1486,24 +1724,42 @@ class ByteDance2ReferenceNode(IO.ComfyNode): reference_images = model.get("reference_images", {}) reference_videos = model.get("reference_videos", {}) reference_audios = model.get("reference_audios", {}) + reference_assets = model.get("reference_assets", {}) - if not reference_images and not reference_videos: - raise ValueError("At least one reference image or video is required.") + reference_image_assets, reference_video_assets, reference_audio_assets = await _resolve_reference_assets( + cls, list(reference_assets.values()) + ) + + if not reference_images and not reference_videos and not reference_image_assets and not reference_video_assets: + raise ValueError("At least one reference image or video or asset is required.") + + total_images = len(reference_images) + len(reference_image_assets) + if total_images > 9: + raise ValueError( + f"Too many reference images: {total_images} " + f"(images={len(reference_images)}, image assets={len(reference_image_assets)}). Maximum is 9." + ) + total_videos = len(reference_videos) + len(reference_video_assets) + if total_videos > 3: + raise ValueError( + f"Too many reference videos: {total_videos} " + f"(videos={len(reference_videos)}, video assets={len(reference_video_assets)}). Maximum is 3." + ) + total_audios = len(reference_audios) + len(reference_audio_assets) + if total_audios > 3: + raise ValueError( + f"Too many reference audios: {total_audios} " + f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3." + ) model_id = SEEDANCE_MODELS[model["model"]] - has_video_input = len(reference_videos) > 0 + has_video_input = total_videos > 0 if model.get("auto_downscale") and reference_videos: - max_px = ( - SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}) - .get(model["resolution"], {}) - .get("max") - ) + max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max") if max_px: for key in reference_videos: - reference_videos[key] = resize_video_to_pixel_budget( - reference_videos[key], max_px - ) + reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px) total_video_duration = 0.0 for i, key in enumerate(reference_videos, 1): @@ -1531,8 +1787,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode): if total_audio_duration > 15.1: raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.") + asset_labels = _build_asset_labels( + reference_assets, + reference_image_assets, + reference_video_assets, + reference_audio_assets, + len(reference_images), + len(reference_videos), + len(reference_audios), + ) + prompt_text = _rewrite_asset_refs(model["prompt"], asset_labels) + content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [ - TaskTextContent(text=model["prompt"]), + TaskTextContent(text=prompt_text), ] for i, key in enumerate(reference_images, 1): content.append( @@ -1573,6 +1840,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode): ), ), ) + for url in reference_image_assets.values(): + content.append( + TaskImageContent( + image_url=TaskImageContentUrl(url=url), + role="reference_image", + ), + ) + for url in reference_video_assets.values(): + content.append( + TaskVideoContent(video_url=TaskVideoContentUrl(url=url)), + ) + for url in reference_audio_assets.values(): + content.append( + TaskAudioContent(audio_url=TaskAudioContentUrl(url=url)), + ) initial_response = await sync_op( cls, ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), @@ -1627,6 +1909,156 @@ async def process_video_task( return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) +class ByteDanceCreateImageAsset(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ByteDanceCreateImageAsset", + display_name="ByteDance Create Image Asset", + category="api node/image/ByteDance", + description=( + "Create a Seedance 2.0 personal image asset. Uploads the input image and " + "registers it in the given asset group. If group_id is empty, runs a real-person " + "H5 authentication flow to create a new group before adding the asset." + ), + inputs=[ + IO.Image.Input("image", tooltip="Image to register as a personal asset."), + IO.String.Input( + "group_id", + default="", + tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the " + "same person. Leave empty to run real-person authentication in the browser and create a new group.", + ), + # IO.String.Input( + # "name", + # default="", + # tooltip="Asset name (up to 64 characters).", + # ), + ], + outputs=[ + IO.String.Output(display_name="asset_id"), + IO.String.Output(display_name="group_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + # is_api_node=True, + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + group_id: str = "", + # name: str = "", + ) -> IO.NodeOutput: + # if len(name) > 64: + # raise ValueError("Name of asset can not be greater then 64 symbols") + validate_image_dimensions(image, min_width=300, max_width=6000, min_height=300, max_height=6000) + validate_image_aspect_ratio(image, min_ratio=(0.4, 1), max_ratio=(2.5, 1)) + resolved_group = await _resolve_group_id(cls, group_id) + asset_id = await _create_seedance_asset( + cls, + group_id=resolved_group, + url=await upload_image_to_comfyapi(cls, image), + name="", + asset_type="Image", + ) + await _wait_for_asset_active(cls, asset_id, resolved_group) + PromptServer.instance.send_progress_text( + f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n" + f"group_id: {resolved_group}", + cls.hidden.unique_id, + ) + return IO.NodeOutput(asset_id, resolved_group) + + +class ByteDanceCreateVideoAsset(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ByteDanceCreateVideoAsset", + display_name="ByteDance Create Video Asset", + category="api node/video/ByteDance", + description=( + "Create a Seedance 2.0 personal video asset. Uploads the input video and " + "registers it in the given asset group. If group_id is empty, runs a real-person " + "H5 authentication flow to create a new group before adding the asset." + ), + inputs=[ + IO.Video.Input("video", tooltip="Video to register as a personal asset."), + IO.String.Input( + "group_id", + default="", + tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the " + "same person. Leave empty to run real-person authentication in the browser and create a new group.", + ), + # IO.String.Input( + # "name", + # default="", + # tooltip="Asset name (up to 64 characters).", + # ), + ], + outputs=[ + IO.String.Output(display_name="asset_id"), + IO.String.Output(display_name="group_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + # is_api_node=True, + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + group_id: str = "", + # name: str = "", + ) -> IO.NodeOutput: + # if len(name) > 64: + # raise ValueError("Name of asset can not be greater then 64 symbols") + validate_video_duration(video, min_duration=2, max_duration=15) + validate_video_dimensions(video, min_width=300, max_width=6000, min_height=300, max_height=6000) + + w, h = video.get_dimensions() + if h > 0: + ratio = w / h + if not (0.4 <= ratio <= 2.5): + raise ValueError(f"Asset video aspect ratio (W/H) must be in [0.4, 2.5], got {ratio:.3f} ({w}x{h}).") + pixels = w * h + if not (409_600 <= pixels <= 927_408): + raise ValueError( + f"Asset video total pixels (W×H) must be in [409600, 927408], " f"got {pixels:,} ({w}x{h})." + ) + + fps = float(video.get_frame_rate()) + if not (24 <= fps <= 60): + raise ValueError(f"Asset video FPS must be in [24, 60], got {fps:.2f}.") + + resolved_group = await _resolve_group_id(cls, group_id) + asset_id = await _create_seedance_asset( + cls, + group_id=resolved_group, + url=await upload_video_to_comfyapi(cls, video), + name="", + asset_type="Video", + ) + await _wait_for_asset_active(cls, asset_id, resolved_group) + PromptServer.instance.send_progress_text( + f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n" + f"group_id: {resolved_group}", + cls.hidden.unique_id, + ) + return IO.NodeOutput(asset_id, resolved_group) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1640,6 +2072,8 @@ class ByteDanceExtension(ComfyExtension): ByteDance2TextToVideoNode, ByteDance2FirstLastFrameNode, ByteDance2ReferenceNode, + ByteDanceCreateImageAsset, + ByteDanceCreateVideoAsset, ] diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9d730b81a..b0cf97ae4 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -156,6 +156,7 @@ async def poll_op( estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, + extra_text: str | None = None, ) -> M: raw = await poll_op_raw( cls, @@ -176,6 +177,7 @@ async def poll_op( estimated_duration=estimated_duration, cancel_endpoint=cancel_endpoint, cancel_timeout=cancel_timeout, + extra_text=extra_text, ) if not isinstance(raw, dict): raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") @@ -260,6 +262,7 @@ async def poll_op_raw( estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, + extra_text: str | None = None, ) -> dict[str, Any]: """ Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, @@ -299,6 +302,7 @@ async def poll_op_raw( price=state.price, is_queued=state.is_queued, processing_elapsed_seconds=int(proc_elapsed), + extra_text=extra_text, ) await asyncio.sleep(1.0) except Exception as exc: @@ -389,6 +393,7 @@ async def poll_op_raw( price=state.price, is_queued=False, processing_elapsed_seconds=int(state.base_processing_elapsed), + extra_text=extra_text, ) return resp_json @@ -462,6 +467,7 @@ def _display_time_progress( price: float | None = None, is_queued: bool | None = None, processing_elapsed_seconds: int | None = None, + extra_text: str | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds @@ -469,7 +475,8 @@ def _display_time_progress( time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" else: time_line = f"Time elapsed: {int(elapsed_seconds)}s" - _display_text(node_cls, time_line, status=status, price=price) + text = f"{time_line}\n\n{extra_text}" if extra_text else time_line + _display_text(node_cls, text, status=status, price=price) async def _diagnose_connectivity() -> dict[str, bool]: From 749d5b4e8d4308c67fee6faa4ef4dfbde23087f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 23 Apr 2026 07:07:43 +0300 Subject: [PATCH 18/35] feat: SAM (segment anything) 3.1 support (CORE-34) (#13408) --- comfy/ldm/sam3/detector.py | 596 ++++++++++ comfy/ldm/sam3/sam.py | 425 +++++++ comfy/ldm/sam3/tracker.py | 1785 ++++++++++++++++++++++++++++++ comfy/model_base.py | 5 + comfy/model_detection.py | 12 + comfy/supported_models.py | 53 +- comfy/text_encoders/sam3_clip.py | 97 ++ comfy_extras/nodes_sam3.py | 529 +++++++++ nodes.py | 1 + 9 files changed, 3502 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/sam3/detector.py create mode 100644 comfy/ldm/sam3/sam.py create mode 100644 comfy/ldm/sam3/tracker.py create mode 100644 comfy/text_encoders/sam3_clip.py create mode 100644 comfy_extras/nodes_sam3.py diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py new file mode 100644 index 000000000..6ae919a79 --- /dev/null +++ b/comfy/ldm/sam3/detector.py @@ -0,0 +1,596 @@ +# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import roi_align + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker +from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__ +from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine + +TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker} +from comfy.ops import cast_to_input + + +def box_cxcywh_to_xyxy(x): + cx, cy, w, h = x.unbind(-1) + return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1) + + +def gen_sineembed_for_position(pos_tensor, num_feats=256): + """Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats).""" + assert num_feats % 2 == 0 + hdim = num_feats // 2 + freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim) + embeds = [] + for c in range(pos_tensor.shape[-1]): + raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs + embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2)) + return torch.cat(embeds, dim=-1).to(pos_tensor.dtype) + + +class SplitMHA(nn.Module): + """Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight).""" + def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + + def forward(self, q_input, k_input=None, v_input=None, mask=None): + q = self.q_proj(q_input) + if k_input is None: + k = self.k_proj(q_input) + v = self.v_proj(q_input) + else: + k = self.k_proj(k_input) + v = self.v_proj(v_input if v_input is not None else k_input) + if mask is not None and mask.ndim == 2: + mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast + dtype = q.dtype # manual_cast may produce mixed dtypes + out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask) + return self.out_proj(out) + + +class MLPWithNorm(nn.Module): + """MLP with residual connection and output LayerNorm.""" + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None): + super().__init__() + dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.layers = nn.ModuleList([ + operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) + for i in range(num_layers) + ]) + self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype) + self.residual = residual and (input_dim == output_dim) + + def forward(self, x): + orig = x + for i, layer in enumerate(self.layers): + x = layer(x) + if i < len(self.layers) - 1: + x = F.relu(x) + if self.residual: + x = x + orig + return self.out_norm(x) + + +class EncoderLayer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, x, pos, text_memory=None, text_mask=None): + normed = self.norm1(x) + q_k = normed + pos + x = x + self.self_attn(q_k, q_k, normed) + if text_memory is not None: + normed = self.norm2(x) + x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask) + normed = self.norm3(x) + x = x + self.linear2(F.relu(self.linear1(normed))) + return x + + +class TransformerEncoder(nn.Module): + """Checkpoint: transformer.encoder.layers.N.*""" + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + + def forward(self, x, pos, text_memory=None, text_mask=None): + for layer in self.layers: + x = layer(x, pos, text_memory, text_mask) + return x + + +class DecoderLayer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + + def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None): + q_k = x + x_pos + x = self.norm2(x + self.self_attn(q_k, q_k, x)) + if text_memory is not None: + x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask)) + x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias)) + x = self.norm3(x + self.linear2(F.relu(self.linear1(x)))) + return x + + +class TransformerDecoder(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, + num_queries=200, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.num_queries = num_queries + + self.layers = nn.ModuleList([ + DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype) + self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes + self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each) + self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations) + + self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations) + self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations) + + self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + + @staticmethod + def _inverse_sigmoid(x): + return torch.log(x / (1 - x + 1e-6) + 1e-6) + + def _compute_box_rpb(self, ref_points, H, W): + """Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias.""" + boxes_xyxy = box_cxcywh_to_xyxy(ref_points) + B, Q, _ = boxes_xyxy.shape + coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H + coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W + deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2] + deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2] + + log2_8 = float(math.log2(8)) + def log_scale(d): + return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8 + + rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype)) + rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype)) + + bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2) + pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype) + return torch.cat([pres_bias, bias], dim=2) + + def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72): + B = memory.shape[0] + tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1) + presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1) + ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid() + + for layer_idx, layer in enumerate(self.layers): + query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model)) + tgt_with_pres = torch.cat([presence_out, tgt], dim=1) + pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1) + tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos, + text_memory, text_mask, self._compute_box_rpb(ref_points, H, W)) + presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:] + if layer_idx < len(self.layers) - 1: + ref_inv = self._inverse_sigmoid(ref_points) + ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach() + + query_out = self.norm(tgt) + ref_inv = self._inverse_sigmoid(ref_points) + boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid() + presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1) + return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence} + + +class Transformer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6, + num_queries=200, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations) + self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations) + + +class GeometryEncoder(nn.Module): + def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.roi_size = roi_size + self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True) + self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype) + self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype) + self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype) + self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype) + self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype) + self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.encode = nn.ModuleList([ + EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + + def _encode_points(self, coords, labels, img_feat_2d): + """Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized.""" + B, N, _ = coords.shape + embed = self.points_direct_project(coords) + # Pool features from backbone at point locations via grid_sample + grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1] + sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1] + embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C] + # Positional encoding of coordinates + x, y = coords[:, :, 0], coords[:, :, 1] # [B, N] + pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten()) + enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1) + embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed)) + embed = embed + cast_to_input(self.label_embed(labels.long()), embed) + return embed + + def _encode_boxes(self, boxes, labels, img_feat_2d): + """Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh.""" + B, N, _ = boxes.shape + embed = self.boxes_direct_project(boxes) + # ROI align from backbone at box regions + H, W = img_feat_2d.shape[-2:] + boxes_xyxy = box_cxcywh_to_xyxy(boxes) + scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device) + boxes_scaled = boxes_xyxy * scale + sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size) + proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C] + embed = embed + proj + # Positional encoding of box center + size + cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3] + enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten()) + enc = enc.view(B, N, -1) + embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed)) + embed = embed + cast_to_input(self.label_embed(labels.long()), embed) + return embed + + def forward(self, points=None, boxes=None, image_features=None): + """Encode geometry prompts. image_features: [B, HW, C] flattened backbone features.""" + # Prepare 2D image features for pooling + img_feat_2d = None + if image_features is not None: + B = image_features.shape[0] + HW, C = image_features.shape[1], image_features.shape[2] + hw = int(math.sqrt(HW)) + img_normed = self.img_pre_norm(image_features) + img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw) + + embeddings = [] + if points is not None: + coords, labels = points + embeddings.append(self._encode_points(coords, labels, img_feat_2d)) + if boxes is not None: + B = boxes.shape[0] + box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device) + embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d)) + if not embeddings: + return None + geo = torch.cat(embeddings, dim=1) + geo = self.norm(geo) + if image_features is not None: + for layer in self.encode: + geo = layer(geo, torch.zeros_like(geo), image_features) + geo = self.encode_norm(geo) + return self.final_proj(geo) + + +class PixelDecoder(nn.Module): + """Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation.""" + def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None): + super().__init__() + self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)]) + self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)]) + + def forward(self, backbone_features): + prev = backbone_features[-1] + for i, feat in enumerate(backbone_features[:-1][::-1]): + prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest")))) + return prev + + +class MaskPredictor(nn.Module): + def __init__(self, d_model=256, device=None, dtype=None, operations=None): + super().__init__() + self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, query_embeddings, pixel_features): + mask_embed = self.mask_embed(query_embeddings) + return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features) + + +class SegmentationHead(nn.Module): + def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations) + self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations) + self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype) + self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype) + + def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None): + if encoder_hidden_states is not None and prompt is not None: + enc_normed = self.cross_attn_norm(encoder_hidden_states) + enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask) + encoder_hidden_states = enc_cross + encoder_hidden_states + + if encoder_hidden_states is not None: + B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1] + encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W) + backbone_features = list(backbone_features) + backbone_features[-1] = encoder_visual + + pixel_features = self.pixel_decoder(backbone_features) + instance_features = self.instance_seg_head(pixel_features) + masks = self.mask_predictor(query_embeddings, instance_features) + return masks + + +class DotProductScoring(nn.Module): + def __init__(self, d_model=256, device=None, dtype=None, operations=None): + super().__init__() + self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations) + self.scale = 1.0 / (d_model ** 0.5) + + def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None): + prompt = self.prompt_mlp(prompt_embeddings) + if prompt_mask is not None: + weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype) + pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1) + else: + pooled = prompt.mean(dim=1) + hs = self.hs_proj(query_embeddings) + pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype) + scores = torch.matmul(hs, pp) + return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1) + + +class SAM3Detector(nn.Module): + def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + image_model = kwargs.pop("image_model", "SAM3") + for k in ("num_heads", "num_head_channels"): + kwargs.pop(k, None) + multiplex = image_model == "SAM31" + # SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0) + self.scalp = 0 if multiplex else 1 + self.backbone = nn.ModuleDict({ + "vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs), + "language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}), + }) + self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations) + self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations) + self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations) + self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations) + + def _get_backbone_features(self, images): + """Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions).""" + bb = self.backbone["vision_backbone"] + if bb.multiplex: + all_f, all_p, tf, tp = bb(images, tracker_mode="propagation") + else: + all_f, all_p, tf, tp = bb(images, need_tracker=True) + return all_f, all_p, tf, tp + + @staticmethod + def _run_geo_layer(layer, x, memory, memory_pos): + x = x + layer.self_attn(layer.norm1(x)) + x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory) + x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x)))) + return x + + def _detect(self, features, positions, text_embeddings=None, text_mask=None, + points=None, boxes=None): + """Shared detection: geometry encoding, transformer, scoring, segmentation.""" + B = features[0].shape[0] + # Scalp for encoder (use top-level feature), but keep all levels for segmentation head + seg_features = features + if self.scalp > 0: + features = features[:-self.scalp] + positions = positions[:-self.scalp] + enc_feat, enc_pos = features[-1], positions[-1] + _, _, H, W = enc_feat.shape + img_flat = enc_feat.flatten(2).permute(0, 2, 1) + pos_flat = enc_pos.flatten(2).permute(0, 2, 1) + + has_prompts = text_embeddings is not None or points is not None or boxes is not None + if has_prompts: + geo_enc = self.geometry_encoder + geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat) + geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1))) + for layer in geo_enc.encode: + geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat) + geo_cls = geo_enc.encode_norm(geo_cls) + if text_embeddings is not None and text_embeddings.shape[0] != B: + text_embeddings = text_embeddings.expand(B, -1, -1) + if text_mask is not None and text_mask.shape[0] != B: + text_mask = text_mask.expand(B, -1) + parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None] + text_embeddings = torch.cat(parts, dim=1) + n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0) + if text_mask is not None: + text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1) + else: + text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device) + + memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask) + dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W) + query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"] + + if text_embeddings is not None: + scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask) + else: + scores = torch.zeros(B, query_out.shape[1], device=query_out.device) + + masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask) + return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out + + def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None): + features, positions, _, _ = self._get_backbone_features(images) + + if text_embeddings is not None: + text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings) + if text_mask is not None: + text_mask = text_mask.bool() + + boxes_xyxy, scores, masks, dec_out = self._detect( + features, positions, text_embeddings, text_mask, points, boxes) + + if orig_size is not None: + oh, ow = orig_size + boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype) + masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False) + + return { + "boxes": boxes_xyxy, + "scores": scores, + "masks": masks, + "presence": dec_out.get("presence"), + } + + def forward_from_trunk(self, trunk_out, text_embeddings, text_mask): + """Run detection using a pre-computed ViTDet trunk output. + + text_embeddings must already be resized through language_backbone.resizer. + Returns dict with boxes (normalized xyxy), scores, masks at detector resolution. + """ + bb = self.backbone["vision_backbone"] + features = [conv(trunk_out) for conv in bb.convs] + positions = [cast_to_input(bb.position_encoding(f), f) for f in features] + + if text_mask is not None: + text_mask = text_mask.bool() + + boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask) + return {"boxes": boxes_xyxy, "scores": scores, "masks": masks} + + +class SAM3Model(nn.Module): + def __init__(self, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + image_model = kwargs.get("image_model", "SAM3") + tracker_cls = TRACKER_CLASSES[image_model] + self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs) + self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs) + + def forward(self, images, **kwargs): + return self.detector(images, **kwargs) + + def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None): + """Interactive segmentation using SAM decoder with point/box/mask prompts. + + Args: + images: [B, 3, 1008, 1008] preprocessed images + point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space + box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space + mask_inputs: [B, 1, H, W] coarse mask logits to refine + Returns: + [B, 1, image_size, image_size] high-res mask logits + """ + bb = self.detector.backbone["vision_backbone"] + if bb.multiplex: + _, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive") + else: + _, _, tracker_features, tracker_positions = bb(images, need_tracker=True) + if self.detector.scalp > 0: + tracker_features = tracker_features[:-self.detector.scalp] + tracker_positions = tracker_positions[:-self.detector.scalp] + + high_res = list(tracker_features[:-1]) + backbone_feat = tracker_features[-1] + B, C, H, W = backbone_feat.shape + # Add no-memory embedding (init frame path) + no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None) + if no_mem is None: + no_mem = getattr(self.tracker, 'no_mem_embed', None) + if no_mem is not None: + feat_flat = backbone_feat.flatten(2).permute(0, 2, 1) + feat_flat = feat_flat + cast_to_input(no_mem, feat_flat) + backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2) + + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + _, high_res_masks, _, _ = self.tracker._forward_sam_heads( + backbone_features=backbone_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + box_inputs=box_inputs, + high_res_features=high_res, + multimask_output=(0 < num_pts <= 1), + ) + return high_res_masks + + def forward_video(self, images, initial_masks, pbar=None, text_prompts=None, + new_det_thresh=0.5, max_objects=0, detect_interval=1): + """Track video with optional per-frame text-prompted detection.""" + bb = self.detector.backbone["vision_backbone"] + + def backbone_fn(frame, frame_idx=None): + trunk_out = bb.trunk(frame) + if bb.multiplex: + _, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True) + else: + _, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True) + return tf, tp, trunk_out + + detect_fn = None + if text_prompts: + resizer = self.detector.backbone["language_backbone"]["resizer"] + resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts] + def detect_fn(trunk_out): + all_scores, all_masks = [], [] + for emb, mask in resized: + det = self.detector.forward_from_trunk(trunk_out, emb, mask) + all_scores.append(det["scores"]) + all_masks.append(det["masks"]) + return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)} + + if hasattr(self.tracker, 'track_video_with_detection'): + return self.tracker.track_video_with_detection( + backbone_fn, images, initial_masks, detect_fn, + new_det_thresh=new_det_thresh, max_objects=max_objects, + detect_interval=detect_interval, backbone_obj=bb, pbar=pbar) + # SAM3 (non-multiplex) — no detection support, requires initial masks + if initial_masks is None: + raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking") + return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb) diff --git a/comfy/ldm/sam3/sam.py b/comfy/ldm/sam3/sam.py new file mode 100644 index 000000000..272781d45 --- /dev/null +++ b/comfy/ldm/sam3/sam.py @@ -0,0 +1,425 @@ +# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.layers import EmbedND +from comfy.ops import cast_to_input + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None): + super().__init__() + dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)]) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x) + return torch.sigmoid(x) if self.sigmoid_output else x + + +class SAMAttention(nn.Module): + def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + internal_dim = embedding_dim // downsample_rate + kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype) + + def forward(self, q, k, v): + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + return self.out_proj(optimized_attention(q, k, v, self.num_heads)) + + +class TwoWayAttentionBlock(nn.Module): + def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None): + super().__init__() + self.skip_first_layer_pe = skip_first_layer_pe + self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype)) + self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + + def forward(self, queries, keys, query_pe, key_pe): + if self.skip_first_layer_pe: + queries = self.norm1(self.self_attn(queries, queries, queries)) + else: + q = queries + query_pe + queries = self.norm1(queries + self.self_attn(q, q, queries)) + q, k = queries + query_pe, keys + key_pe + queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys)) + queries = self.norm3(queries + self.mlp(queries)) + q, k = queries + query_pe, keys + key_pe + keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries)) + return queries, keys + + +class TwoWayTransformer(nn.Module): + def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate, + skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations) + for i in range(depth) + ]) + self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + + def forward(self, image_embedding, image_pe, point_embedding): + queries, keys = point_embedding, image_embedding + for layer in self.layers: + queries, keys = layer(queries, keys, point_embedding, image_pe) + q, k = queries + point_embedding, keys + image_pe + queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys)) + return queries, keys + + +class PositionEmbeddingRandom(nn.Module): + """Fourier feature positional encoding with random gaussian projection.""" + def __init__(self, num_pos_feats=64, scale=None): + super().__init__() + self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats)) + + def _encode(self, normalized_coords): + """Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32.""" + orig_dtype = normalized_coords.dtype + proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32) + projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix + return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype) + + def forward(self, size, device=None): + h, w = size + dev = device if device is not None else self.positional_encoding_gaussian_matrix.device + ones = torch.ones((h, w), device=dev, dtype=torch.float32) + norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1) + return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0) + + def forward_with_coords(self, pixel_coords, image_size): + norm = pixel_coords.clone() + norm[:, :, 0] /= image_size[1] + norm[:, :, 1] /= image_size[0] + return self._encode(norm) + + +# ViTDet backbone + FPN neck + +def window_partition(x: torch.Tensor, window_size: int): + B, H, W, C = x.shape + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw): + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0): + """Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2].""" + t = torch.arange(end_x * end_y, dtype=torch.float32) + ids = torch.stack([(t % end_x) * scale_pos, + torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1) + return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0)) + + +class _ViTMLP(nn.Module): + def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None): + super().__init__() + hidden = int(dim * mlp_ratio) + self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype) + self.act = nn.GELU() + self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class Attention(nn.Module): + """ViTDet multi-head attention with fused QKV projection.""" + + def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.use_rope = use_rope + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + def forward(self, x, freqs_cis=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) + if self.use_rope and freqs_cis is not None: + q, k = apply_rope(q, k, freqs_cis) + return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True)) + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None): + super().__init__() + self.window_size = window_size + self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations) + self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations) + + def forward(self, x, freqs_cis=None): + shortcut = x + x = self.norm1(x) + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + x = x.view(x.shape[0], self.window_size * self.window_size, -1) + x = self.attn(x, freqs_cis=freqs_cis) + x = x.view(-1, self.window_size, self.window_size, x.shape[-1]) + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + else: + B, H, W, C = x.shape + x = x.view(B, H * W, C) + x = self.attn(x, freqs_cis=freqs_cis) + x = x.view(B, H, W, C) + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None): + super().__init__() + self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype) + + def forward(self, x): + return self.proj(x) + + +class ViTDet(nn.Module): + def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24, + global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.num_heads = num_heads + self.global_att_blocks = set(global_att_blocks) + + self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations) + + num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype)) + + self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + + grid_size = img_size // patch_size + pretrain_grid = pretrain_img_size // patch_size + + self.blocks = nn.ModuleList() + for i in range(depth): + is_global = i in self.global_att_blocks + self.blocks.append(Block( + embed_dim, num_heads, mlp_ratio, qkv_bias, + window_size=0 if is_global else window_size, + use_rope=use_rope, + device=device, dtype=dtype, operations=operations, + )) + + if use_rope: + rope_scale = pretrain_grid / grid_size + self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False) + self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False) + else: + self.freqs_cis = None + self.freqs_cis_window = None + + def _get_pos_embed(self, num_tokens): + pos = self.pos_embed + if pos.shape[1] == num_tokens: + return pos + cls_pos = pos[:, :1] + spatial_pos = pos[:, 1:] + old_size = int(math.sqrt(spatial_pos.shape[1])) + new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size + spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2) + tiles_h = new_size // old_size + 1 + tiles_w = new_size // old_size + 1 + tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size] + tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1) + return torch.cat([cls_pos, tiled], dim=1) + + def forward(self, x): + x = self.patch_embed(x) + B, C, Hp, Wp = x.shape + x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C) + + pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x) + x = x + pos[:, 1:Hp * Wp + 1] + + x = x.view(B, Hp, Wp, C) + x = self.ln_pre(x) + + freqs_cis_global = self.freqs_cis + freqs_cis_win = self.freqs_cis_window + if freqs_cis_global is not None: + freqs_cis_global = cast_to_input(freqs_cis_global, x) + if freqs_cis_win is not None: + freqs_cis_win = cast_to_input(freqs_cis_win, x) + + for block in self.blocks: + fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global + x = block(x, freqs_cis=fc) + + return x.permute(0, 3, 1, 2) + + +class FPNScaleConv(nn.Module): + def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None): + super().__init__() + if scale == 4.0: + self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype) + self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype) + proj_in = in_dim // 4 + elif scale == 2.0: + self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype) + proj_in = in_dim // 2 + elif scale == 1.0: + proj_in = in_dim + elif scale == 0.5: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + proj_in = in_dim + self.scale = scale + self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype) + self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype) + + def forward(self, x): + if self.scale == 4.0: + x = F.gelu(self.dconv_2x2_0(x)) + x = self.dconv_2x2_1(x) + elif self.scale == 2.0: + x = self.dconv_2x2(x) + elif self.scale == 0.5: + x = self.pool(x) + x = self.conv_1x1(x) + x = self.conv_3x3(x) + return x + + +class PositionEmbeddingSine(nn.Module): + """2D sinusoidal position encoding (DETR-style) with result caching.""" + def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None): + super().__init__() + assert num_pos_feats % 2 == 0 + self.half_dim = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + self.scale = scale if scale is not None else 2 * math.pi + self._cache = {} + + def _sincos(self, vals): + """Encode 1D values to interleaved sin/cos features.""" + freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim) + raw = vals[..., None] * self.scale / freqs + return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2) + + def _encode_xy(self, x, y): + """Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim].""" + dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim) + pos_x = x[:, None] * self.scale / dim_t + pos_y = y[:, None] * self.scale / dim_t + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) + return pos_x, pos_y + + def encode_boxes(self, cx, cy, w, h): + """Encode box center + size to [N, d_model+2] features.""" + pos_x, pos_y = self._encode_xy(cx, cy) + return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + + def forward(self, x): + B, C, H, W = x.shape + key = (H, W, x.device) + if key not in self._cache: + gy = torch.arange(H, dtype=torch.float32, device=x.device) + gx = torch.arange(W, dtype=torch.float32, device=x.device) + if self.normalize: + gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6) + yy, xx = torch.meshgrid(gy, gx, indexing="ij") + self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0) + return self._cache[key].expand(B, -1, -1, -1) + + +class SAM3VisionBackbone(nn.Module): + def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs) + self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True) + self.multiplex = multiplex + + fpn_args = dict(device=device, dtype=dtype, operations=operations) + if multiplex: + scales = [4.0, 2.0, 1.0] + self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + else: + scales = [4.0, 2.0, 1.0, 0.5] + self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + + def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False): + backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images) + + if tracker_only: + # Skip detector FPN when only tracker features are needed (video tracking) + if self.multiplex: + tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs + else: + tracker_convs = self.sam2_convs + tracker_features = [conv(backbone_out) for conv in tracker_convs] + tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features] + return None, None, tracker_features, tracker_positions + + features = [conv(backbone_out) for conv in self.convs] + positions = [cast_to_input(self.position_encoding(f), f) for f in features] + + if self.multiplex: + if tracker_mode == "propagation": + tracker_convs = self.propagation_convs + elif tracker_mode == "interactive": + tracker_convs = self.interactive_convs + else: + return features, positions, None, None + elif need_tracker: + tracker_convs = self.sam2_convs + else: + return features, positions, None, None + + tracker_features = [conv(backbone_out) for conv in tracker_convs] + tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features] + return features, positions, tracker_features, tracker_positions diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py new file mode 100644 index 000000000..6ff6369d1 --- /dev/null +++ b/comfy/ldm/sam3/tracker.py @@ -0,0 +1,1785 @@ +# SAM3 video tracker: memory encoder, memory attention, SAM mask decoder/prompt encoder. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +try: + import cv2 + _HAS_CV2 = True +except ImportError: + from scipy import ndimage + _HAS_CV2 = False + +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.sam3.sam import rope_2d, PositionEmbeddingSine +from comfy.ops import cast_to_input +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.cascade.common import LayerNorm2d_op +from comfy.ldm.sam3.sam import MLP, PositionEmbeddingRandom +from comfy.ldm.sam3.sam import TwoWayTransformer as SAMTwoWayTransformer + +NO_OBJ_SCORE = -1024.0 + + +def to_spatial(x, H, W): + """Reshape (B, H*W, C) → (B, C, H, W).""" + return x.view(x.shape[0], H, W, -1).permute(0, 3, 1, 2) + +class MultiplexState: + """Tracks object-to-slot assignments for multiplex tracking. Provides mux/demux operations.""" + + def __init__(self, num_objects, multiplex_count, device, dtype): + self.multiplex_count = multiplex_count + self.device = device + self.dtype = dtype + self._build(num_objects) + + def mux(self, x): + """[N_obj, ...] -> [num_buckets, multiplex_count, ...]""" + out_shape = (self.num_buckets, self.multiplex_count) + x.shape[1:] + return (self.mux_matrix.to(device=x.device, dtype=x.dtype) @ x.reshape(self.total_valid_entries, -1)).view(out_shape) + + def demux(self, x): + """[num_buckets, multiplex_count, ...] -> [N_obj, ...]""" + out_shape = (self.total_valid_entries,) + x.shape[2:] + flat = x.reshape(self.num_buckets * self.multiplex_count, -1) + return (self.demux_matrix.to(device=x.device, dtype=x.dtype) @ flat).view(out_shape) + + def get_valid_object_mask(self): + """[num_buckets, multiplex_count] bool tensor, True for valid slots.""" + return (self.mux_matrix.sum(dim=1) > 0).reshape(self.num_buckets, self.multiplex_count) + + def _build(self, num_objects): + M = self.multiplex_count + self.num_buckets = (num_objects + M - 1) // M + self.total_valid_entries = num_objects + total_slots = self.num_buckets * M + self.mux_matrix = torch.zeros(total_slots, num_objects, device=self.device, dtype=self.dtype) + self.demux_matrix = torch.zeros(num_objects, total_slots, device=self.device, dtype=self.dtype) + oids = torch.arange(num_objects, device=self.device) + slots = (oids // M) * M + (oids % M) + self.mux_matrix[slots, oids] = 1.0 + self.demux_matrix[oids, slots] = 1.0 + + def add_objects(self, n_new): + """Grow multiplex state for n_new additional objects.""" + self._build(self.total_valid_entries + n_new) + +def _compute_mask_overlap(masks_a, masks_b): + """Max of IoU and IoM (intersection over minimum area). More robust to size differences.""" + a_flat = (masks_a > 0).float().flatten(1) + b_flat = (masks_b > 0).float().flatten(1) + intersection = a_flat @ b_flat.T + area_a = a_flat.sum(1, keepdim=True) + area_b = b_flat.sum(1, keepdim=True).T + iou = intersection / (area_a + area_b - intersection).clamp(min=1) + iom = intersection / torch.min(area_a.expand_as(iou), area_b.expand_as(iou)).clamp(min=1) + return torch.max(iou, iom) + + +def _nms_masks(masks, scores, thresh=0.5): + """Mask-based NMS using IoU+IoM overlap. Returns (filtered_masks, filtered_scores).""" + order = scores.argsort(descending=True) + masks, scores = masks[order], scores[order] + keep = [] + for i in range(masks.shape[0]): + if keep: + if _compute_mask_overlap(masks[i:i+1], masks[torch.tensor(keep, device=masks.device)]).max() >= thresh: + continue + keep.append(i) + return masks[keep], scores[keep] + + +def _get_connected_components(mask_bin): + """Get connected component labels and areas. mask_bin: [B, 1, H, W] uint8.""" + labels_list, areas_list = [], [] + for i in range(mask_bin.shape[0]): + m = mask_bin[i, 0].cpu().numpy() + if _HAS_CV2: + _, labeled, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8) + areas = stats[labeled, cv2.CC_STAT_AREA].astype('int32') + else: + labeled, num_features = ndimage.label(m) + areas = np.zeros_like(m, dtype=np.int32) + for c in range(1, num_features + 1): + component = labeled == c + areas[component] = component.sum() + labels_list.append(torch.from_numpy(labeled).to(mask_bin.device)) + areas_list.append(torch.from_numpy(areas).to(device=mask_bin.device, dtype=torch.int32)) + return torch.stack(labels_list).unsqueeze(1), torch.stack(areas_list).unsqueeze(1) + + +def fill_holes_in_mask_scores(mask, max_area=0): + """Remove small foreground sprinkles and fill small background holes using connected components.""" + if max_area <= 0: + return mask + + # Fill holes: small connected components in background → foreground + mask_bg = (mask <= 0).to(torch.uint8) + _, areas_bg = _get_connected_components(mask_bg) + small_bg = mask_bg.bool() & (areas_bg <= max_area) + mask = torch.where(small_bg, 0.1, mask) + + # Remove sprinkles: small connected components in foreground → background + # Only remove if area < min(max_area, half of total foreground area) + mask_fg = (mask > 0).to(torch.uint8) + fg_area_thresh = mask_fg.sum(dim=(2, 3), keepdim=True, dtype=torch.int32) + fg_area_thresh.floor_divide_(2).clamp_(max=max_area) + _, areas_fg = _get_connected_components(mask_fg) + small_fg = mask_fg.bool() & (areas_fg <= fg_area_thresh) + mask = torch.where(small_fg, -0.1, mask) + + return mask + + +def apply_rope_memory(q, k, freqs, num_heads, num_k_exclude_rope=0): + """Apply 2D axial RoPE to memory attention using flux rope format. + + Args: + q: [B, Nq, C] projected queries (current frame features) + k: [B, Nk, C] projected keys (memory tokens) + freqs: [1, Nq, dim//2, 2, 2] flux-format rotation matrices for one frame + num_heads: number of attention heads + num_k_exclude_rope: number of trailing k tokens to skip RoPE (object pointers) + """ + B, Nq, C = q.shape + head_dim = C // num_heads + + # freqs shape: [1, 1, Nq, dim//2, 2, 2] (heads broadcast dim already included) + q_h = q.view(B, Nq, num_heads, head_dim).transpose(1, 2) + q_h = apply_rope1(q_h, freqs) + q = q_h.transpose(1, 2).reshape(B, Nq, C) + + # Apply RoPE to k (excluding last num_k_exclude_rope tokens) + Nk = k.shape[1] + num_k_rope = Nk - num_k_exclude_rope + if num_k_rope > 0: + # Repeat freqs for multiple frames of spatial memory + Nf = freqs.shape[2] # spatial positions in one frame + if num_k_rope > Nf: + r = (num_k_rope + Nf - 1) // Nf + pe_k = freqs.repeat(1, 1, r, 1, 1, 1)[:, :, :num_k_rope] + else: + pe_k = freqs[:, :, :num_k_rope] + + k_h = k[:, :num_k_rope].view(B, num_k_rope, num_heads, head_dim).transpose(1, 2) + k_h = apply_rope1(k_h, pe_k) + k = k.clone() + k[:, :num_k_rope] = k_h.transpose(1, 2).reshape(B, num_k_rope, C) + + return q, k + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """1D sinusoidal positional encoding for temporal positions.""" + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + pos_embed = pos_inds.unsqueeze(-1) / dim_t + return torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + + +def _pad_to_buckets(tensor, target_buckets): + """Pad a [num_buckets, ...] tensor to target_buckets along dim 0 if needed.""" + if tensor.shape[0] >= target_buckets: + return tensor + pad_shape = (target_buckets - tensor.shape[0],) + tensor.shape[1:] + return torch.cat([tensor, torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)], dim=0) + + +def pack_masks(masks): + """Pack binary masks [*, H, W] to bit-packed [*, H, W//8] uint8. W must be divisible by 8.""" + binary = masks > 0 + shifts = torch.arange(8, device=masks.device) + return (binary.view(*masks.shape[:-1], -1, 8) * (1 << shifts)).sum(-1).byte() + + +def unpack_masks(packed): + """Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8].""" + shifts = torch.arange(8, device=packed.device) + return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool() + + +def _compute_backbone(backbone_fn, frame, frame_idx=None): + """Compute backbone features for a single frame. Returns (vision_feats, vision_pos, feat_sizes, features, trunk_out).""" + features, positions, trunk_out = backbone_fn(frame, frame_idx=frame_idx) + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in features] + vision_feats = [x.flatten(2).permute(0, 2, 1) for x in features] + vision_pos = [x.flatten(2).permute(0, 2, 1) for x in positions] + return vision_feats, vision_pos, feat_sizes, features, trunk_out + + +def collect_memory_tokens(output_dict, frame_idx, num_maskmem, maskmem_tpos_enc, device, + collect_image_feats=False, tpos_v2=False, num_buckets=None): + """Collect spatial memory, position encodings, and optionally image features from past frames.""" + to_cat_memory, to_cat_memory_pos = [], [] + to_cat_image_feat, to_cat_image_pos = [], [] + + def _append(out, tpos_idx): + feats = out["maskmem_features"].to(device) + if num_buckets is not None: + feats = _pad_to_buckets(feats, num_buckets) + to_cat_memory.append(feats.flatten(2).permute(0, 2, 1)) + enc = out["maskmem_pos_enc"][-1].to(device).flatten(2).permute(0, 2, 1) + if num_buckets is not None: + enc = _pad_to_buckets(enc, num_buckets) + tpos = cast_to_input(maskmem_tpos_enc[tpos_idx], enc) + to_cat_memory_pos.append(enc + tpos) + if collect_image_feats and "image_features" in out: + to_cat_image_feat.append(out["image_features"].to(device)) + to_cat_image_pos.append(out["image_pos_enc"].to(device) + tpos) + + cond_outputs = output_dict["cond_frame_outputs"] + for t, out in cond_outputs.items(): + if tpos_v2: + t_pos = frame_idx - t + tpos_idx = num_maskmem - t_pos - 1 if 0 < t_pos < num_maskmem else num_maskmem - 1 + else: + tpos_idx = num_maskmem - 1 + _append(out, tpos_idx) + + for t_pos in range(1, num_maskmem): + out = output_dict["non_cond_frame_outputs"].get(frame_idx - (num_maskmem - t_pos), None) + if out is None or out.get("maskmem_features") is None: + continue + _append(out, num_maskmem - t_pos - 1) + + return to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs + + +def compute_tpos_enc(rel_pos_list, device, d_model, proj_layer, dtype=None, max_abs_pos=None): + """Temporal position encoding for object pointers.""" + pos_enc = torch.tensor(rel_pos_list, dtype=torch.float32, device=device) / max((max_abs_pos or 2) - 1, 1) + pos_enc = get_1d_sine_pe(pos_enc, dim=d_model) + if dtype is not None: + pos_enc = pos_enc.to(dtype) + return proj_layer(pos_enc) + + +def forward_sam_heads(backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, + image_size, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + """Shared SAM prompt encoder + mask decoder forward for both SAM3 and SAM3.1 trackers.""" + device = backbone_features.device + # Batch size from inputs (mask_inputs may have N_obj > 1 while backbone is batch 1) + if mask_inputs is not None: + B = mask_inputs.shape[0] + elif box_inputs is not None: + B = box_inputs.shape[0] + elif point_inputs is not None: + B = point_inputs["point_coords"].shape[0] + else: + B = backbone_features.shape[0] + + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + else: + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + if mask_inputs is not None: + prompt_size = (prompt_encoder.image_embedding_size[0] * 4, prompt_encoder.image_embedding_size[1] * 4) + if mask_inputs.shape[-2:] != prompt_size: + sam_mask_prompt = F.interpolate(mask_inputs, size=prompt_size, mode="bilinear", align_corners=False, antialias=True) + else: + sam_mask_prompt = mask_inputs + else: + sam_mask_prompt = None + + sparse, dense = prompt_encoder(points=(sam_point_coords, sam_point_labels), boxes=box_inputs, masks=sam_mask_prompt) + sparse = cast_to_input(sparse, backbone_features) + dense = cast_to_input(dense, backbone_features) + image_pe = cast_to_input(prompt_encoder.get_dense_pe(), backbone_features) + + low_res_multimasks, ious, sam_output_tokens, object_score_logits = mask_decoder( + image_embeddings=backbone_features, image_pe=image_pe, + sparse_prompt_embeddings=sparse, dense_prompt_embeddings=dense, + high_res_features=high_res_features, multimask_output=multimask_output, return_all=True, + ) + + is_obj_appearing = object_score_logits > 0 + low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, + torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_multimasks.dtype)) + high_res_multimasks = F.interpolate(low_res_multimasks, size=(image_size, image_size), mode="bilinear", align_corners=False) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + obj_ptr = obj_ptr_proj(sam_output_token) + obj_ptr = no_obj_fn(obj_ptr, is_obj_appearing) + + return low_res_masks, high_res_masks, obj_ptr, object_score_logits + + +def use_mask_as_output(backbone_features, high_res_features, mask_inputs, mask_downsample, + prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, image_size, backbone_stride): + """Shared mask-as-output for both SAM3 and SAM3.1 trackers.""" + out_scale, out_bias = 20.0, -10.0 + mask_inputs_float = cast_to_input(mask_inputs, backbone_features) + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate(high_res_masks, size=(image_size // backbone_stride * 4,) * 2, + mode="bilinear", align_corners=False, antialias=True) + _, _, obj_ptr, _ = forward_sam_heads( + backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, + image_size, mask_inputs=mask_downsample(mask_inputs_float), high_res_features=high_res_features, + ) + is_obj_appearing = torch.any(mask_inputs.flatten(1) > 0.0, dim=1)[..., None] + alpha = is_obj_appearing.to(obj_ptr.dtype) + object_score_logits = out_scale * alpha + out_bias + return low_res_masks, high_res_masks, obj_ptr, object_score_logits + + +# Split attention with configurable input dims (for asymmetric cross-attention) +class SplitAttn(nn.Module): + def __init__(self, embed_dim, num_heads=1, kv_dim=None, internal_dim=None, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + kv_dim = kv_dim or embed_dim + internal_dim = internal_dim or embed_dim + self.q_proj = operations.Linear(embed_dim, internal_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(internal_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, q, k=None, v=None, rope=None, num_k_exclude_rope=0): + if k is None: + k = q + if v is None: + v = k + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) + out = optimized_attention(q, k, v, self.num_heads) + return self.out_proj(out) + + +class MemoryAttnLayer(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.self_attn = SplitAttn(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_image = SplitAttn(d_model, num_heads, kv_dim=kv_dim, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, x, memory, memory_pos=None, rope=None, num_k_exclude_rope=0): + x = x + self.self_attn(self.norm1(x), rope=rope) + mem_k = memory + memory_pos if memory_pos is not None else memory + x = x + self.cross_attn_image(self.norm2(x), mem_k, memory, rope=rope, num_k_exclude_rope=num_k_exclude_rope) + normed = self.norm3(x) + x = x + self.linear2(F.relu(self.linear1(normed))) + return x + + +class MemoryAttnEncoder(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14, + device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + MemoryAttnLayer(d_model, num_heads, kv_dim, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + hw = image_size // patch_size + self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False) + + def forward(self, x, memory, src_pos=None, memory_pos=None, num_k_exclude_rope=0): + if src_pos is not None: + x = x + 0.1 * src_pos + + rope = self._rope.to(device=x.device) + for layer in self.layers: + x = layer(x, memory, memory_pos=memory_pos, rope=rope, num_k_exclude_rope=num_k_exclude_rope) + return self.norm(x) + + +class MemoryTransformer(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = MemoryAttnEncoder(d_model, num_heads, kv_dim, dim_ff, num_layers, device=device, dtype=dtype, operations=operations) + + +def _upscale_masks(output_upscaling, conv_s0, conv_s1, src_out, high_res_features): + """Shared upscaling for SAM mask decoders: deconv + high-res feature integration.""" + dc1, ln1, act1, dc2, act2 = output_upscaling + if high_res_features is not None: + upscaled = act1(ln1(dc1(src_out) + conv_s1(high_res_features[1]))) + upscaled = act2(dc2(upscaled) + conv_s0(high_res_features[0])) + else: + upscaled = act2(dc2(act1(ln1(dc1(src_out))))) + return upscaled + + +class SAMMaskDecoder(nn.Module): + def __init__(self, d_model=256, num_multimask_outputs=3, device=None, dtype=None, operations=None): + super().__init__() + self.num_mask_tokens = num_multimask_outputs + 1 + + self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations) + + self.iou_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.mask_tokens = operations.Embedding(self.num_mask_tokens, d_model, device=device, dtype=dtype) + self.obj_score_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + + # Output upscaling: d_model -> d_model//4 -> d_model//8 at 4x resolution + LN2d = LayerNorm2d_op(operations) + self.output_upscaling = nn.Sequential( + operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(), + operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(), + ) + + # High-res feature integration + self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype) + self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype) + + # Per-mask hypernetwork MLPs + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations) + for _ in range(self.num_mask_tokens) + ]) + + self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_tokens, 3, device=device, dtype=dtype, operations=operations) + self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, + high_res_features=None, multimask_output=False, return_all=False): + B = sparse_prompt_embeddings.shape[0] + ref = sparse_prompt_embeddings + # Token order: [obj_score(1), iou(1), mask(num_mask_tokens)] + tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref), + cast_to_input(self.mask_tokens.weight, ref)], dim=0) + tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1) + + src = image_embeddings + if src.shape[0] != B: + src = src.expand(B, -1, -1, -1) + src = src + dense_prompt_embeddings + pos_src = image_pe.expand(B, -1, -1, -1) + + b, c, h, w = src.shape + src_flat = src.flatten(2).permute(0, 2, 1) + pos_flat = pos_src.flatten(2).permute(0, 2, 1) + + hs, src_out = self.transformer(src_flat, pos_flat, tokens) + + obj_score_token_out = hs[:, 0, :] + iou_token_out = hs[:, 1, :] + mask_tokens_out = hs[:, 2:2 + self.num_mask_tokens, :] + + src_out = src_out.permute(0, 2, 1).view(b, c, h, w) + upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features) + + hyper_in = torch.stack([ + mlp(mask_tokens_out[:, i, :]) for i, mlp in enumerate(self.output_hypernetworks_mlps) + ], dim=1) + + masks = (hyper_in @ upscaled.flatten(2)).view(B, self.num_mask_tokens, upscaled.shape[2], upscaled.shape[3]) + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(obj_score_token_out) + + if multimask_output: + out_masks = masks[:, 1:] + out_iou = iou_pred[:, 1:] + out_tokens = mask_tokens_out[:, 1:] + else: + out_masks = masks[:, 0:1] + out_iou = iou_pred[:, 0:1] + out_tokens = mask_tokens_out[:, 0:1] + + if return_all: + return out_masks, out_iou, out_tokens, object_score_logits + return out_masks, out_iou + + +class SAMPromptEncoder(nn.Module): + def __init__(self, d_model=256, image_embedding_size=(72, 72), input_image_size=(1008, 1008), device=None, dtype=None, operations=None): + super().__init__() + self.embed_dim = d_model + self.image_embedding_size = image_embedding_size + self.input_image_size = input_image_size + + self.pe_layer = PositionEmbeddingRandom(d_model // 2) + self.point_embeddings = nn.ModuleList([ + operations.Embedding(1, d_model, device=device, dtype=dtype) for _ in range(4) + ]) + self.not_a_point_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + + LN2d = LayerNorm2d_op(operations) + self.mask_downscaling = nn.Sequential( + operations.Conv2d(1, 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(4, device=device, dtype=dtype), nn.GELU(), + operations.Conv2d(4, 16, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(16, device=device, dtype=dtype), nn.GELU(), + operations.Conv2d(16, d_model, kernel_size=1, device=device, dtype=dtype), + ) + self.no_mask_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + + def get_dense_pe(self): + return self.pe_layer(self.image_embedding_size) + + def forward(self, points=None, boxes=None, masks=None): + ref = points[0] if points is not None else boxes if boxes is not None else masks + B = 1 + sparse = torch.empty((B, 0, self.embed_dim), device=ref.device, dtype=ref.dtype) + + if points is not None: + coords, labels = points + B = coords.shape[0] + # Pad with an extra point (label=-1) when no boxes are provided (matching reference) + if boxes is None: + coords = torch.cat([coords, torch.zeros(B, 1, 2, device=coords.device, dtype=coords.dtype)], dim=1) + labels = torch.cat([labels, -torch.ones(B, 1, device=labels.device, dtype=labels.dtype)], dim=1) + pe = self.pe_layer.forward_with_coords(coords + 0.5, self.input_image_size) + for i in range(4): + pe[labels == i] += cast_to_input(self.point_embeddings[i].weight, ref) + invalid = (labels == -1) + pe[invalid] = 0.0 + pe[invalid] += cast_to_input(self.not_a_point_embed.weight, ref) + sparse = torch.cat([sparse.expand(B, -1, -1), pe], dim=1) + + if boxes is not None: + B = boxes.shape[0] + corners = self.pe_layer.forward_with_coords((boxes.reshape(-1, 2, 2) + 0.5), self.input_image_size) + corners[:, 0] += cast_to_input(self.point_embeddings[2].weight, ref) + corners[:, 1] += cast_to_input(self.point_embeddings[3].weight, ref) + sparse = torch.cat([sparse.expand(B, -1, -1), corners], dim=1) + + if masks is not None: + dense = self.mask_downscaling(masks) + else: + dense = cast_to_input(self.no_mask_embed.weight, ref).reshape(1, -1, 1, 1).expand( + B, -1, self.image_embedding_size[0], self.image_embedding_size[1]) + + return sparse, dense + + +class CXBlock(nn.Module): + def __init__(self, dim=256, kernel_size=7, device=None, dtype=None, operations=None): + super().__init__() + self.dwconv = operations.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, device=device, dtype=dtype) + self.pwconv1 = operations.Linear(dim, 4 * dim, device=device, dtype=dtype) + self.pwconv2 = operations.Linear(4 * dim, dim, device=device, dtype=dtype) + self.gamma = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) + + def forward(self, x): + residual = x + x = self.dwconv(x).permute(0, 2, 3, 1) + x = self.pwconv2(F.gelu(self.pwconv1(self.norm(x)))) + x.mul_(cast_to_input(self.gamma, x)) + return residual + x.permute(0, 3, 1, 2) + + +class MaskDownSampler(nn.Module): + def __init__(self, out_dim=256, in_chans=1, channels=None, interpol_size=(1152, 1152), device=None, dtype=None, operations=None): + super().__init__() + self.interpol_size = list(interpol_size) if interpol_size else None + if channels is None: + channels = [4, 16, 64, out_dim] # SAM3 default + LN2d = LayerNorm2d_op(operations) + layers = [] + prev = in_chans + for ch in channels: + layers += [operations.Conv2d(prev, ch, kernel_size=3, stride=2, padding=1, device=device, dtype=dtype), + LN2d(ch, device=device, dtype=dtype), nn.GELU()] + prev = ch + layers.append(operations.Conv2d(prev, out_dim, kernel_size=1, device=device, dtype=dtype)) + self.encoder = nn.Sequential(*layers) + + def forward(self, x): + if self.interpol_size is not None and list(x.shape[-2:]) != self.interpol_size: + x = F.interpolate(x, size=self.interpol_size, mode="bilinear", align_corners=False, antialias=True) + return self.encoder(x) + + +class Fuser(nn.Module): + def __init__(self, dim=256, num_layers=2, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.Sequential(*[CXBlock(dim, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)]) + + def forward(self, x): + return self.layers(x) + + +# --- SAM3.1 Multiplex components --- + +class DecoupledMemoryAttnLayer(nn.Module): + """Decoupled cross-attention layer for SAM3.1: fuses image and memory projections.""" + + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + # Self-attention projections (flat, not nested) + self.self_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # Cross-attention projections + self.cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # Image cross-attention (q/k only, fused with cross_attn) + self.image_cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.image_cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # FFN + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, image, x, memory_image, memory, memory_image_pos=None, + rope=None, num_k_exclude_rope=0): + # Self-attention with RoPE + normed = self.norm1(x) + q = self.self_attn_q_proj(normed) + k = self.self_attn_k_proj(normed) + v = self.self_attn_v_proj(normed) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, 0) + x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) + + # Decoupled cross-attention: fuse image and memory projections + normed = self.norm2(x) + q = self.image_cross_attn_q_proj(image) + self.cross_attn_q_proj(normed) + k = self.image_cross_attn_k_proj(memory_image) + self.cross_attn_k_proj(memory) + if memory_image_pos is not None: + k = k + memory_image_pos + v = self.cross_attn_v_proj(memory) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) + x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) + + # FFN + x = x + self.linear2(F.gelu(self.linear1(self.norm3(x)))) + return image, x + + +class DecoupledMemoryEncoder(nn.Module): + """Memory attention encoder for SAM3.1 with decoupled cross-attention.""" + + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14, + device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + DecoupledMemoryAttnLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + hw = image_size // patch_size + self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False) + + def forward(self, x, memory, memory_pos=None, src_pos=None, num_k_exclude_rope=0, + memory_image=None, memory_image_pos=None): + image = x # constant residual for decoupled cross-attention + output = x + if src_pos is not None: + output = output + 0.1 * src_pos + + B, _, C = x.shape + rope = self._rope.to(device=x.device) + + # memory_image: raw backbone features from past frames for decoupled cross-attention + if memory_image is None: + # Fallback: use spatial portion of memory (without obj pointers) + num_spatial = memory.shape[1] - num_k_exclude_rope + memory_image = memory[:, :num_spatial] + memory_image_pos = memory_pos[:, :num_spatial] if memory_pos is not None else None + # Pad memory_image to match memory length (zeros for obj pointer tokens) + if memory_image.shape[1] < memory.shape[1]: + pad_len = memory.shape[1] - memory_image.shape[1] + pad = torch.zeros(B, pad_len, C, device=memory.device, dtype=memory.dtype) + memory_image = torch.cat([memory_image, pad], dim=1) + if memory_image_pos is not None: + ptr_pos = memory_pos[:, -pad_len:] if memory_pos is not None else torch.zeros_like(pad) + memory_image_pos = torch.cat([memory_image_pos, ptr_pos], dim=1) + + for layer in self.layers: + image, output = layer(image, output, memory_image, memory, + memory_image_pos=memory_image_pos, rope=rope, + num_k_exclude_rope=num_k_exclude_rope) + + return self.norm(output) + + +class DecoupledMemoryTransformer(nn.Module): + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = DecoupledMemoryEncoder(d_model, num_heads, dim_ff, num_layers, + device=device, dtype=dtype, operations=operations) + + +class MemoryBackbone(nn.Module): + """Memory encoder: downsamples mask, fuses with pixel features, optionally compresses.""" + + def __init__(self, d_model=256, out_dim=None, in_chans=1, channels=None, device=None, dtype=None, operations=None): + super().__init__() + self.mask_downsampler = MaskDownSampler(d_model, in_chans=in_chans, channels=channels, device=device, dtype=dtype, operations=operations) + self.pix_feat_proj = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype) + self.fuser = Fuser(d_model, num_layers=2, device=device, dtype=dtype, operations=operations) + self.has_out_proj = out_dim is not None and out_dim != d_model + if self.has_out_proj: + self.out_proj = operations.Conv2d(d_model, out_dim, kernel_size=1, device=device, dtype=dtype) + feat_dim = out_dim + else: + feat_dim = d_model + self.position_encoding = PositionEmbeddingSine(num_pos_feats=feat_dim, normalize=True) + + def forward(self, image_features, mask_for_mem, skip_mask_sigmoid=False): + if not skip_mask_sigmoid: + mask_for_mem = mask_for_mem.sigmoid() + mask_features = self.mask_downsampler(cast_to_input(mask_for_mem, image_features)) + if mask_features.shape[-2:] != image_features.shape[-2:]: + mask_features = F.interpolate(mask_features, size=image_features.shape[-2:], mode="bilinear", align_corners=False) + features = self.pix_feat_proj(image_features) + mask_features + features = self.fuser(features) + if self.has_out_proj: + features = self.out_proj(features) + pos = cast_to_input(self.position_encoding(features), features) + return {"vision_features": features, "vision_pos_enc": [pos]} + + +class MultiplexMaskDecoder(nn.Module): + """SAM mask decoder for SAM3.1 multiplex: predicts masks for num_multiplex objects simultaneously. + + Uses multimask_outputs_only=True: num_mask_output_per_object = num_multimask_outputs (no +1). + Hypernetwork MLPs are shared across multiplex objects. + Token order: [obj_score_token(M), iou_token(M), mask_tokens(M*T)]. + """ + + def __init__(self, d_model=256, num_multiplex=16, num_multimask_outputs=3, device=None, dtype=None, operations=None): + super().__init__() + self.num_multiplex = num_multiplex + self.num_mask_output_per_object = num_multimask_outputs # 3 (multimask_outputs_only) + total_mask_tokens = num_multiplex * self.num_mask_output_per_object # 48 + + self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations) + + self.obj_score_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype) + self.iou_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype) + self.mask_tokens = operations.Embedding(total_mask_tokens, d_model, device=device, dtype=dtype) + + LN2d = LayerNorm2d_op(operations) + self.output_upscaling = nn.Sequential( + operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(), + operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(), + ) + self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype) + self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype) + + # Shared across all multiplex objects (one per mask output) + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations) + for _ in range(self.num_mask_output_per_object) + ]) + self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_output_per_object, 3, device=device, dtype=dtype, operations=operations) + self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, + high_res_features=None, multimask_output=False, return_all=False, extra_per_object_embeddings=None): + B = sparse_prompt_embeddings.shape[0] + M = self.num_multiplex + T = self.num_mask_output_per_object + + # Token order: [obj_score(M), iou(M), mask(M*T)] + ref = sparse_prompt_embeddings + mask_tokens = cast_to_input(self.mask_tokens.weight, ref) + if extra_per_object_embeddings is not None: + mask_tokens = mask_tokens.view(1, M, T, -1).expand(B, -1, -1, -1) + extra_per_object_embeddings.unsqueeze(2) + mask_tokens = mask_tokens.flatten(1, 2) # [B, M*T, C] + other_tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref)], dim=0).unsqueeze(0).expand(B, -1, -1) + tokens = torch.cat([other_tokens, mask_tokens, sparse_prompt_embeddings], dim=1) + else: + tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref), mask_tokens], dim=0) + tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1) + + src = image_embeddings + if src.shape[0] != B: + src = src.expand(B, -1, -1, -1) + src = src + dense_prompt_embeddings + pos_src = image_pe.expand(B, -1, -1, -1) + + b, c, h, w = src.shape + hs, src_out = self.transformer(src.flatten(2).permute(0, 2, 1), pos_src.flatten(2).permute(0, 2, 1), tokens) + + # Parse output tokens + obj_score_token_out = hs[:, :M] + iou_token_out = hs[:, M:2 * M] + mask_tokens_out = hs[:, 2 * M:2 * M + M * T] + + src_out = src_out.permute(0, 2, 1).view(b, c, h, w) + upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features) + + # Reshape mask tokens to [B, M, T, C] and apply shared hypernetwork MLPs per mask output index + mask_tokens_2d = mask_tokens_out.view(B, M, T, -1) + hyper_in = torch.stack([ + self.output_hypernetworks_mlps[i](mask_tokens_2d[:, :, i, :]) # [B, M, C//8] + for i in range(T) + ], dim=2) # [B, M, T, C//8] + + # Generate masks: [B, M*T, H*W] -> [B, M, T, H, W] + masks = torch.bmm(hyper_in.flatten(1, 2), upscaled.flatten(2)).view(b, M, T, upscaled.shape[2], upscaled.shape[3]) + + # IoU and object scores + iou_pred = self.iou_prediction_head(iou_token_out).view(b, M, T) + object_score_logits = self.pred_obj_score_head(obj_score_token_out) # [B, M, 1] + + # multimask_outputs_only: always output all T masks (no singlemask token) + sam_tokens_out = mask_tokens_2d[:, :, 0:1] # [B, M, 1, C] + + if return_all: + return masks, iou_pred, sam_tokens_out, object_score_logits + return masks, iou_pred + + +class SAM3Tracker(nn.Module): + def __init__(self, d_model=256, mem_dim=64, num_maskmem=7, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + + # Memory attention transformer + self.transformer = MemoryTransformer(d_model, num_heads=1, kv_dim=mem_dim, dim_ff=2048, num_layers=4, + device=device, dtype=dtype, operations=operations) + # SAM components + self.sam_mask_decoder = SAMMaskDecoder(d_model, device=device, dtype=dtype, operations=operations) + self.sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations) + + # Memory backbone + self.maskmem_backbone = MemoryBackbone(d_model, out_dim=mem_dim, device=device, dtype=dtype, operations=operations) + + # Standalone parameters + self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype)) + self.no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype)) + self.register_buffer("no_mem_pos_enc", torch.zeros(1, 1, d_model, device=device, dtype=dtype)) # checkpoint key, unused in forward + self.no_obj_embed_spatial = nn.Parameter(torch.zeros(1, mem_dim, device=device, dtype=dtype)) + self.no_obj_ptr = nn.Parameter(torch.zeros(1, d_model, device=device, dtype=dtype)) + + # Object pointer projection + self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype) + + # Mask downsample: Conv2d stride 4 to reduce GT mask to SAM logit scale + self.mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype) + + # Config + self.d_model = d_model + self.mem_dim = mem_dim + self.num_maskmem = num_maskmem + self.image_size = 1008 + self.backbone_stride = 14 + self.max_obj_ptrs_in_encoder = 16 + self.sigmoid_scale_for_mem_enc = 20.0 + self.sigmoid_bias_for_mem_enc = -10.0 + + def _no_obj_blend(self, obj_ptr, is_obj): + alpha = is_obj.to(obj_ptr.dtype) + return torch.lerp(cast_to_input(self.no_obj_ptr, obj_ptr), obj_ptr, alpha) + + def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + return forward_sam_heads(backbone_features, self.sam_prompt_encoder, self.sam_mask_decoder, + self.obj_ptr_proj, self._no_obj_blend, self.image_size, + point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + return use_mask_as_output(backbone_features, high_res_features, mask_inputs, + self.mask_downsample, self.sam_prompt_encoder, self.sam_mask_decoder, + self.obj_ptr_proj, self._no_obj_blend, self.image_size, self.backbone_stride) + + def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, output_dict, num_frames): + """Fuse current frame features with memory from previous frames.""" + B = current_vision_feats[-1].shape[0] + C = self.d_model + H, W = feat_sizes[-1] + device = current_vision_feats[-1].device + + if self.num_maskmem == 0: + return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W) + + if is_init_cond_frame: + # First conditioning frame: no memory yet, add no_mem_embed + pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + to_cat_memory, to_cat_memory_pos, _, _, cond_outputs = collect_memory_tokens( + output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device) + + max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder) + pos_and_ptrs = [] + for t, out in cond_outputs.items(): + if t <= frame_idx: + pos_and_ptrs.append(((frame_idx - t), out["obj_ptr"].to(device))) + for t_diff in range(1, max_obj_ptrs): + t = frame_idx - t_diff + if t < 0: + break + out = output_dict["non_cond_frame_outputs"].get(t, None) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device))) + + num_obj_ptr_tokens = 0 + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + obj_ptrs = torch.stack(ptrs_list, dim=1) # [B, N, C=256] + + # Temporal position encoding for pointers + obj_pos = compute_tpos_enc( + list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj, + max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype + ) # [N, mem_dim=64] + obj_pos = obj_pos.unsqueeze(0).expand(B, -1, -1) # [B, N, 64] + + # Split each 256-dim pointer into 4 x 64-dim tokens + if self.mem_dim < C: + N = obj_ptrs.shape[1] + obj_ptrs = obj_ptrs.view(B, N, C // self.mem_dim, self.mem_dim) # [B, N, 4, 64] + obj_ptrs = obj_ptrs.reshape(B, N * (C // self.mem_dim), self.mem_dim) # [B, N*4, 64] + obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, C // self.mem_dim, -1) + obj_pos = obj_pos.reshape(B, N * (C // self.mem_dim), self.mem_dim) # [B, N*4, 64] + + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[1] + + if len(to_cat_memory) == 0: + # No memory available yet, add no_mem_embed + pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + # Concatenate all memory and position encodings [B, total_mem, mem_dim=64] + memory = torch.cat(to_cat_memory, dim=1) + memory_pos = torch.cat(to_cat_memory_pos, dim=1) + + # Run memory attention encoder + pix_feat = current_vision_feats[-1] # [B, HW, C] + src_pos = current_vision_pos_embeds[-1] # [B, HW, C] + + pix_feat_with_mem = self.transformer.encoder( + x=pix_feat, + memory=memory, + src_pos=src_pos, + memory_pos=memory_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + ) + return to_spatial(pix_feat_with_mem, H, W) + + def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False): + """Encode predicted mask into memory features.""" + if is_mask_from_pts: + mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype) + else: + mask_for_mem = torch.sigmoid(pred_masks_high_res) + + mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc) + + maskmem_out = self.maskmem_backbone(pix_feat, mask_for_mem, skip_mask_sigmoid=True) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + # Add no_obj_embed for occluded objects + alpha = (object_score_logits > 0).to(maskmem_features.dtype)[..., None, None] + no_obj = cast_to_input(self.no_obj_embed_spatial, maskmem_features)[..., None, None].expand_as(maskmem_features) + return maskmem_features + (1 - alpha) * no_obj, maskmem_pos_enc + + def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, mask_inputs, output_dict, + num_frames, point_inputs=None): + """Track one frame: fuse with memory, predict mask, encode memory.""" + current_out = {} + + # High-res features for SAM head [stride-8, stride-4] + if len(current_vision_feats) > 1: + high_res_features = [ + x.view(x.shape[0], feat_sizes[i][0], feat_sizes[i][1], -1).permute(0, 3, 1, 2) + for i, x in enumerate(current_vision_feats[:-1]) + ] + else: + high_res_features = None + + # Top-level feature for memory + H, W = feat_sizes[-1] + + if mask_inputs is not None: + # Conditioning frame: use mask directly + pix_feat = to_spatial(current_vision_feats[-1], H, W) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # Track frame: fuse with memory, then SAM decoder + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + output_dict=output_dict, + num_frames=num_frames, + ) + # Use multimask for point prompts on init frames (picks best of 3 candidates) + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = is_init_cond_frame and 0 < num_pts <= 1 + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs + + # Clean low-res masks: remove sprinkles and fill holes + low_res_masks = fill_holes_in_mask_scores(low_res_masks, max_area=200) + high_res_masks = F.interpolate(low_res_masks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + current_out["object_score_logits"] = object_score_logits + + # Encode memory + if self.num_maskmem > 0: + pix_feat = to_spatial(current_vision_feats[-1], H, W) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None): + vision_feats, vision_pos, feat_sizes, _, _ = _compute_backbone(backbone_fn, frame, frame_idx) + # SAM3: drop last FPN level + return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1] + + def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None): + """Track one object, computing backbone per frame to save VRAM.""" + N = images.shape[0] + device, dt = images.device, images.dtype + output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} + all_masks = [] + + for frame_idx in tqdm(range(N), desc="tracking"): + vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame( + backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx) + mask_input = None + if frame_idx == 0: + mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt), + size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + mask_input = (mask_input > 0.5).to(dt) + + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=(frame_idx == 0), + current_vision_feats=vision_feats, current_vision_pos_embeds=vision_pos, + feat_sizes=feat_sizes, mask_inputs=mask_input, output_dict=output_dict, num_frames=N) + + if frame_idx == 0: + output_dict["cond_frame_outputs"][frame_idx] = current_out + else: + output_dict["non_cond_frame_outputs"][frame_idx] = current_out + lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder) + for old_idx in list(output_dict["non_cond_frame_outputs"]): + if old_idx < frame_idx - lookback: + del output_dict["non_cond_frame_outputs"][old_idx] + # Move masks to CPU immediately to free VRAM + all_masks.append(current_out["pred_masks_high_res"].to(comfy.model_management.intermediate_device())) + if pbar is not None: + pbar.update(1) + + return torch.cat(all_masks, dim=0) # [N, 1, H, W] + + def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs): + """Track one or more objects across video frames. + + Args: + backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame + images: [N, 3, 1008, 1008] video frames + initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object) + pbar: optional progress bar + + Returns: + [N, N_obj, image_size, image_size] predicted mask logits per frame per object + """ + N_obj = initial_masks.shape[0] + per_object = [] + for obj_idx in range(N_obj): + obj_masks = self._track_single_object( + backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar) + per_object.append(obj_masks) + + return torch.cat(per_object, dim=1) # [N, N_obj, H, W] + + +class SAM31Tracker(nn.Module): + """SAM3.1 multiplex tracker: decoupled memory attention, dual decoder, 16-object multiplex.""" + + def __init__(self, d_model=256, mem_dim=256, num_maskmem=7, num_multiplex=16, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.d_model = d_model + self.mem_dim = mem_dim + self.num_maskmem = num_maskmem + self.num_multiplex = num_multiplex + self.image_size = 1008 + self.backbone_stride = 14 + self.max_obj_ptrs_in_encoder = 16 + self.sigmoid_scale_for_mem_enc = 2.0 + self.sigmoid_bias_for_mem_enc = -1.0 + + # Memory attention (decoupled cross-attention, 8 heads matching reference) + self.transformer = DecoupledMemoryTransformer(d_model, num_heads=8, dim_ff=2048, num_layers=4, + device=device, dtype=dtype, operations=operations) + + # Propagation decoder (multiplex: 16 objects, multimask_outputs_only) + self.sam_mask_decoder = MultiplexMaskDecoder(d_model, num_multiplex, num_multimask_outputs=3, + device=device, dtype=dtype, operations=operations) + # Interactive decoder (single object, same as SAM3) + self.interactive_sam_mask_decoder = SAMMaskDecoder(d_model, num_multimask_outputs=3, + device=device, dtype=dtype, operations=operations) + self.interactive_sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations) + + # Memory backbone (mem_dim=256, no out_proj compression) + self.maskmem_backbone = MemoryBackbone(d_model, in_chans=num_multiplex * 2, channels=[16, 64, 256, 1024], + device=device, dtype=dtype, operations=operations) + + # Standalone parameters + self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype)) + self.no_obj_embed_spatial = nn.Parameter(torch.zeros(num_multiplex, mem_dim, device=device, dtype=dtype)) + self.interactivity_no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype)) + + # Object pointer projection + self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype) + self.no_obj_ptr_linear = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.interactive_obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + + # Interactive mask downsample + self.interactive_mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype) + + # Multiplex validity embeddings + self.output_valid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype)) + self.output_invalid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype)) + + # Position encoding for image (used by multiplex decoder) + self.image_pe_layer = PositionEmbeddingRandom(d_model // 2) + + def _no_obj_blend(self, obj_ptr, is_obj): + alpha = is_obj.to(obj_ptr.dtype) + return torch.lerp(self.no_obj_ptr_linear(obj_ptr), obj_ptr, alpha) + + def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + return forward_sam_heads(backbone_features, self.interactive_sam_prompt_encoder, self.interactive_sam_mask_decoder, + self.interactive_obj_ptr_proj, self._no_obj_blend, self.image_size, + point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + return use_mask_as_output(backbone_features, high_res_features, mask_inputs, + self.interactive_mask_downsample, self.interactive_sam_prompt_encoder, + self.interactive_sam_mask_decoder, self.interactive_obj_ptr_proj, + self._no_obj_blend, self.image_size, self.backbone_stride) + + def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, + current_vision_pos_embeds, feat_sizes, output_dict, num_frames, + multiplex_state=None): + B = current_vision_feats[-1].shape[0] + C = self.d_model + H, W = feat_sizes[-1] + device = current_vision_feats[-1].device + num_buc = multiplex_state.num_buckets if multiplex_state is not None else None + + if self.num_maskmem == 0: + return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W) + + if is_init_cond_frame: + pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs = collect_memory_tokens( + output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device, + collect_image_feats=True, tpos_v2=True, num_buckets=num_buc) + + max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder) + pos_and_ptrs = [] + for t, out in cond_outputs.items(): + if t <= frame_idx and "obj_ptr" in out: + ptr = out["obj_ptr"].to(device) + if num_buc is not None: + ptr = _pad_to_buckets(ptr, num_buc) + pos_and_ptrs.append(((frame_idx - t), ptr)) + for t_diff in range(1, max_obj_ptrs): + t = frame_idx - t_diff + if t < 0: + break + out = output_dict["non_cond_frame_outputs"].get(t, None) + if out is not None and "obj_ptr" in out: + ptr = out["obj_ptr"].to(device) + if num_buc is not None: + ptr = _pad_to_buckets(ptr, num_buc) + pos_and_ptrs.append((t_diff, ptr)) + + num_obj_ptr_tokens = 0 + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + obj_ptrs = torch.stack(ptrs_list, dim=1) # [num_buckets, N, M, C] + B_ptr = obj_ptrs.shape[0] + N_ptrs = obj_ptrs.shape[1] + M = obj_ptrs.shape[2] + obj_ptrs = obj_ptrs.reshape(B_ptr, N_ptrs * M, -1) + obj_pos = compute_tpos_enc(list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj, + max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype) + obj_pos = obj_pos.unsqueeze(0).expand(B_ptr, -1, -1) + obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, M, -1).reshape(B_ptr, N_ptrs * M, -1) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[1] + + if len(to_cat_memory) == 0: + pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + memory = torch.cat(to_cat_memory, dim=1) + memory_pos = torch.cat(to_cat_memory_pos, dim=1) + + # Expand vision features to num_buckets if memory has more buckets than B + mem_B = memory.shape[0] + x = current_vision_feats[-1] + x_pos = current_vision_pos_embeds[-1] + if x.shape[0] < mem_B: + x = x.expand(mem_B, -1, -1) + x_pos = x_pos.expand(mem_B, -1, -1) + + if len(to_cat_image_feat) > 0: + # Decoupled cross-attention: separate image features from memory + memory_image = cast_to_input(torch.cat(to_cat_image_feat, dim=1), x) + memory_image_pos = cast_to_input(torch.cat(to_cat_image_pos, dim=1), x) + if memory_image.shape[0] < mem_B: + memory_image = memory_image.expand(mem_B, -1, -1) + memory_image_pos = memory_image_pos.expand(mem_B, -1, -1) + pix_feat_with_mem = self.transformer.encoder( + x=x, + memory=cast_to_input(memory, x), + memory_pos=cast_to_input(memory_pos, x), + src_pos=cast_to_input(x_pos, x), + num_k_exclude_rope=num_obj_ptr_tokens, + memory_image=memory_image, + memory_image_pos=memory_image_pos, + ) + else: + pix_feat_with_mem = self.transformer.encoder( + x=x, + memory=memory, + memory_pos=memory_pos, + src_pos=x_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + ) + return to_spatial(pix_feat_with_mem, H, W) + + def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False, + multiplex_state=None, is_conditioning=False, cond_obj_mask=None): + if is_mask_from_pts: + mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype) + else: + mask_for_mem = torch.sigmoid(pred_masks_high_res) + mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc) + + # Mux masks: [N_obj, 1, H, W] -> [num_buckets, M, H, W] + mux_masks = multiplex_state.mux(mask_for_mem[:, 0]) + + # Conditioning channel: 1.0 = clean mask (trust it), 0.0 = propagation (noisy) + N_obj = mask_for_mem.shape[0] + cond_values = torch.full((N_obj,), 0.0, device=mask_for_mem.device, dtype=mask_for_mem.dtype) + if is_conditioning: + cond_values[:] = 1.0 + elif cond_obj_mask is not None: + cond_values[cond_obj_mask] = 1.0 + cond_spatial = cond_values.view(-1, 1, 1, 1).expand_as(mask_for_mem[:, 0:1, :, :]).squeeze(1) + mux_cond = multiplex_state.mux(cond_spatial) # [num_buckets, M, H, W] + mux_input = torch.cat([mux_masks, mux_cond], dim=1) # [num_buckets, 2*M, H, W] + + maskmem_out = self.maskmem_backbone(pix_feat, mux_input, skip_mask_sigmoid=True) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + # Add no_obj_embed_spatial for occluded objects + is_obj = (object_score_logits > 0).float() # [N_obj, 1] + mux_is_obj = multiplex_state.mux(is_obj) # [num_buckets, M, 1] + no_obj_embed = cast_to_input(self.no_obj_embed_spatial, maskmem_features) # [M, C] + no_obj_spatial = no_obj_embed.unsqueeze(0)[..., None, None] # [1, M, C, 1, 1] + # Expand and sum across multiplex slots weighted by (1 - is_obj) + alpha = mux_is_obj[..., None, None] # [num_buckets, M, 1, 1, 1] + per_slot_no_obj = ((1 - alpha) * no_obj_spatial).sum(dim=1) # [num_buckets, C, 1, 1] + maskmem_features = maskmem_features + per_slot_no_obj.expand_as(maskmem_features) + + return maskmem_features, maskmem_pos_enc + + def _forward_propagation(self, backbone_features, high_res_features=None, multiplex_state=None): + """Propagation path using the multiplex SAM decoder (no prompts).""" + B = backbone_features.shape[0] + device = backbone_features.device + + # Suppression embeddings from valid object mask + valid_mask = cast_to_input(multiplex_state.get_valid_object_mask().unsqueeze(-1).float(), backbone_features) + output_valid = cast_to_input(self.output_valid_embed, backbone_features).unsqueeze(0) + output_invalid = cast_to_input(self.output_invalid_embed, backbone_features).unsqueeze(0) + extra_embed = valid_mask * output_valid + (1 - valid_mask) * output_invalid + + image_pe = self.image_pe_layer((backbone_features.shape[-2], backbone_features.shape[-1]), device=backbone_features.device) + image_pe = cast_to_input(image_pe, backbone_features) + + masks, iou_pred, sam_tokens_out, object_score_logits = self.sam_mask_decoder( + image_embeddings=backbone_features, image_pe=image_pe, + sparse_prompt_embeddings=torch.empty(B, 0, self.d_model, device=device, dtype=backbone_features.dtype), + dense_prompt_embeddings=torch.zeros(B, self.d_model, *backbone_features.shape[-2:], device=device, dtype=backbone_features.dtype), + high_res_features=high_res_features, multimask_output=True, return_all=True, + extra_per_object_embeddings=extra_embed.expand(B, -1, -1), + ) + # masks: [B=num_buckets, M, T, H, W] + # Demux to per-object: [N_obj, T, H, W] + masks_obj = multiplex_state.demux(masks) + iou_obj = multiplex_state.demux(iou_pred) + score_obj = multiplex_state.demux(object_score_logits) + tokens_obj = multiplex_state.demux(sam_tokens_out) + + # Select best mask by IoU for each object + best_idx = torch.argmax(iou_obj, dim=-1) # [N_obj] + N_obj = masks_obj.shape[0] + obj_range = torch.arange(N_obj, device=device) + low_res_masks = masks_obj[obj_range, best_idx].unsqueeze(1) # [N_obj, 1, H, W] + # Suppress masks for objects with low confidence + is_obj = score_obj > 0 + low_res_masks = torch.where(is_obj[:, :, None, None], low_res_masks, + torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_masks.dtype)) + high_res_masks = F.interpolate(low_res_masks.float(), size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + + # Object pointer: compute per-object, mux for storage as [num_buckets, M, C] + sam_token = tokens_obj[:, 0] # [N_obj, C] + obj_ptr = self.obj_ptr_proj(sam_token) + is_obj = (score_obj > 0).float() + no_obj = self.no_obj_ptr_linear(obj_ptr) + obj_ptr = is_obj * obj_ptr + (1 - is_obj) * no_obj + obj_ptr_muxed = multiplex_state.mux(obj_ptr) # [num_buckets, M, C] + + return low_res_masks, high_res_masks, obj_ptr_muxed, score_obj + + def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, + feat_sizes, mask_inputs, output_dict, num_frames, point_inputs=None, + interactive_high_res=None, interactive_backbone=None, propagation_high_res=None, + multiplex_state=None, run_mem_encoder=True): + current_out = {} + H, W = feat_sizes[-1] + + if mask_inputs is not None: + # Conditioning frame: use interactive features if available, else propagation + if interactive_backbone is not None: + pix_feat = interactive_backbone + # Add no_mem_embed for interactive path + pix_flat = pix_feat.flatten(2) + bf = pix_flat.permute(0, 2, 1) + cast_to_input(self.interactivity_no_mem_embed, pix_flat) + pix_feat = to_spatial(bf, H, W) + hi_res = interactive_high_res + else: + # Fallback: interactive backbone not available (e.g. called outside track_video). + # Propagation features work but may produce lower-quality conditioning. + pix_feat = to_spatial(current_vision_feats[-1], H, W) + hi_res = propagation_high_res + sam_outputs = self._use_mask_as_output(pix_feat, hi_res, mask_inputs) + elif point_inputs is not None: + # Interactive path: use interactive SAM decoder + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames, + multiplex_state=multiplex_state, + ) + hi_res = interactive_high_res if interactive_high_res is not None else propagation_high_res + num_pts = point_inputs["point_labels"].size(1) + multimask_output = is_init_cond_frame and 0 < num_pts <= 1 + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, point_inputs=point_inputs, + high_res_features=hi_res, multimask_output=multimask_output, + ) + else: + # Propagation path: use multiplex SAM decoder with propagation features + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames, + multiplex_state=multiplex_state, + ) + sam_outputs = self._forward_propagation(pix_feat_with_mem, propagation_high_res, + multiplex_state=multiplex_state) + + (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs + + # Mux obj_ptr if it came from interactive path (shape [B, C]) vs propagation ([num_buckets, M, C]) + if multiplex_state is not None and obj_ptr.dim() == 2: + obj_ptr = multiplex_state.mux(obj_ptr) # [N_obj, C] -> [num_buckets, M, C] + + # Encode memory (can be deferred with run_mem_encoder=False) + if run_mem_encoder and self.num_maskmem > 0: + pix_feat = to_spatial(current_vision_feats[-1], H, W) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + multiplex_state=multiplex_state, + is_conditioning=(mask_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + # Store propagation image features for decoupled memory attention + current_out["image_features"] = current_vision_feats[-1] # [B, HW, C] + current_out["image_pos_enc"] = current_vision_pos_embeds[-1] # [B, HW, C] + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + current_out["object_score_logits"] = object_score_logits + + return current_out + + def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None): + vision_feats, vision_pos, feat_sizes, features, trunk_out = _compute_backbone(backbone_fn, frame, frame_idx) + return vision_feats, vision_pos, feat_sizes, list(features[:-1]), trunk_out + + @staticmethod + def _suppress_recently_occluded(low_res_masks, last_occluded, frame_idx, threshold=0.3): + """Suppress overlapping masks for objects that were most recently occluded. + Prevents corrupted masks from occluded objects from contaminating other objects.""" + N_obj = low_res_masks.shape[0] + if N_obj <= 1: + return low_res_masks + binary = low_res_masks[:, 0] > 0 # [N_obj, H, W] + iou = _compute_mask_overlap(low_res_masks[:, 0], low_res_masks[:, 0]) + overlapping = torch.triu(iou >= threshold, diagonal=1) # [N, N] upper triangle + last_occ_i = last_occluded.unsqueeze(1) # [N, 1] + last_occ_j = last_occluded.unsqueeze(0) # [1, N] + # Suppress the more recently occluded object in each overlapping pair + suppress_i = overlapping & (last_occ_i > last_occ_j) & (last_occ_j > -1) + suppress_j = overlapping & (last_occ_j > last_occ_i) & (last_occ_i > -1) + to_suppress = suppress_i.any(dim=1) | suppress_j.any(dim=0) + # Update last_occluded for occluded/suppressed objects + is_empty = ~binary.any(dim=(-1, -2)) + newly_occluded = is_empty | to_suppress + last_occluded[newly_occluded] = frame_idx + # Suppress masks + low_res_masks[to_suppress] = -10.0 + return low_res_masks + + def _deferred_memory_encode(self, current_out, N_obj, vision_feats, feat_sizes, mux_state, device, + cond_obj_mask=None): + """Deferred memory encoding for propagation frames. cond_obj_mask: per-object bool for conditioning.""" + low_res_masks = current_out["pred_masks"] # [N_obj, 1, H_low, W_low] + + if N_obj > 1: + lr = low_res_masks.squeeze(1) # [N_obj, H, W] + max_obj = torch.argmax(lr, dim=0, keepdim=True) + batch_inds = torch.arange(N_obj, device=device)[:, None, None] + pixel_nol = torch.where(max_obj == batch_inds, lr, torch.clamp(lr, max=-10.0)) + area_before = (lr > 0).sum(dim=(-1, -2)).float().clamp(min=1) + area_after = (pixel_nol > 0).sum(dim=(-1, -2)).float() + shrink_ok = (area_after / area_before) >= 0.3 + low_res_masks = torch.where( + shrink_ok[:, None, None, None].expand_as(low_res_masks), + low_res_masks, torch.clamp(low_res_masks, max=-10.0)) + + interpol_size = self.maskmem_backbone.mask_downsampler.interpol_size + mem_masks = F.interpolate(low_res_masks, size=interpol_size, + mode="bilinear", align_corners=False) + + obj_scores = torch.where( + (mem_masks > 0).any(dim=(-1, -2)), 10.0, -10.0) + + pix_feat = to_spatial(vision_feats[-1], feat_sizes[-1][0], feat_sizes[-1][1]) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, pred_masks_high_res=mem_masks, + object_score_logits=obj_scores, + multiplex_state=mux_state, cond_obj_mask=cond_obj_mask) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + + def _add_detected_objects(self, new_masks, mux_state, vision_feats, feat_sizes, current_out): + """Grow MultiplexState with new detections, merge masks, re-encode memory. Modifies current_out.""" + n_old = mux_state.total_valid_entries + mux_state.add_objects(new_masks.shape[0]) + N_obj = mux_state.total_valid_entries + # Stored memory with old bucket counts is padded at read time by _pad_to_buckets + for k in ("pred_masks", "pred_masks_high_res"): + det = F.interpolate(new_masks.unsqueeze(1), size=current_out[k].shape[-2:], + mode="bilinear", align_corners=False) + current_out[k] = torch.cat([current_out[k], det], dim=0) + if self.num_maskmem > 0: + # Mark new objects as conditioning (clean detection masks) so model trusts them + cond_mask = torch.zeros(N_obj, dtype=torch.bool, device=new_masks.device) + cond_mask[n_old:] = True + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, + mux_state, new_masks.device, cond_obj_mask=cond_mask) + + def _condition_with_masks(self, masks, frame_idx, vision_feats, vision_pos, feat_sizes, + high_res_prop, output_dict, N, mux_state, backbone_obj, frame, + trunk_out, threshold=0.5): + """Condition tracker with masks on a frame.""" + mask_input = F.interpolate(masks if masks.dim() == 4 else masks.unsqueeze(1), + size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + mask_input = (mask_input > threshold).to(masks.dtype) + hi_res = lo_feat = None + if backbone_obj is not None and backbone_obj.multiplex: + _, _, itf, _ = backbone_obj(frame, tracker_mode="interactive", cached_trunk=trunk_out, tracker_only=True) + hi_res, lo_feat = itf[:-1], itf[-1] + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=True, current_vision_feats=vision_feats, + current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=mask_input, + output_dict=output_dict, num_frames=N, interactive_high_res=hi_res, + interactive_backbone=lo_feat, propagation_high_res=high_res_prop, + multiplex_state=mux_state, run_mem_encoder=True) + output_dict["cond_frame_outputs"][frame_idx] = current_out + return current_out + + def _match_and_add_detections(self, det_masks, det_scores, current_out, mux_state, + vision_feats, feat_sizes, device, max_objects=0, + keep_alive=None): + """Match detections against tracked masks, add new objects, recondition degraded tracks. + Updates keep_alive counters: +1 for matched tracks, -1 for unmatched.""" + N_obj = mux_state.total_valid_entries + if det_masks.shape[0] == 0: + if keep_alive is not None: + for i in range(N_obj): + keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1) + return [] + + # Match at low-res (like reference) + trk_masks = current_out["pred_masks"][:, 0] # [N_obj, H_low, W_low] + det_resized = F.interpolate(det_masks.unsqueeze(1), size=trk_masks.shape[-2:], + mode="bilinear", align_corners=False)[:, 0] + overlap = _compute_mask_overlap(det_resized, trk_masks) + + # Update keep_alive and find matched tracks + matched = set() + if overlap.shape[1] > 0: + matched = set((overlap >= 0.5).any(dim=0).nonzero(as_tuple=True)[0].tolist()) + if keep_alive is not None: + for i in range(N_obj): + if i in matched: + keep_alive[i] = min(8, keep_alive.get(i, 0) + 1) + else: + keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1) + + # Recondition: high-confidence detections (>=0.8) with high overlap refresh tracked masks + reconditioned = False + if det_scores is not None and overlap.shape[1] > 0: + HIGH_CONF = 0.8 + for det_idx in range(overlap.shape[0]): + if det_scores[det_idx] < HIGH_CONF: + continue + best_trk = overlap[det_idx].argmax().item() + if overlap[det_idx, best_trk] >= 0.5: + # Replace tracked mask with fresh detection mask + current_out["pred_masks"][best_trk] = det_resized[det_idx].unsqueeze(0) + det_hr = F.interpolate(det_masks[det_idx:det_idx+1].unsqueeze(1), + size=current_out["pred_masks_high_res"].shape[-2:], + mode="bilinear", align_corners=False) + current_out["pred_masks_high_res"][best_trk] = det_hr[0] + reconditioned = True + + # Re-encode memory if any tracks were reconditioned + if reconditioned and self.num_maskmem > 0: + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device) + + # Add new detections (not matching any track) + if max_objects > 0 and N_obj >= max_objects: + return [] + max_overlap = overlap.max(dim=1)[0] if overlap.shape[1] > 0 else torch.zeros(overlap.shape[0], device=device) + new_dets = max_overlap < 0.5 + if new_dets.any(): + if max_objects > 0: + slots = max_objects - N_obj + new_dets = new_dets & (torch.cumsum(new_dets.int(), 0) <= slots) + self._add_detected_objects(det_masks[new_dets], mux_state, + vision_feats, feat_sizes, current_out) + if keep_alive is not None: + for i in range(N_obj, mux_state.total_valid_entries): + keep_alive[i] = 1 + return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item() + return [] + + def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None, + new_det_thresh=0.5, max_objects=0, detect_interval=1, + backbone_obj=None, pbar=None): + """Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits.""" + N, device, dt = images.shape[0], images.device, images.dtype + output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} + all_masks = [] + idev = comfy.model_management.intermediate_device() + mux_state = None + if initial_masks is not None: + mux_state = MultiplexState(initial_masks.shape[0], self.num_multiplex, device, dt) + obj_scores = [] # per-object detection score (1.0 for initial masks) + keep_alive = {} if detect_fn is not None else None + last_occluded = torch.empty(0, device=device, dtype=torch.long) # per-object last occluded frame + + # Prefetch next frame's backbone on a separate CUDA stream + prefetch = False + backbone_stream = None + if comfy.model_management.is_device_cuda(device): + try: + backbone_stream = torch.cuda.Stream(device=device) + prefetch = True + except RuntimeError: + pass + cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0) + + for frame_idx in tqdm(range(N), desc="tracking"): + vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb + + # Start next frame's backbone on separate stream (overlaps with current frame's work) + if prefetch and frame_idx + 1 < N: + backbone_stream.wait_stream(torch.cuda.current_stream(device)) + with torch.cuda.stream(backbone_stream): + next_bb = self._compute_backbone_frame( + backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + + # Per-frame detection with NMS (skip if no detect_fn, or interval/max not met) + det_masks = torch.empty(0, device=device) + det_scores = None + run_det = (detect_fn is not None + and frame_idx % max(detect_interval, 1) == 0 + and not (max_objects > 0 and mux_state is not None + and mux_state.total_valid_entries >= max_objects)) + if run_det: + det_out = detect_fn(trunk_out) + scores = det_out["scores"][0].sigmoid() + keep = scores > new_det_thresh + det_masks, det_scores = det_out["masks"][0][keep], scores[keep] + if det_masks.shape[0] > 1: + det_masks, det_scores = _nms_masks(det_masks, det_scores) + + if frame_idx == 0 and initial_masks is not None: + current_out = self._condition_with_masks( + initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos, + feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj, + images[frame_idx:frame_idx + 1], trunk_out) + last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) + obj_scores = [1.0] * mux_state.total_valid_entries + if keep_alive is not None: + for i in range(mux_state.total_valid_entries): + keep_alive[i] = 8 + elif mux_state is None or mux_state.total_valid_entries == 0: + if det_masks.shape[0] > 0: + if max_objects > 0: + det_scores = det_scores[:max_objects] + det_masks = det_masks[:max_objects] + mux_state = MultiplexState(det_masks.shape[0], self.num_multiplex, device, dt) + current_out = self._condition_with_masks( + det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop, + output_dict, N, mux_state, backbone_obj, + images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0) + last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) + obj_scores = det_scores[:mux_state.total_valid_entries].tolist() + if keep_alive is not None: + for i in range(mux_state.total_valid_entries): + keep_alive[i] = 1 + else: + all_masks.append(None) + if pbar is not None: + pbar.update(1) + # Skip to backbone advance at end of loop + if frame_idx + 1 < N: + if prefetch: + torch.cuda.current_stream(device).wait_stream(backbone_stream) + cur_bb = next_bb + else: + cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + continue + else: + N_obj = mux_state.total_valid_entries + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=False, current_vision_feats=vision_feats, + current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=None, + output_dict=output_dict, num_frames=N, propagation_high_res=high_res_prop, + multiplex_state=mux_state, run_mem_encoder=False) + current_out["pred_masks"] = fill_holes_in_mask_scores( + current_out["pred_masks"], max_area=16) + if last_occluded.shape[0] == N_obj and N_obj > 1: + self._suppress_recently_occluded( + current_out["pred_masks"], last_occluded, frame_idx) + if self.num_maskmem > 0: + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device) + output_dict["non_cond_frame_outputs"][frame_idx] = current_out + lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder) + for old_idx in list(output_dict["non_cond_frame_outputs"]): + if old_idx < frame_idx - lookback: + del output_dict["non_cond_frame_outputs"][old_idx] + n_before = mux_state.total_valid_entries + new_obj_scores = self._match_and_add_detections(det_masks, det_scores, current_out, mux_state, + vision_feats, feat_sizes, device, max_objects, + keep_alive if run_det else None) + n_added = mux_state.total_valid_entries - n_before + if n_added > 0: + last_occluded = torch.cat([last_occluded, + torch.full((n_added,), -1, device=device, dtype=torch.long)]) + obj_scores.extend(new_obj_scores) + + masks_out = current_out["pred_masks_high_res"][:, 0] + if keep_alive is not None: + for i in range(masks_out.shape[0]): + if keep_alive.get(i, 0) <= 0: + masks_out[i] = NO_OBJ_SCORE + N_obj_now = mux_state.total_valid_entries if mux_state is not None else 0 + if N_obj_now > 0: + all_masks.append(pack_masks(masks_out).to(idev)) + else: + all_masks.append(None) + if pbar is not None: + pbar.update(1) + + # Next frame's backbone + if frame_idx + 1 < N: + if prefetch: + torch.cuda.current_stream(device).wait_stream(backbone_stream) + cur_bb = next_bb + else: + cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + + if not all_masks or all(m is None for m in all_masks): + return {"packed_masks": None, "n_frames": N, "scores": []} + + max_obj = max(m.shape[0] for m in all_masks if m is not None) + sample = next(m for m in all_masks if m is not None) + empty_packed = torch.zeros(max_obj, *sample.shape[1:], dtype=torch.uint8, device=sample.device) + for i, m in enumerate(all_masks): + if m is None: + all_masks[i] = empty_packed + elif m.shape[0] < max_obj: + pad = torch.zeros(max_obj - m.shape[0], *m.shape[1:], dtype=torch.uint8, device=m.device) + all_masks[i] = torch.cat([m, pad], dim=0) + return {"packed_masks": torch.stack(all_masks, dim=0), "n_frames": N, "scores": obj_scores} diff --git a/comfy/model_base.py b/comfy/model_base.py index 1c7695761..787ea1145 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,6 +54,7 @@ import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model +import comfy.ldm.sam3.detector import comfy.model_management import comfy.patcher_extension @@ -1974,3 +1975,7 @@ class ErnieImage(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + +class SAM3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ca06cdd1e..724a241bf 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "ernie" return dit_config + if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1 + if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "SAM3" + if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys: + dit_config["image_model"] = "SAM31" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal return model_config def unet_prefix_from_state_dict(state_dict): + # SAM3: detector.* and tracker.* at top level, no common prefix + if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict): + return "" + candidates = ["model.diffusion_model.", #ldm/sgm models "model.model.", #audio models "net.", #cosmos diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 58d4ce731..8886f32d5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1781,6 +1781,57 @@ class ErnieImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] +class SAM3(supported_models_base.BASE): + unet_config = {"image_model": "SAM3"} + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + text_encoder_key_prefix = ["detector.backbone.language_backbone."] + unet_extra_prefix = "" + + def process_clip_state_dict(self, state_dict): + clip_keys = getattr(self, "_clip_stash", {}) + clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True) + clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.") + return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")} + + def process_unet_state_dict(self, state_dict): + self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k} + # SAM3.1: remap tracker.model.* -> tracker.* + for k in list(state_dict.keys()): + if k.startswith("tracker.model."): + state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k) + # SAM3.1: remove per-block freqs_cis buffers (computed dynamically) + for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]: + state_dict.pop(k) + # Split fused QKV projections + for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]: + t = state_dict.pop(k) + base, suffix = k.rsplit(".in_proj_", 1) + s = ".weight" if suffix == "weight" else ".bias" + d = t.shape[0] // 3 + state_dict[base + ".q_proj" + s] = t[:d] + state_dict[base + ".k_proj" + s] = t[d:2*d] + state_dict[base + ".v_proj" + s] = t[2*d:] + # Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer + for k in list(state_dict.keys()): + if "sam_mask_decoder.transformer." not in k: + continue + new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.") + if new_k != k: + state_dict[new_k] = state_dict.pop(k) + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SAM3(self, device=device) + + def clip_target(self, state_dict={}): + import comfy.text_encoders.sam3_clip + return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper) + + +class SAM31(SAM3): + unet_config = {"image_model": "SAM31"} + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31] models += [SVD_img2vid] diff --git a/comfy/text_encoders/sam3_clip.py b/comfy/text_encoders/sam3_clip.py new file mode 100644 index 000000000..11cb7d9db --- /dev/null +++ b/comfy/text_encoders/sam3_clip.py @@ -0,0 +1,97 @@ +import re +from comfy import sd1_clip + +SAM3_CLIP_CONFIG = { + "architectures": ["CLIPTextModel"], + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "max_position_embeddings": 32, + "projection_dim": 512, + "vocab_size": 49408, + "layer_norm_eps": 1e-5, + "eos_token_id": 49407, +} + + +class SAM3ClipModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options) + + +class SAM3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data) + self.disable_weights = True + + +def _parse_prompts(text): + """Split comma-separated prompts with optional :N max detections per category""" + text = text.replace("(", "").replace(")", "") + parts = [p.strip() for p in text.split(",") if p.strip()] + result = [] + for part in parts: + m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part) + if m: + text_part = m.group(1).strip() + val = m.group(2) + max_det = max(1, round(float(val))) + result.append((text_part, max_det)) + else: + result.append((part, 1)) + return result + + +class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip") + + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): + parsed = _parse_prompts(text) + if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1): + return super().tokenize_with_weights(text, return_word_ids, **kwargs) + # Tokenize each prompt part separately, store per-part batches and metadata + inner = getattr(self, self.clip) + per_prompt = [] + for prompt_text, max_det in parsed: + batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs) + per_prompt.append((batches, max_det)) + # Main output uses first prompt's tokens (for compatibility) + out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt} + return out + + +class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): + super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip") + + def encode_token_weights(self, token_weight_pairs): + per_prompt = token_weight_pairs.pop("sam3_per_prompt", None) + if per_prompt is None: + return super().encode_token_weights(token_weight_pairs) + + # Encode each prompt separately, pack into extra dict + inner = getattr(self, self.clip) + multi_cond = [] + first_pooled = None + for batches, max_det in per_prompt: + out = inner.encode_token_weights(batches) + cond, pooled = out[0], out[1] + extra = out[2] if len(out) > 2 else {} + if first_pooled is None: + first_pooled = pooled + multi_cond.append({ + "cond": cond, + "attention_mask": extra.get("attention_mask"), + "max_detections": max_det, + }) + + # Return first prompt as main (for non-SAM3 consumers), all prompts in metadata + main = multi_cond[0] + main_extra = {} + if main["attention_mask"] is not None: + main_extra["attention_mask"] = main["attention_mask"] + main_extra["sam3_multi_cond"] = multi_cond + return (main["cond"], first_pooled, main_extra) diff --git a/comfy_extras/nodes_sam3.py b/comfy_extras/nodes_sam3.py new file mode 100644 index 000000000..5cf92ccb3 --- /dev/null +++ b/comfy_extras/nodes_sam3.py @@ -0,0 +1,529 @@ +""" +SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking. +""" + +from typing_extensions import override + +import json +import os +import torch +import torch.nn.functional as F +import comfy.model_management +import comfy.utils +import folder_paths +from comfy_api.latest import ComfyExtension, io, ui +import av +from fractions import Fraction + + +def _extract_text_prompts(conditioning, device, dtype): + """Extract list of (text_embeddings, text_mask) from conditioning.""" + cond_meta = conditioning[0][1] + multi = cond_meta.get("sam3_multi_cond") + prompts = [] + if multi is not None: + for entry in multi: + emb = entry["cond"].to(device=device, dtype=dtype) + mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None + if mask is None: + mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) + prompts.append((emb, mask, entry.get("max_detections", 1))) + else: + emb = conditioning[0][0].to(device=device, dtype=dtype) + mask = cond_meta.get("attention_mask") + if mask is not None: + mask = mask.to(device) + else: + mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) + prompts.append((emb, mask, 1)) + return prompts + + +def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations): + """Refine a coarse detector mask via SAM decoder, cropping to the detection box. + + Returns: [1, H, W] binary mask + """ + def _coarse_fallback(): + return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), + mode="bilinear", align_corners=False)[0] > 0).float() + + if iterations <= 0: + return _coarse_fallback() + + pad_frac = 0.1 + x1, y1, x2, y2 = box_xyxy.tolist() + bw, bh = x2 - x1, y2 - y1 + cx1 = max(0, int(x1 - bw * pad_frac)) + cy1 = max(0, int(y1 - bh * pad_frac)) + cx2 = min(W, int(x2 + bw * pad_frac)) + cy2 = min(H, int(y2 + bh * pad_frac)) + if cx2 <= cx1 or cy2 <= cy1: + return _coarse_fallback() + + crop = orig_image_hwc[cy1:cy2, cx1:cx2, :3] + crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + crop_frame = crop_1008.to(device=device, dtype=dtype) + crop_h, crop_w = cy2 - cy1, cx2 - cx1 + + # Crop coarse mask and refine via SAM on the cropped image + mask_h, mask_w = coarse_mask.shape[-2:] + mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h) + mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h) + if mx2 <= mx1 or my2 <= my1: + return _coarse_fallback() + mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0) + for _ in range(iterations): + coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False) + mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input) + + refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False) + full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype) + full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop + coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False) + return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float() + + + +class SAM3_Detect(io.ComfyNode): + """Open-vocabulary detection and segmentation using text, box, or point prompts.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_Detect", + display_name="SAM3 Detect", + category="detection/", + search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"], + inputs=[ + io.Model.Input("model", display_name="model"), + io.Image.Input("image", display_name="image"), + io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"), + io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"), + io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), + io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), + io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01), + io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"), + io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"), + ], + outputs=[ + io.Mask.Output("masks"), + io.BoundingBox.Output("bboxes"), + ], + ) + + @classmethod + def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput: + B, H, W, C = image.shape + image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + + # Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors. + # Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame). + def _boxes_to_tensor(box_list): + coords = [] + for d in box_list: + cx = (d["x"] + d["width"] / 2) / W + cy = (d["y"] + d["height"] / 2) / H + coords.append([cx, cy, d["width"] / W, d["height"] / H]) + return torch.tensor([coords], dtype=torch.float32) # [1, N, 4] + + per_frame_boxes = None + if bboxes is not None: + if isinstance(bboxes, dict): + # Single box → same for all frames + shared = _boxes_to_tensor([bboxes]) + per_frame_boxes = [shared] * B + elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list): + # list[list[dict]] → per-frame boxes + per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes] + # Pad to B if fewer frames provided + while len(per_frame_boxes) < B: + per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None) + elif isinstance(bboxes, list) and len(bboxes) > 0: + # list[dict] → same boxes for all frames + shared = _boxes_to_tensor(bboxes) + per_frame_boxes = [shared] * B + + # Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...]) + pos_pts = json.loads(positive_coords) if positive_coords else [] + neg_pts = json.loads(negative_coords) if negative_coords else [] + has_points = len(pos_pts) > 0 or len(neg_pts) > 0 + + comfy.model_management.load_model_gpu(model) + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype() + sam3_model = model.model.diffusion_model + + # Build point inputs for tracker SAM decoder path + point_inputs = None + if has_points: + all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \ + [[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts] + all_labels = [1] * len(pos_pts) + [0] * len(neg_pts) + point_inputs = { + "point_coords": torch.tensor([all_coords], dtype=dtype, device=device), + "point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device), + } + + cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else [] + has_text = len(cond_list) > 0 + + # Run per-image through detector (text/boxes) and/or tracker (points) + all_bbox_dicts = [] + all_masks = [] + pbar = comfy.utils.ProgressBar(B) + + for b in range(B): + frame = image_in[b:b+1].to(device=device, dtype=dtype) + b_boxes = None + if per_frame_boxes is not None and per_frame_boxes[b] is not None: + b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype) + + frame_bbox_dicts = [] + frame_masks = [] + + # Point prompts: tracker SAM decoder path with iterative refinement + if point_inputs is not None: + mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs) + for _ in range(max(0, refine_iterations - 1)): + mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) + mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) + frame_masks.append((mask[0] > 0).float()) + + # Box prompts: SAM decoder path (segment inside each box) + if b_boxes is not None and not has_text: + for box_cxcywh in b_boxes[0]: + cx, cy, bw, bh = box_cxcywh.tolist() + # Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners + sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008], + [(cx + bw/2) * 1008, (cy + bh/2) * 1008]]], + device=device, dtype=dtype) + mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box) + for _ in range(max(0, refine_iterations - 1)): + mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) + mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) + frame_masks.append((mask[0] > 0).float()) + + # Text prompts: run detector per text prompt (each detects one category) + for text_embeddings, text_mask, max_det in cond_list: + results = sam3_model( + frame, text_embeddings=text_embeddings, text_mask=text_mask, + boxes=b_boxes, threshold=threshold, orig_size=(H, W)) + + pred_boxes = results["boxes"][0] + scores = results["scores"][0] + masks = results["masks"][0] + + probs = scores.sigmoid() + keep = probs > threshold + kept_boxes = pred_boxes[keep].cpu() + kept_scores = probs[keep].cpu() + kept_masks = masks[keep] + + order = kept_scores.argsort(descending=True)[:max_det] + kept_boxes = kept_boxes[order] + kept_scores = kept_scores[order] + kept_masks = kept_masks[order] + + for box, score in zip(kept_boxes, kept_scores): + frame_bbox_dicts.append({ + "x": float(box[0]), "y": float(box[1]), + "width": float(box[2] - box[0]), "height": float(box[3] - box[1]), + "score": float(score), + }) + for m, box in zip(kept_masks, kept_boxes): + frame_masks.append(_refine_mask( + sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations)) + + all_bbox_dicts.append(frame_bbox_dicts) + if len(frame_masks) > 0: + combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W] + if individual_masks: + all_masks.append(combined) + else: + all_masks.append((combined > 0).any(dim=0).float()) + else: + if individual_masks: + all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device())) + else: + all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device())) + pbar.update(1) + + idev = comfy.model_management.intermediate_device() + all_masks = [m.to(idev) for m in all_masks] + mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks) + return io.NodeOutput(mask_out, all_bbox_dicts) + + +SAM3TrackData = io.Custom("SAM3_TRACK_DATA") + +class SAM3_VideoTrack(io.ComfyNode): + """Track objects across video frames using SAM3's memory-based tracker.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_VideoTrack", + display_name="SAM3 Video Track", + category="detection/", + search_aliases=["sam3", "video", "track", "propagate"], + inputs=[ + io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"), + io.Model.Input("model", display_name="model"), + io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"), + io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"), + io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"), + io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."), + io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."), + ], + outputs=[ + SAM3TrackData.Output("track_data", display_name="track_data"), + ], + ) + + @classmethod + def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput: + N, H, W, C = images.shape + + comfy.model_management.load_model_gpu(model) + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype() + sam3_model = model.model.diffusion_model + + frames = images[..., :3].movedim(-1, 1) + frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype) + + init_masks = None + if initial_mask is not None: + init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype) + + pbar = comfy.utils.ProgressBar(N) + + text_prompts = None + if conditioning is not None and len(conditioning) > 0: + text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)] + elif initial_mask is None: + raise ValueError("Either initial_mask or conditioning must be provided") + + result = sam3_model.forward_video( + images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts, + new_det_thresh=detection_threshold, max_objects=max_objects, + detect_interval=detect_interval) + result["orig_size"] = (H, W) + return io.NodeOutput(result) + + +class SAM3_TrackPreview(io.ComfyNode): + """Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_TrackPreview", + display_name="SAM3 Track Preview", + category="detection/", + inputs=[ + SAM3TrackData.Input("track_data", display_name="track_data"), + io.Image.Input("images", display_name="images", optional=True), + io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05), + io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0), + ], + is_output_node=True, + ) + + COLORS = [ + (0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16), + (0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5), + (0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84), + ] + + # 5x3 bitmap font atlas for digits 0-9 [10, 5, 3] + _glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow) + + @staticmethod + def _get_glyphs(device, scale=3): + key = (device, scale) + if key in SAM3_TrackPreview._glyph_cache: + return SAM3_TrackPreview._glyph_cache[key] + atlas = torch.tensor([ + [[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]], + [[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]], + [[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]], + [[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]], + [[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]], + [[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]], + [[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]], + [[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]], + [[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]], + [[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]], + ], dtype=torch.bool) + glyphs, outlines = [], [] + for d in range(10): + g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1) + padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1)) + o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0) + glyphs.append(g.to(device)) + outlines.append(o.to(device)) + gh, gw = glyphs[0].shape + oh, ow = outlines[0].shape + SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow) + return SAM3_TrackPreview._glyph_cache[key] + + @staticmethod + def _draw_number_gpu(frame, number, cx, cy, color, scale=3): + """Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline.""" + H, W = frame.shape[:2] + device = frame.device + glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale) + color_t = torch.tensor(color, device=device, dtype=frame.dtype) + digs = [int(d) for d in str(number)] + total_w = len(digs) * (gw + scale) - scale + x0 = cx - total_w // 2 + y0 = cy - gh // 2 + for i, d in enumerate(digs): + dx = x0 + i * (gw + scale) + # Black outline + oy0, ox0 = y0 - 1, dx - 1 + osy1, osx1 = max(0, -oy0), max(0, -ox0) + osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0) + if osy2 > osy1 and osx2 > osx1: + fy1, fx1 = oy0 + osy1, ox0 + osx1 + frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0 + # Colored fill + sy1, sx1 = max(0, -y0), max(0, -dx) + sy2, sx2 = min(gh, H - y0), min(gw, W - dx) + if sy2 > sy1 and sx2 > sx1: + fy1, fx1 = y0 + sy1, dx + sx1 + frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t + + @classmethod + def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput: + + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + if images is not None: + H, W = images.shape[1], images.shape[2] + if packed is None: + N, N_obj = track_data["n_frames"], 0 + else: + N, N_obj = packed.shape[0], packed.shape[1] + + import uuid + gpu = comfy.model_management.get_torch_device() + temp_dir = folder_paths.get_temp_directory() + filename = f"sam3_track_preview_{uuid.uuid4().hex[:8]}.mp4" + filepath = os.path.join(temp_dir, filename) + with av.open(filepath, mode='w') as output: + stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000)) + stream.width = W + stream.height = H + stream.pix_fmt = 'yuv420p' + + frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8) + frame_np = frame_cpu.numpy() + if N_obj > 0: + colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)], + device=gpu, dtype=torch.float32) + grid_y = torch.arange(H, device=gpu).view(1, H, 1) + grid_x = torch.arange(W, device=gpu).view(1, 1, W) + for t in range(N): + if images is not None and t < images.shape[0]: + frame = images[t].clone() + else: + frame = torch.zeros(H, W, 3) + + if N_obj > 0: + frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool + frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0] + frame_gpu = frame.to(gpu) + bool_masks = frame_masks > 0.5 + any_mask = bool_masks.any(dim=0) + if any_mask.any(): + obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0) + color_overlay = colors_t[obj_idx_map] + mask_3d = any_mask.unsqueeze(-1) + frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu) + area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1) + cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area + cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area + has = area > 1 + scores = track_data.get("scores", []) + for obj_idx in range(N_obj): + if has[obj_idx]: + _cx, _cy = int(cx[obj_idx]), int(cy[obj_idx]) + color = cls.COLORS[obj_idx % len(cls.COLORS)] + SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color) + if obj_idx < len(scores) and scores[obj_idx] < 1.0: + SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100), + _cx, _cy + 5 * 3 + 3, color, scale=2) + frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte()) + else: + frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte()) + + vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24') + output.mux(stream.encode(vframe.reformat(format='yuv420p'))) + output.mux(stream.encode(None)) + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)])) + + +class SAM3_TrackToMask(io.ComfyNode): + """Select tracked objects by index and output as mask.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_TrackToMask", + display_name="SAM3 Track to Mask", + category="detection/", + inputs=[ + SAM3TrackData.Input("track_data", display_name="track_data"), + io.String.Input("object_indices", display_name="object_indices", default="", + tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."), + ], + outputs=[ + io.Mask.Output("masks", display_name="masks"), + ], + ) + + @classmethod + def execute(cls, track_data, object_indices="") -> io.NodeOutput: + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + + if packed is None: + N = track_data["n_frames"] + return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) + + N, N_obj = packed.shape[0], packed.shape[1] + + if object_indices.strip(): + indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] + indices = [i for i in indices if 0 <= i < N_obj] + else: + indices = list(range(N_obj)) + + if not indices: + return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) + + selected = packed[:, indices] + binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool + union = binary.any(dim=1, keepdim=True).float() + mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0] + return io.NodeOutput(mask_out) + + +class SAM3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SAM3_Detect, + SAM3_VideoTrack, + SAM3_TrackPreview, + SAM3_TrackToMask, + ] + + +async def comfy_entrypoint() -> SAM3Extension: + return SAM3Extension() diff --git a/nodes.py b/nodes.py index bb38e07b8..fb83da896 100644 --- a/nodes.py +++ b/nodes.py @@ -2459,6 +2459,7 @@ async def init_builtin_extra_nodes(): "nodes_curve.py", "nodes_rtdetr.py", "nodes_frame_interpolation.py", + "nodes_sam3.py" ] import_failed = [] From 3cdc0d523f080deb22fee24bfb0080180cde4f6e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 23 Apr 2026 08:47:33 +0300 Subject: [PATCH 19/35] [Partner Nodes] GPTImage: fix price badges, add new resolutions (#13519) * fix(api-nodes): fixed price badges, add new resolutions Signed-off-by: bigcat88 * proper calculate the total run cost when "n > 1" Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_openai.py | 59 +++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 90a29c2f2..bbb758068 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -357,6 +357,10 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0 +def calculate_tokens_price_image_2_0(response: OpenAIImageGenerationResponse) -> float | None: + return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0 + + class OpenAIGPTImage1(IO.ComfyNode): @classmethod @@ -401,7 +405,17 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Combo.Input( "size", default="auto", - options=["auto", "1024x1024", "1024x1536", "1536x1024"], + options=[ + "auto", + "1024x1024", + "1024x1536", + "1536x1024", + "2048x2048", + "2048x1152", + "1152x2048", + "3840x2160", + "2160x3840", + ], tooltip="Image size", optional=True, ), @@ -427,7 +441,7 @@ class OpenAIGPTImage1(IO.ComfyNode): ), IO.Combo.Input( "model", - options=["gpt-image-1", "gpt-image-1.5", 'gpt-image-2'], + options=["gpt-image-1", "gpt-image-1.5", "gpt-image-2"], default="gpt-image-2", optional=True, ), @@ -442,23 +456,36 @@ class OpenAIGPTImage1(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), + depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]), expr=""" ( $ranges := { - "low": [0.011, 0.02], - "medium": [0.046, 0.07], - "high": [0.167, 0.3] + "gpt-image-1": { + "low": [0.011, 0.02], + "medium": [0.042, 0.07], + "high": [0.167, 0.25] + }, + "gpt-image-1.5": { + "low": [0.009, 0.02], + "medium": [0.034, 0.062], + "high": [0.133, 0.22] + }, + "gpt-image-2": { + "low": [0.0048, 0.012], + "medium": [0.041, 0.112], + "high": [0.165, 0.43] + } }; - $range := $lookup($ranges, widgets.quality); - $n := widgets.n; + $range := $lookup($lookup($ranges, widgets.model), widgets.quality); + $nRaw := widgets.n; + $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1; ($n = 1) - ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} + ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}} : { "type":"range_usd", - "min_usd": $range[0], - "max_usd": $range[1], - "format": { "suffix": " x " & $string($n) & "/Run" } + "min_usd": $range[0] * $n, + "max_usd": $range[1] * $n, + "format": { "suffix": "/Run", "approximate": true } } ) """, @@ -483,12 +510,18 @@ class OpenAIGPTImage1(IO.ComfyNode): if mask is not None and image is None: raise ValueError("Cannot use a mask without an input image") + if model in ("gpt-image-1", "gpt-image-1.5"): + if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"): + raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model") + if model == "gpt-image-1": price_extractor = calculate_tokens_price_image_1 elif model == "gpt-image-1.5": price_extractor = calculate_tokens_price_image_1_5 elif model == "gpt-image-2": - price_extractor = calculate_tokens_price_image_1_5 + price_extractor = calculate_tokens_price_image_2_0 + if background == "transparent": + raise ValueError("Transparent background is not supported for GPT Image 2 model") else: raise ValueError(f"Unknown model: {model}") From 5edbdf4364c6c89c3c6a5c6630807b59cb7652ba Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Thu, 23 Apr 2026 22:51:20 +0800 Subject: [PATCH 20/35] chore: update workflow templates to v0.9.61 (#13533) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a25bc0667..8a6ecf6d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.14 -comfyui-workflow-templates==0.9.59 +comfyui-workflow-templates==0.9.61 comfyui-embedded-docs==0.4.3 torch torchsde From 2a14e1e96afdb8ca744663e3f3f5970c5d023f5b Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Thu, 23 Apr 2026 23:15:47 +0800 Subject: [PATCH 21/35] chore: update embedded docs to v0.4.4 (#13535) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8a6ecf6d8..419124f48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.42.14 comfyui-workflow-templates==0.9.61 -comfyui-embedded-docs==0.4.3 +comfyui-embedded-docs==0.4.4 torch torchsde torchvision From abf3d56f27948b122dbcba35847b59e5ff299030 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:49:54 +0300 Subject: [PATCH 22/35] add 4K resolution to Kling nodes (#13536) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_kling.py | 91 ++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 21 deletions(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9a37ccc53..709b3726c 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -276,6 +276,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -862,7 +863,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -904,12 +905,13 @@ class OmniProTextToVideoNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -934,6 +936,8 @@ class OmniProTextToVideoNode(IO.ComfyNode): raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -963,6 +967,12 @@ class OmniProTextToVideoNode(IO.ComfyNode): f"must equal the global duration ({duration}s)." ) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -972,7 +982,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt=prompt, aspect_ratio=aspect_ratio, duration=str(duration), - mode="pro" if resolution == "1080p" else "std", + mode=mode, multi_shot=multi_shot, multi_prompt=multi_prompt_list, shot_type="customize" if multi_shot else None, @@ -1014,7 +1024,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): optional=True, tooltip="Up to 6 additional reference images.", ), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -1061,12 +1071,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -1093,6 +1104,8 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -1161,6 +1174,12 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): image_list.append(OmniParamImage(image_url=i)) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -1170,7 +1189,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): prompt=prompt, duration=str(duration), image_list=image_list, - mode="pro" if resolution == "1080p" else "std", + mode=mode, sound="on" if generate_audio else "off", multi_shot=multi_shot, multi_prompt=multi_prompt_list, @@ -1204,7 +1223,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): "reference_images", tooltip="Up to 7 reference images.", ), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -1251,12 +1270,13 @@ class OmniProImageToVideoNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -1282,6 +1302,8 @@ class OmniProImageToVideoNode(IO.ComfyNode): raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -1320,6 +1342,12 @@ class OmniProImageToVideoNode(IO.ComfyNode): image_list: list[OmniParamImage] = [] for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): image_list.append(OmniParamImage(image_url=i)) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -1330,7 +1358,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio=aspect_ratio, duration=str(duration), image_list=image_list, - mode="pro" if resolution == "1080p" else "std", + mode=mode, sound="on" if generate_audio else "off", multi_shot=multi_shot, multi_prompt=multi_prompt_list, @@ -2860,7 +2888,7 @@ class KlingVideoNode(IO.ComfyNode): IO.DynamicCombo.Option( "kling-v3", [ - IO.Combo.Input("resolution", options=["1080p", "720p"]), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"), IO.Combo.Input( "aspect_ratio", options=["16:9", "9:16", "1:1"], @@ -2913,7 +2941,11 @@ class KlingVideoNode(IO.ComfyNode): ), expr=""" ( - $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $rates := { + "4k": {"off": 0.42, "on": 0.42}, + "1080p": {"off": 0.112, "on": 0.168}, + "720p": {"off": 0.084, "on": 0.126} + }; $res := $lookup(widgets, "model.resolution"); $audio := widgets.generate_audio ? "on" : "off"; $rate := $lookup($lookup($rates, $res), $audio); @@ -2943,7 +2975,12 @@ class KlingVideoNode(IO.ComfyNode): start_frame: Input.Image | None = None, ) -> IO.NodeOutput: _ = seed - mode = "pro" if model["resolution"] == "1080p" else "std" + if model["resolution"] == "4k": + mode = "4k" + elif model["resolution"] == "1080p": + mode = "pro" + else: + mode = "std" custom_multi_shot = False if multi_shot["multi_shot"] == "disabled": shot_type = None @@ -3025,6 +3062,7 @@ class KlingVideoNode(IO.ComfyNode): cls, ApiEndpoint(path=poll_path), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3057,7 +3095,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): IO.DynamicCombo.Option( "kling-v3", [ - IO.Combo.Input("resolution", options=["1080p", "720p"]), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"), ], ), ], @@ -3089,7 +3127,11 @@ class KlingFirstLastFrameNode(IO.ComfyNode): ), expr=""" ( - $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $rates := { + "4k": {"off": 0.42, "on": 0.42}, + "1080p": {"off": 0.112, "on": 0.168}, + "720p": {"off": 0.084, "on": 0.126} + }; $res := $lookup(widgets, "model.resolution"); $audio := widgets.generate_audio ? "on" : "off"; $rate := $lookup($lookup($rates, $res), $audio); @@ -3118,6 +3160,12 @@ class KlingFirstLastFrameNode(IO.ComfyNode): validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + if model["resolution"] == "4k": + mode = "4k" + elif model["resolution"] == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), @@ -3127,7 +3175,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): image=image_url, image_tail=image_tail_url, prompt=prompt, - mode="pro" if model["resolution"] == "1080p" else "std", + mode=mode, duration=str(duration), sound="on" if generate_audio else "off", ), @@ -3140,6 +3188,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) From 6fbb6b6f49ccd1d7d336368540b71248e3701dde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 23 Apr 2026 21:13:17 +0300 Subject: [PATCH 23/35] Fix LTXV Reference Audio node (#13531) --- comfy_extras/nodes_lt.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index d7c2e8744..19d8a387f 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,6 +1,7 @@ import nodes import node_helpers import torch +import torchaudio import comfy.model_management import comfy.model_sampling import comfy.samplers @@ -711,7 +712,14 @@ class LTXVReferenceAudio(io.ComfyNode): @classmethod def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: # Encode reference audio to latents and patchify - audio_latents = audio_vae.encode(reference_audio) + sample_rate = reference_audio["sample_rate"] + vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate) + else: + waveform = reference_audio["waveform"] + + audio_latents = audio_vae.encode(waveform.movedim(1, -1)) b, c, t, f = audio_latents.shape ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) ref_audio = {"tokens": ref_tokens} From ef8f3cbcdc214b3b1647d3ad845aae99a3bf95d1 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 24 Apr 2026 04:14:13 +1000 Subject: [PATCH 24/35] comfy-aimdo 0.2.14: Hotfix async allocator estimations (#13534) This was doing an over-estimate of VRAM used by the async allocator when lots of little small tensors were in play. Also change the versioning scheme to == so we can roll forward aimdo without worrying about stable regressions downstream in comfyUI core. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 419124f48..7a2e4e0a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo>=0.2.12 +comfy-aimdo==0.2.14 requests simpleeval>=1.0.0 blake3 From 084e08c6e2d1c2c450fb74ec4f2ac39c31ea69bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 23 Apr 2026 21:14:42 +0300 Subject: [PATCH 25/35] Disable sageattention for SAM3 (#13529) Causes Nans --- comfy/ldm/sam3/detector.py | 2 +- comfy/ldm/sam3/sam.py | 4 ++-- comfy/ldm/sam3/tracker.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py index 6ae919a79..12d3a01ab 100644 --- a/comfy/ldm/sam3/detector.py +++ b/comfy/ldm/sam3/detector.py @@ -54,7 +54,7 @@ class SplitMHA(nn.Module): if mask is not None and mask.ndim == 2: mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast dtype = q.dtype # manual_cast may produce mixed dtypes - out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask) + out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False) return self.out_proj(out) diff --git a/comfy/ldm/sam3/sam.py b/comfy/ldm/sam3/sam.py index 272781d45..75cb457cf 100644 --- a/comfy/ldm/sam3/sam.py +++ b/comfy/ldm/sam3/sam.py @@ -40,7 +40,7 @@ class SAMAttention(nn.Module): q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) - return self.out_proj(optimized_attention(q, k, v, self.num_heads)) + return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) class TwoWayAttentionBlock(nn.Module): @@ -179,7 +179,7 @@ class Attention(nn.Module): q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) if self.use_rope and freqs_cis is not None: q, k = apply_rope(q, k, freqs_cis) - return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True)) + return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False)) class Block(nn.Module): diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py index 6ff6369d1..8f7481003 100644 --- a/comfy/ldm/sam3/tracker.py +++ b/comfy/ldm/sam3/tracker.py @@ -364,7 +364,7 @@ class SplitAttn(nn.Module): v = self.v_proj(v) if rope is not None: q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) - out = optimized_attention(q, k, v, self.num_heads) + out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False) return self.out_proj(out) @@ -657,7 +657,7 @@ class DecoupledMemoryAttnLayer(nn.Module): v = self.self_attn_v_proj(normed) if rope is not None: q, k = apply_rope_memory(q, k, rope, self.num_heads, 0) - x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) + x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) # Decoupled cross-attention: fuse image and memory projections normed = self.norm2(x) @@ -668,7 +668,7 @@ class DecoupledMemoryAttnLayer(nn.Module): v = self.cross_attn_v_proj(memory) if rope is not None: q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) - x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) + x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) # FFN x = x + self.linear2(F.gelu(self.linear1(self.norm3(x)))) From 2327fa1c908602076318e5ffca02a45d4a7e6af8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:20:24 +1000 Subject: [PATCH 26/35] execution: Add anti-cycle validation (#13169) Currently if the graph contains a cycle, the just inifitiate recursions, hits a catch all then throws a generic error against the output node that seeded the validation. Instead, fail the offending cycling mode chain and handlng it as an error in its own right. Co-authored-by: guill --- execution.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/execution.py b/execution.py index 5e02dffb2..e15eb4bda 100644 --- a/execution.py +++ b/execution.py @@ -811,11 +811,30 @@ class PromptExecutor: self._notify_prompt_lifecycle("end", prompt_id) -async def validate_inputs(prompt_id, prompt, item, validated): +async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): + if visiting is None: + visiting = [] + unique_id = item if unique_id in validated: return validated[unique_id] + if unique_id in visiting: + cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] + cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) + cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) + for node_id in cycle_nodes: + validated[node_id] = (False, [{ + "type": "dependency_cycle", + "message": "Dependency cycle detected", + "details": cycle_path, + "extra_info": { + "node_id": node_id, + "cycle_nodes": cycle_nodes, + } + }], node_id) + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue try: - r = await validate_inputs(prompt_id, prompt, o_id, validated) + visiting.append(unique_id) + try: + r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting) + finally: + visiting.pop() if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue - if len(errors) > 0 or valid is not True: - ret = (False, errors, unique_id) - else: - ret = (True, [], unique_id) + ret = validated.get(unique_id, (True, [], unique_id)) + # Recursive cycle detection may have already populated an error on us. Join it. + ret = ( + ret[0] and valid is True and not errors, + ret[1] + [error for error in errors if error not in ret[1]], + unique_id, + ) validated[unique_id] = ret return ret From 47ccecaee009cce148e8c2a5bdc2ecb302cc52ee Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Fri, 24 Apr 2026 07:56:13 +0800 Subject: [PATCH 27/35] chore: update workflow templates to v0.9.62 (#13539) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7a2e4e0a2..346ce4b76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.14 -comfyui-workflow-templates==0.9.61 +comfyui-workflow-templates==0.9.62 comfyui-embedded-docs==0.4.4 torch torchsde From c5d9edacd0d92cf2b6d9f82e6b60d6250c269e9e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:19:00 -0700 Subject: [PATCH 28/35] Print more tensor values in the preview any node. (#13544) --- comfy_extras/nodes_preview_any.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 0a1558f2b..17e25d514 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -1,5 +1,6 @@ import json from comfy.comfy_types.node_typing import IO +import torch # Preview Any - original implement from # https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py @@ -19,6 +20,7 @@ class PreviewAny(): SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"] def main(self, source=None): + torch.set_printoptions(edgeitems=6) value = 'None' if isinstance(source, str): value = source @@ -33,6 +35,7 @@ class PreviewAny(): except Exception: value = 'source exists, but could not be serialized.' + torch.set_printoptions() return {"ui": {"text": (value,)}, "result": (value,)} NODE_CLASS_MAPPINGS = { From 00d2f4047db3de6c14f965f6f34354d5ed5d0ccc Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Thu, 23 Apr 2026 23:42:22 -0400 Subject: [PATCH 29/35] fix: use textureSize instead of u_resolution for texel size in blur/sharpen shaders (#13347) * fix: use textureSize instead of u_resolution for texel size in blur/sharpen shaders * fix: remove unused u_resolution uniform and fix Glow shader texelSize --------- Co-authored-by: guill --- blueprints/.glsl/Glow_30.frag | 3 +-- blueprints/.glsl/Image_Blur_1.frag | 3 +-- blueprints/.glsl/Sharpen_23.frag | 3 +-- blueprints/.glsl/Unsharp_Mask_26.frag | 3 +-- blueprints/Glow.json | 2 +- blueprints/Image Blur.json | 2 +- blueprints/Sharpen.json | 2 +- blueprints/Unsharp Mask.json | 2 +- 8 files changed, 8 insertions(+), 12 deletions(-) diff --git a/blueprints/.glsl/Glow_30.frag b/blueprints/.glsl/Glow_30.frag index 0ee152628..f3c85a212 100644 --- a/blueprints/.glsl/Glow_30.frag +++ b/blueprints/.glsl/Glow_30.frag @@ -2,7 +2,6 @@ precision mediump float; uniform sampler2D u_image0; -uniform vec2 u_resolution; uniform int u_int0; // Blend mode uniform int u_int1; // Color tint uniform float u_float0; // Intensity @@ -75,7 +74,7 @@ void main() { float t0 = threshold - 0.15; float t1 = threshold + 0.15; - vec2 texelSize = 1.0 / u_resolution; + vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0)); float radius2 = radius * radius; float sampleScale = clamp(radius * 0.75, 0.35, 1.0); diff --git a/blueprints/.glsl/Image_Blur_1.frag b/blueprints/.glsl/Image_Blur_1.frag index 83238111d..1819e1695 100644 --- a/blueprints/.glsl/Image_Blur_1.frag +++ b/blueprints/.glsl/Image_Blur_1.frag @@ -12,7 +12,6 @@ const int RADIAL_SAMPLES = 12; const float RADIAL_STRENGTH = 0.0003; uniform sampler2D u_image0; -uniform vec2 u_resolution; uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL) uniform float u_float0; // Blur radius/amount uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical) @@ -25,7 +24,7 @@ float gaussian(float x, float sigma) { } void main() { - vec2 texelSize = 1.0 / u_resolution; + vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0)); float radius = max(u_float0, 0.0); // Radial (angular) blur - single pass, doesn't use separable diff --git a/blueprints/.glsl/Sharpen_23.frag b/blueprints/.glsl/Sharpen_23.frag index c03f94b66..e7463a329 100644 --- a/blueprints/.glsl/Sharpen_23.frag +++ b/blueprints/.glsl/Sharpen_23.frag @@ -2,14 +2,13 @@ precision highp float; uniform sampler2D u_image0; -uniform vec2 u_resolution; uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0 in vec2 v_texCoord; layout(location = 0) out vec4 fragColor0; void main() { - vec2 texel = 1.0 / u_resolution; + vec2 texel = 1.0 / vec2(textureSize(u_image0, 0)); // Sample center and neighbors vec4 center = texture(u_image0, v_texCoord); diff --git a/blueprints/.glsl/Unsharp_Mask_26.frag b/blueprints/.glsl/Unsharp_Mask_26.frag index f5990cb4a..d968c9c03 100644 --- a/blueprints/.glsl/Unsharp_Mask_26.frag +++ b/blueprints/.glsl/Unsharp_Mask_26.frag @@ -2,7 +2,6 @@ precision highp float; uniform sampler2D u_image0; -uniform vec2 u_resolution; uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5 uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen @@ -19,7 +18,7 @@ float getLuminance(vec3 color) { } void main() { - vec2 texel = 1.0 / u_resolution; + vec2 texel = 1.0 / vec2(textureSize(u_image0, 0)); float radius = max(u_float1, 0.5); float amount = u_float0; float threshold = u_float2; diff --git a/blueprints/Glow.json b/blueprints/Glow.json index 8c690fc68..1dafb2d35 100644 --- a/blueprints/Glow.json +++ b/blueprints/Glow.json @@ -268,7 +268,7 @@ "Node name for S&R": "GLSLShader" }, "widgets_values": [ - "#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / u_resolution;\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}", + "#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}", "from_input" ] }, diff --git a/blueprints/Image Blur.json b/blueprints/Image Blur.json index b1d449e32..3c7a784b0 100644 --- a/blueprints/Image Blur.json +++ b/blueprints/Image Blur.json @@ -331,7 +331,7 @@ "Node name for S&R": "GLSLShader" }, "widgets_values": [ - "#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / u_resolution;\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n", + "#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n", "from_input" ] } diff --git a/blueprints/Sharpen.json b/blueprints/Sharpen.json index bb79f61fc..f332400fd 100644 --- a/blueprints/Sharpen.json +++ b/blueprints/Sharpen.json @@ -267,7 +267,7 @@ "Node name for S&R": "GLSLShader" }, "widgets_values": [ - "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", + "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input" ] } diff --git a/blueprints/Unsharp Mask.json b/blueprints/Unsharp Mask.json index b673eb703..137acaa43 100644 --- a/blueprints/Unsharp Mask.json +++ b/blueprints/Unsharp Mask.json @@ -383,7 +383,7 @@ "Node name for S&R": "GLSLShader" }, "widgets_values": [ - "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n", + "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n", "from_input" ] } From 2e0503780d8cd4285d2b883ba5ba1ea152eb194e Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Thu, 23 Apr 2026 23:51:34 -0400 Subject: [PATCH 30/35] range type (#13322) Co-authored-by: guill --- comfy_api/input/__init__.py | 2 + comfy_api/latest/_input/__init__.py | 2 + comfy_api/latest/_input/range_types.py | 70 ++++++++++++++++++++++++++ comfy_api/latest/_io.py | 38 ++++++++++++++ 4 files changed, 112 insertions(+) create mode 100644 comfy_api/latest/_input/range_types.py diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 16d4acfd1..dc33533cc 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -9,6 +9,7 @@ from comfy_api.latest._input import ( CurveInput, MonotoneCubicCurve, LinearCurve, + RangeInput, ) __all__ = [ @@ -21,4 +22,5 @@ __all__ = [ "CurveInput", "MonotoneCubicCurve", "LinearCurve", + "RangeInput", ] diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py index 05cd3d40a..f0229717e 100644 --- a/comfy_api/latest/_input/__init__.py +++ b/comfy_api/latest/_input/__init__.py @@ -1,5 +1,6 @@ from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve +from .range_types import RangeInput from .video_types import VideoInput __all__ = [ @@ -12,4 +13,5 @@ __all__ = [ "CurveInput", "MonotoneCubicCurve", "LinearCurve", + "RangeInput", ] diff --git a/comfy_api/latest/_input/range_types.py b/comfy_api/latest/_input/range_types.py new file mode 100644 index 000000000..f4c5cb290 --- /dev/null +++ b/comfy_api/latest/_input/range_types.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import logging +import math +import numpy as np + +logger = logging.getLogger(__name__) + + +class RangeInput: + """Represents a levels/range adjustment: input range [min, max] with + optional midpoint (gamma control). + + Generates a 1D LUT identical to GIMP's levels mapping: + 1. Normalize input to [0, 1] using [min, max] + 2. Apply gamma correction: pow(value, 1/gamma) + 3. Clamp to [0, 1] + + The midpoint field is a position in [0, 1] representing where the + midtone falls within [min, max]. It maps to gamma via: + gamma = -log2(midpoint) + So midpoint=0.5 → gamma=1.0 (linear). + """ + + def __init__(self, min_val: float, max_val: float, midpoint: float | None = None): + self.min_val = min_val + self.max_val = max_val + self.midpoint = midpoint + + @staticmethod + def from_raw(data) -> RangeInput: + if isinstance(data, RangeInput): + return data + if isinstance(data, dict): + return RangeInput( + min_val=float(data.get("min", 0.0)), + max_val=float(data.get("max", 1.0)), + midpoint=float(data["midpoint"]) if data.get("midpoint") is not None else None, + ) + raise TypeError(f"Cannot convert {type(data)} to RangeInput") + + def to_lut(self, size: int = 256) -> np.ndarray: + """Generate a float64 lookup table mapping [0, 1] input through this + levels adjustment. + + The LUT maps normalized input values (0..1) to output values (0..1), + matching the GIMP levels formula. + """ + xs = np.linspace(0.0, 1.0, size, dtype=np.float64) + + in_range = self.max_val - self.min_val + if abs(in_range) < 1e-10: + return np.where(xs >= self.min_val, 1.0, 0.0).astype(np.float64) + + # Normalize: map [min, max] → [0, 1] + result = (xs - self.min_val) / in_range + result = np.clip(result, 0.0, 1.0) + + # Gamma correction from midpoint + if self.midpoint is not None and self.midpoint > 0 and self.midpoint != 0.5: + gamma = max(-math.log2(self.midpoint), 0.001) + inv_gamma = 1.0 / gamma + mask = result > 0 + result[mask] = np.power(result[mask], inv_gamma) + + return result + + def __repr__(self) -> str: + mid = f", midpoint={self.midpoint}" if self.midpoint is not None else "" + return f"RangeInput(min={self.min_val}, max={self.max_val}{mid})" diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index fdeffea2d..4942ed46c 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1266,6 +1266,43 @@ class Histogram(ComfyTypeIO): Type = list[int] +@comfytype(io_type="RANGE") +class Range(ComfyTypeIO): + from comfy_api.input import RangeInput + if TYPE_CHECKING: + Type = RangeInput + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: dict=None, + display: str=None, + gradient_stops: list=None, + show_midpoint: bool=None, + midpoint_scale: str=None, + value_min: float=None, + value_max: float=None, + advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = {"min": 0.0, "max": 1.0} + self.display = display + self.gradient_stops = gradient_stops + self.show_midpoint = show_midpoint + self.midpoint_scale = midpoint_scale + self.value_min = value_min + self.value_max = value_max + + def as_dict(self): + return super().as_dict() | prune_dict({ + "display": self.display, + "gradient_stops": self.gradient_stops, + "show_midpoint": self.show_midpoint, + "midpoint_scale": self.midpoint_scale, + "value_min": self.value_min, + "value_max": self.value_max, + }) + + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2276,5 +2313,6 @@ __all__ = [ "BoundingBox", "Curve", "Histogram", + "Range", "NodeReplace", ] From 443074eee92fb0f41b38b83404010069fdb25860 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Thu, 23 Apr 2026 21:00:25 -0700 Subject: [PATCH 31/35] Add OpenAPI 3.1 specification for ComfyUI API (#13397) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add OpenAPI 3.1 specification for ComfyUI API Adds a comprehensive OpenAPI 3.1 spec documenting all HTTP endpoints exposed by ComfyUI's server, including prompt execution, queue management, file uploads, userdata, settings, system stats, object info, assets, and internal routes. The spec was validated against the source code with adversarial review from multiple models, and passes Spectral linting with zero errors. Also removes openapi.yaml from .gitignore so the spec is tracked. * Mark /api/history endpoints as deprecated Address Jacob's review feedback on PR #13397 by explicitly marking the three /api/history operations as deprecated in the OpenAPI spec: * GET /api/history -> superseded by GET /api/jobs * POST /api/history -> superseded by /api/jobs management * GET /api/history/{prompt_id} -> superseded by GET /api/jobs/{job_id} Each operation gains deprecated: true plus a description that names the replacement. A formal sunset timeline (RFC 8594 Deprecation and RFC 8553 Sunset headers, minimum-runway policy) is being defined separately and will be applied as a follow-up. * Address Spectral lint findings in openapi.yaml - Add operation descriptions to 52 endpoints (prompt, queue, upload, view, models, userdata, settings, assets, internal, etc.) - Add schema descriptions to 22 component schemas - Add parameter descriptions to 8 path parameters that were missing them - Remove 6 unused component schemas: TaskOutput, EmbeddingsResponse, ExtensionsResponse, LogRawResponse, UserInfo, UserDataFullInfo No wire/shape changes. Reduces Spectral findings from 92 to 4. The remaining 4 are real issues (WebSocket 101 on /ws, loose error schema, and two snake_case warnings on real wire field names) and are worth addressing separately. * fix(openapi): address jtreminio oneOf review on /api/userdata Restructure the UserData response schemas to address the review feedback on the `oneOf` without a discriminator, and fix two accuracy bugs found while doing it. Changes - GET /api/userdata response: extract the inline `oneOf` to a named schema (`ListUserdataResponse`) and add the missing third variant returned when `split=true` and `full_info=false` (array of `[relative_path, ...path_components]`). Previously only two of the three actual server response shapes were described. - UserDataResponse (POST endpoints): correct the description — this schema is a single item, not a list — and point at the canonical `GetUserDataResponseFullFile` schema instead of the duplicate `UserDataResponseFull`. Also removes the malformed blank line in `UserDataResponseShort`. - Delete the now-unused `UserDataResponseFull` and `UserDataResponseShort` schemas (replaced by reuse of `GetUserDataResponseFullFile` and an inline string variant). - Add an `x-variant-selector` vendor extension to both `oneOf` sites documenting which query-parameter combination selects which branch, since a true OpenAPI `discriminator` is not applicable (the variants are type-disjoint and the selector lives in the request, not the response body). This keeps the shapes the server actually emits (no wire-breaking change) while making the selection rule explicit for SDK generators and readers. --------- Co-authored-by: guill --- .gitignore | 1 - openapi.yaml | 3231 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 3231 insertions(+), 1 deletion(-) create mode 100644 openapi.yaml diff --git a/.gitignore b/.gitignore index 2700ad5c2..0ab4ba75e 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,5 @@ venv*/ *.log web_custom_versions/ .DS_Store -openapi.yaml filtered-openapi.yaml uv.lock diff --git a/openapi.yaml b/openapi.yaml new file mode 100644 index 000000000..77d0e2318 --- /dev/null +++ b/openapi.yaml @@ -0,0 +1,3231 @@ +openapi: 3.1.0 +info: + title: ComfyUI API + description: | + API for ComfyUI - A powerful and modular stable diffusion GUI and backend. + + This API allows you to interact with ComfyUI programmatically, including: + - Submitting and managing workflow executions + - Querying node/object information + - Uploading and viewing files + - Managing user settings and data + - Asset management (feature-gated) + + ## Dual-path routing + Every route registered via `self.routes` in the ComfyUI server is available at + both its bare path (e.g. `/prompt`) and an `/api`-prefixed path (e.g. `/api/prompt`). + This spec uses the `/api`-prefixed versions as canonical. + + ## Multi-user mode + When ComfyUI is started with `--multi-user`, the `Comfy-User` header identifies + the active user for settings, userdata, and history isolation. This is **not** a + security mechanism — it is an organisational convenience with no authentication + or authorisation behind it. + version: 1.0.0 + license: + name: GNU General Public License v3.0 + url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE + +servers: + - url: / + description: Default ComfyUI server (typically http://127.0.0.1:8188) + +tags: + - name: prompt + description: Workflow submission and prompt info + - name: queue + description: Queue inspection and management + - name: history + description: Execution history + - name: upload + description: File upload endpoints + - name: view + description: File viewing / download + - name: system + description: System stats and feature flags + - name: node + description: Node / object_info definitions + - name: model + description: Model folder and file listing + - name: user + description: User management (multi-user mode) + - name: userdata + description: Per-user file storage + - name: settings + description: Per-user settings + - name: extensions + description: Frontend extension JS files + - name: subgraph + description: Global subgraph blueprints + - name: internal + description: Internal / debug endpoints + - name: assets + description: Asset management (feature-gated behind enable-assets) + +paths: + # --------------------------------------------------------------------------- + # WebSocket + # --------------------------------------------------------------------------- + /ws: + get: + operationId: connectWebSocket + tags: [system] + summary: WebSocket connection for real-time updates + description: | + Upgrades to a WebSocket connection that streams execution progress, + node status, and output messages. The server sends an initial `status` + message with the session ID (SID) on connect. + + ## Message types (server → client) + The server sends JSON messages with a `type` field. See the + `x-websocket-messages` list below for the schema of each message type. + parameters: + - name: clientId + in: query + required: false + schema: + type: string + description: Client identifier. If omitted the server assigns one. + responses: + "101": + description: WebSocket upgrade successful + x-websocket-messages: + - type: status + schema: + $ref: "#/components/schemas/StatusWsMessage" + - type: progress + schema: + $ref: "#/components/schemas/ProgressWsMessage" + - type: progress_text + schema: + $ref: "#/components/schemas/ProgressTextWsMessage" + - type: progress_state + schema: + $ref: "#/components/schemas/ProgressStateWsMessage" + - type: executing + schema: + $ref: "#/components/schemas/ExecutingWsMessage" + - type: executed + schema: + $ref: "#/components/schemas/ExecutedWsMessage" + - type: execution_start + schema: + $ref: "#/components/schemas/ExecutionStartWsMessage" + - type: execution_success + schema: + $ref: "#/components/schemas/ExecutionSuccessWsMessage" + - type: execution_cached + schema: + $ref: "#/components/schemas/ExecutionCachedWsMessage" + - type: execution_interrupted + schema: + $ref: "#/components/schemas/ExecutionInterruptedWsMessage" + - type: execution_error + schema: + $ref: "#/components/schemas/ExecutionErrorWsMessage" + - type: logs + schema: + $ref: "#/components/schemas/LogsWsMessage" + - type: notification + schema: + $ref: "#/components/schemas/NotificationWsMessage" + - type: feature_flags + schema: + $ref: "#/components/schemas/FeatureFlagsWsMessage" + - type: asset_download + schema: + $ref: "#/components/schemas/AssetDownloadWsMessage" + - type: asset_export + schema: + $ref: "#/components/schemas/AssetExportWsMessage" + + # --------------------------------------------------------------------------- + # Prompt + # --------------------------------------------------------------------------- + /api/prompt: + get: + operationId: getPromptInfo + tags: [prompt] + summary: Get queue status + description: Returns how many items remain in the execution queue. + responses: + "200": + description: Queue info + content: + application/json: + schema: + $ref: "#/components/schemas/PromptInfo" + post: + operationId: executePrompt + tags: [prompt] + summary: Submit a workflow for execution + description: Submits a workflow for execution. The server validates the graph, assigns a `prompt_id`, and enqueues it. Clients listen on `/ws` for execution progress and output messages. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/PromptRequest" + responses: + "200": + description: Prompt accepted + content: + application/json: + schema: + $ref: "#/components/schemas/PromptResponse" + "400": + description: Validation or node errors + content: + application/json: + schema: + $ref: "#/components/schemas/PromptErrorResponse" + + # --------------------------------------------------------------------------- + # Queue + # --------------------------------------------------------------------------- + /api/queue: + get: + operationId: getQueue + tags: [queue] + summary: Get running and pending queue items + description: Returns the server's current execution queue, split into the currently-running prompt and the list of pending prompts. + responses: + "200": + description: Queue contents + content: + application/json: + schema: + $ref: "#/components/schemas/QueueInfo" + post: + operationId: manageQueue + tags: [queue] + summary: Clear or delete items from the queue + description: Mutates the execution queue. Supports clearing all queued prompts or deleting individual prompts by ID. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/QueueManageRequest" + responses: + "200": + description: Queue updated + + /api/interrupt: + post: + operationId: interruptExecution + tags: [queue] + summary: Interrupt current execution + description: Interrupts the prompt that is currently executing. The next queued prompt (if any) will start immediately after. + requestBody: + required: false + content: + application/json: + schema: + type: object + properties: + prompt_id: + type: string + format: uuid + description: "If provided, only interrupts this specific running prompt. Otherwise interrupts all." + responses: + "200": + description: Interrupt signal sent + + /api/free: + post: + operationId: freeMemory + tags: [queue] + summary: Free GPU memory and/or unload models + description: Frees GPU memory by unloading models and/or freeing the resident model cache, controlled by the request flags. + requestBody: + required: false + content: + application/json: + schema: + type: object + properties: + unload_models: + type: boolean + description: Unload all models from VRAM/RAM + free_memory: + type: boolean + description: Run garbage collection and free cached memory + responses: + "200": + description: Memory freed + + # --------------------------------------------------------------------------- + # Jobs + # --------------------------------------------------------------------------- + /api/jobs: + get: + operationId: listJobs + tags: [queue] + summary: List jobs with filtering and pagination + description: Returns a paginated list of completed prompt executions, newest first. + parameters: + - name: status + in: query + schema: + type: string + description: Filter by job status + - name: workflow_id + in: query + schema: + type: string + description: Filter by workflow ID + - name: sort_by + in: query + schema: + type: string + description: Field to sort by + - name: sort_order + in: query + schema: + type: string + enum: [asc, desc] + description: Sort direction + - name: limit + in: query + schema: + type: integer + description: Maximum number of results (default is unlimited/None) + - name: offset + in: query + schema: + type: integer + default: 0 + description: Pagination offset + responses: + "200": + description: Jobs list + content: + application/json: + schema: + type: object + properties: + jobs: + type: array + items: + $ref: "#/components/schemas/JobEntry" + pagination: + $ref: "#/components/schemas/PaginationInfo" + + /api/jobs/{job_id}: + get: + operationId: getJob + tags: [queue] + summary: Get a single job by ID + description: Returns the full record for a single completed prompt execution, including its outputs, status, and metadata. + parameters: + - name: job_id + in: path + description: The job (prompt) ID to fetch. + required: true + schema: + type: string + format: uuid + responses: + "200": + description: Job detail + content: + application/json: + schema: + $ref: "#/components/schemas/JobDetailResponse" + "404": + description: Job not found + + # --------------------------------------------------------------------------- + # History + # --------------------------------------------------------------------------- + /api/history: + get: + operationId: getHistory + tags: [history] + summary: Get execution history + deprecated: true + description: | + **Deprecated.** Superseded by `GET /api/jobs`, which returns the same + execution records in a paginated, filterable format. Planned for removal + no earlier than a future major release; sunset timeline TBD. + + Returns a dictionary keyed by prompt_id. Each value is a HistoryEntry + containing prompt metadata, outputs, status, and node meta. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: max_items + in: query + schema: + type: integer + description: Maximum number of history entries to return + - name: offset + in: query + schema: + type: integer + description: Pagination offset (number of entries to skip) + responses: + "200": + description: History dictionary keyed by prompt_id + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/HistoryEntry" + post: + operationId: manageHistory + tags: [history] + summary: Clear or delete history entries + deprecated: true + description: | + **Deprecated.** Superseded by the forthcoming job-management endpoints + under `/api/jobs`. Planned for removal no earlier than a future major + release; sunset timeline TBD. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/HistoryManageRequest" + responses: + "200": + description: History updated + + /api/history/{prompt_id}: + get: + operationId: getHistoryByPromptId + tags: [history] + summary: Get history for a specific prompt + deprecated: true + description: | + **Deprecated.** Superseded by `GET /api/jobs/{job_id}`, which returns + the same execution record. Planned for removal no earlier than a future + major release; sunset timeline TBD. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: prompt_id + in: path + description: The prompt ID to fetch history for. + required: true + schema: + type: string + format: uuid + responses: + "200": + description: Single-entry history dictionary. Returns an empty object `{}` if the prompt_id is not found. + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/HistoryEntry" + + # --------------------------------------------------------------------------- + # Upload + # --------------------------------------------------------------------------- + /api/upload/image: + post: + operationId: uploadImage + tags: [upload] + summary: Upload an image file + description: Uploads an image file into one of the input/output/temp directories so it can be referenced by workflow nodes. + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - image + properties: + image: + type: string + format: binary + description: Image file to upload + type: + type: string + enum: [input, temp, output] + default: input + description: Target directory type + overwrite: + type: string + description: 'Set to "true" to overwrite existing files' + subfolder: + type: string + description: Subfolder within the target directory + responses: + "200": + description: Upload result + content: + application/json: + schema: + $ref: "#/components/schemas/UploadResult" + "400": + description: No file provided or invalid request + + /api/upload/mask: + post: + operationId: uploadMask + tags: [upload] + summary: Upload a mask image + description: Uploads a mask image associated with a previously-uploaded reference image. + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - image + - original_ref + properties: + image: + type: string + format: binary + description: Mask image (alpha channel is used) + original_ref: + type: object + description: Reference to the original image file + required: + - filename + properties: + filename: + type: string + description: Filename of the original image + additionalProperties: true + type: + type: string + enum: [input, temp, output] + default: input + description: Target directory type + overwrite: + type: string + description: 'Set to "true" to overwrite existing files' + subfolder: + type: string + description: Subfolder within the target directory + responses: + "200": + description: Upload result + content: + application/json: + schema: + $ref: "#/components/schemas/UploadResult" + "400": + description: No file provided or invalid request + + # --------------------------------------------------------------------------- + # View + # --------------------------------------------------------------------------- + /api/view: + get: + operationId: viewFile + tags: [view] + summary: View or download a file + description: Serves a file (image, audio, or video) from the input/output/temp directory identified by the query parameters. + parameters: + - name: filename + in: query + required: true + schema: + type: string + description: Name of the file to view + - name: type + in: query + schema: + type: string + enum: [input, output, temp] + default: output + description: Directory type + - name: subfolder + in: query + schema: + type: string + description: Subfolder within the directory + - name: preview + in: query + schema: + type: string + description: Preview format hint (e.g. "webp;90") + - name: channel + in: query + schema: + type: string + enum: [rgba, rgb, a] + description: Channel extraction mode + responses: + "200": + description: File content + content: + image/*: + schema: + type: string + format: binary + video/*: + schema: + type: string + format: binary + audio/*: + schema: + type: string + format: binary + application/octet-stream: + schema: + type: string + format: binary + "404": + description: File not found + + /api/view_metadata/{folder_name}: + get: + operationId: viewMetadata + tags: [view] + summary: Get metadata for a file (e.g. safetensors header) + description: Returns embedded metadata parsed from a file in the given folder — for example, the header of a safetensors model. + parameters: + - name: folder_name + in: path + required: true + schema: + type: string + description: Folder type (output, input, temp, etc.) + - name: filename + in: query + required: true + schema: + type: string + description: Filename to read metadata from + responses: + "200": + description: File metadata + content: + application/json: + schema: + type: object + additionalProperties: true + "404": + description: File or metadata not found + + # --------------------------------------------------------------------------- + # System + # --------------------------------------------------------------------------- + /api/system_stats: + get: + operationId: getSystemStats + tags: [system] + summary: Get system statistics + description: Returns hardware, Python, VRAM, and runtime statistics for the running ComfyUI process. + responses: + "200": + description: System stats + content: + application/json: + schema: + $ref: "#/components/schemas/SystemStatsResponse" + + /api/features: + get: + operationId: getFeatures + tags: [system] + summary: Get enabled feature flags + description: Returns a dictionary of feature flag names to their enabled state. + responses: + "200": + description: Feature flags + content: + application/json: + schema: + type: object + additionalProperties: + type: boolean + + # --------------------------------------------------------------------------- + # Node / Object Info + # --------------------------------------------------------------------------- + /api/object_info: + get: + operationId: getObjectInfo + tags: [node] + summary: Get all node definitions + description: | + Returns a dictionary of every registered node class, keyed by class name. + Each value is a NodeInfo object describing inputs, outputs, category, etc. + responses: + "200": + description: All node definitions + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/NodeInfo" + + /api/object_info/{node_class}: + get: + operationId: getObjectInfoByClass + tags: [node] + summary: Get a single node definition + description: Returns the `NodeInfo` definition for a single registered node class. + parameters: + - name: node_class + in: path + required: true + schema: + type: string + description: Node class name (e.g. "KSampler") + responses: + "200": + description: Single node definition + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/NodeInfo" + "404": + description: Node class not found + + /api/embeddings: + get: + operationId: getEmbeddings + tags: [node] + summary: List available embedding names + description: Returns the list of text-encoder embeddings available on disk. + responses: + "200": + description: Embedding names + content: + application/json: + schema: + type: array + items: + type: string + + # --------------------------------------------------------------------------- + # Models + # --------------------------------------------------------------------------- + /api/models: + get: + operationId: getModelTypes + tags: [model] + summary: List model folder type names + description: Returns an array of model type names (e.g. checkpoints, loras, vae). + responses: + "200": + description: Model type names + content: + application/json: + schema: + type: array + items: + type: string + + /api/models/{folder}: + get: + operationId: getModelsByFolder + tags: [model] + summary: List model filenames in a folder + description: Returns the names of model files in the given folder. This endpoint predates `/api/experiment/models/{folder}` and returns names only — prefer the experiment endpoint for new integrations. + parameters: + - name: folder + in: path + required: true + schema: + type: string + description: Model folder type name + responses: + "200": + description: Model filenames + content: + application/json: + schema: + type: array + items: + type: string + "404": + description: Unknown folder type + + /api/experiment/models: + get: + operationId: getExperimentModels + tags: [model] + summary: List model folders with paths + description: Returns an array of model folder objects with name and folder paths. + responses: + "200": + description: Model folders + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/ModelFolder" + + /api/experiment/models/{folder}: + get: + operationId: getExperimentModelsByFolder + tags: [model] + summary: List model files with metadata + description: Returns the model files in the given folder with richer metadata (path index, mtime, size) than the legacy `/api/models/{folder}` endpoint. + parameters: + - name: folder + in: path + required: true + schema: + type: string + description: Model folder type name + responses: + "200": + description: Model files with metadata + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/ModelFile" + "404": + description: Unknown folder type + + /api/experiment/models/preview/{folder}/{path_index}/{filename}: + get: + operationId: getModelPreview + tags: [model] + summary: Get model preview image + description: Returns the preview image associated with a model file, if one exists alongside the model on disk. + parameters: + - name: folder + in: path + required: true + schema: + type: string + description: Model folder type name + - name: path_index + in: path + required: true + schema: + type: integer + description: Path index within the folder + - name: filename + in: path + required: true + schema: + type: string + description: Model filename + responses: + "200": + description: Preview image (WebP) + content: + image/webp: + schema: + type: string + format: binary + "404": + description: Preview not found + + # --------------------------------------------------------------------------- + # Users + # --------------------------------------------------------------------------- + /api/users: + get: + operationId: getUsers + tags: [user] + summary: Get user storage info + description: | + Returns user storage configuration. In single-user mode returns + `{"storage": "server", "migrated": true/false}`. In multi-user mode + returns `{"storage": "server", "users": {"user_id": "user_dir", ...}}`. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + responses: + "200": + description: User info + content: + application/json: + schema: + type: object + properties: + storage: + type: string + description: Storage backend type (always "server") + migrated: + type: boolean + description: Whether migration from browser storage is complete (single-user) + users: + type: object + additionalProperties: + type: string + description: Map of user_id to directory name (multi-user) + post: + operationId: createUser + tags: [user] + summary: Create a new user (multi-user mode) + description: Creates a new user entry. Only meaningful when ComfyUI is running in multi-user mode. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - username + properties: + username: + type: string + description: Username for the new user + responses: + "200": + description: Created user ID + content: + application/json: + schema: + type: string + description: The generated user_id + "400": + description: Username already exists or invalid + + # --------------------------------------------------------------------------- + # Userdata + # --------------------------------------------------------------------------- + /api/userdata: + get: + operationId: listUserdata + tags: [userdata] + summary: List files in a userdata directory + description: Lists files in the authenticated user's data directory. Returns either filename strings or full objects depending on the `full_info` query parameter. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: dir + in: query + required: true + schema: + type: string + description: Directory path relative to the user's data folder + - name: recurse + in: query + schema: + type: boolean + description: Recurse into subdirectories + - name: full_info + in: query + schema: + type: boolean + description: Return full file info objects instead of just names + - name: split + in: query + schema: + type: boolean + description: Split paths into directory components + responses: + "200": + description: File listing + content: + application/json: + schema: + $ref: "#/components/schemas/ListUserdataResponse" + "404": + description: Directory not found + + /api/v2/userdata: + get: + operationId: listUserdataV2 + tags: [userdata] + summary: List files in userdata (v2 format) + description: Lists files in the authenticated user's data directory using the v2 response shape, which always returns full objects. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: path + in: query + schema: + type: string + description: Directory path relative to user data root + responses: + "200": + description: File listing with metadata + content: + application/json: + schema: + type: array + items: + type: object + properties: + name: + type: string + path: + type: string + type: + type: string + enum: [file, directory] + size: + type: integer + modified: + type: number + description: Unix timestamp + + /api/userdata/{file}: + get: + operationId: getUserdataFile + tags: [userdata] + summary: Read a userdata file + description: Reads the contents of a file from the authenticated user's data directory. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: file + in: path + required: true + schema: + type: string + description: File path relative to user data directory + responses: + "200": + description: File content + content: + application/octet-stream: + schema: + type: string + format: binary + "404": + description: File not found + post: + operationId: writeUserdataFile + tags: [userdata] + summary: Write or create a userdata file + description: Writes (creates or replaces) a file in the authenticated user's data directory. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: file + in: path + required: true + schema: + type: string + description: File path relative to user data directory + - name: overwrite + in: query + schema: + type: boolean + description: Allow overwriting existing files + - name: full_info + in: query + schema: + type: boolean + description: Return full file info in response + requestBody: + required: true + content: + application/octet-stream: + schema: + type: string + format: binary + application/json: + schema: {} + responses: + "200": + description: File written + content: + application/json: + schema: + $ref: "#/components/schemas/UserDataResponse" + "409": + description: File exists and overwrite not set + delete: + operationId: deleteUserdataFile + tags: [userdata] + summary: Delete a userdata file + description: Deletes a file from the authenticated user's data directory. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: file + in: path + required: true + schema: + type: string + description: File path relative to user data directory + responses: + "204": + description: File deleted + "404": + description: File not found + + /api/userdata/{file}/move/{dest}: + post: + operationId: moveUserdataFile + tags: [userdata] + summary: Move or rename a userdata file + description: Renames or moves a file within the authenticated user's data directory. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: file + in: path + required: true + schema: + type: string + description: Source file path + - name: dest + in: path + required: true + schema: + type: string + description: Destination file path + - name: overwrite + in: query + schema: + type: boolean + description: Allow overwriting at destination + - name: full_info + in: query + schema: + type: boolean + description: Return full file info in response + responses: + "200": + description: File moved + content: + application/json: + schema: + $ref: "#/components/schemas/UserDataResponse" + "404": + description: Source file not found + "409": + description: Destination exists and overwrite not set + + # --------------------------------------------------------------------------- + # Settings + # --------------------------------------------------------------------------- + /api/settings: + get: + operationId: getSettings + tags: [settings] + summary: Get all user settings + description: Returns all settings for the authenticated user. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + responses: + "200": + description: Settings object + content: + application/json: + schema: + type: object + additionalProperties: true + post: + operationId: updateSettings + tags: [settings] + summary: Update user settings (partial merge) + description: Replaces the authenticated user's settings with the provided object. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + requestBody: + required: true + content: + application/json: + schema: + type: object + additionalProperties: true + description: Partial settings to merge + responses: + "200": + description: Settings updated + + /api/settings/{id}: + get: + operationId: getSetting + tags: [settings] + summary: Get a single setting by key + description: Returns the value of a single setting, identified by key. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: id + in: path + required: true + schema: + type: string + description: Setting key + responses: + "200": + description: Setting value (null if the setting does not exist) + content: + application/json: + schema: + nullable: true + description: The setting value (any JSON type), or null if not set + post: + operationId: updateSetting + tags: [settings] + summary: Set a single setting value + description: Sets the value of a single setting, identified by key. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: id + in: path + required: true + schema: + type: string + description: Setting key + requestBody: + required: true + content: + application/json: + schema: + description: The setting value (any JSON type) + responses: + "200": + description: Setting updated + + # --------------------------------------------------------------------------- + # Extensions / Templates / i18n + # --------------------------------------------------------------------------- + /api/extensions: + get: + operationId: getExtensions + tags: [extensions] + summary: List frontend extension JS file paths + description: Returns the list of frontend extension JS URLs registered by custom nodes, to be loaded by the frontend on startup. + responses: + "200": + description: Array of JS file paths + content: + application/json: + schema: + type: array + items: + type: string + description: Relative path to extension JS file + + /api/workflow_templates: + get: + operationId: getWorkflowTemplates + tags: [extensions] + summary: Get workflow template mappings + description: Returns a map of custom node names to their provided workflow template names. + responses: + "200": + description: Template mappings + content: + application/json: + schema: + type: object + additionalProperties: + type: array + items: + type: string + description: Map of node pack name to array of template names + + /api/i18n: + get: + operationId: getI18n + tags: [extensions] + summary: Get internationalisation translation strings + description: Returns the URLs of translation files contributed by custom nodes, keyed by locale. + responses: + "200": + description: Translation map + content: + application/json: + schema: + type: object + additionalProperties: true + description: Nested map of locale to translation key-value pairs + + # --------------------------------------------------------------------------- + # Subgraphs + # --------------------------------------------------------------------------- + /api/global_subgraphs: + get: + operationId: getGlobalSubgraphs + tags: [subgraph] + summary: List global subgraph blueprints + description: Returns a dictionary of subgraph IDs to their metadata. + responses: + "200": + description: Subgraph metadata dictionary + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/GlobalSubgraphInfo" + + /api/global_subgraphs/{id}: + get: + operationId: getGlobalSubgraph + tags: [subgraph] + summary: Get a global subgraph with full data + description: Returns the blueprint for a globally-registered subgraph, used by the frontend to materialize the subgraph node. + parameters: + - name: id + in: path + required: true + schema: + type: string + description: Subgraph identifier + responses: + "200": + description: Full subgraph data + content: + application/json: + schema: + $ref: "#/components/schemas/GlobalSubgraphData" + "404": + description: Subgraph not found + + # --------------------------------------------------------------------------- + # Node Replacements + # --------------------------------------------------------------------------- + /api/node_replacements: + get: + operationId: getNodeReplacements + tags: [node] + summary: Get node replacement mappings + description: | + Returns a dictionary mapping deprecated or replaced node class names + to their replacement node information. + responses: + "200": + description: Replacement mappings + content: + application/json: + schema: + type: object + additionalProperties: true + + # --------------------------------------------------------------------------- + # Internal (x-internal: true) + # --------------------------------------------------------------------------- + /internal/logs: + get: + operationId: getInternalLogs + tags: [internal] + summary: Get server logs as text + description: Returns structured ComfyUI log entries from the in-memory log buffer. + x-internal: true + responses: + "200": + description: Log text + content: + text/plain: + schema: + type: string + + /internal/logs/raw: + get: + operationId: getInternalLogsRaw + tags: [internal] + summary: Get raw structured log entries + description: Returns the raw ComfyUI log buffer as text, together with metadata about the current size limit. + x-internal: true + responses: + "200": + description: Structured log data + content: + application/json: + schema: + type: object + properties: + entries: + type: array + items: + type: object + properties: + t: + type: number + description: Timestamp + m: + type: string + description: Message + size: + type: object + properties: + cols: + type: integer + rows: + type: integer + + /internal/logs/subscribe: + patch: + operationId: subscribeToLogs + tags: [internal] + summary: Subscribe or unsubscribe a WebSocket client to log streaming + description: Subscribes or unsubscribes the current client from live log streaming over the WebSocket. + x-internal: true + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - clientId + - enabled + properties: + clientId: + type: string + description: WebSocket client ID + enabled: + type: boolean + description: Enable or disable log streaming for this client + responses: + "200": + description: Subscription updated + + /internal/folder_paths: + get: + operationId: getInternalFolderPaths + tags: [internal] + summary: Get configured folder paths + description: Returns the filesystem paths ComfyUI is configured to load models and other assets from, keyed by folder type. + x-internal: true + responses: + "200": + description: Dictionary of folder type to paths + content: + application/json: + schema: + type: object + additionalProperties: + type: array + items: + type: array + items: + type: string + description: Map of folder type name to list of [path, ...] entries + + /internal/files/{directory_type}: + get: + operationId: getInternalFiles + tags: [internal] + summary: List files in a directory type + description: Lists the files present in one of ComfyUI's known directories (input, output, or temp). + x-internal: true + parameters: + - name: directory_type + in: path + required: true + schema: + type: string + description: Directory type (e.g. output, input, temp) + responses: + "200": + description: Array of filenames + content: + application/json: + schema: + type: array + items: + type: string + + # --------------------------------------------------------------------------- + # Assets (x-feature-gate: enable-assets) + # --------------------------------------------------------------------------- + /api/assets/hash/{hash}: + head: + operationId: checkAssetByHash + tags: [assets] + summary: Check if an asset with the given hash exists + description: Returns 204 if an asset with the given content hash already exists, 404 otherwise. Used by clients to deduplicate uploads before transferring bytes. + x-feature-gate: enable-assets + parameters: + - name: hash + in: path + required: true + schema: + type: string + description: "Blake3 hash of the asset (e.g. blake3:abc123...)" + responses: + "200": + description: Asset exists + "404": + description: No asset with this hash + + /api/assets: + get: + operationId: listAssets + tags: [assets] + summary: List assets with filtering and pagination + description: Returns a paginated list of assets, optionally filtered by tags, name, or other query parameters. + x-feature-gate: enable-assets + parameters: + - name: limit + in: query + schema: + type: integer + default: 50 + - name: offset + in: query + schema: + type: integer + default: 0 + - name: include_tags + in: query + schema: + type: array + items: + type: string + style: form + explode: true + description: Tags that assets must have (AND logic) + - name: exclude_tags + in: query + schema: + type: array + items: + type: string + style: form + explode: true + description: Tags that assets must not have + - name: name_contains + in: query + schema: + type: string + description: Filter assets whose name contains this substring + - name: metadata_filter + in: query + schema: + type: string + description: JSON-encoded metadata key/value filter + - name: sort + in: query + schema: + type: string + description: Field to sort by + - name: order + in: query + schema: + type: string + enum: [asc, desc] + description: Sort direction + responses: + "200": + description: Asset list + content: + application/json: + schema: + $ref: "#/components/schemas/ListAssetsResponse" + post: + operationId: createAsset + tags: [assets] + summary: Upload a new asset + description: Uploads a new asset (binary content plus metadata) and registers it in the asset database. + x-feature-gate: enable-assets + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - file + properties: + file: + type: string + format: binary + description: Asset file to upload + name: + type: string + description: Display name for the asset + tags: + type: string + description: Comma-separated tags + user_metadata: + type: string + description: JSON-encoded user metadata + hash: + type: string + description: "Blake3 hash of the file content (e.g. blake3:abc123...)" + mime_type: + type: string + description: MIME type of the file (overrides auto-detected type) + preview_id: + type: string + format: uuid + description: ID of an existing asset to use as the preview image + responses: + "201": + description: Asset created + content: + application/json: + schema: + $ref: "#/components/schemas/AssetCreated" + + /api/assets/from-hash: + post: + operationId: createAssetFromHash + tags: [assets] + summary: Create an asset reference from an existing hash + description: Registers a new asset that references existing content by hash, without re-uploading the bytes. + x-feature-gate: enable-assets + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - hash + - name + properties: + hash: + type: string + description: Blake3 hash of existing content + name: + type: string + description: Display name + tags: + type: array + items: + type: string + user_metadata: + type: object + additionalProperties: true + responses: + "201": + description: Asset created from hash + content: + application/json: + schema: + $ref: "#/components/schemas/AssetCreated" + + /api/assets/{id}: + get: + operationId: getAsset + tags: [assets] + summary: Get asset metadata + description: Returns the metadata for a single asset. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + responses: + "200": + description: Asset metadata + content: + application/json: + schema: + $ref: "#/components/schemas/Asset" + "404": + description: Asset not found + put: + operationId: updateAsset + tags: [assets] + summary: Update asset metadata + description: Updates the mutable metadata of an asset (name, tags, etc.). Binary content is immutable. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: New display name for the asset + user_metadata: + type: object + additionalProperties: true + description: Custom user metadata to set + preview_id: + type: string + format: uuid + description: ID of the asset to use as the preview + responses: + "200": + description: Asset updated + content: + application/json: + schema: + $ref: "#/components/schemas/AssetUpdated" + delete: + operationId: deleteAsset + tags: [assets] + summary: Delete an asset + description: Removes an asset entry. Depending on the server configuration, the underlying content may also be deleted. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + - name: delete_content + in: query + schema: + type: boolean + description: Also delete the underlying content file + responses: + "204": + description: Asset deleted + + /api/assets/{id}/content: + get: + operationId: getAssetContent + tags: [assets] + summary: Download asset file content + description: Returns the binary content of an asset. Supports range requests. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + responses: + "200": + description: Asset file content + content: + application/octet-stream: + schema: + type: string + format: binary + "404": + description: Asset not found + + /api/assets/{id}/tags: + post: + operationId: addAssetTags + tags: [assets] + summary: Add tags to an asset + description: Adds one or more tags to an asset. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - tags + properties: + tags: + type: array + items: + type: string + responses: + "200": + description: Tags added + content: + application/json: + schema: + $ref: "#/components/schemas/TagsModificationResponse" + delete: + operationId: removeAssetTags + tags: [assets] + summary: Remove tags from an asset + description: Removes one or more tags from an asset. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - tags + properties: + tags: + type: array + items: + type: string + responses: + "200": + description: Tags removed + content: + application/json: + schema: + $ref: "#/components/schemas/TagsModificationResponse" + + /api/tags: + get: + operationId: listTags + tags: [assets] + summary: List all known tags with counts + description: Returns the list of all tags known to the asset database, with counts. + x-feature-gate: enable-assets + parameters: + - name: limit + in: query + schema: + type: integer + - name: offset + in: query + schema: + type: integer + - name: search + in: query + schema: + type: string + description: Search term for tag name + responses: + "200": + description: Tag list + content: + application/json: + schema: + $ref: "#/components/schemas/ListTagsResponse" + + /api/assets/tags/refine: + get: + operationId: refineAssetTags + tags: [assets] + summary: Get tag counts for assets matching current filters + description: Returns suggested additional tags that would refine a filtered asset query, together with the count of assets each tag would select. + x-feature-gate: enable-assets + parameters: + - name: include_tags + in: query + schema: + type: array + items: + type: string + style: form + explode: true + description: Tags that assets must have (AND logic) + - name: exclude_tags + in: query + schema: + type: array + items: + type: string + style: form + explode: true + description: Tags that assets must not have + - name: name_contains + in: query + schema: + type: string + description: Filter assets whose name contains this substring + - name: metadata_filter + in: query + schema: + type: string + description: JSON-encoded metadata key/value filter + - name: limit + in: query + schema: + type: integer + - name: offset + in: query + schema: + type: integer + - name: sort + in: query + schema: + type: string + description: Field to sort by + - name: order + in: query + schema: + type: string + enum: [asc, desc] + description: Sort direction + responses: + "200": + description: Tag histogram + content: + application/json: + schema: + $ref: "#/components/schemas/AssetTagHistogramResponse" + + /api/assets/seed: + post: + operationId: seedAssets + tags: [assets] + summary: Trigger asset scan/seed from filesystem + description: Starts a background job that scans the configured directories and registers any assets not yet present in the asset database. + x-feature-gate: enable-assets + requestBody: + required: false + content: + application/json: + schema: + type: object + properties: + roots: + type: array + items: + type: string + description: Root folder paths to scan (if omitted, scans all) + responses: + "200": + description: Seed started + content: + application/json: + schema: + type: object + properties: + status: + type: string + + /api/assets/seed/status: + get: + operationId: getAssetSeedStatus + tags: [assets] + summary: Get asset scan progress + description: Returns the progress and status of the most recently-started asset seed job. + x-feature-gate: enable-assets + responses: + "200": + description: Scan progress + content: + application/json: + schema: + type: object + additionalProperties: true + description: Scan progress details (files scanned, total, status, etc.) + + /api/assets/seed/cancel: + post: + operationId: cancelAssetSeed + tags: [assets] + summary: Cancel an in-progress asset scan + description: Requests cancellation of the currently-running asset seed job. + x-feature-gate: enable-assets + responses: + "200": + description: Scan cancelled + content: + application/json: + schema: + type: object + properties: + status: + type: string + + /api/assets/prune: + post: + operationId: pruneAssets + tags: [assets] + summary: Mark assets whose backing files no longer exist on disk + description: Starts a background job that removes asset entries whose underlying content no longer exists on disk. + x-feature-gate: enable-assets + responses: + "200": + description: Prune result + content: + application/json: + schema: + type: object + properties: + status: + type: string + marked: + type: integer + description: Number of assets marked as missing + +components: + parameters: + ComfyUserHeader: + name: Comfy-User + in: header + required: false + schema: + type: string + description: | + Identifies the active user in multi-user mode. Used for settings, + userdata, and history isolation. This is not a security mechanism — + it is an organisational convenience with no authentication behind it. + + schemas: + # ------------------------------------------------------------------- + # Prompt + # ------------------------------------------------------------------- + PromptRequest: + type: object + description: A workflow submission. Wraps the prompt graph plus optional client identifier and extra per-request data. + required: + - prompt + properties: + prompt: + type: object + description: | + The workflow graph to execute. Keys are node IDs (strings); + values are objects with class_type and inputs. + additionalProperties: true + number: + type: number + description: Priority number for the queue (lower numbers have higher priority) + front: + type: boolean + description: If true, adds the prompt to the front of the queue + extra_data: + type: object + description: Extra data associated with the prompt (e.g. extra_pnginfo) + additionalProperties: true + client_id: + type: string + description: WebSocket client ID to receive progress updates + prompt_id: + type: string + format: uuid + description: "Client-supplied prompt ID. Server generates a UUID if omitted." + partial_execution_targets: + type: array + items: + type: string + description: List of node IDs to execute (partial graph execution) + + PromptResponse: + type: object + description: Server acknowledgement of a workflow submission. Includes the assigned `prompt_id` and current queue position. + properties: + prompt_id: + type: string + format: uuid + description: Unique identifier for the prompt execution + number: + type: number + description: Priority number in the queue + node_errors: + type: object + description: Validation errors keyed by node ID + additionalProperties: + $ref: "#/components/schemas/NodeError" + error: + description: Top-level prompt error (string message or structured error) + oneOf: + - type: string + - $ref: "#/components/schemas/PromptError" + + PromptErrorResponse: + type: object + description: Error response when prompt validation fails + additionalProperties: true + + PromptError: + type: object + description: Structured prompt validation error + properties: + type: + type: string + message: + type: string + details: + type: string + + Error: + type: object + description: Detailed node-level error + properties: + type: + type: string + message: + type: string + details: + type: string + extra_info: + type: object + properties: + input_name: + type: string + additionalProperties: true + + NodeError: + type: object + description: Error details for a single node + properties: + errors: + type: array + items: + $ref: "#/components/schemas/Error" + class_type: + type: string + description: The node's class type + dependent_outputs: + type: array + items: {} + + PromptInfo: + type: object + description: Summary of a queued or recently-executed prompt, as returned by the queue and history endpoints. + properties: + exec_info: + type: object + properties: + queue_remaining: + type: integer + description: Number of items remaining in the queue + + # ------------------------------------------------------------------- + # Queue + # ------------------------------------------------------------------- + QueueInfo: + type: object + description: Queue information with pending and running items + properties: + queue_running: + type: array + description: Currently running queue items + items: + type: array + description: | + Queue item tuple: [number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive] + items: {} + prefixItems: + - type: number + description: Priority number + - type: string + format: uuid + description: prompt_id + - type: object + description: prompt graph + additionalProperties: true + - type: object + description: extra_data + additionalProperties: true + - type: array + description: outputs_to_execute (list of output node IDs) + items: + type: string + - type: object + description: sensitive data (may be omitted) + additionalProperties: true + queue_pending: + type: array + description: Pending queue items (oldest first) + items: + type: array + description: | + Queue item tuple: [number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive] + items: {} + prefixItems: + - type: number + description: Priority number + - type: string + format: uuid + description: prompt_id + - type: object + description: prompt graph + additionalProperties: true + - type: object + description: extra_data + additionalProperties: true + - type: array + description: outputs_to_execute (list of output node IDs) + items: + type: string + - type: object + description: sensitive data (may be omitted) + additionalProperties: true + + QueueManageRequest: + type: object + description: Request to clear or delete from queue + properties: + clear: + type: boolean + description: If true, clear all pending items + delete: + type: array + items: + type: string + description: Array of prompt IDs to delete from queue + + # ------------------------------------------------------------------- + # History + # ------------------------------------------------------------------- + HistoryEntry: + type: object + description: A single execution history entry + properties: + prompt: + type: array + description: | + Prompt tuple: [number, prompt_id, prompt_graph, extra_data, output_node_ids] + items: {} + outputs: + type: object + description: Output data from execution keyed by node ID + additionalProperties: true + status: + type: object + description: Execution status (status_str, completed, messages, etc.) + additionalProperties: true + meta: + type: object + description: Metadata about the execution and nodes + additionalProperties: true + + HistoryManageRequest: + type: object + description: Request to clear or delete history entries + properties: + clear: + type: boolean + description: If true, clear all history + delete: + type: array + items: + type: string + description: Array of prompt IDs to delete from history + + # ------------------------------------------------------------------- + # Jobs + # ------------------------------------------------------------------- + JobEntry: + type: object + description: Lightweight job data for list views + required: + - id + - status + properties: + id: + type: string + format: uuid + description: Unique job identifier (same as prompt_id) + status: + type: string + description: Current job status + create_time: + type: number + description: Job creation timestamp + execution_start_time: + type: number + description: Workflow execution start timestamp + execution_end_time: + type: number + description: Workflow execution end timestamp + preview_output: + type: object + additionalProperties: true + description: Primary preview output + outputs_count: + type: integer + description: Total number of output files + + JobDetailResponse: + type: object + description: Full job details including workflow and outputs + required: + - id + - status + properties: + id: + type: string + format: uuid + status: + type: string + workflow: + type: object + additionalProperties: true + description: Full ComfyUI workflow + outputs: + type: object + additionalProperties: true + description: Full outputs object from execution + execution_error: + $ref: "#/components/schemas/ExecutionError" + create_time: + type: number + update_time: + type: number + execution_start_time: + type: number + execution_end_time: + type: number + preview_output: + type: object + additionalProperties: true + outputs_count: + type: integer + execution_status: + type: object + additionalProperties: true + execution_meta: + type: object + additionalProperties: true + + ExecutionError: + type: object + description: Detailed execution error from ComfyUI + properties: + node_id: + type: string + description: ID of the node that failed + node_type: + type: string + description: Type name of the node + exception_message: + type: string + description: Human-readable error message + exception_type: + type: string + description: Python exception type + traceback: + type: array + items: + type: string + description: Traceback lines + current_inputs: + type: object + additionalProperties: true + current_outputs: + type: object + additionalProperties: true + + PaginationInfo: + type: object + description: Pagination metadata returned alongside list responses. + properties: + offset: + type: integer + limit: + type: integer + total: + type: integer + has_more: + type: boolean + + # ------------------------------------------------------------------- + # Upload / View + # ------------------------------------------------------------------- + UploadResult: + type: object + description: Response body returned by the image/mask upload endpoints, describing where the uploaded file now lives. + properties: + name: + type: string + description: Saved filename (may be renamed to avoid collisions) + subfolder: + type: string + description: Subfolder the file was saved to + type: + type: string + description: Directory type (input, temp) + + # ------------------------------------------------------------------- + # System + # ------------------------------------------------------------------- + DeviceStats: + type: object + description: GPU/compute device statistics + required: + - name + - type + - index + properties: + name: + type: string + description: Device name + type: + type: string + description: Device type (cuda, mps, cpu, etc.) + index: + type: number + description: Device index + vram_total: + type: number + description: Total VRAM in bytes + vram_free: + type: number + description: Free VRAM in bytes + torch_vram_total: + type: number + description: Total PyTorch-managed VRAM in bytes + torch_vram_free: + type: number + description: Free PyTorch-managed VRAM in bytes + + SystemStatsResponse: + type: object + description: Hardware, VRAM, Python, and ComfyUI version information for the running process. + required: + - system + - devices + properties: + system: + type: object + required: + - os + - python_version + - embedded_python + - comfyui_version + - pytorch_version + - argv + - ram_total + - ram_free + properties: + os: + type: string + description: Operating system + python_version: + type: string + description: Python version + embedded_python: + type: boolean + description: Whether using embedded Python + comfyui_version: + type: string + description: ComfyUI version string + pytorch_version: + type: string + description: PyTorch version + required_frontend_version: + type: string + description: Required frontend version + argv: + type: array + items: + type: string + description: Command line arguments + ram_total: + type: number + description: Total RAM in bytes + ram_free: + type: number + description: Free RAM in bytes + installed_templates_version: + type: string + nullable: true + description: Version of the currently installed workflow templates + required_templates_version: + type: string + nullable: true + description: Minimum required workflow templates version for this ComfyUI build + devices: + type: array + items: + $ref: "#/components/schemas/DeviceStats" + + # ------------------------------------------------------------------- + # Node / Object Info + # ------------------------------------------------------------------- + NodeInfo: + type: object + description: 'Definition of a registered node class: its inputs, outputs, category, and display metadata.' + properties: + input: + type: object + description: Input specifications (required and optional groups) + additionalProperties: true + input_order: + type: object + description: Ordered input names per group + additionalProperties: + type: array + items: + type: string + output: + type: array + items: + type: string + description: Output type names + output_is_list: + type: array + items: + type: boolean + description: Whether each output is a list + output_name: + type: array + items: + type: string + description: Display names of outputs + name: + type: string + description: Internal class name + display_name: + type: string + description: Human-readable display name + description: + type: string + description: Node description + python_module: + type: string + description: Python module implementing the node + category: + type: string + description: Node category path + output_node: + type: boolean + description: Whether this is an output node + output_tooltips: + type: array + items: + type: string + description: Tooltips for each output + deprecated: + type: boolean + description: Whether the node is deprecated + experimental: + type: boolean + description: Whether the node is experimental + api_node: + type: boolean + description: Whether this is an API node + is_input_list: + type: boolean + description: Whether the node accepts list inputs + dev_only: + type: boolean + description: Whether the node is developer-only (hidden in production UI) + has_intermediate_output: + type: boolean + description: Whether the node emits intermediate output during execution + search_aliases: + type: array + items: + type: string + description: Alternative search terms for finding this node + essentials_category: + type: string + description: Category override used by the essentials pack + + # ------------------------------------------------------------------- + # Models + # ------------------------------------------------------------------- + ModelFolder: + type: object + description: A configured model folder and the list of disk paths it resolves to. + required: + - name + - folders + properties: + name: + type: string + description: Model folder type name (e.g. "checkpoints") + folders: + type: array + items: + type: string + description: Filesystem paths for this model type + + ModelFile: + type: object + description: A single model file in a folder, with filesystem metadata. + required: + - name + - pathIndex + properties: + name: + type: string + description: Model filename + pathIndex: + type: integer + description: Index into the folder's paths array + modified: + type: number + description: File modification timestamp + created: + type: number + description: File creation timestamp + size: + type: integer + format: int64 + description: File size in bytes + + # ------------------------------------------------------------------- + # Subgraphs + # ------------------------------------------------------------------- + GlobalSubgraphInfo: + type: object + description: Metadata for a global subgraph blueprint (without full data) + required: + - source + - name + - info + properties: + source: + type: string + description: Source type ("templates" or "custom_node") + name: + type: string + description: Display name of the subgraph blueprint + info: + type: object + description: Additional information about the subgraph + required: + - node_pack + properties: + node_pack: + type: string + description: The node pack/module providing this subgraph + data: + type: string + description: The full subgraph JSON data (may be empty in list view) + + GlobalSubgraphData: + type: object + description: Full data for a global subgraph blueprint + required: + - source + - name + - info + - data + properties: + source: + type: string + description: Source type ("templates" or "custom_node") + name: + type: string + description: Display name of the subgraph blueprint + info: + type: object + description: Additional information about the subgraph + required: + - node_pack + properties: + node_pack: + type: string + description: The node pack/module providing this subgraph + data: + type: string + description: The full subgraph JSON data as a string + + # ------------------------------------------------------------------- + # Userdata + # ------------------------------------------------------------------- + UserDataResponse: + description: | + Response body for the POST endpoints `/api/userdata/{file}` and + `/api/userdata/{file}/move/{dest}`. Returns a single item whose + shape depends on the `full_info` query parameter. + x-variant-selector: + full_info=true: file-info object (`GetUserDataResponseFullFile`) + default: relative path string + oneOf: + - $ref: "#/components/schemas/GetUserDataResponseFullFile" + - type: string + description: Relative path of the written or moved file. Returned when `full_info` is absent or false. + + ListUserdataResponse: + description: | + Response body for `GET /api/userdata`. The array item shape is + determined by the `full_info` and `split` query parameters. + x-variant-selector: + full_info=true: array of file-info objects (`GetUserDataResponseFullFile`) + split=true: array of `[relative_path, ...path_components]` arrays + default: array of relative path strings + oneOf: + - type: array + items: + $ref: "#/components/schemas/GetUserDataResponseFullFile" + description: Returned when `full_info=true`. + - type: array + items: + type: array + items: + type: string + minItems: 2 + description: | + Returned when `split=true` and `full_info=false`. Each inner + array is `[relative_path, ...path_components]`. + - type: array + items: + type: string + description: Default shape — array of file paths relative to the user data root. + + GetUserDataResponseFullFile: + type: object + description: A single entry in a full-info user data listing. + properties: + path: + type: string + description: File name or path relative to the user directory + created: + type: number + description: Unix timestamp of file creation + size: + type: integer + description: File size in bytes + modified: + type: integer + format: int64 + description: Unix timestamp of last modification in milliseconds + + # ------------------------------------------------------------------- + # Assets + # ------------------------------------------------------------------- + Asset: + type: object + description: A registered asset — an input/output file tracked in the asset database with content hash and metadata. + required: + - id + - name + - size + - created_at + - updated_at + properties: + id: + type: string + format: uuid + description: Unique identifier for the asset + name: + type: string + description: Name of the asset file + asset_hash: + type: string + description: Blake3 hash of the asset content + pattern: "^blake3:[a-f0-9]{64}$" + size: + type: integer + format: int64 + description: Size of the asset in bytes + mime_type: + type: string + description: MIME type of the asset + tags: + type: array + items: + type: string + description: Tags associated with the asset + user_metadata: + type: object + description: Custom user metadata + additionalProperties: true + metadata: + type: object + description: System-managed metadata (read-only) + additionalProperties: true + readOnly: true + preview_url: + type: string + format: uri + description: URL for asset preview/thumbnail + preview_id: + type: string + format: uuid + description: ID of the preview asset if available + prompt_id: + type: string + format: uuid + description: ID of the prompt that created this asset + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + last_access_time: + type: string + format: date-time + is_immutable: + type: boolean + description: Whether this asset is immutable + + AssetCreated: + description: Response body returned after successfully registering a new asset. + allOf: + - $ref: "#/components/schemas/Asset" + - type: object + required: + - created_new + properties: + created_new: + type: boolean + description: Whether this was a new creation (true) or returned existing (false) + + AssetUpdated: + type: object + description: Response body returned after updating an asset's metadata. + required: + - id + - updated_at + properties: + id: + type: string + format: uuid + name: + type: string + asset_hash: + type: string + pattern: "^blake3:[a-f0-9]{64}$" + tags: + type: array + items: + type: string + mime_type: + type: string + user_metadata: + type: object + additionalProperties: true + updated_at: + type: string + format: date-time + + ListAssetsResponse: + type: object + description: Paginated list of assets. + required: + - assets + - total + - has_more + properties: + assets: + type: array + items: + $ref: "#/components/schemas/Asset" + total: + type: integer + has_more: + type: boolean + + TagInfo: + type: object + description: A tag known to the asset database, with the number of assets bearing it. + required: + - name + - count + properties: + name: + type: string + count: + type: integer + + ListTagsResponse: + type: object + description: Flat list of all tags, with counts. + required: + - tags + - total + - has_more + properties: + tags: + type: array + items: + $ref: "#/components/schemas/TagInfo" + total: + type: integer + has_more: + type: boolean + + AssetTagHistogramResponse: + type: object + description: Tags that would refine a filtered asset query, with the count of assets each tag would additionally select. + required: + - tag_counts + properties: + tag_counts: + type: object + additionalProperties: + type: integer + description: Map of tag names to occurrence counts + + TagsModificationResponse: + type: object + description: Response body returned after adding or removing tags on an asset. + required: + - total_tags + properties: + added: + type: array + items: + type: string + description: Tags successfully added + removed: + type: array + items: + type: string + description: Tags successfully removed + already_present: + type: array + items: + type: string + description: Tags already present (for add) + not_present: + type: array + items: + type: string + description: Tags not present (for remove) + total_tags: + type: array + items: + type: string + description: All tags on the asset after the operation + + # ------------------------------------------------------------------- + # Result / Output types + # ------------------------------------------------------------------- + ResultItem: + type: object + description: A single output file reference + properties: + filename: + type: string + subfolder: + type: string + type: + type: string + enum: [input, output, temp] + display_name: + type: string + + NodeOutputs: + type: object + description: | + Outputs from a single node execution. Known keys are listed below, + but custom nodes may add arbitrary keys (additionalProperties). + properties: + images: + type: array + items: + $ref: "#/components/schemas/ResultItem" + audio: + type: array + items: + $ref: "#/components/schemas/ResultItem" + video: + type: array + items: + $ref: "#/components/schemas/ResultItem" + animated: + type: array + items: + type: boolean + text: + oneOf: + - type: string + - type: array + items: + type: string + additionalProperties: true + + TerminalSize: + type: object + description: Terminal dimensions + properties: + cols: + type: number + row: + type: number + + LogEntry: + type: object + description: A single log entry + properties: + t: + type: string + description: Timestamp + m: + type: string + description: Log message + + StatusWsMessageStatus: + type: object + description: Inner payload of a `status` WebSocket message, describing the execution queue state. + properties: + exec_info: + type: object + required: + - queue_remaining + properties: + queue_remaining: + type: integer + + StatusWsMessage: + type: object + description: Initial status message sent on connect + queue status updates + properties: + status: + $ref: "#/components/schemas/StatusWsMessageStatus" + sid: + type: string + description: Session ID assigned by the server + + ProgressWsMessage: + type: object + description: Node execution progress (step N of M) + required: + - value + - max + - prompt_id + - node + properties: + value: + type: integer + description: Current step + max: + type: integer + description: Total steps + prompt_id: + type: string + node: + type: string + description: Node ID currently executing + + ProgressTextWsMessage: + type: object + description: Text-based progress update from a node + properties: + nodeId: + type: string + text: + type: string + prompt_id: + type: string + + NodeProgressState: + type: object + description: Progress state for a single node + properties: + value: + type: number + max: + type: number + state: + type: string + enum: [pending, running, finished, error] + node_id: + type: string + prompt_id: + type: string + display_node_id: + type: string + parent_node_id: + type: string + real_node_id: + type: string + + ProgressStateWsMessage: + type: object + description: Bulk progress state for all nodes in a prompt + required: + - prompt_id + - nodes + properties: + prompt_id: + type: string + nodes: + type: object + description: Map of node ID to progress state + additionalProperties: + $ref: "#/components/schemas/NodeProgressState" + + ExecutingWsMessage: + type: object + description: Fired when a node begins execution + required: + - node + - display_node + - prompt_id + properties: + node: + type: string + description: Node ID + display_node: + type: string + description: Display node ID (may differ for subgraphs) + prompt_id: + type: string + + ExecutedWsMessage: + type: object + description: Fired when a node completes execution with output + required: + - node + - display_node + - prompt_id + - output + properties: + node: + type: string + display_node: + type: string + prompt_id: + type: string + output: + $ref: "#/components/schemas/NodeOutputs" + merge: + type: boolean + description: Whether to merge with existing output + + ExecutionWsMessageBase: + type: object + description: Base fields for execution lifecycle messages + required: + - prompt_id + - timestamp + properties: + prompt_id: + type: string + timestamp: + type: integer + description: Unix timestamp in milliseconds + + ExecutionStartWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + description: Fired when prompt execution begins + + ExecutionSuccessWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + description: Fired when prompt execution completes successfully + + ExecutionCachedWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + - type: object + properties: + nodes: + type: array + items: + type: string + description: List of node IDs that were cached + description: Fired when nodes are served from cache + + ExecutionInterruptedWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + - type: object + properties: + node_id: + type: string + node_type: + type: string + executed: + type: array + items: + type: string + description: Node IDs that completed before interruption + description: Fired when execution is interrupted by user + + ExecutionErrorWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + - type: object + properties: + node_id: + type: string + node_type: + type: string + executed: + type: array + items: + type: string + exception_message: + type: string + exception_type: + type: string + traceback: + type: array + items: + type: string + current_inputs: {} + current_outputs: {} + description: Fired when a node throws an exception during execution + + LogsWsMessage: + type: object + description: Streaming log entries from the server + properties: + size: + $ref: "#/components/schemas/TerminalSize" + entries: + type: array + items: + $ref: "#/components/schemas/LogEntry" + + NotificationWsMessage: + type: object + description: Server notification (e.g. model download complete) + properties: + value: + type: string + id: + type: string + + FeatureFlagsWsMessage: + type: object + description: Feature flags sent on connect + additionalProperties: true + + AssetDownloadWsMessage: + type: object + description: Asset download progress + required: + - task_id + - asset_name + - bytes_total + - bytes_downloaded + - progress + - status + properties: + task_id: + type: string + asset_name: + type: string + bytes_total: + type: number + bytes_downloaded: + type: number + progress: + type: number + description: 0.0 to 1.0 + status: + type: string + enum: [created, running, completed, failed] + asset_id: + type: string + error: + type: string + + AssetExportWsMessage: + type: object + description: Bulk asset export progress + required: + - task_id + - assets_total + - assets_attempted + - assets_failed + - bytes_total + - bytes_processed + - progress + - status + properties: + task_id: + type: string + export_name: + type: string + assets_total: + type: number + assets_attempted: + type: number + assets_failed: + type: number + bytes_total: + type: number + bytes_processed: + type: number + progress: + type: number + description: 0.0 to 1.0 + status: + type: string + enum: [created, running, completed, failed] + error: + type: string From 7636599389a6798e813d6036fb0dcf08295e7971 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 24 Apr 2026 16:54:10 +0300 Subject: [PATCH 32/35] chore(api-nodes): add upcoming-deprecation notice to Sora nodes (#13549) --- comfy_api_nodes/nodes_sora.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index afc18bb25..4d9075dcf 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -33,9 +33,13 @@ class OpenAIVideoSora2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIVideoSora2", - display_name="OpenAI Sora - Video", + display_name="OpenAI Sora - Video (Deprecated)", category="api node/video/Sora", - description="OpenAI video and audio generation.", + description=( + "OpenAI video and audio generation.\n\n" + "DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. " + "This node will be removed from ComfyUI at that time." + ), inputs=[ IO.Combo.Input( "model", From 4304c15e9b4acb45fa9241e8e1723f8ce6397550 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 24 Apr 2026 13:46:10 -0700 Subject: [PATCH 33/35] Properly load higher bit depth videos. (#13542) --- comfy_api/latest/_input_impl/video_types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 1b4993aa7..bd8090635 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -248,8 +248,8 @@ class VideoFromFile(VideoInput): continue if self.__duration and frame.pts >= end_pts: break - img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) - img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) + img = frame.to_ndarray(format='gbrpf32le') # shape: (H, W, 3) + img = torch.from_numpy(img) frames.append(img) images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) From 5e3f15a830ff27d3563ef4b43e9f6a0321ea36cd Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Sat, 25 Apr 2026 09:21:39 +0900 Subject: [PATCH 34/35] Bump comfyui-frontend-package to 1.42.15 (#13556) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 346ce4b76..6c7457e03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.42.14 +comfyui-frontend-package==1.42.15 comfyui-workflow-templates==0.9.62 comfyui-embedded-docs==0.4.4 torch From df22bcd5e192ce0b1ae09eaf2e423d0a12cf6638 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 25 Apr 2026 18:02:58 -0700 Subject: [PATCH 35/35] Support loading the alpha channel of videos. (#13564) Not exposed in nodes yet. --- comfy_api/latest/_input_impl/video_types.py | 25 ++++++++++++++++----- comfy_api/latest/_util/video_types.py | 5 ++--- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index bd8090635..eb4d3701d 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -240,19 +240,34 @@ class VideoFromFile(VideoInput): start_time = self.__start_time # Get video frames frames = [] + alphas = None start_pts = int(start_time / video_stream.time_base) end_pts = int((start_time + self.__duration) / video_stream.time_base) container.seek(start_pts, stream=video_stream) + image_format = 'gbrpf32le' for frame in container.decode(video_stream): + if alphas is None: + for comp in frame.format.components: + if comp.is_alpha: + alphas = [] + image_format = 'gbrapf32le' + break + if frame.pts < start_pts: continue if self.__duration and frame.pts >= end_pts: break - img = frame.to_ndarray(format='gbrpf32le') # shape: (H, W, 3) - img = torch.from_numpy(img) - frames.append(img) - images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) + img = frame.to_ndarray(format=image_format) # shape: (H, W, 4) + if alphas is None: + frames.append(torch.from_numpy(img)) + else: + frames.append(torch.from_numpy(img[..., :-1])) + alphas.append(torch.from_numpy(img[..., -1:])) + + images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3) + if alphas is not None: + alphas = torch.stack(alphas) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1) # Get frame rate frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) @@ -295,7 +310,7 @@ class VideoFromFile(VideoInput): }) metadata = container.metadata - return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) + return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata) def get_components(self) -> VideoComponents: if isinstance(self.__file, io.BytesIO): diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index fd3b5a510..c92477f08 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from fractions import Fraction from typing import Optional -from .._input import ImageInput, AudioInput +from .._input import ImageInput, AudioInput, MaskInput class VideoCodec(str, Enum): AUTO = "auto" @@ -48,5 +48,4 @@ class VideoComponents: frame_rate: Fraction audio: Optional[AudioInput] = None metadata: Optional[dict] = None - - + alpha: Optional[MaskInput] = None