Support the siglip 2 naflex model as a clip vision model. (#11831)

Not useful yet.
This commit is contained in:
comfyanonymous 2026-01-12 14:05:54 -08:00 committed by GitHub
parent a3b5d4996a
commit c881a1d689
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 91 additions and 10 deletions

View File

@ -1,6 +1,7 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
import math
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
def scale_dim(size, scale):
scaled = math.ceil(size * scale / patch_size) * patch_size
return max(patch_size, int(scaled))
# Binary search for optimal scale
lo, hi = eps / 10, 100.0
while hi - lo >= eps:
mid = (lo + hi) / 2
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
if (h // patch_size) * (w // patch_size) <= max_num_patches:
lo = mid
else:
hi = mid
return scale_dim(oh, lo), scale_dim(ow, lo)
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
if size > 0:
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
b, c, h, w = image.shape
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
return embeds + embed_weight
class Siglip2Embeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
super().__init__()
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
self.patch_size = patch_size
def forward(self, pixel_values):
b, c, h, w = pixel_values.shape
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
img = img.permute(0, 1, 3, 2, 4, 5)
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
img = self.patch_embedding(img)
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module):
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
if model_type in ["siglip2_vision_model"]:
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
else:
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:

View File

@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
@ -32,9 +33,10 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type)
if model_type == "siglip_vision_model":
self.model_type = config.get("model_type", "clip_vision_model")
self.config = config.copy()
model_class = IMAGE_ENCODERS.get(self.model_type)
if self.model_type == "siglip_vision_model":
self.return_all_hidden_states = True
else:
self.return_all_hidden_states = False
@ -55,7 +57,10 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
if self.model_type == "siglip2_vision_model":
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
else:
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
if len(patch_embedding_shape) == 2:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
else:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")

View File

@ -0,0 +1,14 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": -1,
"intermediate_size": 4304,
"model_type": "siglip2_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"num_patches": 256,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}