mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-03 18:50:51 +08:00
Refactor: move clip_preprocess to comfy.clip_model (#11586)
This commit is contained in:
parent
236b9e211d
commit
d622a61874
@ -2,6 +2,25 @@ import torch
|
|||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
|
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
||||||
|
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
||||||
|
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
||||||
|
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
||||||
|
image = image.movedim(-1, 1)
|
||||||
|
if not (image.shape[2] == size and image.shape[3] == size):
|
||||||
|
if crop:
|
||||||
|
scale = (size / min(image.shape[2], image.shape[3]))
|
||||||
|
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
||||||
|
else:
|
||||||
|
scale_size = (size, size)
|
||||||
|
|
||||||
|
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
||||||
|
h = (image.shape[2] - size)//2
|
||||||
|
w = (image.shape[3] - size)//2
|
||||||
|
image = image[:,:,h:h+size,w:w+size]
|
||||||
|
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||||
|
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||||
|
|
||||||
class CLIPAttention(torch.nn.Module):
|
class CLIPAttention(torch.nn.Module):
|
||||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -17,24 +16,7 @@ class Output:
|
|||||||
def __setitem__(self, key, item):
|
def __setitem__(self, key, item):
|
||||||
setattr(self, key, item)
|
setattr(self, key, item)
|
||||||
|
|
||||||
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
|
clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually
|
||||||
image = image[:, :, :, :3] if image.shape[3] > 3 else image
|
|
||||||
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
|
|
||||||
std = torch.tensor(std, device=image.device, dtype=image.dtype)
|
|
||||||
image = image.movedim(-1, 1)
|
|
||||||
if not (image.shape[2] == size and image.shape[3] == size):
|
|
||||||
if crop:
|
|
||||||
scale = (size / min(image.shape[2], image.shape[3]))
|
|
||||||
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
|
|
||||||
else:
|
|
||||||
scale_size = (size, size)
|
|
||||||
|
|
||||||
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
|
|
||||||
h = (image.shape[2] - size)//2
|
|
||||||
w = (image.shape[3] - size)//2
|
|
||||||
image = image[:,:,h:h+size,w:w+size]
|
|
||||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
|
||||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
|
||||||
|
|
||||||
IMAGE_ENCODERS = {
|
IMAGE_ENCODERS = {
|
||||||
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||||
@ -73,7 +55,7 @@ class ClipVisionModel():
|
|||||||
|
|
||||||
def encode_image(self, image, crop=True):
|
def encode_image(self, image, crop=True):
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||||
|
|
||||||
outputs = Output()
|
outputs = Output()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user