From 4908e7412e44ae78e67bdf59951f5a22178b121b Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Sun, 12 Oct 2025 00:04:43 +0300
Subject: [PATCH] Bug fixes to make siglip2 work

---
 comfy/clip_vision.py           |  8 ++++----
 comfy/ldm/modules/attention.py | 10 ++++++++++
 comfy/sd.py                    |  4 +++-
 comfy_extras/nodes_video.py    |  8 +++++++-
 4 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index a8127d18e..78843da47 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -17,7 +17,7 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True, resize_mode="bicubic"):
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
@@ -29,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
     else:
         scale_size = (size, size)
 
-    image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+    image = torch.nn.functional.interpolate(image, size=scale_size, mode=resize_mode, antialias=True)
     h = (image.shape[2] - size)//2
     w = (image.shape[3] - size)//2
     image = image[:,:,h:h+size,w:w+size]
@@ -71,9 +71,9 @@ class ClipVisionModel():
     def get_sd(self):
         return self.model.state_dict()
 
-    def encode_image(self, image, crop=True):
+    def encode_image(self, image, crop=True, resize_mode="bicubic"):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, resize_mode=resize_mode).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 3706f4344..939c63571 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -1172,6 +1172,16 @@ class MultiheadAttentionComfyv(nn.Module):
             k = self._k_proj(src)
         if v is None:
             v = self._v_proj(src)
+        k, v = k.to(src.device).to(src.dtype), v.to(src.device).to(src.dtype)
+
+        if k is v:
+            if q is k:
+                q = k = v = q.transpose(1, 0)
+            else:
+                q, k = (x.transpose(1, 0) for x in (q, k))
+                v = k
+        else:
+            q, k, v = (x.transpose(1, 0) for x in (q, k, v))
 
         output = optimized_attention(q, k, v, self.num_heads, mask = attn_mask)
         return self.out_proj(output)
diff --git a/comfy/sd.py b/comfy/sd.py
index 28fc45e41..752bcc785 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -510,7 +510,9 @@ class VAE:
             # TODO
             encode_layers = 25
             decode_layers = 4
-            self.not_video = True
+            self.downscale_ratio = 1
+            self.upscale_ratio = 1
+
             self.memory_used_encode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * encode_layers
             self.memory_used_decode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * decode_layers
 
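Note: the transpose block added to MultiheadAttentionComfyv.forward converts the
projected tensors from the seq-first layout (L, N, E) used by torch-style
multihead attention to the batch-first layout that optimized_attention consumes,
while preserving the object identity of shared q/k/v tensors. A minimal
standalone sketch of the same layout conversion (the shapes and the aliasing
behavior shown here are assumptions inferred from the hunk above, not taken from
the codebase):

    import torch

    def to_batch_first(q, k, v):
        # (L, N, E) -> (N, L, E); keep shared tensors shared so later code
        # can still detect self-attention with `is` checks.
        if k is v:
            if q is k:
                q = k = v = q.transpose(1, 0)
            else:
                q, k = (x.transpose(1, 0) for x in (q, k))
                v = k
        else:
            q, k, v = (x.transpose(1, 0) for x in (q, k, v))
        return q, k, v

    q = k = v = torch.randn(77, 2, 768)  # self-attention: one shared tensor
    q, k, v = to_batch_first(q, k, v)
    assert q.shape == (2, 77, 768) and q is k and k is v
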
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 86632d82b..3702cb659 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -6,6 +6,7 @@ import av
 import torch
 import folder_paths
 import json
+import logging
 from typing import Optional
 from typing_extensions import override
 from fractions import Fraction
@@ -93,11 +94,13 @@ class EncodeVideo(io.ComfyNode):
                 chunk = chunk.to(model_dtype)
                 if hasattr(vae, "encode"):
                     try:
+                        chunk = chunk.movedim(1, -1)
                         out = vae.encode(chunk)
                     except:
                         out = model.encode(chunk)
                 else:
-                    out = vae.encode_image(chunk, crop=False)
+                    chunk = chunk.movedim(1, -1)
+                    out = vae.encode_image(chunk, crop=False, resize_mode="bilinear")
                     out = out["image_embeds"]
 
                 out_cpu = out.cpu()
@@ -137,9 +140,12 @@ class ResampleVideo(io.ComfyNode):
 
         src_rate = stream.average_rate or stream.guessed_rate
         src_fps = float(src_rate) if src_rate else None
+
+        if src_fps is None:
+            logging.warning("src_fps for video resampling is None; yielding original frames.")
 
         # yield original frames if asked for upsampling or src is unknown
         if src_fps is None or target_fps > src_fps:
             for packet in container.demux(stream):
                 for frame in packet.decode():
                     arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
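
Note: the `src_fps is None` guard is kept in the resampling condition above,
since comparing `target_fps > None` would raise a TypeError; the new warning
only logs the situation. Separately, both EncodeVideo branches now call
chunk.movedim(1, -1) before encoding, because decoded video chunks are stacked
channels-first while clip_preprocess indexes channels at the last dimension
(image[:, :, :, :3]), i.e. it expects channels-last input. A minimal sketch of
that conversion (the chunk shape is an assumption for illustration):

    import torch

    # Decoded frames stacked channels-first: (N, C, H, W).
    chunk = torch.rand(8, 3, 224, 224)

    # Move the channel axis to the end: (N, H, W, C), the layout implied by
    # image[:, :, :, :3] in clip_preprocess.
    chunk = chunk.movedim(1, -1)
    assert chunk.shape == (8, 224, 224, 3)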