bug fixes for siglip2 to work

Yousef Rafat 2025-10-12 00:04:43 +03:00
parent 89fc51fb91
commit 4908e7412e
4 changed files with 26 additions and 8 deletions

View File

@@ -17,7 +17,7 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True, resize_mode="bicubic"):
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
@@ -29,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
         else:
             scale_size = (size, size)
 
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode=resize_mode, antialias=True)
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]
@@ -71,9 +71,9 @@ class ClipVisionModel():
     def get_sd(self):
         return self.model.state_dict()
 
-    def encode_image(self, image, crop=True):
+    def encode_image(self, image, crop=True, resize_mode = "bicubic"):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, resize_mode=resize_mode).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()
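
Note: to make the change above easier to follow, here is a minimal, standalone sketch (not the repository code) of how the new resize_mode argument is meant to flow from encode_image() through clip_preprocess() into torch.nn.functional.interpolate(). The channels-last input layout is inferred from the image.shape[3] indexing shown in the hunk; everything else here is illustrative only.

import torch
import torch.nn.functional as F

def preprocess_sketch(image, size=224, resize_mode="bicubic"):
    # image: (N, H, W, C) float tensor, channels-last as in clip_preprocess
    x = image.movedim(-1, 1)  # NHWC -> NCHW for interpolate
    # resize_mode is forwarded verbatim; "bicubic" stays the default,
    # "bilinear" is what the video-encode path later in this commit passes
    x = F.interpolate(x, size=(size, size), mode=resize_mode, antialias=True)
    return x

frames = torch.rand(2, 480, 640, 3)
print(preprocess_sketch(frames, size=256, resize_mode="bilinear").shape)  # torch.Size([2, 3, 256, 256])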

View File

@@ -1172,6 +1172,16 @@ class MultiheadAttentionComfyv(nn.Module):
             k = self._k_proj(src)
         if v is None:
             v = self._v_proj(src)
+        k, v = k.to(src.device).to(src.dtype), v.to(src.device).to(src.dtype)
+
+        if k is v:
+            if q is k:
+                q = k = v = q.transpose(1, 0)
+            else:
+                q, k = (x.transpose(1, 0) for x in (q, k))
+                v = k
+        else:
+            q, k, v = (x.transpose(1, 0) for x in (q, k, v))
 
         output = optimized_attention(q, k, v, self.num_heads, mask = attn_mask)
         return self.out_proj(output)
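
As a side note on the identity checks added above, here is a small, self-contained demonstration of why the code tests `k is v` / `q is k` before transposing: when q, k and v alias the same tensor object (self-attention), one shared transpose is enough, instead of creating three separate transposed tensors. The (seq, batch, dim) layout is an assumption for this demo.

import torch

src = torch.randn(16, 2, 64)    # (seq, batch, dim) layout -- assumed for this demo
q = k = v = src                 # self-attention: all three names alias one tensor object

if k is v:
    if q is k:
        q = k = v = q.transpose(1, 0)                 # one shared transpose for all three
    else:
        q, k = (x.transpose(1, 0) for x in (q, k))
        v = k                                         # keep k and v aliased
else:
    q, k, v = (x.transpose(1, 0) for x in (q, k, v))  # no aliasing: transpose each separately

print(q is k and k is v)   # True: the aliasing survives the transpose
print(q.shape)             # torch.Size([2, 16, 64])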

View File

@@ -510,7 +510,9 @@ class VAE:
             # TODO
             encode_layers = 25
             decode_layers = 4
-            self.not_video = True
+            self.downscale_ratio = 1
+            self.upscale_ratio = 1
+
             self.memory_used_encode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * encode_layers
             self.memory_used_decode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * decode_layers
 
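
For context on the two lambdas kept in this hunk, here is a rough, standalone sketch of the estimate they compute: element count times bytes per element times a per-direction layer count. `dtype_size` below is a stand-in for `model_management.dtype_size`; the layer counts are the values from the hunk. Setting both ratios to 1 above presumably means the latent keeps the input's spatial size, which is an assumption about how downscale_ratio/upscale_ratio are consumed elsewhere.

import math
import torch

def dtype_size(dtype):
    # stand-in for model_management.dtype_size(): bytes per element
    return (torch.finfo(dtype).bits if dtype.is_floating_point else torch.iinfo(dtype).bits) // 8

encode_layers = 25
decode_layers = 4
memory_used_encode = lambda shape, dtype: math.prod(shape) * dtype_size(dtype) * encode_layers
memory_used_decode = lambda shape, dtype: math.prod(shape) * dtype_size(dtype) * decode_layers

print(memory_used_encode((1, 3, 512, 512), torch.float16) / 1024**2, "MiB")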

View File

@@ -6,6 +6,7 @@ import av
 import torch
 import folder_paths
 import json
+import logging
 from typing import Optional
 from typing_extensions import override
 from fractions import Fraction
@@ -93,11 +94,13 @@ class EncodeVideo(io.ComfyNode):
                 chunk = chunk.to(model_dtype)
                 if hasattr(vae, "encode"):
                     try:
+                        chunk = chunk.movedim(1, -1)
                         out = vae.encode(chunk)
                     except:
                         out = model.encode(chunk)
                 else:
-                    out = vae.encode_image(chunk, crop=False)
+                    chunk = chunk.movedim(1, -1)
+                    out = vae.encode_image(chunk, crop=False, resize_mode="bilinear")
 
                 out = out["image_embeds"]
                 out_cpu = out.cpu()
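
A quick, assumed-shape illustration of the movedim(1, -1) calls added above: the video chunk arrives channels-first, while encode_image()/clip_preprocess() index channels last (image.shape[3]), so the channel axis is moved to the end before encoding.

import torch

chunk = torch.rand(8, 3, 256, 256)   # (frames, channels, H, W) -- assumed layout of the video chunk
chunk_nhwc = chunk.movedim(1, -1)    # channel axis moved last: (frames, H, W, channels)
print(chunk_nhwc.shape)              # torch.Size([8, 256, 256, 3])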
@@ -137,9 +140,12 @@ class ResampleVideo(io.ComfyNode):
 
         src_rate = stream.average_rate or stream.guessed_rate
         src_fps = float(src_rate) if src_rate else None
 
-        # yield original frames if asked for upsampling or src is unknown
-        if src_fps is None or target_fps > src_fps:
+        if src_fps is None:
+            logging.warning("src_fps for video resampling is None.")
+
+        # yield original frames if asked for upsampling
+        if target_fps > src_fps:
             for packet in container.demux(stream):
                 for frame in packet.decode():
                     arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
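
Finally, a standalone sketch of the reordered frame-rate probe in this hunk: read the source rate from the stream, warn when it cannot be determined, and hand the value back for the later target_fps comparison. The function name and the path argument are placeholders for this example; `average_rate` and `guessed_rate` are the PyAV stream attributes used in the hunk.

import logging
from typing import Optional

import av

def probe_src_fps(path: str) -> Optional[float]:
    container = av.open(path)
    stream = container.streams.video[0]
    src_rate = stream.average_rate or stream.guessed_rate   # Fraction or None
    src_fps = float(src_rate) if src_rate else None
    if src_fps is None:
        logging.warning("src_fps for video resampling is None.")
    container.close()
    return src_fps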