mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 12:50:18 +08:00

node tweaks

This commit is contained in:
parent cd53b3404c
commit cff13ace64
.gitignore (vendored): 5 changed lines

@@ -2,13 +2,14 @@
 /[Oo]utput/
 /[Ii]nput/
 !/input/example.png
-/[Mm]odels/
+/[Mm]odels/*
+![Mm]odels/deepfloyd/put_deepfloyd_hugginface_repos_or_diffusers_cache_here
 
 /[Tt]emp/
 /[Cc]ustom_nodes/*
 ![Cc]ustom_nodes/__init__.py
 !/custom_nodes/example_node.py.example
 **/put*here
-![Mm]odels/deepfloyd/put_deepfloyd_repos_here
 /extra_model_paths.yaml
 /.vs
 .idea/
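A note on the pattern change: git generally cannot re-include files whose parent directory is excluded, so ignoring the directory's contents (/[Mm]odels/*) rather than the directory itself (/[Mm]odels/) is what gives the new ![Mm]odels/deepfloyd/... negation a chance to match.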
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import typing
 from typing import Protocol, ClassVar, Tuple, Dict
 from dataclasses import dataclass, field
 
@@ -8,7 +9,7 @@ class CustomNode(Protocol):
     @classmethod
     def INPUT_TYPES(cls) -> dict: ...
 
-    RETURN_TYPES: ClassVar[Tuple[str]]
+    RETURN_TYPES: ClassVar[typing.Sequence[str]]
     RETURN_NAMES: ClassVar[Tuple[str]] = None
     OUTPUT_IS_LIST: ClassVar[Tuple[bool]] = None
     INPUT_IS_LIST: ClassVar[bool] = None
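A hedged illustration of why the annotation loosens from ClassVar[Tuple[str]] to ClassVar[typing.Sequence[str]]: in typing, Tuple[str] denotes a tuple of exactly one string, so any node declaring two or more outputs never conformed to the protocol; Sequence[str] admits tuples of any length. The node below is hypothetical:

    import typing

    class ExampleTwoOutputNode:
        # a 2-tuple fails against Tuple[str]; Sequence[str] accepts it
        RETURN_TYPES: typing.ClassVar[typing.Sequence[str]] = ("IMAGE", "MASK")
        FUNCTION = "process"

        @classmethod
        def INPUT_TYPES(cls) -> dict:
            return {"required": {"image": ("IMAGE",)}}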
@@ -14,17 +14,17 @@ filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecat
 
 NODE_CLASS_MAPPINGS = {
     # DeepFloyd
-    "IF Loader": Loader,
-    "IF Encoder": Encoder,
-    "IF Stage I": StageI,
-    "IF Stage II": StageII,
-    "IF Stage III": StageIII,
+    "IFLoader": IFLoader,
+    "IFEncoder": IFEncoder,
+    "IFStageI": IFStageI,
+    "IFStageII": IFStageII,
+    "IFStageIII": IFStageIII,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
-    "IF Loader": "IF Loader",
-    "IF Encoder": "IF Encoder",
-    "IF Stage I": "IF Stage I",
-    "IF Stage II": "IF Stage II",
-    "IF Stage III": "IF Stage III",
+    "IFLoader": "DeepFloyd IF Loader",
+    "IFEncoder": "DeepFloyd IF Encoder",
+    "IFStageI": "DeepFloyd IF Stage I",
+    "IFStageII": "DeepFloyd IF Stage II",
+    "IFStageIII": "DeepFloyd IF Stage III",
 }
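One consequence of this hunk worth keeping in mind: the NODE_CLASS_MAPPINGS keys are the node type names serialized into workflow JSON, so renaming "IF Loader" to "IFLoader" breaks workflows saved before this commit. A simplified sketch of the lookup (resolve_node is illustrative, not ComfyUI's actual API):

    def resolve_node(node_type: str):
        # node_type comes from the saved workflow; older graphs still say "IF Loader"
        cls = NODE_CLASS_MAPPINGS[node_type]  # KeyError for pre-rename graphs
        label = NODE_DISPLAY_NAME_MAPPINGS.get(node_type, node_type)
        return cls, label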
@@ -4,13 +4,12 @@ import os.path
 import typing
 
 import torch
-import torchvision.transforms.functional as TF
 from diffusers import DiffusionPipeline, IFPipeline, StableDiffusionUpscalePipeline, IFSuperResolutionPipeline
 from diffusers.utils import is_accelerate_available, is_accelerate_version
 from transformers import T5EncoderModel, BitsAndBytesConfig
 
 from comfy.model_management import throw_exception_if_processing_interrupted, get_torch_device, cpu_state, CPUState
-# todo: this relies on the setup-py cleanup fork
+from comfy.nodes.package_typing import CustomNode
 from comfy.utils import ProgressBar, get_project_root
 
 # todo: find or download the models automatically by their config jsons instead of using well known names
@@ -83,13 +82,16 @@ def _cpu_offload(self: DiffusionPipeline, gpu_id=0):
     self.enable_model_cpu_offload(gpu_id)
 
 
-class Loader:
+class IFLoader(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
-                "model_name": (Loader._MODELS, {"default": "I-M"}),
-                "quantization": (list(Loader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
+                "model_name": (IFLoader._MODELS, {"default": "I-M"}),
+                "quantization": (list(IFLoader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
+            },
+            "optional": {
+                "hugging_face_token": ("STRING", {"default": ""}),
             }
         }
 
@@ -110,9 +112,8 @@ class Loader:
         "16-bit": None,
     }
 
-    # todo: correctly use load_in_8bit
-    def process(self, model_name: str, quantization: str):
-        assert model_name in Loader._MODELS
+    def process(self, model_name: str, quantization: str, hugging_face_token: str = ""):
+        assert model_name in IFLoader._MODELS
 
         model_v: DiffusionPipeline
         model_path: str
@@ -126,14 +127,22 @@ class Loader:
             "device_map": None
         }
 
-        if Loader._QUANTIZATIONS[quantization] is not None:
-            kwargs['quantization_config'] = Loader._QUANTIZATIONS[quantization]
+        if hugging_face_token is not None and hugging_face_token != "":
+            kwargs['access_token'] = hugging_face_token
+        elif 'HUGGING_FACE_HUB_TOKEN' in os.environ:
+            pass
+
+        if IFLoader._QUANTIZATIONS[quantization] is not None:
+            kwargs['quantization_config'] = IFLoader._QUANTIZATIONS[quantization]
 
         if model_name == "t5":
             # find any valid IF model
-            model_path = next(os.path.dirname(file) for file in _find_files(_model_base_path, "model_index.json") if
-                              any(x == T5EncoderModel.__name__ for x in
-                                  json.load(open(file, 'r'))["text_encoder"]))
+            try:
+                model_path = next(os.path.dirname(file) for file in _find_files(_model_base_path, "model_index.json") if
+                                  any(x == T5EncoderModel.__name__ for x in
+                                      json.load(open(file, 'r'))["text_encoder"]))
+            except:
+                model_path = "DeepFloyd/IF-I-M-v1.0"
             kwargs["unet"] = None
         elif model_name == "III":
             model_path = f"{_model_base_path}/stable-diffusion-x4-upscaler"
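Two observations on the new token branch, hedged since the surrounding code is not fully shown: the elif on HUGGING_FACE_HUB_TOKEN can stay a no-op because huggingface_hub reads that environment variable by itself, but kwargs['access_token'] does not match the keyword diffusers documents for from_pretrained in this era, which is use_auth_token (renamed token in later releases). A sketch of the presumably intended plumbing:

    def auth_kwargs(hugging_face_token: str = "") -> dict:
        # an explicit token wins; otherwise huggingface_hub falls back to
        # HUGGING_FACE_HUB_TOKEN or the cached login on its own
        if hugging_face_token:
            return {"use_auth_token": hugging_face_token}
        return {}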
@@ -142,6 +151,13 @@ class Loader:
             model_path = f"{_model_base_path}/IF-{model_name}-v1.0"
             kwargs["text_encoder"] = None
 
+        if not os.path.exists(model_path):
+            kwargs['cache_dir='] = os.path.abspath(_model_base_path)
+            if model_name == "t5":
+                model_path = "DeepFloyd/IF-I-M-v1.0"
+            else:
+                model_path = f"DeepFloyd/IF-{model_name}-v1.0"
+
         model_v = DiffusionPipeline.from_pretrained(
             pretrained_model_name_or_path=model_path,
             **kwargs
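One bug to flag in the added fallback: the key is written as kwargs['cache_dir='] with a stray '=' inside the string, so from_pretrained will never receive a cache_dir. A minimal sketch of the branch as presumably intended (the function wrapper is illustrative):

    import os

    def resolve_model_source(model_name: str, model_path: str, kwargs: dict, base_path: str) -> str:
        # fall back to downloading from the Hugging Face hub into the local models dir
        if not os.path.exists(model_path):
            kwargs["cache_dir"] = os.path.abspath(base_path)  # 'cache_dir', no stray '='
            if model_name == "t5":
                model_path = "DeepFloyd/IF-I-M-v1.0"
            else:
                model_path = f"DeepFloyd/IF-{model_name}-v1.0"
        return model_path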
@@ -155,7 +171,7 @@ class Loader:
         return (model_v,)
 
 
-class Encoder:
+class IFEncoder(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -168,9 +184,7 @@ class Encoder:
 
     CATEGORY = "deepfloyd"
     FUNCTION = "process"
-    MODEL = None
     RETURN_TYPES = ("POSITIVE", "NEGATIVE",)
-    TEXT_ENCODER = None
 
     def process(self, model: IFPipeline, positive, negative):
         positive, negative = model.encode_prompt(
@@ -181,7 +195,7 @@ class Encoder:
         return (positive, negative,)
 
 
-class StageI:
+class IFStageI:
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -228,7 +242,7 @@ class StageI:
         return (image,)
 
 
-class StageII:
+class IFStageII:
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -251,10 +265,7 @@ class StageII:
     def process(self, model, images, positive, negative, seed, steps, cfg):
         images = images.permute(0, 3, 1, 2)
         progress = ProgressBar(steps)
-        batch_size, channels, height, width = images.shape
-        max_dim = max(height, width)
-        images = TF.center_crop(images, max_dim)
-        model.unet.config.sample_size = max_dim * 4
+        batch_size = images.shape[0]
 
         if batch_size > 1:
             positive = positive.repeat(batch_size, 1, 1)
@@ -268,19 +279,22 @@ class StageII:
             image=images,
             prompt_embeds=positive,
             negative_prompt_embeds=negative,
+            height=images.shape[2] // 8 * 8 * 4,
+            width=images.shape[3] // 8 * 8 * 4,
             generator=torch.manual_seed(seed),
             guidance_scale=cfg,
             num_inference_steps=steps,
             callback=callback,
             output_type="pt",
-        ).images.cpu().float()
+        ).images
 
-        images = TF.center_crop(images, [height * 4, width * 4])
+        images = images.clamp(0, 1)
         images = images.permute(0, 2, 3, 1)
+        images = images.to("cpu", torch.float32)
         return (images,)
 
 
-class StageIII:
+class IFStageIII:
     @classmethod
     def INPUT_TYPES(s):
         return {
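The height/width arithmetic added here floors each Stage I dimension to a multiple of 8, then applies the 4x super-resolution factor; unlike the removed center-crop-to-square path, non-square inputs keep their aspect ratio (modulo the 8-pixel alignment). A small worked sketch:

    def stage_ii_size(dim: int, align: int = 8, factor: int = 4) -> int:
        return dim // align * align * factor

    assert stage_ii_size(97) == 384  # 97 -> 96 (aligned down) -> 384 (4x upscale)
    assert stage_ii_size(64) == 256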
@@ -11,9 +11,12 @@ import numpy as np
 from PIL import Image
 import torch
 
+from comfy.nodes.package_typing import CustomNode
+
 MAX_RESOLUTION = 1024
 AUTO_FACTOR = 8
 
 
 def k_centroid_downscale(images, width, height, centroids=2):
     '''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''
@@ -31,13 +34,13 @@ def k_centroid_downscale(images, width, height, centroids=2):
                 # get most common (median) color
                 color_counts = tile.getcolors()
                 most_common_idx = max(color_counts, key=lambda x: x[0])[1]
-                downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3]
+                downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx * 3:(most_common_idx + 1) * 3]
 
     downscaled = downscaled.astype(np.float32) / 255.0
     return torch.from_numpy(downscaled)
 
 
-class ImageKCentroidDownscale:
+class ImageKCentroidDownscale(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -58,7 +61,8 @@ class ImageKCentroidDownscale:
         s = k_centroid_downscale(image, width, height, centroids)
         return (s,)
 
-class ImageKCentroidAutoDownscale:
+
+class ImageKCentroidAutoDownscale(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
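A hedged usage sketch of the function these hunks touch: k_centroid_downscale takes a batch in ComfyUI's IMAGE layout ([B, H, W, C] floats in [0, 1]) and returns the same layout at the target size, with each output pixel snapped to the dominant color of its source tile:

    import torch

    pixels = torch.rand(1, 512, 512, 3)           # one 512x512 RGB image
    small = k_centroid_downscale(pixels, 64, 64)  # -> shape [1, 64, 64, 3]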
@@ -10,6 +10,8 @@ import torch
 from torch.nn import functional as F
 from torch.nn.modules.utils import _pair
 
+from comfy.nodes.package_typing import CustomNode
+
 
 def flatten_modules(m):
     '''Return submodules of module m in flattened form.'''
@@ -38,7 +40,7 @@ def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tens
     return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
 
 
-class MakeModelTileable:
+class MakeModelTileable(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
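For context on the F.conv2d(working, ..., _pair(0), ...) line above: the usual way to make a model tile seamlessly is to replace every Conv2d forward with one that applies circular padding by hand and then convolves with zero padding. A hedged sketch of that replacement (names other than the final F.conv2d call are assumptions, since the body is not fully shown here):

    import torch
    import torch.nn.functional as F
    from torch.nn.modules.utils import _pair

    def tileable_conv_forward(self, input: torch.Tensor, weight: torch.Tensor, bias):
        # wrap the input around both spatial axes so edge features line up,
        # then convolve with no further padding
        pad_h, pad_w = self.padding
        working = F.pad(input, (pad_w, pad_w, pad_h, pad_h), mode="circular")
        return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)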
@@ -7,8 +7,10 @@ import numpy as np
 import rembg
 import torch
 
+from comfy.nodes.package_typing import CustomNode
 
-class BinarizeMask:
+
+class BinarizeMask(CustomNode):
     '''Binarize (threshold) a mask.'''
 
     @classmethod
@@ -36,7 +38,7 @@ class BinarizeMask:
         return (s,)
 
 
-class ImageCutout:
+class ImageCutout(CustomNode):
     '''Perform basic image cutout (adds alpha channel from mask).'''
 
     @classmethod
@@ -65,4 +67,3 @@ NODE_CLASS_MAPPINGS = {
     "BinarizeMask": BinarizeMask,
     "ImageCutout": ImageCutout,
 }
-
@@ -9,6 +9,8 @@ import numpy as np
 from PIL import Image
 import torch
 
+from comfy.nodes.package_typing import CustomNode
+
 PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../../..', 'palettes')
 PAL_EXT = '.png'
 
@@ -18,6 +20,7 @@ QUANTIZE_METHODS = {
     'fast_octree': Image.Quantize.FASTOCTREE,
 }
 
+
 # Determine optimal number of colors.
 # FROM: astropulse/sd-palettize
 #
@@ -59,8 +62,10 @@ def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE):
 
     return best_k
 
 
 palette_warned = False
+
+
 def list_palettes():
     global palette_warned
     palettes = []
@@ -72,7 +77,8 @@ def list_palettes():
             pass
     if not palettes and not palette_warned:
         palette_warned = True
-        print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
+        print(
+            "ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
     return palettes
 
 
@@ -90,7 +96,7 @@ def load_palette(name):
     return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT)))
 
 
-class ImagePalettize:
+class ImagePalettize(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -122,7 +128,7 @@ class ImagePalettize:
         if palette not in {'auto_best_k', 'auto_fixed_k'}:
             pal_entries = load_palette(palette)
             k = len(pal_entries) // 3
-            pal_img = Image.new('P', (1, 1)) # image size doesn't matter it only holds the palette
+            pal_img = Image.new('P', (1, 1))  # image size doesn't matter it only holds the palette
             pal_img.putpalette(pal_entries)
 
         results = []
@@ -143,7 +149,7 @@ class ImagePalettize:
             results.append(np.array(i))
 
         result = np.array(results).astype(np.float32) / 255.0
-        return (torch.from_numpy(result), )
+        return (torch.from_numpy(result),)
 
 
 NODE_CLASS_MAPPINGS = {
@@ -9,9 +9,12 @@ import numpy as np
 from PIL import Image
 import torch
 
+from comfy.nodes.package_typing import CustomNode
+
 MAX_RESOLUTION = 8192
 
-class ImageSolidColor:
+
+class ImageSolidColor(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -32,7 +35,7 @@ class ImageSolidColor:
     def render(self, width, height, r, g, b):
         color = torch.tensor([r, g, b]) / 255.0
         result = color.expand(1, height, width, 3)
-        return (result, )
+        return (result,)
 
 
 NODE_CLASS_MAPPINGS = {
@@ -42,4 +45,3 @@ NODE_CLASS_MAPPINGS = {
 NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageSolidColor": "Solid Color",
 }
-
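A note on the unchanged render body: expand does not allocate a height x width buffer; it produces a broadcast view over the 3-element color tensor, so even a MAX_RESOLUTION-sized solid color costs essentially nothing beyond the 3 floats it is built from:

    import torch

    color = torch.tensor([255, 0, 0]) / 255.0  # pure red
    image = color.expand(1, 64, 64, 3)         # ComfyUI IMAGE layout [B, H, W, C]
    assert image.shape == (1, 64, 64, 3)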
@@ -7,11 +7,12 @@ import numpy as np
 import rembg
 import torch
 
+from comfy.nodes.package_typing import CustomNode
 
 MODELS = rembg.sessions.sessions_names
 
 
-class ImageRemoveBackground:
+class ImageRemoveBackground(CustomNode):
     '''Remove background from image (adds an alpha channel)'''
 
     @classmethod
@@ -59,17 +60,18 @@ class ImageRemoveBackground:
             i = 255. * i.cpu().numpy()
             i = np.clip(i, 0, 255).astype(np.uint8)
             i = rembg.remove(i,
                             alpha_matting=(alpha_matting == "enabled"),
                             alpha_matting_foreground_threshold=am_foreground_thr,
                             alpha_matting_background_threshold=am_background_thr,
                             alpha_matting_erode_size=am_erode_size,
                             session=session,
                             )
            results.append(i.astype(np.float32) / 255.0)
 
        s = torch.from_numpy(np.array(results))
        return (s,)
 
 
 class ImageEstimateForegroundMask:
     '''
     Return a mask of which pixels are estimated to belong to foreground.
@@ -29,5 +29,4 @@ diffusers>=0.16.1
 protobuf==3.20.3
 rembg
 psutil
-https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl; platform_system == "Windows"
-bitsandbytes; platform_system != "Windows"
+bitsandbytes>=0.40.1
setup.py: 7 changed lines

@@ -54,6 +54,11 @@ Packages that should have a specific option set when a GPU accelerator is presen
 """
 gpu_accelerated_packages = {"rembg": "rembg[gpu]"}
 
+"""
+The URL to the bitsandbytes package to use on Windows
+"""
+bitsandbytes_windows = "https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl"
+
 """
 Indicates if we're installing an editable (develop) mode package
 """
@@ -152,6 +157,8 @@ def dependencies() -> [str]:
         requirement = InstallRequirement(Requirement(package), comes_from=f"{package_name}=={version}")
         candidate = finder.find_best_candidate(requirement.name, requirement.specifier)
         if candidate.best_candidate is not None:
+            if requirement.name == "bitsandbytes" and platform.system().lower() == 'windows':
+                _dependencies[i] = f"{requirement.name} @ {bitsandbytes_windows}"
             if gpu_accelerated and requirement.name in gpu_accelerated_packages:
                 _dependencies[i] = gpu_accelerated_packages[requirement.name]
             if any([url in candidate.best_candidate.link.url for url in _alternative_indices]):
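Taken together with the requirements.txt hunk above, this moves the Windows-only wheel out of requirements.txt: the environment-marker lines collapse to a plain bitsandbytes>=0.40.1, and setup.py rewrites that requirement into a PEP 508 direct reference ('name @ url') at install time. A minimal sketch (pin_windows_bitsandbytes is illustrative; bitsandbytes_windows is the constant added above):

    import platform

    def pin_windows_bitsandbytes(name: str, spec: str) -> str:
        if name == "bitsandbytes" and platform.system().lower() == "windows":
            return f"{name} @ {bitsandbytes_windows}"
        return spec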