mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 11:07:24 +08:00
revamp
This commit is contained in:
parent
0faba8740b
commit
553d28969a
@ -674,8 +674,7 @@ class Decoder(nn.Module):
|
|||||||
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||||
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
|
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||||
p1_out = self.conv_out1(_p1)
|
p1_out = self.conv_out1(_p1)
|
||||||
fake = torch.empty_like(p1_out)
|
return p1_out
|
||||||
return p1_out, fake, fake, fake
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleConvs(nn.Module):
|
class SimpleConvs(nn.Module):
|
||||||
78
comfy/bg_removal_model.py
Normal file
78
comfy/bg_removal_model.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
from .utils import load_torch_file
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.clip_model
|
||||||
|
import comfy.background_removal.birefnet
|
||||||
|
|
||||||
|
# Registry of background-removal architectures. Keys are the "model_type"
# strings looked up when a model is built; values are the classes to instantiate.
BG_REMOVAL_MODELS = {
    "birefnet": comfy.background_removal.birefnet.BiRefNet,
}
|
||||||
|
|
||||||
|
class BackgroundRemovalModel():
    """Background-removal network built from a JSON config file.

    The constructor reads the config, instantiates the architecture named by
    its ``model_type`` entry and wraps the network in a ``CoreModelPatcher``
    configured with the load and offload devices.
    """

    def __init__(self, json_config):
        """Build the model described by the JSON file at ``json_config``.

        Raises:
            ValueError: if the config's ``model_type`` has no entry in
                ``BG_REMOVAL_MODELS``.
        """
        with open(json_config, encoding="utf-8") as f:
            config = json.load(f)

        # Preprocessing parameters; the defaults are used when the config
        # does not provide them.
        self.image_size = config.get("image_size", 1024)
        self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
        self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
        self.model_type = config.get("model_type", "birefnet")
        self.config = config.copy()

        model_class = BG_REMOVAL_MODELS.get(self.model_type)
        if model_class is None:
            # Fail with a clear message instead of "'NoneType' object is not
            # callable" further down.
            raise ValueError(
                "Unsupported background removal model type: {!r} (supported: {})".format(
                    self.model_type, ", ".join(sorted(BG_REMOVAL_MODELS))
                )
            )

        # Device and dtype selection reuses the text-encoder helpers.
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
        self.model.eval()

        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        """Load state dict ``sd`` non-strictly and return ``load_state_dict``'s result."""
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

    def get_sd(self):
        """Return the wrapped model's state dict."""
        return self.model.state_dict()

    def encode_image(self, image):
        """Run the model on ``image`` and return a sigmoid mask at the input size.

        The image is preprocessed to ``self.image_size`` without cropping, and
        the model output is resized back with bicubic interpolation.
        """
        comfy.model_management.load_model_gpu(self.patcher)
        # assumes image is (batch, height, width, channels) — TODO confirm
        H, W = image.shape[1], image.shape[2]
        pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
        out = self.model(pixel_values=pixel_values)
        out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)

        mask = out.sigmoid()
        # A 3-D result gets a leading dimension added.
        if mask.ndim == 3:
            mask = mask.unsqueeze(0)
        # If dimension 1 does not have size 1, move the last dimension there.
        if mask.shape[1] != 1:
            mask = mask.movedim(-1, 1)

        return mask
|
||||||
|
|
||||||
|
|
||||||
|
def load_background_removal_model(sd):
    """Build a background-removal model from the state dict ``sd``.

    Returns ``None`` when ``sd`` does not contain the key used to select the
    BiRefNet config. Otherwise the weights are loaded, missing keys are
    logged as a warning, and every key that was not reported as unexpected
    is removed from ``sd`` in place before the model is returned.
    """
    # Guard: only state dicts carrying this key are handled.
    if "bb.layers.1.blocks.0.attn.relative_position_index" not in sd:
        return None

    config_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal")
    json_config = os.path.join(config_dir, "birefnet.json")

    bg_model = BackgroundRemovalModel(json_config)
    missing, unexpected = bg_model.load_sd(sd)
    if len(missing) > 0:
        logging.warning("missing background removal: {}".format(missing))

    # Prune the caller's dict so that only the unexpected keys remain.
    leftover = set(unexpected)
    for name in list(sd.keys()):
        if name not in leftover:
            sd.pop(name)
    return bg_model
|
||||||
|
|
||||||
|
def load(ckpt_path):
    """Read the checkpoint at ``ckpt_path`` and build a background-removal model.

    The loaded state dict is handed to ``load_background_removal_model`` and
    that function's result is returned unchanged.
    """
    state_dict = load_torch_file(ckpt_path)
    return load_background_removal_model(state_dict)
