wip deep_floyd nodes

This commit is contained in:
Benjamin Berman 2023-08-15 18:50:52 -07:00
parent 8ecb5c11e9
commit cd53b3404c
2 changed files with 21 additions and 11 deletions

1
.gitignore vendored
View File

@ -8,6 +8,7 @@
![Cc]ustom_nodes/__init__.py
!/custom_nodes/example_node.py.example
**/put*here
![Mm]odels/deepfloyd/put_deepfloyd_repos_here
/extra_model_paths.yaml
/.vs
.idea/

View File

@ -7,7 +7,7 @@ import torch
import torchvision.transforms.functional as TF
from diffusers import DiffusionPipeline, IFPipeline, StableDiffusionUpscalePipeline, IFSuperResolutionPipeline
from diffusers.utils import is_accelerate_available, is_accelerate_version
from transformers import T5EncoderModel
from transformers import T5EncoderModel, BitsAndBytesConfig
from comfy.model_management import throw_exception_if_processing_interrupted, get_torch_device, cpu_state, CPUState
# todo: this relies on the setup-py cleanup fork
@ -89,8 +89,7 @@ class Loader:
return {
"required": {
"model_name": (Loader._MODELS, {"default": "I-M"}),
"load_in_8bit": ([False, True], {"default": False}),
"device": ("STRING", {"default": ""}),
"quantization": (list(Loader._QUANTIZATIONS.keys()), {"default": "16-bit"}),
}
}
@ -100,8 +99,19 @@ class Loader:
_MODELS = ["I-M", "I-L", "I-XL", "II-M", "II-L", "III", "t5"]
_QUANTIZATIONS = {
"4-bit": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
),
"8-bit": BitsAndBytesConfig(
load_in_8bit=True,
),
"16-bit": None,
}
# todo: correctly use load_in_8bit
def process(self, model_name: str, load_in_8bit: bool, device: str):
def process(self, model_name: str, quantization: str):
assert model_name in Loader._MODELS
model_v: DiffusionPipeline
@ -113,18 +123,17 @@ class Loader:
"feature_extractor": None,
"safety_checker": None,
"watermarker": None,
"load_in_8bit": load_in_8bit,
# todo: fix diffusers when using device_map auto on multi-gpu setups, layers are not assigned to different devices correctly
"device_map": None if device else "auto"
"device_map": None
}
if Loader._QUANTIZATIONS[quantization] is not None:
kwargs['quantization_config'] = Loader._QUANTIZATIONS[quantization]
if model_name == "t5":
# find any valid IF model
model_path = next(os.path.dirname(file) for file in _find_files(_model_base_path, "model_index.json") if
any(x == T5EncoderModel.__name__ for x in
json.load(open(file, 'r'))["text_encoder"]))
# todo: this must use load_in_8bit correctly
# kwargs["text_encoder"] = text_encoder
kwargs["unet"] = None
elif model_name == "III":
model_path = f"{_model_base_path}/stable-diffusion-x4-upscaler"
@ -138,8 +147,8 @@ class Loader:
**kwargs
)
if device:
model_v = model_v.to(device)
device = get_torch_device()
model_v = model_v.to(device)
_cpu_offload(model_v, gpu_id=model_v.device.index)