mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
79 lines
2.8 KiB
Python
79 lines
2.8 KiB
Python
from .utils import load_torch_file
|
|
import os
|
|
import json
|
|
import torch
|
|
import logging
|
|
|
|
import comfy.ops
|
|
import comfy.model_patcher
|
|
import comfy.model_management
|
|
import comfy.clip_model
|
|
import comfy.background_removal.birefnet
|
|
|
|
BG_REMOVAL_MODELS = {
|
|
"birefnet": comfy.background_removal.birefnet.BiRefNet
|
|
}
|
|
|
|
class BackgroundRemovalModel():
|
|
def __init__(self, json_config):
|
|
with open(json_config) as f:
|
|
config = json.load(f)
|
|
|
|
self.image_size = config.get("image_size", 1024)
|
|
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
|
|
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
|
|
self.model_type = config.get("model_type", "birefnet")
|
|
self.config = config.copy()
|
|
model_class = BG_REMOVAL_MODELS.get(self.model_type)
|
|
|
|
self.load_device = comfy.model_management.text_encoder_device()
|
|
offload_device = comfy.model_management.text_encoder_offload_device()
|
|
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
|
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
|
self.model.eval()
|
|
|
|
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
|
|
|
def load_sd(self, sd):
|
|
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
|
|
|
def get_sd(self):
|
|
return self.model.state_dict()
|
|
|
|
def encode_image(self, image):
|
|
comfy.model_management.load_model_gpu(self.patcher)
|
|
H, W = image.shape[1], image.shape[2]
|
|
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
|
|
out = self.model(pixel_values=pixel_values)
|
|
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
|
|
|
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
|
if mask.ndim == 3:
|
|
mask = mask.unsqueeze(0)
|
|
if mask.shape[1] != 1:
|
|
mask = mask.movedim(-1, 1)
|
|
|
|
return mask
|
|
|
|
|
|
def load_background_removal_model(sd):
|
|
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
|
|
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
|
|
else:
|
|
return None
|
|
|
|
bg_model = BackgroundRemovalModel(json_config)
|
|
m, u = bg_model.load_sd(sd)
|
|
if len(m) > 0:
|
|
logging.warning("missing background removal: {}".format(m))
|
|
u = set(u)
|
|
keys = list(sd.keys())
|
|
for k in keys:
|
|
if k not in u:
|
|
sd.pop(k)
|
|
return bg_model
|
|
|
|
def load(ckpt_path):
|
|
sd = load_torch_file(ckpt_path)
|
|
return load_background_removal_model(sd)
|