Created a PathServer and added a1111 path option to use models saved on a1111 directory

This commit is contained in:
franky 2023-03-18 09:35:15 +01:00
parent 51bbbf8d64
commit 2b85c4d3dd
4 changed files with 82 additions and 50 deletions

15
main.py
View File

@ -5,6 +5,8 @@ import shutil
import threading
import asyncio
from path_server import PathServer
if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@ -83,7 +85,18 @@ if __name__ == "__main__":
try:
p_index = sys.argv.index('--port')
port = int(sys.argv[p_index + 1])
except:
except ValueError as e:
pass
try:
print(f"{sys.argv = }")
a1111_index = sys.argv.index('--a1111')
print(f"{a1111_index = }")
a1111 = sys.argv[a1111_index + 1]
print(f"{a1111 = }")
PathServer().set_a1111_path(a1111)
print(f"{PathServer().get('checkpoints') = }")
except ValueError as e:
pass
if '--quick-test-for-ci' in sys.argv:

View File

@ -23,6 +23,8 @@ import comfy_extras.clip_vision
import model_management
import importlib
from path_server import PathServer
supported_ckpt_extensions = ['.ckpt', '.pth']
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
try:
@ -207,32 +209,25 @@ class VAEEncodeForInpaint:
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
class CheckpointLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
config_dir = os.path.join(models_dir, "configs")
ckpt_dir = os.path.join(models_dir, "checkpoints")
embedding_directory = os.path.join(models_dir, "embeddings")
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ),
"ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}}
return {"required": { "config_name": (filter_files_extensions(recursive_search(PathServer().get('configs')), '.yaml'), ),
"ckpt_name": (filter_files_extensions(recursive_search(PathServer().get('checkpoints')), supported_ckpt_extensions), )}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders"
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = os.path.join(self.config_dir, config_name)
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
config_path = os.path.join(PathServer().get('configs'), config_name)
ckpt_path = os.path.join(PathServer().get('checkpoints'), ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=PathServer().get('embeddings'))
class CheckpointLoaderSimple:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
ckpt_dir = os.path.join(models_dir, "checkpoints")
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
print(f"{PathServer().get('checkpoints') = }")
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(PathServer().get('checkpoints')), supported_ckpt_extensions), ),
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -240,8 +235,9 @@ class CheckpointLoaderSimple:
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
print(f"{PathServer().get('checkpoints') = }")
ckpt_path = os.path.join(PathServer().get('checkpoints'), ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=PathServer().get('embeddings'))
return out
class CLIPSetLastLayer:
@ -261,13 +257,11 @@ class CLIPSetLastLayer:
return (clip,)
class LoraLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
lora_dir = os.path.join(models_dir, "loras")
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip": ("CLIP", ),
"lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ),
"lora_name": (filter_files_extensions(recursive_search(PathServer().get('loras')), supported_pt_extensions), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
@ -277,16 +271,14 @@ class LoraLoader:
CATEGORY = "loaders"
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
lora_path = os.path.join(self.lora_dir, lora_name)
lora_path = os.path.join(PathServer().get('loras'), lora_name)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
return (model_lora, clip_lora)
class VAELoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
vae_dir = os.path.join(models_dir, "vae")
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}}
return {"required": { "vae_name": (filter_files_extensions(recursive_search(PathServer().get('vae')), supported_pt_extensions), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -294,16 +286,14 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = os.path.join(self.vae_dir, vae_name)
vae_path = os.path.join(PathServer().get('vae'), vae_name)
vae = comfy.sd.VAE(ckpt_path=vae_path)
return (vae,)
class ControlNetLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
controlnet_dir = os.path.join(models_dir, "controlnet")
@classmethod
def INPUT_TYPES(s):
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(PathServer().get('controlnet')), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
@ -311,17 +301,15 @@ class ControlNetLoader:
CATEGORY = "loaders"
def load_controlnet(self, control_net_name):
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
controlnet_path = os.path.join(PathServer().get('controlnet'), control_net_name)
controlnet = comfy.sd.load_controlnet(controlnet_path)
return (controlnet,)
class DiffControlNetLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
controlnet_dir = os.path.join(models_dir, "controlnet")
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
"control_net_name": (filter_files_extensions(recursive_search(PathServer().get('controlnet')), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
@ -329,7 +317,7 @@ class DiffControlNetLoader:
CATEGORY = "loaders"
def load_controlnet(self, model, control_net_name):
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
controlnet_path = os.path.join(PathServer().get('controlnet'), control_net_name)
controlnet = comfy.sd.load_controlnet(controlnet_path, model)
return (controlnet,)
@ -361,11 +349,9 @@ class ControlNetApply:
return (c, )
class T2IAdapterLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
@classmethod
def INPUT_TYPES(s):
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(PathServer().get('t2i_adapter')), supported_pt_extensions), )}}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_t2i_adapter"
@ -373,16 +359,14 @@ class T2IAdapterLoader:
CATEGORY = "loaders"
def load_t2i_adapter(self, t2i_adapter_name):
t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
t2i_path = os.path.join(PathServer().get('t2i_adapter'), t2i_adapter_name)
t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path)
return (t2i_adapter,)
class CLIPLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
clip_dir = os.path.join(models_dir, "clip")
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
return {"required": { "clip_name": (filter_files_extensions(recursive_search(PathServer().get('clip')), supported_pt_extensions), ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@ -390,16 +374,14 @@ class CLIPLoader:
CATEGORY = "loaders"
def load_clip(self, clip_name):
clip_path = os.path.join(self.clip_dir, clip_name)
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
clip_path = os.path.join(PathServer().get('clip'), clip_name)
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=PathServer().get('embeddings'))
return (clip,)
class CLIPVisionLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
clip_dir = os.path.join(models_dir, "clip_vision")
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
return {"required": { "clip_name": (filter_files_extensions(recursive_search(PathServer().get('clip_vision')), supported_pt_extensions), ),
}}
RETURN_TYPES = ("CLIP_VISION",)
FUNCTION = "load_clip"
@ -407,7 +389,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders"
def load_clip(self, clip_name):
clip_path = os.path.join(self.clip_dir, clip_name)
clip_path = os.path.join(PathServer().get('clip_vision'), clip_name)
clip_vision = comfy_extras.clip_vision.load(clip_path)
return (clip_vision,)
@ -427,11 +409,9 @@ class CLIPVisionEncode:
return (output,)
class StyleModelLoader:
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
style_model_dir = os.path.join(models_dir, "style_models")
@classmethod
def INPUT_TYPES(s):
return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}}
return {"required": { "style_model_name": (filter_files_extensions(recursive_search(PathServer().get('style_models')), supported_pt_extensions), )}}
RETURN_TYPES = ("STYLE_MODEL",)
FUNCTION = "load_style_model"
@ -439,7 +419,7 @@ class StyleModelLoader:
CATEGORY = "loaders"
def load_style_model(self, style_model_name):
style_model_path = os.path.join(self.style_model_dir, style_model_name)
style_model_path = os.path.join(PathServer().get('style_models'), style_model_name)
style_model = comfy.sd.load_style_model(style_model_path)
return (style_model,)

