node tweaks

Commit cff13ace64 (parent cd53b3404c)
Repository: https://github.com/comfyanonymous/ComfyUI.git (mirror)
.gitignore (vendored), 5 lines changed
@@ -2,13 +2,14 @@
 /[Oo]utput/
 /[Ii]nput/
 !/input/example.png
-/[Mm]odels/
-![Mm]odels/deepfloyd/put_deepfloyd_hugginface_repos_or_diffusers_cache_here
+/[Mm]odels/*

 /[Tt]emp/
 /[Cc]ustom_nodes/*
 ![Cc]ustom_nodes/__init__.py
 !/custom_nodes/example_node.py.example
+**/put*here
+![Mm]odels/deepfloyd/put_deepfloyd_repos_here
 /extra_model_paths.yaml
 /.vs
 .idea/
[comfy/nodes/package_typing.py]

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import typing
 from typing import Protocol, ClassVar, Tuple, Dict
 from dataclasses import dataclass, field

@@ -8,7 +9,7 @@ class CustomNode(Protocol):
     @classmethod
     def INPUT_TYPES(cls) -> dict: ...

-    RETURN_TYPES: ClassVar[Tuple[str]]
+    RETURN_TYPES: ClassVar[typing.Sequence[str]]
     RETURN_NAMES: ClassVar[Tuple[str]] = None
     OUTPUT_IS_LIST: ClassVar[Tuple[bool]] = None
     INPUT_IS_LIST: ClassVar[bool] = None
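Loosening RETURN_TYPES from ClassVar[Tuple[str]] to ClassVar[typing.Sequence[str]] lets implementations declare their outputs as either a tuple or a list and still satisfy the protocol. A minimal sketch of a conforming node (the class and its behavior are hypothetical, for illustration only):

    from comfy.nodes.package_typing import CustomNode

    class InvertMask(CustomNode):  # hypothetical node, not part of this commit
        @classmethod
        def INPUT_TYPES(cls) -> dict:
            return {"required": {"mask": ("MASK",)}}

        RETURN_TYPES = ("MASK",)  # a list such as ["MASK"] now also type-checks
        FUNCTION = "process"
        CATEGORY = "mask"

        def process(self, mask):
            return (1.0 - mask,)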
[DeepFloyd nodes package __init__]

@@ -14,17 +14,17 @@ filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecat

 NODE_CLASS_MAPPINGS = {
     # DeepFloyd
-    "IF Loader": Loader,
-    "IF Encoder": Encoder,
-    "IF Stage I": StageI,
-    "IF Stage II": StageII,
-    "IF Stage III": StageIII,
+    "IFLoader": IFLoader,
+    "IFEncoder": IFEncoder,
+    "IFStageI": IFStageI,
+    "IFStageII": IFStageII,
+    "IFStageIII": IFStageIII,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
-    "IF Loader": "IF Loader",
-    "IF Encoder": "IF Encoder",
-    "IF Stage I": "IF Stage I",
-    "IF Stage II": "IF Stage II",
-    "IF Stage III": "IF Stage III",
+    "IFLoader": "DeepFloyd IF Loader",
+    "IFEncoder": "DeepFloyd IF Encoder",
+    "IFStageI": "DeepFloyd IF Stage I",
+    "IFStageII": "DeepFloyd IF Stage II",
+    "IFStageIII": "DeepFloyd IF Stage III",
 }
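Renaming the mapping keys is a breaking change for saved graphs: ComfyUI records the NODE_CLASS_MAPPINGS key as each node's type in workflow JSON and resolves it by plain dict lookup, roughly:

    # simplified sketch of how a saved node resolves to a class
    NODE_CLASS_MAPPINGS = {"IFLoader": object}  # stand-in for the real class
    node_type = "IF Loader"                     # type from a workflow saved before this commit
    cls = NODE_CLASS_MAPPINGS[node_type]        # raises KeyError: the old key is gone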
[DeepFloyd nodes module]

@@ -4,13 +4,12 @@ import os.path
 import typing

 import torch
 import torchvision.transforms.functional as TF
 from diffusers import DiffusionPipeline, IFPipeline, StableDiffusionUpscalePipeline, IFSuperResolutionPipeline
 from diffusers.utils import is_accelerate_available, is_accelerate_version
 from transformers import T5EncoderModel, BitsAndBytesConfig

 from comfy.model_management import throw_exception_if_processing_interrupted, get_torch_device, cpu_state, CPUState
-# todo: this relies on the setup-py cleanup fork
 from comfy.nodes.package_typing import CustomNode
 from comfy.utils import ProgressBar, get_project_root

 # todo: find or download the models automatically by their config jsons instead of using well known names
@@ -83,13 +82,16 @@ def _cpu_offload(self: DiffusionPipeline, gpu_id=0):
     self.enable_model_cpu_offload(gpu_id)


-class Loader:
+class IFLoader(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
-                "model_name": (Loader._MODELS, {"default": "I-M"}),
-                "quantization": (list(Loader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
+                "model_name": (IFLoader._MODELS, {"default": "I-M"}),
+                "quantization": (list(IFLoader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
             },
+            "optional": {
+                "hugging_face_token": ("STRING", {"default": ""}),
+            }
         }
@@ -110,9 +112,8 @@ class Loader:
         "16-bit": None,
     }

-    # todo: correctly use load_in_8bit
-    def process(self, model_name: str, quantization: str):
-        assert model_name in Loader._MODELS
+    def process(self, model_name: str, quantization: str, hugging_face_token: str = ""):
+        assert model_name in IFLoader._MODELS

         model_v: DiffusionPipeline
         model_path: str
@@ -126,14 +127,22 @@ class Loader:
             "device_map": None
         }

-        if Loader._QUANTIZATIONS[quantization] is not None:
-            kwargs['quantization_config'] = Loader._QUANTIZATIONS[quantization]
+        if hugging_face_token is not None and hugging_face_token != "":
+            kwargs['access_token'] = hugging_face_token
+        elif 'HUGGING_FACE_HUB_TOKEN' in os.environ:
+            pass
+
+        if IFLoader._QUANTIZATIONS[quantization] is not None:
+            kwargs['quantization_config'] = IFLoader._QUANTIZATIONS[quantization]

         if model_name == "t5":
             # find any valid IF model
-            model_path = next(os.path.dirname(file) for file in _find_files(_model_base_path, "model_index.json") if
-                              any(x == T5EncoderModel.__name__ for x in
-                                  json.load(open(file, 'r'))["text_encoder"]))
+            try:
+                model_path = next(os.path.dirname(file) for file in _find_files(_model_base_path, "model_index.json") if
+                                  any(x == T5EncoderModel.__name__ for x in
+                                      json.load(open(file, 'r'))["text_encoder"]))
+            except:
+                model_path = "DeepFloyd/IF-I-M-v1.0"
             kwargs["unet"] = None
         elif model_name == "III":
             model_path = f"{_model_base_path}/stable-diffusion-x4-upscaler"
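Only the "16-bit": None entry of _QUANTIZATIONS is visible in these hunks; the dict evidently maps the UI choice to an optional quantization config passed through to from_pretrained. A plausible shape, assuming the 8-bit entry (not shown in the diff) uses bitsandbytes via transformers:

    from transformers import BitsAndBytesConfig

    _QUANTIZATIONS = {
        "8-bit": BitsAndBytesConfig(load_in_8bit=True),  # assumed; not visible in the hunk
        "16-bit": None,                                  # no quantization_config, regular load
    }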
@@ -142,6 +151,13 @@ class Loader:
             model_path = f"{_model_base_path}/IF-{model_name}-v1.0"
             kwargs["text_encoder"] = None

+        if not os.path.exists(model_path):
+            kwargs['cache_dir='] = os.path.abspath(_model_base_path)
+            if model_name == "t5":
+                model_path = "DeepFloyd/IF-I-M-v1.0"
+            else:
+                model_path = f"DeepFloyd/IF-{model_name}-v1.0"
+
         model_v = DiffusionPipeline.from_pretrained(
             pretrained_model_name_or_path=model_path,
             **kwargs
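The new block falls back to the DeepFloyd organization on the Hugging Face Hub whenever no local copy exists, caching the download under the local models directory (note the hunk spells the kwarg key as 'cache_dir=' with a trailing equals sign, as committed). The same pattern in isolation, with illustrative paths:

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "DeepFloyd/IF-I-M-v1.0",       # hub repo id used when nothing is found locally
        cache_dir="models/deepfloyd",  # illustrative cache location
    )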
@@ -155,7 +171,7 @@ class Loader:
         return (model_v,)


-class Encoder:
+class IFEncoder(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -168,9 +184,7 @@ class Encoder:

     CATEGORY = "deepfloyd"
     FUNCTION = "process"
-    MODEL = None
     RETURN_TYPES = ("POSITIVE", "NEGATIVE",)
-    TEXT_ENCODER = None

     def process(self, model: IFPipeline, positive, negative):
         positive, negative = model.encode_prompt(
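IFEncoder delegates to diffusers' IFPipeline.encode_prompt, which runs the T5 text encoder once and returns a (prompt_embeds, negative_embeds) pair that the later stages reuse. A rough standalone equivalent (the repo is gated and needs a Hugging Face login; prompts are illustrative):

    from diffusers import IFPipeline

    pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0")
    positive, negative = pipe.encode_prompt(
        "a watercolor fox",             # positive prompt
        negative_prompt="low quality",  # negative prompt
    )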
@@ -181,7 +195,7 @@ class Encoder:
         return (positive, negative,)


-class StageI:
+class IFStageI:
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -228,7 +242,7 @@ class StageI:
         return (image,)


-class StageII:
+class IFStageII:
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -251,10 +265,7 @@ class StageII:
     def process(self, model, images, positive, negative, seed, steps, cfg):
         images = images.permute(0, 3, 1, 2)
         progress = ProgressBar(steps)
-        batch_size, channels, height, width = images.shape
-        max_dim = max(height, width)
-        images = TF.center_crop(images, max_dim)
-        model.unet.config.sample_size = max_dim * 4
+        batch_size = images.shape[0]

         if batch_size > 1:
             positive = positive.repeat(batch_size, 1, 1)
@@ -268,19 +279,22 @@ class StageII:
             image=images,
             prompt_embeds=positive,
             negative_prompt_embeds=negative,
+            height=images.shape[2] // 8 * 8 * 4,
+            width=images.shape[3] // 8 * 8 * 4,
             generator=torch.manual_seed(seed),
             guidance_scale=cfg,
             num_inference_steps=steps,
             callback=callback,
             output_type="pt",
-        ).images.cpu().float()
+        ).images

-        images = TF.center_crop(images, [height * 4, width * 4])
         images = images.clamp(0, 1)
         images = images.permute(0, 2, 3, 1)
+        images = images.to("cpu", torch.float32)
         return (images,)


-class StageIII:
+class IFStageIII:
     @classmethod
     def INPUT_TYPES(s):
         return {
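The new height/width arguments replace the earlier square center-crop: each input side is rounded down to a multiple of 8 and then scaled by Stage II's 4x upsampling factor. A worked check of the arithmetic:

    def stage2_side(px: int) -> int:
        # round down to a multiple of 8, then apply the 4x upscale
        return px // 8 * 8 * 4

    assert stage2_side(512) == 2048
    assert stage2_side(771) == 3072  # 771 rounds down to 768, times 4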
[k-centroid downscale node module]

@@ -11,9 +11,12 @@ import numpy as np
 from PIL import Image
 import torch

+from comfy.nodes.package_typing import CustomNode
+
 MAX_RESOLUTION = 1024
 AUTO_FACTOR = 8


 def k_centroid_downscale(images, width, height, centroids=2):
     '''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''

@@ -31,13 +34,13 @@ def k_centroid_downscale(images, width, height, centroids=2):
             # get most common (median) color
             color_counts = tile.getcolors()
             most_common_idx = max(color_counts, key=lambda x: x[0])[1]
-            downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3]
+            downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx * 3:(most_common_idx + 1) * 3]

     downscaled = downscaled.astype(np.float32) / 255.0
     return torch.from_numpy(downscaled)


-class ImageKCentroidDownscale:
+class ImageKCentroidDownscale(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {

@@ -58,7 +61,8 @@ class ImageKCentroidDownscale:
         s = k_centroid_downscale(image, width, height, centroids)
         return (s,)

-class ImageKCentroidAutoDownscale:
+
+class ImageKCentroidAutoDownscale(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
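k-centroid downscaling maps each output pixel to a tile of the source, quantizes that tile down to `centroids` colors, and keeps the most frequent one, which preserves hard pixel-art edges that box or bilinear filtering would smear. One iteration of the tile loop, approximately (the source image and crop box are illustrative):

    from PIL import Image

    src = Image.new('RGB', (64, 64), (200, 50, 50))  # stand-in source image
    tile = src.crop((0, 0, 8, 8)).quantize(colors=2, method=Image.Quantize.FASTOCTREE, kmeans=2)
    count, idx = max(tile.getcolors(), key=lambda c: c[0])  # most frequent palette entry
    rgb = tile.getpalette()[idx * 3:(idx + 1) * 3]          # its RGB triple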
[tileable model node module]

@@ -10,6 +10,8 @@ import torch
 from torch.nn import functional as F
 from torch.nn.modules.utils import _pair

+from comfy.nodes.package_typing import CustomNode
+

 def flatten_modules(m):
     '''Return submodules of module m in flattened form.'''

@@ -38,7 +40,7 @@ def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tens
     return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)


-class MakeModelTileable:
+class MakeModelTileable(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
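The replacement Conv2d forward pads its input with wrap-around (circular) padding and then convolves with zero padding, so every convolution in the model sees the image as a torus and the output tiles seamlessly. The core trick in isolation:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 64, 64)
    x = F.pad(x, (1, 1, 1, 1), mode="circular")          # wrap edges instead of zero-filling
    y = F.conv2d(x, torch.randn(8, 3, 3, 3), padding=0)  # padding was applied manually above
    assert y.shape == (1, 8, 64, 64)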
[mask and cutout node module]

@@ -7,8 +7,10 @@ import numpy as np
 import rembg
 import torch

+from comfy.nodes.package_typing import CustomNode
+

-class BinarizeMask:
+class BinarizeMask(CustomNode):
     '''Binarize (threshold) a mask.'''

     @classmethod

@@ -36,7 +38,7 @@ class BinarizeMask:
         return (s,)


-class ImageCutout:
+class ImageCutout(CustomNode):
     '''Perform basic image cutout (adds alpha channel from mask).'''

     @classmethod

@@ -65,4 +67,3 @@ NODE_CLASS_MAPPINGS = {
     "BinarizeMask": BinarizeMask,
     "ImageCutout": ImageCutout,
 }
-
[image palettize node module]

@@ -9,6 +9,8 @@ import numpy as np
 from PIL import Image
 import torch

+from comfy.nodes.package_typing import CustomNode
+
 PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../../..', 'palettes')
 PAL_EXT = '.png'

@@ -18,6 +20,7 @@ QUANTIZE_METHODS = {
     'fast_octree': Image.Quantize.FASTOCTREE,
 }

+
 # Determine optimal number of colors.
 # FROM: astropulse/sd-palettize
 #

@@ -59,8 +62,10 @@ def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE):

     return best_k

+
 palette_warned = False

+
 def list_palettes():
     global palette_warned
     palettes = []

@@ -72,7 +77,8 @@ def list_palettes():
             pass
     if not palettes and not palette_warned:
         palette_warned = True
-        print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
+        print(
+            "ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
     return palettes

@@ -90,7 +96,7 @@ def load_palette(name):
     return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT)))


-class ImagePalettize:
+class ImagePalettize(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {

@@ -122,7 +128,7 @@ class ImagePalettize:
         if palette not in {'auto_best_k', 'auto_fixed_k'}:
             pal_entries = load_palette(palette)
             k = len(pal_entries) // 3
-            pal_img = Image.new('P', (1, 1)) # image size doesn't matter it only holds the palette
+            pal_img = Image.new('P', (1, 1))  # image size doesn't matter it only holds the palette
             pal_img.putpalette(pal_entries)

         results = []

@@ -143,7 +149,7 @@ class ImagePalettize:
             results.append(np.array(i))

         result = np.array(results).astype(np.float32) / 255.0
-        return (torch.from_numpy(result), )
+        return (torch.from_numpy(result),)


 NODE_CLASS_MAPPINGS = {
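The fixed-palette branch uses the standard PIL idiom: stuff the target palette into a throwaway 'P'-mode image, then quantize against it. A minimal standalone version (palette and input are illustrative):

    from PIL import Image

    pal_img = Image.new('P', (1, 1))              # size is irrelevant; it only carries the palette
    pal_img.putpalette([0, 0, 0, 255, 255, 255])  # two-color palette: black, white
    img = Image.new('RGB', (32, 32), (180, 180, 180))  # stand-in input image
    out = img.quantize(palette=pal_img, dither=Image.Dither.NONE).convert('RGB')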
[solid color node module]

@@ -9,9 +9,12 @@ import numpy as np
 from PIL import Image
 import torch

+from comfy.nodes.package_typing import CustomNode
+
 MAX_RESOLUTION = 8192

-class ImageSolidColor:
+
+class ImageSolidColor(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {

@@ -32,7 +35,7 @@ class ImageSolidColor:
     def render(self, width, height, r, g, b):
         color = torch.tensor([r, g, b]) / 255.0
         result = color.expand(1, height, width, 3)
-        return (result, )
+        return (result,)


 NODE_CLASS_MAPPINGS = {

@@ -42,4 +45,3 @@ NODE_CLASS_MAPPINGS = {
 NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageSolidColor": "Solid Color",
 }
-
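color.expand(1, height, width, 3) broadcasts the 3-element color across the whole canvas as a zero-copy view, producing the BHWC float tensor that ComfyUI uses for images (the permutes in the DeepFloyd hunks above convert to and from this layout):

    import torch

    color = torch.tensor([255, 0, 0]) / 255.0  # pure red
    img = color.expand(1, 64, 64, 3)           # 1x64x64x3 view, no allocation
    assert img.shape == (1, 64, 64, 3)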
[background removal node module]

@@ -7,11 +7,12 @@ import numpy as np
 import rembg
 import torch

+from comfy.nodes.package_typing import CustomNode
+
 MODELS = rembg.sessions.sessions_names


-class ImageRemoveBackground:
+class ImageRemoveBackground(CustomNode):
     '''Remove background from image (adds an alpha channel)'''

     @classmethod

@@ -59,17 +60,18 @@ class ImageRemoveBackground:
             i = 255. * i.cpu().numpy()
             i = np.clip(i, 0, 255).astype(np.uint8)
             i = rembg.remove(i,
-                alpha_matting=(alpha_matting == "enabled"),
-                alpha_matting_foreground_threshold=am_foreground_thr,
-                alpha_matting_background_threshold=am_background_thr,
-                alpha_matting_erode_size=am_erode_size,
-                session=session,
-            )
+                             alpha_matting=(alpha_matting == "enabled"),
+                             alpha_matting_foreground_threshold=am_foreground_thr,
+                             alpha_matting_background_threshold=am_background_thr,
+                             alpha_matting_erode_size=am_erode_size,
+                             session=session,
+                             )
             results.append(i.astype(np.float32) / 255.0)

         s = torch.from_numpy(np.array(results))
         return (s,)


 class ImageEstimateForegroundMask:
     '''
     Return a mask of which pixels are estimated to belong to foreground.
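rembg.remove works directly on uint8 numpy arrays and returns an RGBA array; the session pins which ONNX model does the segmentation, and the alpha-matting flags refine the cutout edge. Minimal usage (the input image is a stand-in):

    import numpy as np
    import rembg

    session = rembg.new_session("u2net")         # any name from rembg.sessions.sessions_names
    rgb = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in input image
    rgba = rembg.remove(rgb, session=session)    # adds the alpha channel
    assert rgba.shape == (64, 64, 4)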
[requirements.txt]

@@ -29,5 +29,4 @@ diffusers>=0.16.1
 protobuf==3.20.3
 rembg
 psutil
-https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl; platform_system == "Windows"
-bitsandbytes; platform_system != "Windows"
+bitsandbytes>=0.40.1
setup.py, 7 lines changed
@@ -54,6 +54,11 @@ Packages that should have a specific option set when a GPU accelerator is presen
 """
 gpu_accelerated_packages = {"rembg": "rembg[gpu]"}

+"""
+The URL to the bitsandbytes package to use on Windows
+"""
+bitsandbytes_windows = "https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl"
+
 """
 Indicates if we're installing an editable (develop) mode package
 """

@@ -152,6 +157,8 @@ def dependencies() -> [str]:
         requirement = InstallRequirement(Requirement(package), comes_from=f"{package_name}=={version}")
         candidate = finder.find_best_candidate(requirement.name, requirement.specifier)
         if candidate.best_candidate is not None:
+            if requirement.name == "bitsandbytes" and platform.system().lower() == 'windows':
+                _dependencies[i] = f"{requirement.name} @ {bitsandbytes_windows}"
             if gpu_accelerated and requirement.name in gpu_accelerated_packages:
                 _dependencies[i] = gpu_accelerated_packages[requirement.name]
             if any([url in candidate.best_candidate.link.url for url in _alternative_indices]):
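setup.py now rewrites the resolved bitsandbytes requirement into a PEP 508 direct reference on Windows, which is the installer-side equivalent of the environment markers dropped from requirements.txt. For comparison, the same constraint spelled declaratively with markers and a direct reference:

    bitsandbytes @ https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl ; platform_system == "Windows"
    bitsandbytes>=0.40.1 ; platform_system != "Windows"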