|
||||||
@ -10,7 +10,6 @@ import comfy.model_management
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.clip_model
|
import comfy.clip_model
|
||||||
import comfy.image_encoders.dino2
|
import comfy.image_encoders.dino2
|
||||||
import comfy.image_encoders.birefnet
|
|
||||||
|
|
||||||
class Output:
|
class Output:
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
@ -25,7 +24,6 @@ IMAGE_ENCODERS = {
|
|||||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||||
"birefnet": comfy.image_encoders.birefnet.BiRefNet
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ClipVisionModel():
|
class ClipVisionModel():
|
||||||
@ -37,7 +35,6 @@ class ClipVisionModel():
|
|||||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||||
self.model_type = config.get("model_type", "clip_vision_model")
|
self.model_type = config.get("model_type", "clip_vision_model")
|
||||||
self.resize_to_original = config.get("resize_to_original", False)
|
|
||||||
self.config = config.copy()
|
self.config = config.copy()
|
||||||
model_class = IMAGE_ENCODERS.get(self.model_type)
|
model_class = IMAGE_ENCODERS.get(self.model_type)
|
||||||
if self.model_type == "siglip_vision_model":
|
if self.model_type == "siglip_vision_model":
|
||||||
@ -61,15 +58,11 @@ class ClipVisionModel():
|
|||||||
|
|
||||||
def encode_image(self, image, crop=True):
|
def encode_image(self, image, crop=True):
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
H, W = image.shape[1], image.shape[2]
|
|
||||||
if self.model_type == "siglip2_vision_model":
|
if self.model_type == "siglip2_vision_model":
|
||||||
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
|
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||||
else:
|
else:
|
||||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
|
||||||
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
|
||||||
if self.resize_to_original:
|
|
||||||
resized = torch.nn.functional.interpolate(out[0], size=(H, W), mode="bicubic", antialias=False)
|
|
||||||
out = (resized,) + out[1:]
|
|
||||||
|
|
||||||
outputs = Output()
|
outputs = Output()
|
||||||
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
|
||||||
@ -137,9 +130,6 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||||||
else:
|
else:
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||||
|
|
||||||
elif "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
|
|
||||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "birefnet.json")
|
|
||||||
|
|
||||||
# Dinov2
|
# Dinov2
|
||||||
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
|
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
|
||||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||||
|
|||||||
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|||||||
from spandrel import ImageModelDescriptor
|
from spandrel import ImageModelDescriptor
|
||||||
from comfy.clip_vision import ClipVisionModel
|
from comfy.clip_vision import ClipVisionModel
|
||||||
from comfy.clip_vision import Output as ClipVisionOutput_
|
from comfy.clip_vision import Output as ClipVisionOutput_
|
||||||
|
from comfy.bg_removal_model import BackgroundRemovalModel
|
||||||
from comfy.controlnet import ControlNet
|
from comfy.controlnet import ControlNet
|
||||||
from comfy.hooks import HookGroup, HookKeyframeGroup
|
from comfy.hooks import HookGroup, HookKeyframeGroup
|
||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
@ -613,6 +614,11 @@ class Model(ComfyTypeIO):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
Type = ModelPatcher
|
Type = ModelPatcher
|
||||||
|
|
||||||
|
@comfytype(io_type="BACKGROUND_REMOVAL")
class BackgroundRemoval(ComfyTypeIO):
    """IO type registered under the ``BACKGROUND_REMOVAL`` io_type."""
    if TYPE_CHECKING:
        # Bound only for static type checkers.
        Type = BackgroundRemovalModel
