mirror of https://github.com/comfyanonymous/ComfyUI.git

commit 4908e7412e (parent 89fc51fb91)

    bug fixes for siglip2 to work
@@ -17,7 +17,7 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True, resize_mode="bicubic"):
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
@@ -29,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
         else:
             scale_size = (size, size)
 
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode=resize_mode, antialias=True)
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]
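These two hunks thread a new resize_mode argument through clip_preprocess; only the interpolation filter applied before the center crop changes, and "bicubic" remains the default, so existing callers are unaffected. A minimal standalone sketch of what the knob controls (the tensor shape here is an arbitrary assumption, not from the commit):

    import torch

    image = torch.rand(1, 3, 480, 640)  # NCHW, as clip_preprocess sees it internally
    for mode in ("bicubic", "bilinear"):
        resized = torch.nn.functional.interpolate(image, size=(224, 224), mode=mode, antialias=True)
        print(mode, resized.shape)  # torch.Size([1, 3, 224, 224]) for both modes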
@@ -71,9 +71,9 @@ class ClipVisionModel():
     def get_sd(self):
         return self.model.state_dict()
 
-    def encode_image(self, image, crop=True):
+    def encode_image(self, image, crop=True, resize_mode="bicubic"):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop, resize_mode=resize_mode).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
 
         outputs = Output()
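encode_image now forwards resize_mode to clip_preprocess, so callers can pick the resampling filter per call. A usage sketch, assuming a ComfyUI runtime and a hypothetical checkpoint path:

    import torch
    import comfy.clip_vision

    # path is an assumption; any CLIP/SigLIP vision checkpoint would do
    clip_vision = comfy.clip_vision.load("models/clip_vision/siglip2.safetensors")
    image = torch.rand(1, 512, 512, 3)  # NHWC float image in [0, 1]
    out = clip_vision.encode_image(image, crop=False, resize_mode="bilinear")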
@@ -1172,6 +1172,16 @@ class MultiheadAttentionComfyv(nn.Module):
             k = self._k_proj(src)
         if v is None:
             v = self._v_proj(src)
+        k, v = k.to(src.device).to(src.dtype), v.to(src.device).to(src.dtype)
+
+        if k is v:
+            if q is k:
+                q = k = v = q.transpose(1, 0)
+            else:
+                q, k = (x.transpose(1, 0) for x in (q, k))
+                v = k
+        else:
+            q, k, v = (x.transpose(1, 0) for x in (q, k, v))
 
         output = optimized_attention(q, k, v, self.num_heads, mask = attn_mask)
         return self.out_proj(output)
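The added block swaps the first two dimensions (converting between sequence-first and batch-first layouts) while preserving tensor aliasing: in self-attention q, k, and v are the same tensor, so it transposes once and restores the aliases instead of transposing three times. A standalone sketch of just that logic (the helper name is ours, not from the commit):

    import torch

    def transpose_preserving_aliases(q, k, v):
        # mirror of the aliasing-aware transpose in the hunk above
        if k is v:
            if q is k:
                q = k = v = q.transpose(1, 0)
            else:
                q, k = (x.transpose(1, 0) for x in (q, k))
                v = k
        else:
            q, k, v = (x.transpose(1, 0) for x in (q, k, v))
        return q, k, v

    x = torch.randn(2, 5, 8)
    q, k, v = transpose_preserving_aliases(x, x, x)
    print(q.shape, q is k is v)  # torch.Size([5, 2, 8]) True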
@@ -510,7 +510,9 @@ class VAE:
             # TODO
             encode_layers = 25
             decode_layers = 4
-            self.not_video = True
+            self.downscale_ratio = 1
+            self.upscale_ratio = 1
+
             self.memory_used_encode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * encode_layers
             self.memory_used_decode = lambda shape, dtype: math.prod(shape) * model_management.dtype_size(dtype) * decode_layers
 
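The memory lambdas estimate working memory as element count × bytes per element × a per-layer constant. A quick sketch of the arithmetic (using 2 bytes per element as an fp16 assumption in place of model_management.dtype_size):

    import math

    def memory_used_encode(shape, dtype_size, encode_layers=25):
        # math.prod(shape) = number of elements in the input tensor
        return math.prod(shape) * dtype_size * encode_layers

    # a 1x3x512x512 fp16 input -> about 37.5 MiB estimated for encoding
    print(memory_used_encode((1, 3, 512, 512), 2) / (1024 ** 2))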
@@ -6,6 +6,7 @@ import av
 import torch
 import folder_paths
 import json
+import logging
 from typing import Optional
 from typing_extensions import override
 from fractions import Fraction
@@ -93,11 +94,13 @@ class EncodeVideo(io.ComfyNode):
             chunk = chunk.to(model_dtype)
             if hasattr(vae, "encode"):
                 try:
+                    chunk = chunk.movedim(1, -1)
                     out = vae.encode(chunk)
                 except:
                     out = model.encode(chunk)
             else:
-                out = vae.encode_image(chunk, crop=False)
+                chunk = chunk.movedim(1, -1)
+                out = vae.encode_image(chunk, crop=False, resize_mode="bilinear")
                 out = out["image_embeds"]
 
             out_cpu = out.cpu()
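The movedim(1, -1) calls fix a layout mismatch: decoded video chunks arrive as (frames, channels, H, W), while the encode paths expect channels-last pixels. A minimal shape check (the sizes are arbitrary):

    import torch

    chunk = torch.rand(8, 3, 256, 256)  # (frames, channels, H, W) from the decoder
    chunk = chunk.movedim(1, -1)        # channel axis moved last
    print(chunk.shape)                  # torch.Size([8, 256, 256, 3])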
@@ -137,9 +140,12 @@ class ResampleVideo(io.ComfyNode):
 
         src_rate = stream.average_rate or stream.guessed_rate
         src_fps = float(src_rate) if src_rate else None
 
-        # yield original frames if asked for upsampling or src is unknown
-        if src_fps is None or target_fps > src_fps:
+        if src_fps is None:
+            logging.warning("src_fps for video resampling is None.")
+
+        # yield original frames if asked for upsampling
+        if target_fps > src_fps:
             for packet in container.demux(stream):
                 for frame in packet.decode():
                     arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
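The fps detection reads PyAV's average_rate, falls back to guessed_rate, and now logs a warning when both are missing. A small sketch of that fallback chain (the Fraction inputs mimic what PyAV reports; the helper name is ours):

    import logging
    from fractions import Fraction

    def detect_src_fps(average_rate, guessed_rate):
        src_rate = average_rate or guessed_rate
        if src_rate is None:
            logging.warning("src_fps for video resampling is None.")
            return None
        return float(src_rate)

    print(detect_src_fps(Fraction(30000, 1001), None))  # 29.97...
    print(detect_src_fps(None, Fraction(25)))           # 25.0
    print(detect_src_fps(None, None))                   # None (warning logged)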