From 553d28969adc7b7a6d895a3a8d9b1d25390ee7c0 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Thu, 7 May 2026 23:30:03 +0300
Subject: [PATCH] revamp

Move BiRefNet out of the clip_vision image-encoder path and into a
dedicated background-removal subsystem: a standalone loader module
(comfy/bg_removal_model.py), its own models folder
(models/background_removal), and loader/apply nodes
(comfy_extras/nodes_bg_removal.py). Drop the ClipVisionToMask
workaround node and the resize_to_original hack from clip_vision, and
make BiRefNet's decoder return only the final segmentation map.

---
 .../birefnet.json                                  |  0
 .../birefnet.py                                    |  3 +-
 comfy/bg_removal_model.py                          | 78 +++++++++++++++++++
 comfy/clip_vision.py                               | 10 ---
 comfy_api/latest/_io.py                            |  7 ++
 comfy_extras/nodes_bg_removal.py                   | 58 ++++++++++++++
 comfy_extras/nodes_mask.py                         | 24 ------
 folder_paths.py                                    |  2 +
 .../put_background_removal_models_here             |  0
 nodes.py                                           |  3 +-
 10 files changed, 148 insertions(+), 37 deletions(-)
 rename comfy/{image_encoders => background_removal}/birefnet.json (100%)
 rename comfy/{image_encoders => background_removal}/birefnet.py (99%)
 create mode 100644 comfy/bg_removal_model.py
 create mode 100644 comfy_extras/nodes_bg_removal.py
 create mode 100644 models/background_removal/put_background_removal_models_here

diff --git a/comfy/image_encoders/birefnet.json b/comfy/background_removal/birefnet.json
similarity index 100%
rename from comfy/image_encoders/birefnet.json
rename to comfy/background_removal/birefnet.json
diff --git a/comfy/image_encoders/birefnet.py b/comfy/background_removal/birefnet.py
similarity index 99%
rename from comfy/image_encoders/birefnet.py
rename to comfy/background_removal/birefnet.py
index 25ca5b57e..df54b2b90 100644
--- a/comfy/image_encoders/birefnet.py
+++ b/comfy/background_removal/birefnet.py
@@ -674,8 +674,7 @@ class Decoder(nn.Module):
             patches_batch = self.get_patches_batch(x, _p1) if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
         p1_out = self.conv_out1(_p1)
-        fake = torch.empty_like(p1_out)
-        return p1_out, fake, fake, fake
+        return p1_out
 
 
 class SimpleConvs(nn.Module):
diff --git a/comfy/bg_removal_model.py b/comfy/bg_removal_model.py
new file mode 100644
index 000000000..cb7c2ee53
--- /dev/null
+++ b/comfy/bg_removal_model.py
@@ -0,0 +1,78 @@
+from .utils import load_torch_file
+import os
+import json
+import torch
+import logging
+
+import comfy.ops
+import comfy.model_patcher
+import comfy.model_management
+import comfy.clip_model
+import comfy.background_removal.birefnet
+
+BG_REMOVAL_MODELS = {
+    "birefnet": comfy.background_removal.birefnet.BiRefNet
+}
+
+class BackgroundRemovalModel():
+    def __init__(self, json_config):
+        with open(json_config) as f:
+            config = json.load(f)
+
+        self.image_size = config.get("image_size", 1024)
+        self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
+        self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
+        self.model_type = config.get("model_type", "birefnet")
+        self.config = config.copy()
+        model_class = BG_REMOVAL_MODELS.get(self.model_type)
+
+        self.load_device = comfy.model_management.text_encoder_device()
+        offload_device = comfy.model_management.text_encoder_offload_device()
+        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
+        self.model.eval()
+
+        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+
+    def load_sd(self, sd):
+        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
+
+    def get_sd(self):
+        return self.model.state_dict()
+
+    def encode_image(self, image):
+        comfy.model_management.load_model_gpu(self.patcher)
+        H, W = image.shape[1], image.shape[2]
+        pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
+        out = self.model(pixel_values=pixel_values)
+        out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
+
+        mask = out.sigmoid()
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(0)
+        if mask.shape[1] != 1:
+            mask = mask.movedim(-1, 1)
+
+        return mask
+
+
+def load_background_removal_model(sd):
+    if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
+    else:
+        return None
+
+    bg_model = BackgroundRemovalModel(json_config)
+    m, u = bg_model.load_sd(sd)
+    if len(m) > 0:
+        logging.warning("missing background removal: {}".format(m))
+    u = set(u)
+    keys = list(sd.keys())
+    for k in keys:
+        if k not in u:
+            sd.pop(k)
+    return bg_model
+
+def load(ckpt_path):
+    sd = load_torch_file(ckpt_path)
+    return load_background_removal_model(sd)
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 1b5076fc9..f132526f1 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -10,7 +10,6 @@ import comfy.model_management
 import comfy.utils
 import comfy.clip_model
 import comfy.image_encoders.dino2
-import comfy.image_encoders.birefnet
 
 class Output:
     def __getitem__(self, key):
@@ -25,7 +24,6 @@ IMAGE_ENCODERS = {
     "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
     "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
     "dinov2": comfy.image_encoders.dino2.Dinov2Model,
-    "birefnet": comfy.image_encoders.birefnet.BiRefNet
 }
 
 class ClipVisionModel():
@@ -37,7 +35,6 @@
         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
         self.model_type = config.get("model_type", "clip_vision_model")
-        self.resize_to_original = config.get("resize_to_original", False)
         self.config = config.copy()
         model_class = IMAGE_ENCODERS.get(self.model_type)
         if self.model_type == "siglip_vision_model":
