node tweaks

Benjamin Berman 2023-08-21 11:44:51 -07:00
parent cd53b3404c
commit cff13ace64
12 changed files with 100 additions and 61 deletions

.gitignore vendored
View File

@@ -2,13 +2,14 @@
 /[Oo]utput/
 /[Ii]nput/
 !/input/example.png
-/[Mm]odels/
+/[Mm]odels/*
+![Mm]odels/deepfloyd/put_deepfloyd_hugginface_repos_or_diffusers_cache_here
 /[Tt]emp/
 /[Cc]ustom_nodes/*
 ![Cc]ustom_nodes/__init__.py
 !/custom_nodes/example_node.py.example
 **/put*here
-![Mm]odels/deepfloyd/put_deepfloyd_repos_here
 /extra_model_paths.yaml
 /.vs
 .idea/
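Note on the /[Mm]odels/ change: git does not re-include files whose parent directory is excluded, so the deepfloyd negation only takes effect once the directory's contents are ignored entry-by-entry with a trailing *, as done here.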

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
+import typing
 from typing import Protocol, ClassVar, Tuple, Dict
 from dataclasses import dataclass, field
@@ -8,7 +9,7 @@ class CustomNode(Protocol):
     @classmethod
     def INPUT_TYPES(cls) -> dict: ...
-    RETURN_TYPES: ClassVar[Tuple[str]]
+    RETURN_TYPES: ClassVar[typing.Sequence[str]]
     RETURN_NAMES: ClassVar[Tuple[str]] = None
     OUTPUT_IS_LIST: ClassVar[Tuple[bool]] = None
     INPUT_IS_LIST: ClassVar[bool] = None
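For reference, a minimal sketch (hypothetical, not part of this commit) of a node class that satisfies the CustomNode protocol after this change; a plain tuple of type names still checks against the widened typing.Sequence[str] annotation:

from comfy.nodes.package_typing import CustomNode

class InvertMask:  # hypothetical example node
    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {"required": {"mask": ("MASK",)}}

    RETURN_TYPES = ("MASK",)  # tuples remain valid Sequence[str] values
    FUNCTION = "process"
    CATEGORY = "example"

    def process(self, mask):
        return (1.0 - mask,)

node: CustomNode = InvertMask()  # structural (static) check against the Protocol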

View File

@@ -14,17 +14,17 @@ filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecat
 NODE_CLASS_MAPPINGS = {
     # DeepFloyd
-    "IF Loader": Loader,
-    "IF Encoder": Encoder,
-    "IF Stage I": StageI,
-    "IF Stage II": StageII,
-    "IF Stage III": StageIII,
+    "IFLoader": IFLoader,
+    "IFEncoder": IFEncoder,
+    "IFStageI": IFStageI,
+    "IFStageII": IFStageII,
+    "IFStageIII": IFStageIII,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
-    "IF Loader": "IF Loader",
-    "IF Encoder": "IF Encoder",
-    "IF Stage I": "IF Stage I",
-    "IF Stage II": "IF Stage II",
-    "IF Stage III": "IF Stage III",
+    "IFLoader": "DeepFloyd IF Loader",
+    "IFEncoder": "DeepFloyd IF Encoder",
+    "IFStageI": "DeepFloyd IF Stage I",
+    "IFStageII": "DeepFloyd IF Stage II",
+    "IFStageIII": "DeepFloyd IF Stage III",
 }
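A note on this rename (general ComfyUI behavior; the snippet is only a simplified sketch): the NODE_CLASS_MAPPINGS key is the node's type identifier serialized into workflow JSON, while NODE_DISPLAY_NAME_MAPPINGS only controls the UI label, so workflows saved with the old "IF Loader" style keys will no longer resolve after this change.

# Simplified sketch of how the two mappings are consumed:
node_type = "IFLoader"                                        # stored in saved workflows
node_cls = NODE_CLASS_MAPPINGS[node_type]                     # class used at execution time
label = NODE_DISPLAY_NAME_MAPPINGS.get(node_type, node_type)  # text shown in menus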

View File

@@ -4,13 +4,12 @@ import os.path
 import typing
 import torch
-import torchvision.transforms.functional as TF
 from diffusers import DiffusionPipeline, IFPipeline, StableDiffusionUpscalePipeline, IFSuperResolutionPipeline
 from diffusers.utils import is_accelerate_available, is_accelerate_version
 from transformers import T5EncoderModel, BitsAndBytesConfig
 from comfy.model_management import throw_exception_if_processing_interrupted, get_torch_device, cpu_state, CPUState
-# todo: this relies on the setup-py cleanup fork
+from comfy.nodes.package_typing import CustomNode
 from comfy.utils import ProgressBar, get_project_root
 # todo: find or download the models automatically by their config jsons instead of using well known names
@@ -83,13 +82,16 @@ def _cpu_offload(self: DiffusionPipeline, gpu_id=0):
     self.enable_model_cpu_offload(gpu_id)
 
-class Loader:
+class IFLoader(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
-                "model_name": (Loader._MODELS, {"default": "I-M"}),
-                "quantization": (list(Loader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
+                "model_name": (IFLoader._MODELS, {"default": "I-M"}),
+                "quantization": (list(IFLoader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
+            },
+            "optional": {
+                "hugging_face_token": ("STRING", {"default": ""}),
             }
         }
@@ -110,9 +112,8 @@ class Loader:
         "16-bit": None,
     }
 
-    # todo: correctly use load_in_8bit
-    def process(self, model_name: str, quantization: str):
-        assert model_name in Loader._MODELS
+    def process(self, model_name: str, quantization: str, hugging_face_token: str = ""):
+        assert model_name in IFLoader._MODELS
 
         model_v: DiffusionPipeline
         model_path: str
@@ -126,14 +127,22 @@ class Loader:
             "device_map": None
         }
-        if Loader._QUANTIZATIONS[quantization] is not None:
-            kwargs['quantization_config'] = Loader._QUANTIZATIONS[quantization]
+        if hugging_face_token is not None and hugging_face_token != "":
+            kwargs['access_token'] = hugging_face_token
+        elif 'HUGGING_FACE_HUB_TOKEN' in os.environ:
+            pass
+
+        if IFLoader._QUANTIZATIONS[quantization] is not None:
+            kwargs['quantization_config'] = IFLoader._QUANTIZATIONS[quantization]
 
         if model_name == "t5":
             # find any valid IF model
-            model_path = next(os.path.dirname(file) for file in _find_files(_model_base_path, "model_index.json") if
-                              any(x == T5EncoderModel.__name__ for x in
-                                  json.load(open(file, 'r'))["text_encoder"]))
+            try:
+                model_path = next(os.path.dirname(file) for file in _find_files(_model_base_path, "model_index.json") if
+                                  any(x == T5EncoderModel.__name__ for x in
+                                      json.load(open(file, 'r'))["text_encoder"]))
+            except:
+                model_path = "DeepFloyd/IF-I-M-v1.0"
             kwargs["unet"] = None
         elif model_name == "III":
             model_path = f"{_model_base_path}/stable-diffusion-x4-upscaler"
@@ -142,6 +151,13 @@ class Loader:
             model_path = f"{_model_base_path}/IF-{model_name}-v1.0"
             kwargs["text_encoder"] = None
 
+        if not os.path.exists(model_path):
+            kwargs['cache_dir'] = os.path.abspath(_model_base_path)
+            if model_name == "t5":
+                model_path = "DeepFloyd/IF-I-M-v1.0"
+            else:
+                model_path = f"DeepFloyd/IF-{model_name}-v1.0"
+
         model_v = DiffusionPipeline.from_pretrained(
             pretrained_model_name_or_path=model_path,
             **kwargs
@@ -155,7 +171,7 @@ class Loader:
         return (model_v,)
 
-class Encoder:
+class IFEncoder(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -168,9 +184,7 @@ class Encoder:
     CATEGORY = "deepfloyd"
     FUNCTION = "process"
-    MODEL = None
     RETURN_TYPES = ("POSITIVE", "NEGATIVE",)
-    TEXT_ENCODER = None
 
     def process(self, model: IFPipeline, positive, negative):
         positive, negative = model.encode_prompt(
@@ -181,7 +195,7 @@ class Encoder:
         return (positive, negative,)
 
-class StageI:
+class IFStageI:
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -228,7 +242,7 @@ class StageI:
         return (image,)
 
-class StageII:
+class IFStageII:
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -251,10 +265,7 @@ class StageII:
     def process(self, model, images, positive, negative, seed, steps, cfg):
         images = images.permute(0, 3, 1, 2)
         progress = ProgressBar(steps)
-        batch_size, channels, height, width = images.shape
-        max_dim = max(height, width)
-        images = TF.center_crop(images, max_dim)
-        model.unet.config.sample_size = max_dim * 4
+        batch_size = images.shape[0]
 
         if batch_size > 1:
             positive = positive.repeat(batch_size, 1, 1)
@@ -268,19 +279,22 @@ class StageII:
             image=images,
             prompt_embeds=positive,
             negative_prompt_embeds=negative,
+            height=images.shape[2] // 8 * 8 * 4,
+            width=images.shape[3] // 8 * 8 * 4,
             generator=torch.manual_seed(seed),
             guidance_scale=cfg,
             num_inference_steps=steps,
             callback=callback,
             output_type="pt",
-        ).images.cpu().float()
-        images = TF.center_crop(images, [height * 4, width * 4])
+        ).images
+        images = images.clamp(0, 1)
         images = images.permute(0, 2, 3, 1)
+        images = images.to("cpu", torch.float32)
 
         return (images,)
 
-class StageIII:
+class IFStageIII:
     @classmethod
     def INPUT_TYPES(s):
         return {
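The net effect of the loader changes, as a condensed sketch (hypothetical paths; assumes diffusers' documented from_pretrained behavior, where a Hub repo id plus cache_dir downloads into that cache):

import os
from diffusers import DiffusionPipeline

model_base = "models/deepfloyd"                    # illustrative base path
model_path = f"{model_base}/IF-I-M-v1.0"
kwargs = {"device_map": None}                      # as in the loader above
if not os.path.exists(model_path):                 # no local copy: fall back to the Hub
    kwargs["cache_dir"] = os.path.abspath(model_base)
    model_path = "DeepFloyd/IF-I-M-v1.0"           # Hub repo id
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=model_path, **kwargs)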

View File

@@ -11,9 +11,12 @@ import numpy as np
 from PIL import Image
 import torch
 
+from comfy.nodes.package_typing import CustomNode
+
 MAX_RESOLUTION = 1024
 AUTO_FACTOR = 8
 
+
 def k_centroid_downscale(images, width, height, centroids=2):
     '''k-centroid scaling, based on: https://github.com/Astropulse/stable-diffusion-aseprite/blob/main/scripts/image_server.py.'''
@@ -31,13 +34,13 @@ def k_centroid_downscale(images, width, height, centroids=2):
             # get most common (mode) color
             color_counts = tile.getcolors()
             most_common_idx = max(color_counts, key=lambda x: x[0])[1]
-            downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx*3:(most_common_idx + 1)*3]
+            downscaled[ii, y, x, :] = tile.getpalette()[most_common_idx * 3:(most_common_idx + 1) * 3]
 
     downscaled = downscaled.astype(np.float32) / 255.0
     return torch.from_numpy(downscaled)
 
-class ImageKCentroidDownscale:
+class ImageKCentroidDownscale(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -58,7 +61,8 @@ class ImageKCentroidDownscale:
         s = k_centroid_downscale(image, width, height, centroids)
         return (s,)
 
-class ImageKCentroidAutoDownscale:
+
+class ImageKCentroidAutoDownscale(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
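Hypothetical usage of k_centroid_downscale as these nodes call it (ComfyUI image tensors are NHWC floats in [0, 1]):

import torch

images = torch.rand(1, 512, 512, 3)                       # batch of one RGB image
small = k_centroid_downscale(images, width=64, height=64, centroids=2)
print(small.shape)                                        # torch.Size([1, 64, 64, 3])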

View File

@@ -10,6 +10,8 @@ import torch
 from torch.nn import functional as F
 from torch.nn.modules.utils import _pair
 
+from comfy.nodes.package_typing import CustomNode
+
 def flatten_modules(m):
     '''Return submodules of module m in flattened form.'''
@@ -38,7 +40,7 @@ def __replacementConv2DConvForward(self, input: torch.Tensor, weight: torch.Tens
     return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
 
-class MakeModelTileable:
+class MakeModelTileable(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {

View File

@@ -7,8 +7,10 @@ import numpy as np
 import rembg
 import torch
 
+from comfy.nodes.package_typing import CustomNode
+
-class BinarizeMask:
+class BinarizeMask(CustomNode):
     '''Binarize (threshold) a mask.'''
 
     @classmethod
@@ -36,7 +38,7 @@ class BinarizeMask:
         return (s,)
 
-class ImageCutout:
+class ImageCutout(CustomNode):
     '''Perform basic image cutout (adds alpha channel from mask).'''
 
     @classmethod
@@ -65,4 +67,3 @@ NODE_CLASS_MAPPINGS = {
     "BinarizeMask": BinarizeMask,
     "ImageCutout": ImageCutout,
 }

View File

@@ -9,6 +9,8 @@ import numpy as np
 from PIL import Image
 import torch
 
+from comfy.nodes.package_typing import CustomNode
+
 PALETTES_PATH = os.path.join(os.path.dirname(__file__), '../../..', 'palettes')
 PAL_EXT = '.png'
@@ -18,6 +20,7 @@ QUANTIZE_METHODS = {
     'fast_octree': Image.Quantize.FASTOCTREE,
 }
 
+
 # Determine optimal number of colors.
 # FROM: astropulse/sd-palettize
 #
@@ -59,8 +62,10 @@ def determine_best_k(image, max_k, quantize_method=Image.Quantize.FASTOCTREE):
     return best_k
 
 palette_warned = False
+
+
 def list_palettes():
     global palette_warned
     palettes = []
@@ -72,7 +77,8 @@ def list_palettes():
         pass
     if not palettes and not palette_warned:
         palette_warned = True
-        print("ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
+        print(
+            "ImagePalettize warning: no fixed palettes found. You can put these in the palettes/ directory below the ComfyUI root.")
     return palettes
 
@@ -90,7 +96,7 @@ def load_palette(name):
     return get_image_colors(Image.open(os.path.join(PALETTES_PATH, name + PAL_EXT)))
 
-class ImagePalettize:
+class ImagePalettize(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -122,7 +128,7 @@ class ImagePalettize:
         if palette not in {'auto_best_k', 'auto_fixed_k'}:
             pal_entries = load_palette(palette)
             k = len(pal_entries) // 3
-            pal_img = Image.new('P', (1, 1)) # image size doesn't matter it only holds the palette
+            pal_img = Image.new('P', (1, 1))  # image size doesn't matter it only holds the palette
             pal_img.putpalette(pal_entries)
 
         results = []
@@ -143,7 +149,7 @@ class ImagePalettize:
             results.append(np.array(i))
 
         result = np.array(results).astype(np.float32) / 255.0
-        return (torch.from_numpy(result), )
+        return (torch.from_numpy(result),)
 
 NODE_CLASS_MAPPINGS = {
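For context, the fixed-palette branch above leans on PIL's documented quantize-to-palette path; roughly (hypothetical two-color palette, reusing the repo's input/example.png):

from PIL import Image

pal_img = Image.new('P', (1, 1))                 # size is irrelevant; it only carries the palette
pal_img.putpalette([0, 0, 0, 255, 255, 255])     # black and white entries
img = Image.open('input/example.png').convert('RGB')
out = img.quantize(palette=pal_img, dither=Image.Dither.NONE).convert('RGB')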

View File

@@ -9,9 +9,12 @@ import numpy as np
 from PIL import Image
 import torch
 
+from comfy.nodes.package_typing import CustomNode
+
 MAX_RESOLUTION = 8192
 
-class ImageSolidColor:
+
+class ImageSolidColor(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -32,7 +35,7 @@ class ImageSolidColor:
     def render(self, width, height, r, g, b):
         color = torch.tensor([r, g, b]) / 255.0
         result = color.expand(1, height, width, 3)
-        return (result, )
+        return (result,)
 
 NODE_CLASS_MAPPINGS = {
@@ -42,4 +45,3 @@ NODE_CLASS_MAPPINGS = {
 NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageSolidColor": "Solid Color",
 }

View File

@@ -7,11 +7,12 @@ import numpy as np
 import rembg
 import torch
 
+from comfy.nodes.package_typing import CustomNode
+
 MODELS = rembg.sessions.sessions_names
 
-class ImageRemoveBackground:
+class ImageRemoveBackground(CustomNode):
     '''Remove background from image (adds an alpha channel)'''
 
     @classmethod
@@ -59,17 +60,18 @@ class ImageRemoveBackground:
             i = 255. * i.cpu().numpy()
             i = np.clip(i, 0, 255).astype(np.uint8)
             i = rembg.remove(i,
                              alpha_matting=(alpha_matting == "enabled"),
                              alpha_matting_foreground_threshold=am_foreground_thr,
                              alpha_matting_background_threshold=am_background_thr,
                              alpha_matting_erode_size=am_erode_size,
                              session=session,
                              )
             results.append(i.astype(np.float32) / 255.0)
 
         s = torch.from_numpy(np.array(results))
         return (s,)
 
+
 class ImageEstimateForegroundMask:
     '''
     Return a mask of which pixels are estimated to belong to foreground.

View File

@@ -29,5 +29,4 @@ diffusers>=0.16.1
 protobuf==3.20.3
 rembg
 psutil
-https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl; platform_system == "Windows"
-bitsandbytes; platform_system != "Windows"
+bitsandbytes>=0.40.1

View File

@@ -54,6 +54,11 @@ Packages that should have a specific option set when a GPU accelerator is presen
 """
 gpu_accelerated_packages = {"rembg": "rembg[gpu]"}
 
+"""
+The URL to the bitsandbytes package to use on Windows
+"""
+bitsandbytes_windows = "https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl"
+
 """
 Indicates if we're installing an editable (develop) mode package
 """
@@ -152,6 +157,8 @@ def dependencies() -> [str]:
         requirement = InstallRequirement(Requirement(package), comes_from=f"{package_name}=={version}")
         candidate = finder.find_best_candidate(requirement.name, requirement.specifier)
         if candidate.best_candidate is not None:
+            if requirement.name == "bitsandbytes" and platform.system().lower() == 'windows':
+                _dependencies[i] = f"{requirement.name} @ {bitsandbytes_windows}"
             if gpu_accelerated and requirement.name in gpu_accelerated_packages:
                 _dependencies[i] = gpu_accelerated_packages[requirement.name]
             if any([url in candidate.best_candidate.link.url for url in _alternative_indices]):
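The Windows override emits a PEP 508 direct reference, so pip installs bitsandbytes from the pinned wheel instead of PyPI. A minimal sketch of the rewrite (assuming platform is imported at the top of this script):

import platform

def pin_windows_bitsandbytes(name: str) -> str:
    # "<name> @ <url>" is PEP 508 direct-reference syntax understood by pip
    if name == "bitsandbytes" and platform.system().lower() == "windows":
        return f"{name} @ {bitsandbytes_windows}"
    return name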