From 1a157e1f97d32c27b3b8bd842bfc5e448c240fe7 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 17 Mar 2026 14:32:43 -0700 Subject: [PATCH 01/42] Reduce LTX VAE VRAM usage and save use cases from OOMs/Tiler (#13013) * ltx: vae: scale the chunk size with the users VRAM Scale this linearly down for users with low VRAM. * ltx: vae: free non-chunking recursive intermediates * ltx: vae: cleanup some intermediates The conv layer can be the VRAM peak and it does a torch.cat. So cleanup the pieces of the cat. Also clear our the cache ASAP as each layer detect its end as this VAE surges in VRAM at the end due to the ended padding increasing the size of the final frame convolutions off-the-books to the chunker. So if all the earlier layers free up their cache it can offset that surge. Its a fragmentation nightmare, and the chance of it having to recache the pyt allocator is very high, but you wont OOM. --- comfy/ldm/lightricks/vae/causal_conv3d.py | 4 ++ .../vae/causal_video_autoencoder.py | 41 +++++++++++++++---- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index b8341edbc..356394239 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -65,9 +65,13 @@ class CausalConv3d(nn.Module): self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) x = torch.cat(pieces, dim=2) + del pieces + del cached if needs_caching: self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + elif is_end: + self.temporal_cache_state[tid] = (None, True) return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :] diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 9f14f64a5..0504140ef 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -297,7 +297,23 @@ class Encoder(nn.Module): module.temporal_cache_state.pop(tid, None) -MAX_CHUNK_SIZE=(128 * 1024 ** 2) +MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3 +MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3 +MIN_CHUNK_SIZE = 32 * 1024 ** 2 +MAX_CHUNK_SIZE = 128 * 1024 ** 2 + +def get_max_chunk_size(device: torch.device) -> int: + total_memory = comfy.model_management.get_total_memory(dev=device) + + if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING: + return MIN_CHUNK_SIZE + if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING: + return MAX_CHUNK_SIZE + + interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / ( + MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING + ) + return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE)) class Decoder(nn.Module): r""" @@ -525,8 +541,11 @@ class Decoder(nn.Module): timestep_shift_scale = ada_values.unbind(dim=1) output = [] + max_chunk_size = get_max_chunk_size(sample.device) - def run_up(idx, sample, ended): + def run_up(idx, sample_ref, ended): + sample = sample_ref[0] + sample_ref[0] = None if idx >= len(self.up_blocks): sample = self.conv_norm_out(sample) if timestep_shift_scale is not None: @@ -554,13 +573,21 @@ class Decoder(nn.Module): return total_bytes = sample.numel() * sample.element_size() - num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE - samples = torch.chunk(sample, chunks=num_chunks, dim=2) + num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size - for chunk_idx, sample1 in enumerate(samples): - run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1) + if num_chunks == 1: + # when we are not chunking, detach our x so the callee can free it as soon as they are done + next_sample_ref = [sample] + del sample + run_up(idx + 1, next_sample_ref, ended) + return + else: + samples = torch.chunk(sample, chunks=num_chunks, dim=2) - run_up(0, sample, True) + for chunk_idx, sample1 in enumerate(samples): + run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1) + + run_up(0, [sample], True) sample = torch.cat(output, dim=2) sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) From 035414ede49c1b043ea6de054ca512bcbf0f6b35 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 17 Mar 2026 14:34:39 -0700 Subject: [PATCH 02/42] Reduce WAN VAE VRAM, Save use cases for OOM/Tiler (#13014) * wan: vae: encoder: Add feature cache layer that corks singles If a downsample only gives you a single frame, save it to the feature cache and return nothing to the top level. This increases the efficiency of cacheability, but also prepares support for going two by two rather than four by four on the frames. * wan: remove all concatentation with the feature cache The loopers are now responsible for ensuring that non-final frames are processes at least two-by-two, elimiating the need for this cat case. * wan: vae: recurse and chunk for 2+2 frames on decode Avoid having to clone off slices of 4 frame chunks and reduce the size of the big 6 frame convolutions down to 4. Save the VRAMs. * wan: encode frames 2x2. Reduce VRAM usage greatly by encoding frames 2 at a time rather than 4. * wan: vae: remove cloning The loopers now control the chunking such there is noever more than 2 frames, so just cache these slices directly and avoid the clone allocations completely. * wan: vae: free consumer caller tensors on recursion * wan: vae: restyle a little to match LTX --- comfy/ldm/wan/vae.py | 180 +++++++++++++++++++------------------------ 1 file changed, 81 insertions(+), 99 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 71f73c64e..a96b83c6c 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -99,7 +99,7 @@ class Resample(nn.Module): else: self.resample = nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): b, c, t, h, w = x.size() if self.mode == 'upsample3d': if feat_cache is not None: @@ -109,22 +109,7 @@ class Resample(nn.Module): feat_idx[0] += 1 else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] if feat_cache[idx] == 'Rep': x = self.time_conv(x) else: @@ -145,19 +130,24 @@ class Resample(nn.Module): if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 + feat_cache[idx] = x else: - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - + cache_x = x[:, :, -1:, :, :] x = self.time_conv( torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x - feat_idx[0] += 1 + + deferred_x = feat_cache[idx + 1] + if deferred_x is not None: + x = torch.cat([deferred_x, x], 2) + feat_cache[idx + 1] = None + + if x.shape[2] == 1 and not final: + feat_cache[idx + 1] = x + x = None + + feat_idx[0] += 2 return x @@ -177,19 +167,12 @@ class ResidualBlock(nn.Module): self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ if in_dim != out_dim else nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -213,7 +196,7 @@ class AttentionBlock(nn.Module): self.proj = ops.Conv2d(dim, dim, 1) self.optimized_attention = vae_attention() - def forward(self, x): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): identity = x b, c, t, h, w = x.size() x = rearrange(x, 'b c t h w -> (b t) c h w') @@ -283,17 +266,10 @@ class Encoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -303,14 +279,16 @@ class Encoder3d(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache, feat_idx, final=final) + if x is None: + return None else: x = layer(x) ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, final=final) else: x = layer(x) @@ -318,14 +296,7 @@ class Encoder3d(nn.Module): for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -393,14 +364,7 @@ class Decoder3d(nn.Module): ## conv1 if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -409,42 +373,56 @@ class Decoder3d(nn.Module): ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + out_chunks = [] + + def run_up(layer_idx, x_ref, feat_idx): + x = x_ref[0] + x_ref[0] = None + if layer_idx >= len(self.upsamples): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + return + + layer = self.upsamples[layer_idx] + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: + for frame_idx in range(x.shape[2]): + run_up( + layer_idx, + [x[:, :, frame_idx:frame_idx + 1, :, :]], + feat_idx.copy(), + ) + del x + return + + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) else: x = layer(x) - return x + + next_x_ref = [x] + del x + run_up(layer_idx + 1, next_x_ref, feat_idx) + + run_up(0, [x], feat_idx) + return out_chunks -def count_conv3d(model): +def count_cache_layers(model): count = 0 for m in model.modules(): - if isinstance(m, CausalConv3d): + if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'): count += 1 return count @@ -482,11 +460,12 @@ class WanVAE(nn.Module): conv_idx = [0] ## cache t = x.shape[2] - iter_ = 1 + (t - 1) // 4 + t = 1 + ((t - 1) // 4) * 4 + iter_ = 1 + (t - 1) // 2 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.encoder) - ## 对encode输入的x,按时间拆分为1、4、4、4.... + feat_map = [None] * count_cache_layers(self.encoder) + ## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整) for i in range(iter_): conv_idx = [0] if i == 0: @@ -496,20 +475,23 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, - feat_idx=conv_idx) + feat_idx=conv_idx, + final=(i == (iter_ - 1))) + if out_ is None: + continue out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) return mu def decode(self, z): - conv_idx = [0] # z: [b,c,t,h,w] - iter_ = z.shape[2] + iter_ = 1 + z.shape[2] // 2 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.decoder) + feat_map = [None] * count_cache_layers(self.decoder) x = self.conv2(z) for i in range(iter_): conv_idx = [0] @@ -520,8 +502,8 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.decoder( - x[:, :, i:i + 1, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, feat_idx=conv_idx) - out = torch.cat([out, out_], 2) - return out + out += out_ + return torch.cat(out, 2) From 8b9d039f26f5230ab3d3d6d9dd5d55590681b970 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Wed, 18 Mar 2026 07:17:03 +0900 Subject: [PATCH 03/42] bump manager version to 4.1b6 (#13022) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 1c5e8f071..5b06b56f6 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b5 \ No newline at end of file +comfyui_manager==4.1b6 \ No newline at end of file From 735a0465e5daf1f77909b553b02a9d16d1671be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 18 Mar 2026 02:20:49 +0200 Subject: [PATCH 04/42] Inplace VAE output processing to reduce peak RAM consumption. (#13028) --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 4d427bb9a..652e76d3e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -455,7 +455,7 @@ class VAE: self.output_channels = 3 self.pad_channel_value = None self.process_input = lambda image: image * 2.0 - 1.0 - self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False From 68d542cc0602132d3d2fe624ee7077e44b0fb0ab Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:46:22 -0700 Subject: [PATCH 05/42] Fix case where pixel space VAE could cause issues. (#13030) --- comfy/sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 652e76d3e..df0c4d1d1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -952,8 +952,8 @@ class VAE: batch_number = max(1, batch_number) for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype())) + samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) + out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) pixel_samples[x:x+batch_number] = out From cad24ce26278a72095d33a2b4391572573201542 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:59:10 -0700 Subject: [PATCH 06/42] cascade: remove dead weight init code (#13026) This weight init process is fully shadowed be the weight load and doesnt work in dynamic_vram were the weight allocation is deferred. --- comfy/ldm/cascade/stage_a.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 145e6e69a..e4e30cacd 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -136,16 +136,7 @@ class ResBlock(nn.Module): ops.Linear(c_hidden, c), ) - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False) def _norm(self, x, norm): return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) From b941913f1d2d11dc69c098a375309b13c13bca23 Mon Sep 17 00:00:00 2001 From: Anton Bukov Date: Wed, 18 Mar 2026 05:21:32 +0400 Subject: [PATCH 07/42] fix: run text encoders on MPS GPU instead of CPU for Apple Silicon (#12809) On Apple Silicon, `vram_state` is set to `VRAMState.SHARED` because CPU and GPU share unified memory. However, `text_encoder_device()` only checked for `HIGH_VRAM` and `NORMAL_VRAM`, causing all text encoders to fall back to CPU on MPS devices. Adding `VRAMState.SHARED` to the condition allows non-quantized text encoders (e.g. bf16 Gemma 3 12B) to run on the MPS GPU, providing significant speedup for text encoding and prompt generation. Note: quantized models (fp4/fp8) that use float8_e4m3fn internally will still fall back to CPU via the `supports_cast()` check in `CLIP.__init__()`, since MPS does not support fp8 dtypes. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2c250dacc..5f2e6ef67 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1003,7 +1003,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: From 06957022d4cc6f91e101cf5afdd421e462f820c0 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 18 Mar 2026 19:21:58 +0200 Subject: [PATCH 08/42] fix(api-nodes): add support for "thought_image" in Nano Banana 2 and corrected price badges (#13038) --- comfy_api_nodes/apis/gemini.py | 1 + comfy_api_nodes/nodes_gemini.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 639035fef..22879fe18 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -67,6 +67,7 @@ class GeminiPart(BaseModel): inlineData: GeminiInlineData | None = Field(None) fileData: GeminiFileData | None = Field(None) text: str | None = Field(None) + thought: bool | None = Field(None) class GeminiTextPart(BaseModel): diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 8225ea67e..25d747e76 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( $m := widgets.model; $r := widgets.resolution; $isFlash := $contains($m, "nano banana 2"); - $flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; + $flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154}; $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; $prices := $isFlash ? $flashPrices : $proPrices; {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} @@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/*") for part in parts: + if (part.thought is True) != thought: + continue if part.inlineData: image_data = base64.b64decode(part.inlineData.data) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) @@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode): outputs=[ IO.Image.Output(), IO.String.Output(), + IO.Image.Output( + display_name="thought_image", + tooltip="First image from the model's thinking process. " + "Only available with thinking_level HIGH and IMAGE+TEXT modality.", + ), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput( + await get_image_from_response(response), + get_text_from_response(response), + await get_image_from_response(response, thought=True), + ) class GeminiExtension(ComfyExtension): From b67ed2a45fad8322629289b3347ea15f8926cd45 Mon Sep 17 00:00:00 2001 From: Alexander Brown Date: Wed, 18 Mar 2026 13:36:39 -0700 Subject: [PATCH 09/42] Update comfyui-frontend-package version to 1.41.21 (#13035) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0ce163f71..ad0344ed4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.41.20 +comfyui-frontend-package==1.41.21 comfyui-workflow-templates==0.9.26 comfyui-embedded-docs==0.4.3 torch From dcd659590faac35a1ac36393077f4ab8aac3fea8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:14:18 -0700 Subject: [PATCH 10/42] Make more intermediate values follow the intermediate dtype. (#13051) --- comfy/sample.py | 4 ++-- comfy/sd1_clip.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index a2a39b527..e9c2259ab 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) + samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) + samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return samples diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d89550840..f970510ad 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -46,7 +46,7 @@ class ClipTokenWeightEncoder: out, pooled = o[:2] if pooled is not None: - first_pooled = pooled[0:1].to(model_management.intermediate_device()) + first_pooled = pooled[0:1].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()) else: first_pooled = pooled @@ -63,16 +63,16 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - r = (out[-1:].to(model_management.intermediate_device()), first_pooled) + r = (out[-1:].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled) else: - r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled) + r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled) if len(o) > 2: extra = {} for k in o[2]: v = o[2][k] if k == "attention_mask": - v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()) + v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()) extra[k] = v r = r + (extra,) From 9fff091f354815378b913c6e0ee3a39c0ed79a70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 19 Mar 2026 00:32:26 +0200 Subject: [PATCH 11/42] Further Reduce LTX VAE decode peak RAM usage (#13052) --- .../vae/causal_video_autoencoder.py | 42 +++++++++++++++---- comfy/sd.py | 19 +++++++-- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 0504140ef..f7aae26da 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -473,6 +473,17 @@ class Decoder(nn.Module): self.gradient_checkpointing = False + # Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset) + ts, hs, ws, to = 1, 1, 1, 0 + for block in self.up_blocks: + if isinstance(block, DepthToSpaceUpsample): + ts *= block.stride[0] + hs *= block.stride[1] + ws *= block.stride[2] + if block.stride[0] > 1: + to = to * block.stride[0] + 1 + self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to) + self.timestep_conditioning = timestep_conditioning if timestep_conditioning: @@ -494,11 +505,15 @@ class Decoder(nn.Module): ) - # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + def decode_output_shape(self, input_shape): + c, (ts, hs, ws), to = self._output_scale + return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) + def forward_orig( self, sample: torch.FloatTensor, timestep: Optional[torch.Tensor] = None, + output_buffer: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" batch_size = sample.shape[0] @@ -540,7 +555,13 @@ class Decoder(nn.Module): ) timestep_shift_scale = ada_values.unbind(dim=1) - output = [] + if output_buffer is None: + output_buffer = torch.empty( + self.decode_output_shape(sample.shape), + dtype=sample.dtype, device=comfy.model_management.intermediate_device(), + ) + output_offset = [0] + max_chunk_size = get_max_chunk_size(sample.device) def run_up(idx, sample_ref, ended): @@ -556,7 +577,10 @@ class Decoder(nn.Module): mark_conv3d_ended(self.conv_out) sample = self.conv_out(sample, causal=self.causal) if sample is not None and sample.shape[2] > 0: - output.append(sample.to(comfy.model_management.intermediate_device())) + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + t = sample.shape[2] + output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) + output_offset[0] += t return up_block = self.up_blocks[idx] @@ -588,11 +612,8 @@ class Decoder(nn.Module): run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1) run_up(0, [sample], True) - sample = torch.cat(output, dim=2) - sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - - return sample + return output_buffer def forward(self, *args, **kwargs): try: @@ -1226,7 +1247,10 @@ class VideoVAE(nn.Module): means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x): + def decode_output_shape(self, input_shape): + return self.decoder.decode_output_shape(input_shape) + + def decode(self, x, output_buffer=None): if self.timestep_conditioning: #TODO: seed x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer) diff --git a/comfy/sd.py b/comfy/sd.py index df0c4d1d1..1f9510959 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -951,12 +951,23 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) + # Pre-allocate output for VAEs that support direct buffer writes + preallocated = False + if hasattr(self.first_stage_model, 'decode_output_shape'): + pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) + preallocated = True + for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) - out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) - if pixel_samples is None: - pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - pixel_samples[x:x+batch_number] = out + if preallocated: + self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) + else: + out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) + if pixel_samples is None: + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + pixel_samples[x:x+batch_number].copy_(out) + del out + self.process_output(pixel_samples[x:x+batch_number]) except Exception as e: model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") From 56ff88f9511c4e25cd8ac08b2bfcd21c8ad83121 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:35:25 -0700 Subject: [PATCH 12/42] Fix regression. (#13053) --- comfy/sd1_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index f970510ad..a85170b26 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -72,7 +72,7 @@ class ClipTokenWeightEncoder: for k in o[2]: v = o[2][k] if k == "attention_mask": - v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()) + v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device()) extra[k] = v r = r + (extra,) From f6b869d7d35f7160bf2fdeabaed378d737834540 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 18 Mar 2026 16:42:28 -0700 Subject: [PATCH 13/42] fp16 intermediates doen't work for some text enc models. (#13056) --- comfy/sd1_clip.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index a85170b26..0eb30df27 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -46,7 +46,7 @@ class ClipTokenWeightEncoder: out, pooled = o[:2] if pooled is not None: - first_pooled = pooled[0:1].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()) + first_pooled = pooled[0:1].to(device=model_management.intermediate_device()) else: first_pooled = pooled @@ -63,9 +63,9 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - r = (out[-1:].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled) + r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled) else: - r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled) + r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled) if len(o) > 2: extra = {} From fabed694a2198b1662d521b1c47e11e625601ebe Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 19 Mar 2026 09:58:47 -0700 Subject: [PATCH 14/42] ltx: vae: implement chunked encoder + CPU IO chunking (Big VRAM reductions) (#13062) * ltx: vae: add cache state to downsample block * ltx: vae: Add time stride awareness to causal_conv_3d * ltx: vae: Automate truncation for encoder Other VAEs just truncate without error. Do the same. * sd/ltx: Make chunked_io a flag in its own right Taking this bi-direcitonal, so make it a for-purpose named flag. * ltx: vae: implement chunked encoder + CPU IO chunking People are doing things with big frame counts in LTX including V2V flows. Implement the time-chunked encoder to keep the VRAM down, with the converse of the new CPU pre-allocation technique, where the chunks are brought from the CPU JIT. * ltx: vae-encode: round chunk sizes more strictly Only powers of 2 and multiple of 8 are valid due to cache slicing. --- comfy/ldm/lightricks/vae/causal_conv3d.py | 16 +++- .../vae/causal_video_autoencoder.py | 91 +++++++++++++++---- comfy/sd.py | 11 ++- 3 files changed, 92 insertions(+), 26 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 356394239..7515f0d4e 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -23,6 +23,11 @@ class CausalConv3d(nn.Module): self.in_channels = in_channels self.out_channels = out_channels + if isinstance(stride, int): + self.time_stride = stride + else: + self.time_stride = stride[0] + kernel_size = (kernel_size, kernel_size, kernel_size) self.time_kernel_size = kernel_size[0] @@ -58,18 +63,23 @@ class CausalConv3d(nn.Module): pieces = [ cached, x ] if is_end and not causal: pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) + input_length = sum([piece.shape[2] for piece in pieces]) + cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride) needs_caching = not is_end - if needs_caching and x.shape[2] >= self.time_kernel_size - 1: + if needs_caching and cache_length == 0: + self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False) needs_caching = False - self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + if needs_caching and x.shape[2] >= cache_length: + needs_caching = False + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) x = torch.cat(pieces, dim=2) del pieces del cached if needs_caching: - self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) elif is_end: self.temporal_cache_state[tid] = (None, True) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index f7aae26da..1a15cafd0 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -233,10 +233,7 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]: sample = self.conv_in(sample) checkpoint_fn = ( @@ -247,10 +244,14 @@ class Encoder(nn.Module): for down_block in self.down_blocks: sample = checkpoint_fn(down_block)(sample) + if sample is None or sample.shape[2] == 0: + return None sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) + if sample is None or sample.shape[2] == 0: + return None if self.latent_log_var == "uniform": last_channel = sample[:, -1:, ...] @@ -282,9 +283,35 @@ class Encoder(nn.Module): return sample + def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder + frame_size = sample[:, :, :1, :, :].numel() * sample.element_size() + frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels)) + + outputs = [] + samples = [sample[:, :, :1, :, :]] + if sample.shape[2] > 1: + chunk_t = max(2, max_chunk_size // frame_size) + if chunk_t < 4: + chunk_t = 2 + elif chunk_t < 8: + chunk_t = 4 + else: + chunk_t = (chunk_t // 8) * 8 + samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2)) + for chunk_idx, chunk in enumerate(samples): + if chunk_idx == len(samples) - 1: + mark_conv3d_ended(self) + chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device) + output = self._forward_chunk(chunk) + if output is not None: + outputs.append(output) + + return torch_cat_if_needed(outputs, dim=2) + def forward(self, *args, **kwargs): - #No encoder support so just flag the end so it doesnt use the cache. - mark_conv3d_ended(self) try: return self.forward_orig(*args, **kwargs) finally: @@ -737,12 +764,25 @@ class SpaceToDepthDownsample(nn.Module): causal=True, spatial_padding_mode=spatial_padding_mode, ) + self.temporal_cache_state = {} def forward(self, x, causal: bool = True): - if self.stride[0] == 2: + tid = threading.get_ident() + cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None)) + if cached_input is not None: + x = torch_cat_if_needed([cached_input, x], dim=2) + cached_input = None + + if self.stride[0] == 2 and pad_first: x = torch.cat( [x[:, :, :1, :, :], x], dim=2 ) # duplicate first frames for padding + pad_first = False + + if x.shape[2] < self.stride[0]: + cached_input = x + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) + return None # skip connection x_in = rearrange( @@ -757,15 +797,26 @@ class SpaceToDepthDownsample(nn.Module): # conv x = self.conv(x, causal=causal) - x = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) + if self.stride[0] == 2 and x.shape[2] == 1: + if cached_x is not None: + x = torch_cat_if_needed([cached_x, x], dim=2) + cached_x = None + else: + cached_x = x + x = None - x = x + x_in + if x is not None: + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + cached = add_exchange_cache(x, cached, x_in, dim=2) + + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) return x @@ -1098,6 +1149,8 @@ class processor(nn.Module): return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): + comfy_has_chunked_io = True + def __init__(self, version=0, config=None): super().__init__() @@ -1240,11 +1293,9 @@ class VideoVAE(nn.Module): } return config - def encode(self, x): - frames_count = x.shape[2] - if ((frames_count - 1) % 8) != 0: - raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.") - means, logvar = torch.chunk(self.encoder(x), 2, dim=1) + def encode(self, x, device=None): + x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :] + means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1) return self.per_channel_statistics.normalize(means) def decode_output_shape(self, input_shape): diff --git a/comfy/sd.py b/comfy/sd.py index 1f9510959..b5e7c93a9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -953,7 +953,7 @@ class VAE: # Pre-allocate output for VAEs that support direct buffer writes preallocated = False - if hasattr(self.first_stage_model, 'decode_output_shape'): + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) preallocated = True @@ -1038,8 +1038,13 @@ class VAE: batch_number = max(1, batch_number) samples = None for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) - out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype()) + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + out = self.first_stage_model.encode(pixels_in, device=self.device) + else: + pixels_in = pixels_in.to(self.device) + out = self.first_stage_model.encode(pixels_in) + out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples[x:x + batch_number] = out From 6589562ae3e35dd7694f430629a805306157f530 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:01:12 -0700 Subject: [PATCH 15/42] ltx: vae: implement chunked encoder + CPU IO chunking (Big VRAM reductions) (#13062) * ltx: vae: add cache state to downsample block * ltx: vae: Add time stride awareness to causal_conv_3d * ltx: vae: Automate truncation for encoder Other VAEs just truncate without error. Do the same. * sd/ltx: Make chunked_io a flag in its own right Taking this bi-direcitonal, so make it a for-purpose named flag. * ltx: vae: implement chunked encoder + CPU IO chunking People are doing things with big frame counts in LTX including V2V flows. Implement the time-chunked encoder to keep the VRAM down, with the converse of the new CPU pre-allocation technique, where the chunks are brought from the CPU JIT. * ltx: vae-encode: round chunk sizes more strictly Only powers of 2 and multiple of 8 are valid due to cache slicing. From ab14541ef7965dc61956c447d3066dd3d5c9f33b Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:03:20 -0700 Subject: [PATCH 16/42] memory: Add more exclusion criteria to pinned read (#13067) --- comfy/memory_management.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 563224098..f9078fe7c 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination): if (destination.device.type != "cpu" or file_obj is None or threading.get_ident() != info.thread_id - or destination.numel() * destination.element_size() < info.size): + or destination.numel() * destination.element_size() < info.size + or tensor.numel() * tensor.element_size() != info.size + or tensor.storage_offset() != 0 + or not tensor.is_contiguous()): return False if info.size == 0: From fd0261d2bc0c32fa6c21d20994702f44fd927d4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 19 Mar 2026 19:29:34 +0200 Subject: [PATCH 17/42] Reduce tiled decode peak memory (#13050) --- comfy/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 13b7ca6c8..78c491b98 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1135,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am pbar.update(1) continue - out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) - out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) + out = output[b:b+1].zero_() + out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device) positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] @@ -1151,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am upscaled.append(round(get_pos(d, pos))) ps = function(s_in).to(output_device) - mask = torch.ones_like(ps) + mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device) for d in range(2, dims + 2): feather = round(get_scale(d - 2, overlap[d - 2])) @@ -1174,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am if pbar is not None: pbar.update(1) - output[b:b+1] = out/out_div + out.div_(out_div) return output def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): From 8458ae2686a8d62ee206d3903123868425a4e6a7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 19 Mar 2026 12:27:55 -0700 Subject: [PATCH 18/42] =?UTF-8?q?Revert=20"fix:=20run=20text=20encoders=20?= =?UTF-8?q?on=20MPS=20GPU=20instead=20of=20CPU=20for=20Apple=20Silicon=20(?= =?UTF-8?q?#=E2=80=A6"=20(#13070)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit b941913f1d2d11dc69c098a375309b13c13bca23. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 5f2e6ef67..2c250dacc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1003,7 +1003,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: From 82b868a45a753c875677091d0a91bb5bbaf04cbe Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 19 Mar 2026 19:30:27 -0700 Subject: [PATCH 19/42] Fix VRAM leak in tiler fallback in video VAEs (#13073) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * sd: soft_empty_cache on tiler fallback This doesnt cost a lot and creates the expected VRAM reduction in resource monitors when you fallback to tiler. * wan: vae: Don't recursion in local fns (move run_up) Moved Decoder3d’s recursive run_up out of forward into a class method to avoid nested closure self-reference cycles. This avoids cyclic garbage that delays garbage of tensors which in turn delays VRAM release before tiled fallback. * ltx: vae: Don't recursion in local fns (move run_up) Mov the recursive run_up out of forward into a class method to avoid nested closure self-reference cycles. This avoids cyclic garbage that delays garbage of tensors which in turn delays VRAM release before tiled fallback. --- .../vae/causal_video_autoencoder.py | 96 +++++++++---------- comfy/ldm/wan/vae.py | 74 +++++++------- comfy/sd.py | 2 + 3 files changed, 88 insertions(+), 84 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 1a15cafd0..dd1dfeba0 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -536,6 +536,53 @@ class Decoder(nn.Module): c, (ts, hs, ws), to = self._output_scale return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) + def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size): + sample = sample_ref[0] + sample_ref[0] = None + if idx >= len(self.up_blocks): + sample = self.conv_norm_out(sample) + if timestep_shift_scale is not None: + shift, scale = timestep_shift_scale + sample = sample * (1 + scale) + shift + sample = self.conv_act(sample) + if ended: + mark_conv3d_ended(self.conv_out) + sample = self.conv_out(sample, causal=self.causal) + if sample is not None and sample.shape[2] > 0: + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + t = sample.shape[2] + output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) + output_offset[0] += t + return + + up_block = self.up_blocks[idx] + if ended: + mark_conv3d_ended(up_block) + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + if sample is None or sample.shape[2] == 0: + return + + total_bytes = sample.numel() * sample.element_size() + num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size + + if num_chunks == 1: + # when we are not chunking, detach our x so the callee can free it as soon as they are done + next_sample_ref = [sample] + del sample + self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + return + else: + samples = torch.chunk(sample, chunks=num_chunks, dim=2) + + for chunk_idx, sample1 in enumerate(samples): + self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + def forward_orig( self, sample: torch.FloatTensor, @@ -591,54 +638,7 @@ class Decoder(nn.Module): max_chunk_size = get_max_chunk_size(sample.device) - def run_up(idx, sample_ref, ended): - sample = sample_ref[0] - sample_ref[0] = None - if idx >= len(self.up_blocks): - sample = self.conv_norm_out(sample) - if timestep_shift_scale is not None: - shift, scale = timestep_shift_scale - sample = sample * (1 + scale) + shift - sample = self.conv_act(sample) - if ended: - mark_conv3d_ended(self.conv_out) - sample = self.conv_out(sample, causal=self.causal) - if sample is not None and sample.shape[2] > 0: - sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - t = sample.shape[2] - output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) - output_offset[0] += t - return - - up_block = self.up_blocks[idx] - if (ended): - mark_conv3d_ended(up_block) - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) - - if sample is None or sample.shape[2] == 0: - return - - total_bytes = sample.numel() * sample.element_size() - num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size - - if num_chunks == 1: - # when we are not chunking, detach our x so the callee can free it as soon as they are done - next_sample_ref = [sample] - del sample - run_up(idx + 1, next_sample_ref, ended) - return - else: - samples = torch.chunk(sample, chunks=num_chunks, dim=2) - - for chunk_idx, sample1 in enumerate(samples): - run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1) - - run_up(0, [sample], True) + self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) return output_buffer diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index a96b83c6c..deeb8695b 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -360,6 +360,43 @@ class Decoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, output_channels, 3, padding=1)) + def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks): + x = x_ref[0] + x_ref[0] = None + if layer_idx >= len(self.upsamples): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + return + + layer = self.upsamples[layer_idx] + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: + for frame_idx in range(x.shape[2]): + self.run_up( + layer_idx, + [x[:, :, frame_idx:frame_idx + 1, :, :]], + feat_cache, + feat_idx.copy(), + out_chunks, + ) + del x + return + + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + next_x_ref = [x] + del x + self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks) + def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 if feat_cache is not None: @@ -380,42 +417,7 @@ class Decoder3d(nn.Module): out_chunks = [] - def run_up(layer_idx, x_ref, feat_idx): - x = x_ref[0] - x_ref[0] = None - if layer_idx >= len(self.upsamples): - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - cache_x = x[:, :, -CACHE_T:, :, :] - x = layer(x, feat_cache[feat_idx[0]]) - feat_cache[feat_idx[0]] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - out_chunks.append(x) - return - - layer = self.upsamples[layer_idx] - if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: - for frame_idx in range(x.shape[2]): - run_up( - layer_idx, - [x[:, :, frame_idx:frame_idx + 1, :, :]], - feat_idx.copy(), - ) - del x - return - - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - next_x_ref = [x] - del x - run_up(layer_idx + 1, next_x_ref, feat_idx) - - run_up(0, [x], feat_idx) + self.run_up(0, [x], feat_cache, feat_idx, out_chunks) return out_chunks diff --git a/comfy/sd.py b/comfy/sd.py index b5e7c93a9..e207bb0fd 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -978,6 +978,7 @@ class VAE: do_tile = True if do_tile: + comfy.model_management.soft_empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -1059,6 +1060,7 @@ class VAE: do_tile = True if do_tile: + comfy.model_management.soft_empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4 From f49856af57888f60d09f470a6509456f5ee23c99 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 19 Mar 2026 19:34:58 -0700 Subject: [PATCH 20/42] ltx: vae: Fix missing init variable (#13074) Forgot to push this ammendment. Previous test results apply to this. --- comfy/ldm/lightricks/vae/causal_video_autoencoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index dd1dfeba0..998122c85 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -602,6 +602,7 @@ class Decoder(nn.Module): ) timestep_shift_scale = None + scaled_timestep = None if self.timestep_conditioning: assert ( timestep is not None From e4455fd43acd3f975905455ace7497136962968a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 20 Mar 2026 05:05:01 +0200 Subject: [PATCH 21/42] [API Nodes] mark seedream-3-0-t2i and seedance-1-0-lite models as deprecated (#13060) * chore(api-nodes): mark seedream-3-0-t2i and seedance-1-0-lite models as deprecated * fix(api-nodes): fixed old regression in the ByteDanceImageReference node --------- Co-authored-by: Jedrzej Kosinski --- comfy_api_nodes/nodes_bytedance.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 6dbd5984e..de0c22e70 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -47,6 +47,10 @@ SEEDREAM_MODELS = { BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} +DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"} + +logger = logging.getLogger(__name__) + def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: @@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode): price_badge=IO.PriceBadge( expr="""{"type":"usd","usd":0.03}""", ), + is_deprecated=True, ) @classmethod @@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ] return await process_video_task( cls, - payload=Image2VideoTaskCreationRequest(model=model, content=x), + payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -952,6 +957,12 @@ async def process_video_task( payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, estimated_duration: int | None, ) -> IO.NodeOutput: + if payload.model in DEPRECATED_MODELS: + logger.warning( + "Model '%s' is deprecated and will be deactivated on May 13, 2026. " + "Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.", + payload.model, + ) initial_response = await sync_op( cls, ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), From 589228e671e84518bf77919ee4e574749ab772c8 Mon Sep 17 00:00:00 2001 From: drozbay <17261091+drozbay@users.noreply.github.com> Date: Thu, 19 Mar 2026 21:42:42 -0600 Subject: [PATCH 22/42] Add slice_cond and per-model context window cond resizing (#12645) * Add slice_cond and per-model context window cond resizing * Fix cond_value.size() call in context window cond resizing * Expose additional advanced inputs for ContextWindowsManualNode Necessary for WanAnimate context windows workflow, which needs cond_retain_index_list = 0 to work properly with its reference input. --------- --- comfy/context_windows.py | 54 ++++++++++++++++++++++++++- comfy/model_base.py | 32 ++++++++++++++++ comfy_extras/nodes_context_windows.py | 4 +- 3 files changed, 87 insertions(+), 3 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index b54f7f39a..cb44ee6e8 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -93,6 +93,50 @@ class IndexListCallbacks: return {} +def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]): + if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)): + return None + cond_tensor = cond_value.cond + if temporal_dim >= cond_tensor.ndim: + return None + + cond_size = cond_tensor.size(temporal_dim) + + if temporal_scale == 1: + expected_size = x_in.size(window.dim) - temporal_offset + if cond_size != expected_size: + return None + + if temporal_offset == 0 and temporal_scale == 1: + sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list) + return cond_value._copy_with(sliced) + + # skip leading latent positions that have no corresponding conditioning (e.g. reference frames) + if temporal_offset > 0: + indices = [i - temporal_offset for i in window.index_list[temporal_offset:]] + indices = [i for i in indices if 0 <= i] + else: + indices = list(window.index_list) + + if not indices: + return None + + if temporal_scale > 1: + scaled = [] + for i in indices: + for k in range(temporal_scale): + si = i * temporal_scale + k + if si < cond_size: + scaled.append(si) + indices = scaled + if not indices: + return None + + idx = tuple([slice(None)] * temporal_dim + [indices]) + sliced = cond_tensor[idx].to(device) + return cond_value._copy_with(sliced) + + @dataclass class ContextSchedule: name: str @@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC): new_cond_item[cond_key] = result handled = True break + if not handled and self._model is not None: + result = self._model.resize_cond_for_context_window( + cond_key, cond_value, window, x_in, device, + retain_index_list=self.cond_retain_index_list) + if result is not None: + new_cond_item[cond_key] = result + handled = True if handled: continue if isinstance(cond_value, torch.Tensor): - if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) # Handle audio_embed (temporal dim is 1) @@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC): return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + self._model = model self.set_step(timestep, model_options) context_windows = self.get_context_windows(model, x_in, model_options) enumerated_context_windows = list(enumerate(context_windows)) diff --git a/comfy/model_base.py b/comfy/model_base.py index d9d5a9293..88905e191 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -285,6 +285,12 @@ class BaseModel(torch.nn.Module): return data return None + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + """Override in subclasses to handle model-specific cond slicing for context windows. + Return a sliced cond object, or None to fall through to default handling. + Use comfy.context_windows.slice_cond() for common cases.""" + return None + def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -1375,6 +1381,12 @@ class WAN21_Vace(WAN21): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "vace_context": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) @@ -1427,6 +1439,12 @@ class WAN21_HuMo(WAN21): return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_Animate(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) @@ -1444,6 +1462,14 @@ class WAN22_Animate(WAN21): out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + import comfy.context_windows + if cond_key == "face_pixel_values": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1) + if cond_key == "pose_latents": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) @@ -1480,6 +1506,12 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 93a5204e1..0e43f2e44 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode): io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), From c646d211be359df56617ffabcdd43cb53e191e97 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 20 Mar 2026 21:23:16 +0200 Subject: [PATCH 23/42] feat(api-nodes): add Quiver SVG nodes (#13047) --- comfy_api_nodes/apis/quiver.py | 43 +++++ comfy_api_nodes/nodes_quiver.py | 291 ++++++++++++++++++++++++++++++++ 2 files changed, 334 insertions(+) create mode 100644 comfy_api_nodes/apis/quiver.py create mode 100644 comfy_api_nodes/nodes_quiver.py diff --git a/comfy_api_nodes/apis/quiver.py b/comfy_api_nodes/apis/quiver.py new file mode 100644 index 000000000..bc8708754 --- /dev/null +++ b/comfy_api_nodes/apis/quiver.py @@ -0,0 +1,43 @@ +from pydantic import BaseModel, Field + + +class QuiverImageObject(BaseModel): + url: str = Field(...) + + +class QuiverTextToSVGRequest(BaseModel): + model: str = Field(default="arrow-preview") + prompt: str = Field(...) + instructions: str | None = Field(default=None) + references: list[QuiverImageObject] | None = Field(default=None, max_length=4) + temperature: float | None = Field(default=None, ge=0, le=2) + top_p: float | None = Field(default=None, ge=0, le=1) + presence_penalty: float | None = Field(default=None, ge=-2, le=2) + + +class QuiverImageToSVGRequest(BaseModel): + model: str = Field(default="arrow-preview") + image: QuiverImageObject = Field(...) + auto_crop: bool | None = Field(default=None) + target_size: int | None = Field(default=None, ge=128, le=4096) + temperature: float | None = Field(default=None, ge=0, le=2) + top_p: float | None = Field(default=None, ge=0, le=1) + presence_penalty: float | None = Field(default=None, ge=-2, le=2) + + +class QuiverSVGResponseItem(BaseModel): + svg: str = Field(...) + mime_type: str | None = Field(default="image/svg+xml") + + +class QuiverSVGUsage(BaseModel): + total_tokens: int | None = Field(default=None) + input_tokens: int | None = Field(default=None) + output_tokens: int | None = Field(default=None) + + +class QuiverSVGResponse(BaseModel): + id: str | None = Field(default=None) + created: int | None = Field(default=None) + data: list[QuiverSVGResponseItem] = Field(...) + usage: QuiverSVGUsage | None = Field(default=None) diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py new file mode 100644 index 000000000..61533263f --- /dev/null +++ b/comfy_api_nodes/nodes_quiver.py @@ -0,0 +1,291 @@ +from io import BytesIO + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.quiver import ( + QuiverImageObject, + QuiverImageToSVGRequest, + QuiverSVGResponse, + QuiverTextToSVGRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + sync_op, + upload_image_to_comfyapi, + validate_string, +) +from comfy_extras.nodes_images import SVG + + +class QuiverTextToSVGNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="QuiverTextToSVGNode", + display_name="Quiver Text to SVG", + category="api node/image/Quiver", + description="Generate an SVG from a text prompt using Quiver AI.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired SVG output.", + ), + IO.String.Input( + "instructions", + multiline=True, + default="", + tooltip="Additional style or formatting guidance.", + optional=True, + ), + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="ref_", + min=0, + max=4, + ), + tooltip="Up to 4 reference images to guide the generation.", + optional=True, + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "arrow-preview", + [ + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Randomness control. Higher values increase randomness.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=1.0, + min=0.05, + max=1.0, + step=0.05, + display_mode=IO.NumberDisplay.slider, + tooltip="Nucleus sampling parameter.", + advanced=True, + ), + IO.Float.Input( + "presence_penalty", + default=0.0, + min=-2.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Token presence penalty.", + advanced=True, + ), + ], + ), + ], + tooltip="Model to use for SVG generation.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.429}""", + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + instructions: str = None, + reference_images: IO.Autogrow.Type = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1) + + references = None + if reference_images: + references = [] + for key in reference_images: + url = await upload_image_to_comfyapi(cls, reference_images[key]) + references.append(QuiverImageObject(url=url)) + if len(references) > 4: + raise ValueError("Maximum 4 reference images are allowed.") + + instructions_val = instructions.strip() if instructions else None + if instructions_val == "": + instructions_val = None + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"), + response_model=QuiverSVGResponse, + data=QuiverTextToSVGRequest( + model=model["model"], + prompt=prompt, + instructions=instructions_val, + references=references, + temperature=model.get("temperature"), + top_p=model.get("top_p"), + presence_penalty=model.get("presence_penalty"), + ), + ) + + svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data] + return IO.NodeOutput(SVG(svg_data)) + + +class QuiverImageToSVGNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="QuiverImageToSVGNode", + display_name="Quiver Image to SVG", + category="api node/image/Quiver", + description="Vectorize a raster image into SVG using Quiver AI.", + inputs=[ + IO.Image.Input( + "image", + tooltip="Input image to vectorize.", + ), + IO.Boolean.Input( + "auto_crop", + default=False, + tooltip="Automatically crop to the dominant subject.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "arrow-preview", + [ + IO.Int.Input( + "target_size", + default=1024, + min=128, + max=4096, + tooltip="Square resize target in pixels.", + ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Randomness control. Higher values increase randomness.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=1.0, + min=0.05, + max=1.0, + step=0.05, + display_mode=IO.NumberDisplay.slider, + tooltip="Nucleus sampling parameter.", + advanced=True, + ), + IO.Float.Input( + "presence_penalty", + default=0.0, + min=-2.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Token presence penalty.", + advanced=True, + ), + ], + ), + ], + tooltip="Model to use for SVG vectorization.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.429}""", + ), + ) + + @classmethod + async def execute( + cls, + image, + auto_crop: bool, + model: dict, + seed: int, + ) -> IO.NodeOutput: + image_url = await upload_image_to_comfyapi(cls, image) + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"), + response_model=QuiverSVGResponse, + data=QuiverImageToSVGRequest( + model=model["model"], + image=QuiverImageObject(url=image_url), + auto_crop=auto_crop if auto_crop else None, + target_size=model.get("target_size"), + temperature=model.get("temperature"), + top_p=model.get("top_p"), + presence_penalty=model.get("presence_penalty"), + ), + ) + + svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data] + return IO.NodeOutput(SVG(svg_data)) + + +class QuiverExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + QuiverTextToSVGNode, + QuiverImageToSVGNode, + ] + + +async def comfy_entrypoint() -> QuiverExtension: + return QuiverExtension() From 45d5c83a3005e7fc28ce9e4ff04b77875052eb51 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 20 Mar 2026 13:08:26 -0700 Subject: [PATCH 24/42] Make EmptyImage node follow intermediate device/dtype. (#13079) --- nodes.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index e93fa9767..2c4650a20 100644 --- a/nodes.py +++ b/nodes.py @@ -1966,9 +1966,11 @@ class EmptyImage: CATEGORY = "image" def generate(self, width, height, batch_size=1, color=0): - r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF) - g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF) - b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF) + dtype = comfy.model_management.intermediate_dtype() + device = comfy.model_management.intermediate_device() + r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF, device=device, dtype=dtype) + g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF, device=device, dtype=dtype) + b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF, device=device, dtype=dtype) return (torch.cat((r, g, b), dim=-1), ) class ImagePadForOutpaint: From 87cda1fc25ca11a55ede88bf264cfe0a20d340ce Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 20 Mar 2026 17:03:42 -0700 Subject: [PATCH 25/42] Move inline comfy.context_windows imports to top-level in model_base.py (#13083) The recent PR that added resize_cond_for_context_window methods to model classes used inline 'import comfy.context_windows' in each method body. This moves that import to the top-level import section, replacing 4 duplicate inline imports with a single top-level one. --- comfy/model_base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 88905e191..43ec93324 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model +import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -1383,7 +1384,6 @@ class WAN21_Vace(WAN21): def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): if cond_key == "vace_context": - import comfy.context_windows return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list) return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) @@ -1441,7 +1441,6 @@ class WAN21_HuMo(WAN21): def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): if cond_key == "audio_embed": - import comfy.context_windows return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) @@ -1463,7 +1462,6 @@ class WAN22_Animate(WAN21): return out def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - import comfy.context_windows if cond_key == "face_pixel_values": return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1) if cond_key == "pose_latents": @@ -1508,7 +1506,6 @@ class WAN22_S2V(WAN21): def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): if cond_key == "audio_embed": - import comfy.context_windows return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) From dc719cde9c448c65242ae2d4ba400ba18c36846f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 20 Mar 2026 20:09:15 -0400 Subject: [PATCH 26/42] ComfyUI version 0.18.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 701f4d66a..a3b7204dc 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.17.0" +__version__ = "0.18.0" diff --git a/pyproject.toml b/pyproject.toml index e2ca79be7..6db9b1267 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.17.0" +version = "0.18.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From a11f68dd3b5393b6afc37e01c91fa84963d2668a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 20 Mar 2026 20:15:50 -0700 Subject: [PATCH 27/42] Fix canny node not working with fp16. (#13085) --- comfy_extras/nodes_canny.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 5e7c4eabb..648b4279d 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -3,6 +3,7 @@ from typing_extensions import override import comfy.model_management from comfy_api.latest import ComfyExtension, io +import torch class Canny(io.ComfyNode): @@ -29,8 +30,8 @@ class Canny(io.ComfyNode): @classmethod def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: - output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) + output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold) + img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1) return io.NodeOutput(img_out) From b5d32e6ad23f3deb0cd16b5f2afa81ff92d89e6e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 21 Mar 2026 14:47:42 -0700 Subject: [PATCH 28/42] Fix sampling issue with fp16 intermediates. (#13099) --- comfy/samplers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 8be449ef7..0a4d062db 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -985,8 +985,8 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - noise = noise.to(device) - latent_image = latent_image.to(device) + noise = noise.to(device=device, dtype=torch.float32) + latent_image = latent_image.to(device=device, dtype=torch.float32) sigmas = sigmas.to(device) cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) @@ -1028,6 +1028,7 @@ class CFGGuider: denoise_mask, _ = comfy.utils.pack_latents(denoise_masks) else: denoise_mask = denoise_masks[0] + denoise_mask = denoise_mask.float() self.conds = {} for k in self.original_conds: From 11c15d8832ab8a95ebe31f85c131429978668c76 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 21 Mar 2026 14:53:25 -0700 Subject: [PATCH 29/42] Fix fp16 intermediates giving different results. (#13100) --- comfy/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index e9c2259ab..653829582 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -8,12 +8,12 @@ import comfy.nested_tensor def prepare_noise_inner(latent_image, generator, noise_inds=None): if noise_inds is None: - return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype) unique_inds, inverse = np.unique(noise_inds, return_inverse=True) noises = [] for i in range(unique_inds[-1]+1): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype) if i in unique_inds: noises.append(noise) noises = [noises[i] for i in inverse] From 25b6d1d6298c380c1d4de90ff9f38484a84ada19 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:44:35 -0700 Subject: [PATCH 30/42] wan: vae: Fix light/color change (#13101) There was an issue where the resample split was too early and dropped one of the rolling convolutions a frame early. This is most noticable as a lighting/color change between pixel frames 5->6 (latent 2->3), or as a lighting change between the first and last frame in an FLF wan flow. --- comfy/ldm/wan/vae.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index deeb8695b..57b0dabf7 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -376,11 +376,16 @@ class Decoder3d(nn.Module): return layer = self.upsamples[layer_idx] - if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: - for frame_idx in range(x.shape[2]): + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2: + for frame_idx in range(0, x.shape[2], 2): self.run_up( - layer_idx, - [x[:, :, frame_idx:frame_idx + 1, :, :]], + layer_idx + 1, + [x[:, :, frame_idx:frame_idx + 2, :, :]], feat_cache, feat_idx.copy(), out_chunks, @@ -388,11 +393,6 @@ class Decoder3d(nn.Module): del x return - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - next_x_ref = [x] del x self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks) From ebf6b52e322664af91fcdc8b8848d31d5fb98f66 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 21 Mar 2026 22:32:16 -0400 Subject: [PATCH 31/42] ComfyUI v0.18.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index a3b7204dc..61d7672ca 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.18.0" +__version__ = "0.18.1" diff --git a/pyproject.toml b/pyproject.toml index 6db9b1267..1fc9402a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.18.0" +version = "0.18.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From d49420b3c7daf86cae1d7419e37848a974e1b7be Mon Sep 17 00:00:00 2001 From: Talmaj Date: Sun, 22 Mar 2026 04:51:05 +0100 Subject: [PATCH 32/42] LongCat-Image edit (#13003) --- comfy/ldm/flux/model.py | 2 +- comfy/model_base.py | 5 +++-- comfy/text_encoders/llama.py | 11 +++++++++-- comfy/text_encoders/longcat_image.py | 25 ++++++++++++++++++++----- comfy/text_encoders/qwen_vl.py | 3 +++ 5 files changed, 36 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8e7912e6d..2020326c2 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -386,7 +386,7 @@ class Flux(nn.Module): h = max(h, ref.shape[-2] + h_offset) w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) ref_num_tokens.append(kontext.shape[1]) diff --git a/comfy/model_base.py b/comfy/model_base.py index 43ec93324..bfffe2402 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -937,9 +937,10 @@ class LongCatImage(Flux): transformer_options = transformer_options.copy() rope_opts = transformer_options.get("rope_options", {}) rope_opts = dict(rope_opts) + pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0 rope_opts.setdefault("shift_t", 1.0) - rope_opts.setdefault("shift_y", 512.0) - rope_opts.setdefault("shift_x", 512.0) + rope_opts.setdefault("shift_y", pe_len) + rope_opts.setdefault("shift_x", pe_len) transformer_options["rope_options"] = rope_opts return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ccc200b7a..9fdea999c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module): grid = e.get("extra", None) start = e.get("index") if position_ids is None: - position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long) position_ids[:, :start] = torch.arange(0, start, device=embeds.device) end = e.get("size") + start len_max = int(grid.max()) // 2 start_next = len_max + start - position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) + if attention_mask is not None: + # Assign compact sequential positions to attended tokens only, + # skipping over padding so post-padding tokens aren't inflated. + after_mask = attention_mask[0, end:] + text_positions = after_mask.cumsum(0) - 1 + start_next + offset + position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:]) + else: + position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) position_ids[0, start:end] = start + offset max_d = int(grid[0][1]) // 2 position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] diff --git a/comfy/text_encoders/longcat_image.py b/comfy/text_encoders/longcat_image.py index 882d80901..0962779e3 100644 --- a/comfy/text_encoders/longcat_image.py +++ b/comfy/text_encoders/longcat_image.py @@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer): return [output] +IMAGE_PAD_TOKEN_ID = 151655 + class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): + T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + SUFFIX = "<|im_end|>\n<|im_start|>assistant\n" + def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__( embedding_directory=embedding_directory, @@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): name="qwen25_7b", tokenizer=LongCatImageBaseTokenizer, ) - self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" - self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs): skip_template = False if text.startswith("<|im_start|>"): skip_template = True @@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): text, return_word_ids=return_word_ids, disable_weights=True, **kwargs ) else: + has_images = images is not None and len(images) > 0 + template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX + prefix_ids = base_tok.tokenizer( - self.longcat_template_prefix, add_special_tokens=False + template_prefix, add_special_tokens=False )["input_ids"] suffix_ids = base_tok.tokenizer( - self.longcat_template_suffix, add_special_tokens=False + self.SUFFIX, add_special_tokens=False )["input_ids"] prompt_tokens = base_tok.tokenize_with_weights( @@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): suffix_pairs = [(t, 1.0) for t in suffix_ids] combined = prefix_pairs + prompt_pairs + suffix_pairs + + if has_images: + embed_count = 0 + for i in range(len(combined)): + if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images): + combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1]) + embed_count += 1 + tokens = {"qwen25_7b": [combined]} return tokens diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py index 3b18ce730..98c350a12 100644 --- a/comfy/text_encoders/qwen_vl.py +++ b/comfy/text_encoders/qwen_vl.py @@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module): hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) hidden_states = self.merger(hidden_states) + # Potentially important for spatially precise edits. This is present in the HF implementation. + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] return hidden_states From 6265a239f379f1a5cf2bfdcd3a9631d4c11e50fb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 22 Mar 2026 15:46:18 -0700 Subject: [PATCH 33/42] Add warning for users who disable dynamic vram. (#13113) --- main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.py b/main.py index f99aee38e..cd4483c67 100644 --- a/main.py +++ b/main.py @@ -471,6 +471,9 @@ if __name__ == "__main__": if sys.version_info.major == 3 and sys.version_info.minor < 10: logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") + if args.disable_dynamic_vram: + logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.") + event_loop, _, start_all_func = start_comfyui() try: x = start_all_func() From da6edb5a4e5745869d64ae05b96263da42d5364e Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 24 Mar 2026 01:59:21 +0900 Subject: [PATCH 34/42] bump manager version to 4.1b8 (#13108) --- manager_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manager_requirements.txt b/manager_requirements.txt index 5b06b56f6..90a2be84e 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b6 \ No newline at end of file +comfyui_manager==4.1b8 From e87858e9743f92222cdb478f1f835135750b6a0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 24 Mar 2026 00:22:24 +0200 Subject: [PATCH 35/42] feat: LTX2: Support reference audio (ID-LoRA) (#13111) --- comfy/ldm/lightricks/av_model.py | 42 +++++++++++++++++ comfy/model_base.py | 4 ++ comfy_extras/nodes_lt.py | 80 ++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 08d686b7b..6f2ba41ef 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel): additional_args["has_spatial_mask"] = has_spatial_mask ax, a_latent_coords = self.a_patchifier.patchify(ax) + + # Inject reference audio for ID-LoRA in-context conditioning + ref_audio = kwargs.get("ref_audio", None) + ref_audio_seq_len = 0 + if ref_audio is not None: + ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device) + if ref_tokens.shape[0] < ax.shape[0]: + ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1) + ref_audio_seq_len = ref_tokens.shape[1] + B = ax.shape[0] + + # Compute negative temporal positions matching ID-LoRA convention: + # offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0 + p = self.a_patchifier + tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate + ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device) + ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device) + time_offset = ref_end[-1].item() + tpl + ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1) + ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1) + ref_pos = torch.stack([ref_start, ref_end], dim=-1) + + additional_args["ref_audio_seq_len"] = ref_audio_seq_len + additional_args["target_audio_seq_len"] = ax.shape[1] + ax = torch.cat([ref_tokens, ax], dim=1) + a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2) + ax = self.audio_patchify_proj(ax) # additional_args.update({"av_orig_shape": list(x.shape)}) @@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel): # Prepare audio timestep a_timestep = kwargs.get("a_timestep") + ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0) + if ref_audio_seq_len > 0 and a_timestep is not None: + # Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma. + target_len = kwargs.get("target_audio_seq_len") + if a_timestep.dim() <= 1: + a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len) + ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype) + a_timestep = torch.cat([ref_ts, a_timestep], dim=1) if a_timestep is not None: a_timestep_scaled = a_timestep * self.timestep_scale_multiplier a_timestep_flat = a_timestep_scaled.flatten() @@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel): v_embedded_timestep = embedded_timestep[0] a_embedded_timestep = embedded_timestep[1] + # Trim reference audio tokens before unpatchification + ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0) + if ref_audio_seq_len > 0: + ax = ax[:, ref_audio_seq_len:] + if a_embedded_timestep.shape[1] > 1: + a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:] + # Expand compressed video timestep if needed if isinstance(v_embedded_timestep, CompressedTimestep): v_embedded_timestep = v_embedded_timestep.expand() diff --git a/comfy/model_base.py b/comfy/model_base.py index bfffe2402..70aff886e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1061,6 +1061,10 @@ class LTXAV(BaseModel): if guide_attention_entries is not None: out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + ref_audio = kwargs.get("ref_audio", None) + if ref_audio is not None: + out['ref_audio'] = comfy.conds.CONDConstant(ref_audio) + return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index c05571143..d7c2e8744 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -3,6 +3,7 @@ import node_helpers import torch import comfy.model_management import comfy.model_sampling +import comfy.samplers import comfy.utils import math import numpy as np @@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode): return io.NodeOutput(video_latent, audio_latent) +class LTXVReferenceAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVReferenceAudio", + display_name="LTXV Reference Audio (ID-LoRA)", + category="conditioning/audio", + description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."), + io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."), + io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."), + ], + outputs=[ + io.Model.Output(), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: + # Encode reference audio to latents and patchify + audio_latents = audio_vae.encode(reference_audio) + b, c, t, f = audio_latents.shape + ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) + ref_audio = {"tokens": ref_tokens} + + positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio}) + negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio}) + + # Patch model with identity guidance + m = model.clone() + scale = identity_guidance_scale + model_sampling = m.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + + def post_cfg_function(args): + if scale == 0: + return args["denoised"] + + sigma = args["sigma"] + sigma_ = sigma[0].item() + if sigma_ > sigma_start or sigma_ < sigma_end: + return args["denoised"] + + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + model_options = args["model_options"].copy() + x = args["input"] + + # Strip ref_audio from conditioning for the no-reference pass + noref_cond = [] + for entry in cond: + new_entry = entry.copy() + mc = new_entry.get("model_conds", {}).copy() + mc.pop("ref_audio", None) + new_entry["model_conds"] = mc + noref_cond.append(new_entry) + + (pred_noref,) = comfy.samplers.calc_cond_batch( + args["model"], [noref_cond], x, sigma, model_options + ) + + return cfg_result + (cond_pred - pred_noref) * scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return io.NodeOutput(m, positive, negative) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension): LTXVCropGuides, LTXVConcatAVLatent, LTXVSeparateAVLatent, + LTXVReferenceAudio, ] From 2d4970ff677970fbca9f9f562296eda46de8aa4c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:43:41 -0700 Subject: [PATCH 36/42] Update frontend version to 1.42.8 (#13126) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ad0344ed4..26cc94354 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.41.21 +comfyui-frontend-package==1.42.8 comfyui-workflow-templates==0.9.26 comfyui-embedded-docs==0.4.3 torch From 2d5fd3f5dde51574d77601dbe4c163a95a56121a Mon Sep 17 00:00:00 2001 From: Kelly Yang <124ykl@gmail.com> Date: Tue, 24 Mar 2026 11:22:30 -0700 Subject: [PATCH 37/42] fix: set default values of Color Adjustment node to zero (#13084) Co-authored-by: Jedrzej Kosinski --- blueprints/Color Adjustment.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blueprints/Color Adjustment.json b/blueprints/Color Adjustment.json index c599f7213..47f3df783 100644 --- a/blueprints/Color Adjustment.json +++ b/blueprints/Color Adjustment.json @@ -1 +1 @@ -{"revision": 0, "last_node_id": 14, "last_link_id": 0, "nodes": [{"id": 14, "type": "36677b92-5dd8-47a5-9380-4da982c1894f", "pos": [3610, -2630], "size": [270, 150], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["4", "value"], ["5", "value"], ["7", "value"], ["6", "value"]]}, "widgets_values": [], "title": "Color Adjustment"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "36677b92-5dd8-47a5-9380-4da982c1894f", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 16, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Color Adjustment", "inputNode": {"id": -10, "bounding": [3110, -3560, 120, 60]}, "outputNode": {"id": -20, "bounding": [4070, -3560, 120, 60]}, "inputs": [{"id": "0431d493-5f28-4430-bd00-84733997fc08", "name": "images.image0", "type": "IMAGE", "linkIds": [29], "localized_name": "images.image0", "label": "image", "pos": [3210, -3540]}], "outputs": [{"id": "bee8ea06-a114-4612-8937-939f2c927bdb", "name": "IMAGE0", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [4090, -3540]}], "widgets": [], "nodes": [{"id": 15, "type": "GLSLShader", "pos": [3590, -3940], "size": [420, 252], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 29}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 34}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 30}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 31}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": 33}, {"label": "u_float4", "localized_name": "floats.u_float4", "name": "floats.u_float4", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [28]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // temperature (-100 to 100)\nuniform float u_float1; // tint (-100 to 100)\nuniform float u_float2; // vibrance (-100 to 100)\nuniform float u_float3; // saturation (-100 to 100)\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst float INPUT_SCALE = 0.01;\nconst float TEMP_TINT_PRIMARY = 0.3;\nconst float TEMP_TINT_SECONDARY = 0.15;\nconst float VIBRANCE_BOOST = 2.0;\nconst float SATURATION_BOOST = 2.0;\nconst float SKIN_PROTECTION = 0.5;\nconst float EPSILON = 0.001;\nconst vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);\n\nvoid main() {\n vec4 tex = texture(u_image0, v_texCoord);\n vec3 color = tex.rgb;\n \n // Scale inputs: -100/100 \u2192 -1/1\n float temperature = u_float0 * INPUT_SCALE;\n float tint = u_float1 * INPUT_SCALE;\n float vibrance = u_float2 * INPUT_SCALE;\n float saturation = u_float3 * INPUT_SCALE;\n \n // Temperature (warm/cool): positive = warm, negative = cool\n color.r += temperature * TEMP_TINT_PRIMARY;\n color.b -= temperature * TEMP_TINT_PRIMARY;\n \n // Tint (green/magenta): positive = green, negative = magenta\n color.g += tint * TEMP_TINT_PRIMARY;\n color.r -= tint * TEMP_TINT_SECONDARY;\n color.b -= tint * TEMP_TINT_SECONDARY;\n \n // Single clamp after temperature/tint\n color = clamp(color, 0.0, 1.0);\n \n // Vibrance with skin protection\n if (vibrance != 0.0) {\n float maxC = max(color.r, max(color.g, color.b));\n float minC = min(color.r, min(color.g, color.b));\n float sat = maxC - minC;\n float gray = dot(color, LUMA_WEIGHTS);\n \n if (vibrance < 0.0) {\n // Desaturate: -100 \u2192 gray\n color = mix(vec3(gray), color, 1.0 + vibrance);\n } else {\n // Boost less saturated colors more\n float vibranceAmt = vibrance * (1.0 - sat);\n \n // Branchless skin tone protection\n float isWarmTone = step(color.b, color.g) * step(color.g, color.r);\n float warmth = (color.r - color.b) / max(maxC, EPSILON);\n float skinTone = isWarmTone * warmth * sat * (1.0 - sat);\n vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);\n \n color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);\n }\n }\n \n // Saturation\n if (saturation != 0.0) {\n float gray = dot(color, LUMA_WEIGHTS);\n float satMix = saturation < 0.0\n ? 1.0 + saturation // -100 \u2192 gray\n : 1.0 + saturation * SATURATION_BOOST; // +100 \u2192 3x boost\n color = mix(vec3(gray), color, satMix);\n }\n \n fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);\n}", "from_input"]}, {"id": 6, "type": "PrimitiveFloat", "pos": [3290, -3610], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "vibrance", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [26, 31]}], "title": "Vibrance", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 7, "type": "PrimitiveFloat", "pos": [3290, -3720], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "saturation", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [33]}], "title": "Saturation", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 5, "type": "PrimitiveFloat", "pos": [3290, -3830], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "tint", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [30]}], "title": "Tint", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 255, 0]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 0, 255]}]}, "widgets_values": [0]}, {"id": 4, "type": "PrimitiveFloat", "pos": [3290, -3940], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "temperature", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [34]}], "title": "Temperature", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [68, 136, 255]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 136, 0]}]}, "widgets_values": [100]}], "groups": [], "links": [{"id": 34, "origin_id": 4, "origin_slot": 0, "target_id": 15, "target_slot": 2, "type": "FLOAT"}, {"id": 30, "origin_id": 5, "origin_slot": 0, "target_id": 15, "target_slot": 3, "type": "FLOAT"}, {"id": 31, "origin_id": 6, "origin_slot": 0, "target_id": 15, "target_slot": 4, "type": "FLOAT"}, {"id": 33, "origin_id": 7, "origin_slot": 0, "target_id": 15, "target_slot": 5, "type": "FLOAT"}, {"id": 29, "origin_id": -10, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "IMAGE"}, {"id": 28, "origin_id": 15, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} +{"revision": 0, "last_node_id": 14, "last_link_id": 0, "nodes": [{"id": 14, "type": "36677b92-5dd8-47a5-9380-4da982c1894f", "pos": [3610, -2630], "size": [270, 150], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["4", "value"], ["5", "value"], ["7", "value"], ["6", "value"]]}, "widgets_values": [], "title": "Color Adjustment"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "36677b92-5dd8-47a5-9380-4da982c1894f", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 16, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Color Adjustment", "inputNode": {"id": -10, "bounding": [3110, -3560, 120, 60]}, "outputNode": {"id": -20, "bounding": [4070, -3560, 120, 60]}, "inputs": [{"id": "0431d493-5f28-4430-bd00-84733997fc08", "name": "images.image0", "type": "IMAGE", "linkIds": [29], "localized_name": "images.image0", "label": "image", "pos": [3210, -3540]}], "outputs": [{"id": "bee8ea06-a114-4612-8937-939f2c927bdb", "name": "IMAGE0", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [4090, -3540]}], "widgets": [], "nodes": [{"id": 15, "type": "GLSLShader", "pos": [3590, -3940], "size": [420, 252], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 29}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 34}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 30}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 31}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": 33}, {"label": "u_float4", "localized_name": "floats.u_float4", "name": "floats.u_float4", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [28]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // temperature (-100 to 100)\nuniform float u_float1; // tint (-100 to 100)\nuniform float u_float2; // vibrance (-100 to 100)\nuniform float u_float3; // saturation (-100 to 100)\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst float INPUT_SCALE = 0.01;\nconst float TEMP_TINT_PRIMARY = 0.3;\nconst float TEMP_TINT_SECONDARY = 0.15;\nconst float VIBRANCE_BOOST = 2.0;\nconst float SATURATION_BOOST = 2.0;\nconst float SKIN_PROTECTION = 0.5;\nconst float EPSILON = 0.001;\nconst vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);\n\nvoid main() {\n vec4 tex = texture(u_image0, v_texCoord);\n vec3 color = tex.rgb;\n \n // Scale inputs: -100/100 \u2192 -1/1\n float temperature = u_float0 * INPUT_SCALE;\n float tint = u_float1 * INPUT_SCALE;\n float vibrance = u_float2 * INPUT_SCALE;\n float saturation = u_float3 * INPUT_SCALE;\n \n // Temperature (warm/cool): positive = warm, negative = cool\n color.r += temperature * TEMP_TINT_PRIMARY;\n color.b -= temperature * TEMP_TINT_PRIMARY;\n \n // Tint (green/magenta): positive = green, negative = magenta\n color.g += tint * TEMP_TINT_PRIMARY;\n color.r -= tint * TEMP_TINT_SECONDARY;\n color.b -= tint * TEMP_TINT_SECONDARY;\n \n // Single clamp after temperature/tint\n color = clamp(color, 0.0, 1.0);\n \n // Vibrance with skin protection\n if (vibrance != 0.0) {\n float maxC = max(color.r, max(color.g, color.b));\n float minC = min(color.r, min(color.g, color.b));\n float sat = maxC - minC;\n float gray = dot(color, LUMA_WEIGHTS);\n \n if (vibrance < 0.0) {\n // Desaturate: -100 \u2192 gray\n color = mix(vec3(gray), color, 1.0 + vibrance);\n } else {\n // Boost less saturated colors more\n float vibranceAmt = vibrance * (1.0 - sat);\n \n // Branchless skin tone protection\n float isWarmTone = step(color.b, color.g) * step(color.g, color.r);\n float warmth = (color.r - color.b) / max(maxC, EPSILON);\n float skinTone = isWarmTone * warmth * sat * (1.0 - sat);\n vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);\n \n color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);\n }\n }\n \n // Saturation\n if (saturation != 0.0) {\n float gray = dot(color, LUMA_WEIGHTS);\n float satMix = saturation < 0.0\n ? 1.0 + saturation // -100 \u2192 gray\n : 1.0 + saturation * SATURATION_BOOST; // +100 \u2192 3x boost\n color = mix(vec3(gray), color, satMix);\n }\n \n fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);\n}", "from_input"]}, {"id": 6, "type": "PrimitiveFloat", "pos": [3290, -3610], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "vibrance", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [26, 31]}], "title": "Vibrance", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 7, "type": "PrimitiveFloat", "pos": [3290, -3720], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "saturation", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [33]}], "title": "Saturation", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 5, "type": "PrimitiveFloat", "pos": [3290, -3830], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "tint", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [30]}], "title": "Tint", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 255, 0]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 0, 255]}]}, "widgets_values": [0]}, {"id": 4, "type": "PrimitiveFloat", "pos": [3290, -3940], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "temperature", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [34]}], "title": "Temperature", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [68, 136, 255]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 136, 0]}]}, "widgets_values": [0]}], "groups": [], "links": [{"id": 34, "origin_id": 4, "origin_slot": 0, "target_id": 15, "target_slot": 2, "type": "FLOAT"}, {"id": 30, "origin_id": 5, "origin_slot": 0, "target_id": 15, "target_slot": 3, "type": "FLOAT"}, {"id": 31, "origin_id": 6, "origin_slot": 0, "target_id": 15, "target_slot": 4, "type": "FLOAT"}, {"id": 33, "origin_id": 7, "origin_slot": 0, "target_id": 15, "target_slot": 5, "type": "FLOAT"}, {"id": 29, "origin_id": -10, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "IMAGE"}, {"id": 28, "origin_id": 15, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} From f9ec85f739aeab3fbc0f89baaa1e0fc196f2ff2c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 24 Mar 2026 22:27:39 +0200 Subject: [PATCH 38/42] feat(api-nodes): update xAI Grok nodes (#13140) --- comfy_api_nodes/apis/grok.py | 10 +- comfy_api_nodes/nodes_grok.py | 251 ++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py index c56c8aecc..fbedb53e0 100644 --- a/comfy_api_nodes/apis/grok.py +++ b/comfy_api_nodes/apis/grok.py @@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel): class VideoGenerationRequest(BaseModel): model: str = Field(...) prompt: str = Field(...) - image: InputUrlObject | None = Field(...) + image: InputUrlObject | None = Field(None) + reference_images: list[InputUrlObject] | None = Field(None) duration: int = Field(...) aspect_ratio: str | None = Field(...) resolution: str = Field(...) seed: int = Field(...) +class VideoExtensionRequest(BaseModel): + prompt: str = Field(...) + video: InputUrlObject = Field(...) + duration: int = Field(default=6) + model: str | None = Field(default=None) + + class VideoEditRequest(BaseModel): model: str = Field(...) prompt: str = Field(...) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index 0716d6239..dabc899d6 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import ( ImageGenerationResponse, InputUrlObject, VideoEditRequest, + VideoExtensionRequest, VideoGenerationRequest, VideoGenerationResponse, VideoStatusResponse, @@ -21,6 +22,7 @@ from comfy_api_nodes.util import ( poll_op, sync_op, tensor_to_base64_string, + upload_images_to_comfyapi, upload_video_to_comfyapi, validate_string, validate_video_duration, @@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None: return None +def _extract_grok_video_price(response) -> float | None: + price = _extract_grok_price(response) + if price is not None: + return price * 1.43 + return None + + class GrokImageNode(IO.ComfyNode): @classmethod @@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode): seed: int, image: Input.Image | None = None, ) -> IO.NodeOutput: + if model == "grok-imagine-video-beta": + model = "grok-imagine-video" image_url = None if image is not None: if get_number_of_images(image) != 1: @@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.video.url)) +class GrokVideoReferenceNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoReferenceNode", + display_name="Grok Reference-to-Video", + category="api node/video/Grok", + description="Generate video guided by reference images as style and content references.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "grok-imagine-video", + [ + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="reference_", + min=1, + max=7, + ), + tooltip="Up to 7 reference images to guide the video generation.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=6, + min=2, + max=10, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + ], + tooltip="The model to use for video generation.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model.duration", "model.resolution"], + input_groups=["model.reference_images"], + ), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $refs := inputGroups["model.reference_images"]; + $rate := $res = "720p" ? 0.07 : 0.05; + $price := ($rate * $dur + 0.002 * $refs) * 1.43; + {"type":"usd","usd": $price} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + ref_image_urls = await upload_images_to_comfyapi( + cls, + list(model["reference_images"].values()), + mime_type="image/png", + wait_label="Uploading base images", + max_images=7, + ) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), + data=VideoGenerationRequest( + model=model["model"], + reference_images=[InputUrlObject(url=i) for i in ref_image_urls], + prompt=prompt, + resolution=model["resolution"], + duration=model["duration"], + aspect_ratio=model["aspect_ratio"], + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + price_extractor=_extract_grok_video_price, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokVideoExtendNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoExtendNode", + display_name="Grok Video Extend", + category="api node/video/Grok", + description="Extend an existing video with a seamless continuation based on a text prompt.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of what should happen next in the video.", + ), + IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "grok-imagine-video", + [ + IO.Int.Input( + "duration", + default=8, + min=2, + max=10, + step=1, + tooltip="Length of the extension in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + ], + tooltip="The model to use for video extension.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]), + expr=""" + ( + $dur := $lookup(widgets, "model.duration"); + { + "type": "range_usd", + "min_usd": (0.02 + 0.05 * $dur) * 1.43, + "max_usd": (0.15 + 0.05 * $dur) * 1.43 + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + video: Input.Video, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + validate_video_duration(video, min_duration=2, max_duration=15) + video_size = get_fs_object_size(video.get_stream_source()) + if video_size > 50 * 1024 * 1024: + raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"), + data=VideoExtensionRequest( + prompt=prompt, + video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), + duration=model["duration"], + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + price_extractor=_extract_grok_video_price, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + class GrokExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension): GrokImageNode, GrokImageEditNode, GrokVideoNode, + GrokVideoReferenceNode, GrokVideoEditNode, + GrokVideoExtendNode, ] From c2862b24af49ff40b251ea2a4e22b92af9e92982 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:36:12 -0700 Subject: [PATCH 39/42] Update templates package version. (#13141) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 26cc94354..76f824906 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.8 -comfyui-workflow-templates==0.9.26 +comfyui-workflow-templates==0.9.36 comfyui-embedded-docs==0.4.3 torch torchsde From 8e73678dae6e5763bc860d6f98566243a494f9c2 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Tue, 24 Mar 2026 17:47:28 -0400 Subject: [PATCH 40/42] CURVE node (#12757) * CURVE node * remove curve to sigmas node * feat: add CurveInput ABC with MonotoneCubicCurve implementation (#12986) CurveInput is an abstract base class so future curve representations (bezier, LUT-based, analytical functions) can be added without breaking downstream nodes that type-check against CurveInput. MonotoneCubicCurve is the concrete implementation that: - Mirrors frontend createMonotoneInterpolator (curveUtils.ts) exactly - Pre-computes slopes as numpy arrays at construction time - Provides vectorised interp_array() using numpy for batch evaluation - interp() for single-value evaluation - to_lut() for generating lookup tables CurveEditor node wraps raw widget points in MonotoneCubicCurve. * linear curve * refactor: move CurveEditor to comfy_extras/nodes_curve.py with V3 schema * feat: add HISTOGRAM type and histogram support to CurveEditor * code improve --------- Co-authored-by: Christian Byrne --- comfy_api/input/__init__.py | 8 + comfy_api/latest/_input/__init__.py | 5 + comfy_api/latest/_input/curve_types.py | 219 +++++++++++++++++++++++++ comfy_api/latest/_io.py | 20 ++- comfy_extras/nodes_curve.py | 42 +++++ nodes.py | 1 + 6 files changed, 292 insertions(+), 3 deletions(-) create mode 100644 comfy_api/latest/_input/curve_types.py create mode 100644 comfy_extras/nodes_curve.py diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 68ff78270..16d4acfd1 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -5,6 +5,10 @@ from comfy_api.latest._input import ( MaskInput, LatentInput, VideoInput, + CurvePoint, + CurveInput, + MonotoneCubicCurve, + LinearCurve, ) __all__ = [ @@ -13,4 +17,8 @@ __all__ = [ "MaskInput", "LatentInput", "VideoInput", + "CurvePoint", + "CurveInput", + "MonotoneCubicCurve", + "LinearCurve", ] diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py index 14f0e72f4..05cd3d40a 100644 --- a/comfy_api/latest/_input/__init__.py +++ b/comfy_api/latest/_input/__init__.py @@ -1,4 +1,5 @@ from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput +from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve from .video_types import VideoInput __all__ = [ @@ -7,4 +8,8 @@ __all__ = [ "VideoInput", "MaskInput", "LatentInput", + "CurvePoint", + "CurveInput", + "MonotoneCubicCurve", + "LinearCurve", ] diff --git a/comfy_api/latest/_input/curve_types.py b/comfy_api/latest/_input/curve_types.py new file mode 100644 index 000000000..b6dd7adf9 --- /dev/null +++ b/comfy_api/latest/_input/curve_types.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import logging +import math +from abc import ABC, abstractmethod +import numpy as np + +logger = logging.getLogger(__name__) + + +CurvePoint = tuple[float, float] + + +class CurveInput(ABC): + """Abstract base class for curve inputs. + + Subclasses represent different curve representations (control-point + interpolation, analytical functions, LUT-based, etc.) while exposing a + uniform evaluation interface to downstream nodes. + """ + + @property + @abstractmethod + def points(self) -> list[CurvePoint]: + """The control points that define this curve.""" + + @abstractmethod + def interp(self, x: float) -> float: + """Evaluate the curve at a single *x* value in [0, 1].""" + + def interp_array(self, xs: np.ndarray) -> np.ndarray: + """Vectorised evaluation over a numpy array of x values. + + Subclasses should override this for better performance. The default + falls back to scalar ``interp`` calls. + """ + return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs)) + + def to_lut(self, size: int = 256) -> np.ndarray: + """Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1].""" + return self.interp_array(np.linspace(0.0, 1.0, size)) + + @staticmethod + def from_raw(data) -> CurveInput: + """Convert raw curve data (dict or point list) to a CurveInput instance. + + Accepts: + - A ``CurveInput`` instance (returned as-is). + - A dict with ``"points"`` and optional ``"interpolation"`` keys. + - A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic). + """ + if isinstance(data, CurveInput): + return data + if isinstance(data, dict): + raw_points = data["points"] + interpolation = data.get("interpolation", "monotone_cubic") + else: + raw_points = data + interpolation = "monotone_cubic" + points = [(float(x), float(y)) for x, y in raw_points] + if interpolation == "linear": + return LinearCurve(points) + if interpolation != "monotone_cubic": + logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation) + return MonotoneCubicCurve(points) + + +class MonotoneCubicCurve(CurveInput): + """Monotone cubic Hermite interpolation over control points. + + Mirrors the frontend ``createMonotoneInterpolator`` in + ``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that + backend evaluation matches the editor preview exactly. + + All heavy work (sorting, slope computation) happens once at construction. + ``interp_array`` is fully vectorised with numpy. + """ + + def __init__(self, control_points: list[CurvePoint]): + sorted_pts = sorted(control_points, key=lambda p: p[0]) + self._points = [(float(x), float(y)) for x, y in sorted_pts] + self._xs = np.array([p[0] for p in self._points], dtype=np.float64) + self._ys = np.array([p[1] for p in self._points], dtype=np.float64) + self._slopes = self._compute_slopes() + + @property + def points(self) -> list[CurvePoint]: + return list(self._points) + + def _compute_slopes(self) -> np.ndarray: + xs, ys = self._xs, self._ys + n = len(xs) + if n < 2: + return np.zeros(n, dtype=np.float64) + + dx = np.diff(xs) + dy = np.diff(ys) + dx_safe = np.where(dx == 0, 1.0, dx) + deltas = np.where(dx == 0, 0.0, dy / dx_safe) + + slopes = np.empty(n, dtype=np.float64) + slopes[0] = deltas[0] + slopes[-1] = deltas[-1] + for i in range(1, n - 1): + if deltas[i - 1] * deltas[i] <= 0: + slopes[i] = 0.0 + else: + slopes[i] = (deltas[i - 1] + deltas[i]) / 2 + + for i in range(n - 1): + if deltas[i] == 0: + slopes[i] = 0.0 + slopes[i + 1] = 0.0 + else: + alpha = slopes[i] / deltas[i] + beta = slopes[i + 1] / deltas[i] + s = alpha * alpha + beta * beta + if s > 9: + t = 3 / math.sqrt(s) + slopes[i] = t * alpha * deltas[i] + slopes[i + 1] = t * beta * deltas[i] + return slopes + + def interp(self, x: float) -> float: + xs, ys, slopes = self._xs, self._ys, self._slopes + n = len(xs) + if n == 0: + return 0.0 + if n == 1: + return float(ys[0]) + if x <= xs[0]: + return float(ys[0]) + if x >= xs[-1]: + return float(ys[-1]) + + hi = int(np.searchsorted(xs, x, side='right')) + hi = min(hi, n - 1) + lo = hi - 1 + + dx = xs[hi] - xs[lo] + if dx == 0: + return float(ys[lo]) + + t = (x - xs[lo]) / dx + t2 = t * t + t3 = t2 * t + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]) + + def interp_array(self, xs_in: np.ndarray) -> np.ndarray: + """Fully vectorised evaluation using numpy.""" + xs, ys, slopes = self._xs, self._ys, self._slopes + n = len(xs) + if n == 0: + return np.zeros_like(xs_in, dtype=np.float64) + if n == 1: + return np.full_like(xs_in, ys[0], dtype=np.float64) + + hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1) + lo = hi - 1 + + dx = xs[hi] - xs[lo] + dx_safe = np.where(dx == 0, 1.0, dx) + t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe) + t2 = t * t + t3 = t2 * t + + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + + result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi] + result = np.where(xs_in <= xs[0], ys[0], result) + result = np.where(xs_in >= xs[-1], ys[-1], result) + return result + + def __repr__(self) -> str: + return f"MonotoneCubicCurve(points={self._points})" + + +class LinearCurve(CurveInput): + """Piecewise linear interpolation over control points. + + Mirrors the frontend ``createLinearInterpolator`` in + ``ComfyUI_frontend/src/components/curve/curveUtils.ts``. + """ + + def __init__(self, control_points: list[CurvePoint]): + sorted_pts = sorted(control_points, key=lambda p: p[0]) + self._points = [(float(x), float(y)) for x, y in sorted_pts] + self._xs = np.array([p[0] for p in self._points], dtype=np.float64) + self._ys = np.array([p[1] for p in self._points], dtype=np.float64) + + @property + def points(self) -> list[CurvePoint]: + return list(self._points) + + def interp(self, x: float) -> float: + xs, ys = self._xs, self._ys + n = len(xs) + if n == 0: + return 0.0 + if n == 1: + return float(ys[0]) + return float(np.interp(x, xs, ys)) + + def interp_array(self, xs_in: np.ndarray) -> np.ndarray: + if len(self._xs) == 0: + return np.zeros_like(xs_in, dtype=np.float64) + if len(self._xs) == 1: + return np.full_like(xs_in, self._ys[0], dtype=np.float64) + return np.interp(xs_in, self._xs, self._ys) + + def __repr__(self) -> str: + return f"LinearCurve(points={self._points})" diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 7ca8f4e0c..1cbc8ed26 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from comfy.samplers import CFGGuider, Sampler from comfy.sd import CLIP, VAE from comfy.sd import StyleModel as StyleModel_ - from comfy_api.input import VideoInput + from comfy_api.input import VideoInput, CurveInput as CurveInput_ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker @@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO): @comfytype(io_type="CURVE") class Curve(ComfyTypeIO): - CurvePoint = tuple[float, float] - Type = list[CurvePoint] + from comfy_api.input import CurvePoint + if TYPE_CHECKING: + Type = CurveInput_ class Input(WidgetInput): def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, @@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO): if default is None: self.default = [(0.0, 0.0), (1.0, 1.0)] + def as_dict(self): + d = super().as_dict() + if self.default is not None: + d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"} + return d + + +@comfytype(io_type="HISTOGRAM") +class Histogram(ComfyTypeIO): + """A histogram represented as a list of bin counts.""" + Type = list[int] + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): @@ -2240,5 +2253,6 @@ __all__ = [ "PriceBadge", "BoundingBox", "Curve", + "Histogram", "NodeReplace", ] diff --git a/comfy_extras/nodes_curve.py b/comfy_extras/nodes_curve.py new file mode 100644 index 000000000..9016a84f9 --- /dev/null +++ b/comfy_extras/nodes_curve.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from comfy_api.latest import ComfyExtension, io +from comfy_api.input import CurveInput +from typing_extensions import override + + +class CurveEditor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CurveEditor", + display_name="Curve Editor", + category="utils", + inputs=[ + io.Curve.Input("curve"), + io.Histogram.Input("histogram", optional=True), + ], + outputs=[ + io.Curve.Output("curve"), + ], + ) + + @classmethod + def execute(cls, curve, histogram=None) -> io.NodeOutput: + result = CurveInput.from_raw(curve) + + ui = {} + if histogram is not None: + ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram) + + return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result) + + +class CurveExtension(ComfyExtension): + @override + async def get_node_list(self): + return [CurveEditor] + + +async def comfy_entrypoint(): + return CurveExtension() diff --git a/nodes.py b/nodes.py index 2c4650a20..79874d051 100644 --- a/nodes.py +++ b/nodes.py @@ -2455,6 +2455,7 @@ async def init_builtin_extra_nodes(): "nodes_sdpose.py", "nodes_math.py", "nodes_painter.py", + "nodes_curve.py", ] import_failed = [] From a0a64c679fca700a087d0cdfa419912a3e8b3bf8 Mon Sep 17 00:00:00 2001 From: Dante Date: Wed, 25 Mar 2026 07:38:08 +0900 Subject: [PATCH 41/42] Add Number Convert node (#13041) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Number Convert node for unified numeric type conversion Consolidates fragmented IntToFloat/FloatToInt nodes (previously only available via third-party packs like ComfyMath, FillNodes, etc.) into a single core node. - Single input accepting INT, FLOAT, STRING, and BOOL types - Two outputs: FLOAT and INT - Conversion: bool→0/1, string→parsed number, float↔int standard cast - Follows Math Expression node patterns (comfy_api, io.Schema, etc.) Refs: COM-16925 * Register nodes_number_convert.py in extras_files list Without this entry in nodes.py, the Number Convert node file would not be discovered and loaded at startup. * Add isfinite guard, exception chaining, and unit tests for Number Convert node - Add math.isfinite() check to prevent int() crash on inf/nan string inputs - Use 'from None' for cleaner exception chaining on string parse failure - Add 21 unit tests covering all input types and error paths --- comfy_extras/nodes_number_convert.py | 79 +++++++++++ nodes.py | 1 + .../nodes_number_convert_test.py | 123 ++++++++++++++++++ 3 files changed, 203 insertions(+) create mode 100644 comfy_extras/nodes_number_convert.py create mode 100644 tests-unit/comfy_extras_test/nodes_number_convert_test.py diff --git a/comfy_extras/nodes_number_convert.py b/comfy_extras/nodes_number_convert.py new file mode 100644 index 000000000..b2822c856 --- /dev/null +++ b/comfy_extras/nodes_number_convert.py @@ -0,0 +1,79 @@ +"""Number Convert node for unified numeric type conversion. + +Provides a single node that converts INT, FLOAT, STRING, and BOOL +inputs into FLOAT and INT outputs. +""" + +from __future__ import annotations + +import math + +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class NumberConvertNode(io.ComfyNode): + """Converts various types to numeric FLOAT and INT outputs.""" + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ComfyNumberConvert", + display_name="Number Convert", + category="math", + search_aliases=[ + "int to float", "float to int", "number convert", + "int2float", "float2int", "cast", "parse number", + "string to number", "bool to int", + ], + inputs=[ + io.MultiType.Input( + "value", + [io.Int, io.Float, io.String, io.Boolean], + display_name="value", + ), + ], + outputs=[ + io.Float.Output(display_name="FLOAT"), + io.Int.Output(display_name="INT"), + ], + ) + + @classmethod + def execute(cls, value) -> io.NodeOutput: + if isinstance(value, bool): + float_val = 1.0 if value else 0.0 + elif isinstance(value, (int, float)): + float_val = float(value) + elif isinstance(value, str): + text = value.strip() + if not text: + raise ValueError("Cannot convert empty string to number.") + try: + float_val = float(text) + except ValueError: + raise ValueError( + f"Cannot convert string to number: {value!r}" + ) from None + else: + raise TypeError( + f"Unsupported input type: {type(value).__name__}" + ) + + if not math.isfinite(float_val): + raise ValueError( + f"Cannot convert non-finite value to number: {float_val}" + ) + + return io.NodeOutput(float_val, int(float_val)) + + +class NumberConvertExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [NumberConvertNode] + + +async def comfy_entrypoint() -> NumberConvertExtension: + return NumberConvertExtension() diff --git a/nodes.py b/nodes.py index 79874d051..37ceac2fc 100644 --- a/nodes.py +++ b/nodes.py @@ -2454,6 +2454,7 @@ async def init_builtin_extra_nodes(): "nodes_nag.py", "nodes_sdpose.py", "nodes_math.py", + "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", ] diff --git a/tests-unit/comfy_extras_test/nodes_number_convert_test.py b/tests-unit/comfy_extras_test/nodes_number_convert_test.py new file mode 100644 index 000000000..0046fa8f4 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_number_convert_test.py @@ -0,0 +1,123 @@ +import pytest +from unittest.mock import patch, MagicMock + +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 +mock_server = MagicMock() + +with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}): + from comfy_extras.nodes_number_convert import NumberConvertNode + + +class TestNumberConvertExecute: + @staticmethod + def _exec(value) -> object: + return NumberConvertNode.execute(value) + + # --- INT input --- + + def test_int_input(self): + result = self._exec(42) + assert result[0] == 42.0 + assert result[1] == 42 + + def test_int_zero(self): + result = self._exec(0) + assert result[0] == 0.0 + assert result[1] == 0 + + def test_int_negative(self): + result = self._exec(-7) + assert result[0] == -7.0 + assert result[1] == -7 + + # --- FLOAT input --- + + def test_float_input(self): + result = self._exec(3.14) + assert result[0] == 3.14 + assert result[1] == 3 + + def test_float_truncation_toward_zero(self): + result = self._exec(-2.9) + assert result[0] == -2.9 + assert result[1] == -2 # int() truncates toward zero, not floor + + def test_float_output_type(self): + result = self._exec(5) + assert isinstance(result[0], float) + + def test_int_output_type(self): + result = self._exec(5.7) + assert isinstance(result[1], int) + + # --- BOOL input --- + + def test_bool_true(self): + result = self._exec(True) + assert result[0] == 1.0 + assert result[1] == 1 + + def test_bool_false(self): + result = self._exec(False) + assert result[0] == 0.0 + assert result[1] == 0 + + # --- STRING input --- + + def test_string_integer(self): + result = self._exec("42") + assert result[0] == 42.0 + assert result[1] == 42 + + def test_string_float(self): + result = self._exec("3.14") + assert result[0] == 3.14 + assert result[1] == 3 + + def test_string_negative(self): + result = self._exec("-5.5") + assert result[0] == -5.5 + assert result[1] == -5 + + def test_string_with_whitespace(self): + result = self._exec(" 7.0 ") + assert result[0] == 7.0 + assert result[1] == 7 + + def test_string_scientific_notation(self): + result = self._exec("1e3") + assert result[0] == 1000.0 + assert result[1] == 1000 + + # --- STRING error paths --- + + def test_empty_string_raises(self): + with pytest.raises(ValueError, match="Cannot convert empty string"): + self._exec("") + + def test_whitespace_only_string_raises(self): + with pytest.raises(ValueError, match="Cannot convert empty string"): + self._exec(" ") + + def test_non_numeric_string_raises(self): + with pytest.raises(ValueError, match="Cannot convert string to number"): + self._exec("abc") + + def test_string_inf_raises(self): + with pytest.raises(ValueError, match="non-finite"): + self._exec("inf") + + def test_string_nan_raises(self): + with pytest.raises(ValueError, match="non-finite"): + self._exec("nan") + + def test_string_negative_inf_raises(self): + with pytest.raises(ValueError, match="non-finite"): + self._exec("-inf") + + # --- Unsupported type --- + + def test_unsupported_type_raises(self): + with pytest.raises(TypeError, match="Unsupported input type"): + self._exec([1, 2, 3]) From 5ebb0c2e0b72945c271a2fb4db749585aa32a13c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Mar 2026 08:39:04 +0800 Subject: [PATCH 42/42] FP8 bwd training (#13121) --- comfy/model_management.py | 1 + comfy/ops.py | 65 ++++++++++++++++++++++++++++--------- comfy_extras/nodes_train.py | 9 +++++ 3 files changed, 59 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2c250dacc..9617d8388 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -55,6 +55,7 @@ total_vram = 0 # Training Related State in_training = False +training_fp8_bwd = False def get_supported_float8_types(): diff --git a/comfy/ops.py b/comfy/ops.py index 1518ec9de..ca25693db 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -777,8 +777,16 @@ from .quant_ops import ( class QuantLinearFunc(torch.autograd.Function): - """Custom autograd function for quantized linear: quantized forward, compute_dtype backward. - Handles any input rank by flattening to 2D for matmul and restoring shape after. + """Custom autograd function for quantized linear: quantized forward, optionally FP8 backward. + + When training_fp8_bwd is enabled: + - Forward: quantize input per layout (FP8/NVFP4), use quantized matmul + - Backward: all matmuls use FP8 tensor cores via torch.mm dispatch + - Cached input is FP8 (half the memory of bf16) + + When training_fp8_bwd is disabled: + - Forward: quantize input per layout, use quantized matmul + - Backward: dequantize weight to compute_dtype, use standard matmul """ @staticmethod @@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function): input_shape = input_float.shape inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D - # Quantize input (same as inference path) + # Quantize input for forward (same layout as weight) if layout_type is not None: q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) else: @@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function): output = torch.nn.functional.linear(q_input, w, b) - # Restore original input shape + # Unflatten output to match original input shape if len(input_shape) > 2: output = output.unflatten(0, input_shape[:-1]) - ctx.save_for_backward(input_float, weight) + # Save for backward ctx.input_shape = input_shape ctx.has_bias = bias is not None ctx.compute_dtype = compute_dtype ctx.weight_requires_grad = weight.requires_grad + ctx.fp8_bwd = comfy.model_management.training_fp8_bwd + + if ctx.fp8_bwd: + # Cache FP8 quantized input — half the memory of bf16 + if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'): + ctx.q_input = q_input # already FP8, reuse + else: + # NVFP4 or other layout — quantize input to FP8 for backward + ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout") + ctx.save_for_backward(weight) + else: + ctx.q_input = None + ctx.save_for_backward(input_float, weight) return output @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, grad_output): - input_float, weight = ctx.saved_tensors compute_dtype = ctx.compute_dtype grad_2d = grad_output.flatten(0, -2).to(compute_dtype) - # Dequantize weight to compute dtype for backward matmul - if isinstance(weight, QuantizedTensor): - weight_f = weight.dequantize().to(compute_dtype) + # Value casting — only difference between fp8 and non-fp8 paths + if ctx.fp8_bwd: + weight, = ctx.saved_tensors + # Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm + grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout") + if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"): + weight_mm = weight + elif isinstance(weight, QuantizedTensor): + weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout") + else: + weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout") + input_mm = ctx.q_input else: - weight_f = weight.to(compute_dtype) + input_float, weight = ctx.saved_tensors + # Standard tensors → torch.mm does regular matmul + grad_mm = grad_2d + if isinstance(weight, QuantizedTensor): + weight_mm = weight.dequantize().to(compute_dtype) + else: + weight_mm = weight.to(compute_dtype) + input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None - # grad_input = grad_output @ weight - grad_input = torch.mm(grad_2d, weight_f) + # Computation — same for both paths, dispatch handles the rest + grad_input = torch.mm(grad_mm, weight_mm) if len(ctx.input_shape) > 2: grad_input = grad_input.unflatten(0, ctx.input_shape[:-1]) - # grad_weight (only if weight requires grad, typically frozen for quantized training) grad_weight = None if ctx.weight_requires_grad: - input_f = input_float.flatten(0, -2).to(compute_dtype) - grad_weight = torch.mm(grad_2d.t(), input_f) + grad_weight = torch.mm(grad_mm.t(), input_mm) - # grad_bias grad_bias = None if ctx.has_bias: grad_bias = grad_2d.sum(dim=0) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 0ad0acee6..df1b39fd5 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode): default="bf16", tooltip="The dtype to use for lora.", ), + io.Boolean.Input( + "quantized_backward", + default=False, + tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.", + ), io.Combo.Input( "algorithm", options=list(adapter_maps.keys()), @@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode): seed, training_dtype, lora_dtype, + quantized_backward, algorithm, gradient_checkpointing, checkpoint_depth, @@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode): seed = seed[0] training_dtype = training_dtype[0] lora_dtype = lora_dtype[0] + quantized_backward = quantized_backward[0] algorithm = algorithm[0] gradient_checkpointing = gradient_checkpointing[0] offloading = offloading[0] @@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode): bucket_mode = bucket_mode[0] bypass_mode = bypass_mode[0] + comfy.model_management.training_fp8_bwd = quantized_backward + # Process latents based on mode if bucket_mode: latents = _process_latents_bucket_mode(latents)