@@ -61,15 +58,11 @@
 
     def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        H, W = image.shape[1], image.shape[2]
         if self.model_type == "siglip2_vision_model":
             pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
         else:
             pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
-        if self.resize_to_original:
-            resized = torch.nn.functional.interpolate(out[0], size=(H, W), mode="bicubic", antialias=False)
-            out = (resized,) + out[1:]
 
         outputs = Output()
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
@@ -137,9 +130,6 @@
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
 
-    elif "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
-        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "birefnet.json")
-
     # Dinov2
     elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
         json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index fdeffea2d..a1fec3ee2 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
     from spandrel import ImageModelDescriptor
     from comfy.clip_vision import ClipVisionModel
     from comfy.clip_vision import Output as ClipVisionOutput_
+    from comfy.bg_removal_model import BackgroundRemovalModel
     from comfy.controlnet import ControlNet
     from comfy.hooks import HookGroup, HookKeyframeGroup
     from comfy.model_patcher import ModelPatcher
@@ -613,6 +614,11 @@ class Model(ComfyTypeIO):
     if TYPE_CHECKING:
         Type = ModelPatcher
 
+@comfytype(io_type="BACKGROUND_REMOVAL")
+class BackgroundRemoval(ComfyTypeIO):
+    if TYPE_CHECKING:
+        Type = BackgroundRemovalModel
+
 @comfytype(io_type="CLIP_VISION")
 class ClipVision(ComfyTypeIO):
     if TYPE_CHECKING:
@@ -2219,6 +2225,7 @@ __all__ = [
     "ModelPatch",
     "ClipVision",
     "ClipVisionOutput",
+    "BackgroundRemoval",
     "AudioEncoder",
     "AudioEncoderOutput",
     "StyleModel",
diff --git a/comfy_extras/nodes_bg_removal.py b/comfy_extras/nodes_bg_removal.py
new file mode 100644
index 000000000..28b7459aa
--- /dev/null
+++ b/comfy_extras/nodes_bg_removal.py
@@ -0,0 +1,58 @@
+import folder_paths
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO
+from comfy.bg_removal_model import load
+
+
+class LoadBackGroundRemovalModel(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        files = folder_paths.get_filename_list("background_removal")
+        return IO.Schema(
+            node_id="LoadBackGroundRemovalModel",
+            category="loaders",
+            inputs=[
+                IO.Combo.Input("background_removal_name", options=sorted(files)),
+            ],
+            outputs=[
+                IO.BackgroundRemoval.Output("bg_model")
+            ]
+        )
+    @classmethod
+    def execute(cls, background_removal_name):
+        path = folder_paths.get_full_path_or_raise("background_removal", background_removal_name)
+        bg = load(path)
+        if bg is None:
+            raise RuntimeError("ERROR: background removal model file is invalid and does not contain a valid background removal model.")
+        return IO.NodeOutput(bg)
+
+class RemoveBackGround(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RemoveBackGround",
+            category="encode",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.BackgroundRemoval.Input("bg_removal_model")
+            ],
+            outputs=[
+                IO.Mask.Output("mask")
+            ]
+        )
+    @classmethod
+    def execute(cls, image, bg_removal_model):
+        mask = bg_removal_model.encode_image(image)
+        return IO.NodeOutput(mask)
+
+class BackGroundRemovalExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            LoadBackGroundRemovalModel,
+            RemoveBackGround
+        ]
+
+
+async def comfy_entrypoint() -> BackGroundRemovalExtension:
+    return BackGroundRemovalExtension()
diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py
index cd8111683..e2b8844a4 100644
--- a/comfy_extras/nodes_mask.py
+++ b/comfy_extras/nodes_mask.py
@@ -390,29 +390,6 @@
     expand_mask = execute # TODO: remove
 
 
-class ClipVisionToMask(IO.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="ClipVisionToMask",
-            inputs = [
-                IO.ClipVisionOutput.Input("clip_vision_output")
-            ],
-            outputs = [IO.Mask.Output("mask")]
-        )
-    @classmethod
-    def execute(cls, clip_vision_output):
-        if not isinstance(clip_vision_output, torch.Tensor):
-            mask = clip_vision_output["last_hidden_state"]
-        mask = mask.sigmoid()
-        if mask.ndim == 3:
-            mask = mask.unsqueeze(0)
-        if mask.shape[1] != 1:
-            mask = mask.movedim(-1, 1)
-        return IO.NodeOutput(mask)
-
-    clip_vision_to_mask = execute
-
 class ThresholdMask(IO.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -476,7 +453,6 @@
             GrowMask,
             ThresholdMask,
             MaskPreview,
-            ClipVisionToMask
         ]
 
 
diff --git a/folder_paths.py b/folder_paths.py
index 9c96540e3..63750125c 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
 
 folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
 
+folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
+
 output_directory = os.path.join(base_path, "output")
 temp_directory = os.path.join(base_path, "temp")
 input_directory = os.path.join(base_path, "input")
diff --git a/models/background_removal/put_background_removal_models_here b/models/background_removal/put_background_removal_models_here
new file mode 100644
index 000000000..e69de29bb
diff --git a/nodes.py b/nodes.py
index 299b3d758..c1e5d5699 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes():
         "nodes_number_convert.py",
         "nodes_painter.py",
         "nodes_curve.py",
-        "nodes_rtdetr.py"
+        "nodes_rtdetr.py",
+        "nodes_bg_removal.py"
     ]
 
     import_failed = []