37
path_server.py Normal file
View File

@ -0,0 +1,37 @@
from singleton_decorator import singleton
import os
@singleton
class PathServer():
def __init__(self):
self.paths = {
'checkpoints': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'checkpoints'),
'clip': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'clip'),
'clip_vision': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'clip_vision'),
'configs': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'configs'),
'controlnet': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'controlnet'),
'embeddings': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'embeddings'),
'loras': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'loras'),
'style_models': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'style_models'),
't2i_adapter': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 't2i_adapter'),
'upscale_models': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'upscale_models'),
'vae': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'vae'),
}
def set_a1111_path(self, a1111_path):
self.paths['checkpoints'] = os.path.join(a1111_path, 'models', 'Stable-diffusion')
self.paths['clip'] = os.path.join(a1111_path, 'models', 'clip-interrogator')
# self.paths['clip_vision'] = os.path.join(a1111_path, 'models', '')
self.paths['configs'] = os.path.join(a1111_path, 'models', 'Stable-diffusion')
self.paths['controlnet'] = os.path.join(a1111_path, 'models', 'ControlNet')
self.paths['embeddings'] = os.path.join(a1111_path, 'embeddings')
self.paths['loras'] = os.path.join(a1111_path, 'models', 'Lora')
#self.paths['style_models'] = os.path.join(a1111_path, 'models', '')
#self.paths['t2i_adapter'] = os.path.join(a1111_path, 'models', '')
self.paths['upscale_models'] = os.path.join(a1111_path, 'models', 'ESRGAN')
self.paths['vae'] = os.path.join(a1111_path, 'models', 'VAE')
def get(self, key):
return self.paths[key]

View File

@ -9,3 +9,5 @@ pytorch_lightning
aiohttp
accelerate
pyyaml
singleton-decorator