|
||||||
|
|
||||||
@comfytype(io_type="CLIP_VISION")
|
@comfytype(io_type="CLIP_VISION")
|
||||||
class ClipVision(ComfyTypeIO):
|
class ClipVision(ComfyTypeIO):
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -2219,6 +2225,7 @@ __all__ = [
|
|||||||
"ModelPatch",
|
"ModelPatch",
|
||||||
"ClipVision",
|
"ClipVision",
|
||||||
"ClipVisionOutput",
|
"ClipVisionOutput",
|
||||||
|
"BackgroundRemoval",
|
||||||
"AudioEncoder",
|
"AudioEncoder",
|
||||||
"AudioEncoderOutput",
|
"AudioEncoderOutput",
|
||||||
"StyleModel",
|
"StyleModel",
|
||||||
|
|||||||
58
comfy_extras/nodes_bg_removal.py
Normal file
58
comfy_extras/nodes_bg_removal.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import folder_paths
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
|
from comfy.bg_removal_model import load
|
||||||
|
|
||||||
|
|
||||||
|
class LoadBackGroundRemovalModel(IO.ComfyNode):
    """Loader node that turns a file from the "background_removal" folder into a model."""

    @classmethod
    def define_schema(cls):
        """Describe the node: one combo input of available files, one model output."""
        files = folder_paths.get_filename_list("background_removal")
        return IO.Schema(
            node_id="LoadBackGroundRemovalModel",
            category="loaders",
            inputs=[
                IO.Combo.Input("background_removal_name", options=sorted(files)),
            ],
            outputs=[
                IO.BackgroundRemoval.Output("bg_model")
            ]
        )

    @classmethod
    def execute(cls, background_removal_name):
        """Load the selected checkpoint and return it as the node output.

        Raises:
            RuntimeError: if the loader returns ``None`` for the file.
        """
        path = folder_paths.get_full_path_or_raise("background_removal", background_removal_name)
        bg = load(path)
        if bg is None:
            # The message names the background-removal loader rather than the
            # clip-vision loader it was copied from.
            raise RuntimeError("ERROR: background removal file is invalid and does not contain a valid background removal model.")
        return IO.NodeOutput(bg)
|
||||||
|
|
||||||
|
class RemoveBackGround(IO.ComfyNode):
    """Node that runs a background-removal model on an image and outputs a mask."""

    @classmethod
    def define_schema(cls):
        """Describe the node: image and model inputs, a single mask output."""
        node_inputs = [
            IO.Image.Input("image"),
            IO.BackgroundRemoval.Input("bg_removal_model"),
        ]
        node_outputs = [
            IO.Mask.Output("mask"),
        ]
        return IO.Schema(
            node_id="RemoveBackGround",
            category="encode",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, image, bg_removal_model):
        """Return the mask produced by the model's ``encode_image`` for ``image``."""
        return IO.NodeOutput(bg_removal_model.encode_image(image))
|
||||||
|
|
||||||
|
class BackGroundRemovalExtension(ComfyExtension):
    """Extension exposing the background-removal loader and inference nodes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        nodes = [LoadBackGroundRemovalModel, RemoveBackGround]
        return nodes
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> BackGroundRemovalExtension:
    """Create and return this module's extension instance."""
    extension = BackGroundRemovalExtension()
    return extension
|
||||||
@ -390,29 +390,6 @@ class GrowMask(IO.ComfyNode):
|
|||||||
|
|
||||||
expand_mask = execute # TODO: remove
|
expand_mask = execute # TODO: remove
|
||||||
|
|
||||||
class ClipVisionToMask(IO.ComfyNode):
    """Node that converts a clip-vision output into a mask by applying a sigmoid."""

    @classmethod
    def define_schema(cls):
        """Describe the node: one clip-vision-output input, one mask output."""
        return IO.Schema(
            node_id="ClipVisionToMask",
            inputs = [
                IO.ClipVisionOutput.Input("clip_vision_output")
            ],
            outputs = [IO.Mask.Output("mask")]
        )

    @classmethod
    def execute(cls, clip_vision_output):
        """Return the sigmoid of the input as a mask.

        ``clip_vision_output`` may be a tensor, which is used directly, or an
        object indexable by ``"last_hidden_state"``.
        """
        # A tensor input previously left `mask` unbound and raised
        # UnboundLocalError; use it as-is.
        if isinstance(clip_vision_output, torch.Tensor):
            mask = clip_vision_output
        else:
            mask = clip_vision_output["last_hidden_state"]
        mask = mask.sigmoid()
        # A 3-D result gets a leading dimension added.
        if mask.ndim == 3:
            mask = mask.unsqueeze(0)
        # If dimension 1 does not have size 1, move the last dimension there.
        if mask.shape[1] != 1:
            mask = mask.movedim(-1, 1)
        return IO.NodeOutput(mask)

    clip_vision_to_mask = execute
|
|
||||||
|
|
||||||
class ThresholdMask(IO.ComfyNode):
|
class ThresholdMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -476,7 +453,6 @@ class MaskExtension(ComfyExtension):
|
|||||||
GrowMask,
|
GrowMask,
|
||||||
ThresholdMask,
|
ThresholdMask,
|
||||||
MaskPreview,
|
MaskPreview,
|
||||||
ClipVisionToMask
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
|
|||||||
|
|
||||||
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
|
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
|
||||||
|
|
||||||
|
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
|
||||||
|
|
||||||
output_directory = os.path.join(base_path, "output")
|
output_directory = os.path.join(base_path, "output")
|
||||||
temp_directory = os.path.join(base_path, "temp")
|
temp_directory = os.path.join(base_path, "temp")
|
||||||
input_directory = os.path.join(base_path, "input")
|
input_directory = os.path.join(base_path, "input")
|
||||||
|
|||||||
3
nodes.py
3
nodes.py
@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_number_convert.py",
|
"nodes_number_convert.py",
|
||||||
"nodes_painter.py",
|
"nodes_painter.py",
|
||||||
"nodes_curve.py",
|
"nodes_curve.py",
|
||||||
"nodes_rtdetr.py"
|
"nodes_rtdetr.py",
|
||||||
|
"nodes_bg_removal.py"
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user