mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Created a PathServer and added a1111 path option to use models saved on a1111 directory
This commit is contained in:
parent
51bbbf8d64
commit
2b85c4d3dd
15
main.py
15
main.py
@ -5,6 +5,8 @@ import shutil
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
from path_server import PathServer
|
||||
|
||||
if os.name == "nt":
|
||||
import logging
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
@ -83,7 +85,18 @@ if __name__ == "__main__":
|
||||
try:
|
||||
p_index = sys.argv.index('--port')
|
||||
port = int(sys.argv[p_index + 1])
|
||||
except:
|
||||
except ValueError as e:
|
||||
pass
|
||||
|
||||
try:
|
||||
print(f"{sys.argv = }")
|
||||
a1111_index = sys.argv.index('--a1111')
|
||||
print(f"{a1111_index = }")
|
||||
a1111 = sys.argv[a1111_index + 1]
|
||||
print(f"{a1111 = }")
|
||||
PathServer().set_a1111_path(a1111)
|
||||
print(f"{PathServer().get('checkpoints') = }")
|
||||
except ValueError as e:
|
||||
pass
|
||||
|
||||
if '--quick-test-for-ci' in sys.argv:
|
||||
|
||||
78
nodes.py
78
nodes.py
@ -23,6 +23,8 @@ import comfy_extras.clip_vision
|
||||
import model_management
|
||||
import importlib
|
||||
|
||||
from path_server import PathServer
|
||||
|
||||
supported_ckpt_extensions = ['.ckpt', '.pth']
|
||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
|
||||
try:
|
||||
@ -207,32 +209,25 @@ class VAEEncodeForInpaint:
|
||||
return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
|
||||
|
||||
class CheckpointLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
config_dir = os.path.join(models_dir, "configs")
|
||||
ckpt_dir = os.path.join(models_dir, "checkpoints")
|
||||
embedding_directory = os.path.join(models_dir, "embeddings")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "config_name": (filter_files_extensions(recursive_search(s.config_dir), '.yaml'), ),
|
||||
"ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), )}}
|
||||
return {"required": { "config_name": (filter_files_extensions(recursive_search(PathServer().get('configs')), '.yaml'), ),
|
||||
"ckpt_name": (filter_files_extensions(recursive_search(PathServer().get('checkpoints')), supported_ckpt_extensions), )}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
||||
config_path = os.path.join(self.config_dir, config_name)
|
||||
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
|
||||
config_path = os.path.join(PathServer().get('configs'), config_name)
|
||||
ckpt_path = os.path.join(PathServer().get('checkpoints'), ckpt_name)
|
||||
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=PathServer().get('embeddings'))
|
||||
|
||||
class CheckpointLoaderSimple:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
ckpt_dir = os.path.join(models_dir, "checkpoints")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions), ),
|
||||
print(f"{PathServer().get('checkpoints') = }")
|
||||
return {"required": { "ckpt_name": (filter_files_extensions(recursive_search(PathServer().get('checkpoints')), supported_ckpt_extensions), ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
@ -240,8 +235,9 @@ class CheckpointLoaderSimple:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=CheckpointLoader.embedding_directory)
|
||||
print(f"{PathServer().get('checkpoints') = }")
|
||||
ckpt_path = os.path.join(PathServer().get('checkpoints'), ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=PathServer().get('embeddings'))
|
||||
return out
|
||||
|
||||
class CLIPSetLastLayer:
|
||||
@ -261,13 +257,11 @@ class CLIPSetLastLayer:
|
||||
return (clip,)
|
||||
|
||||
class LoraLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
lora_dir = os.path.join(models_dir, "loras")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"clip": ("CLIP", ),
|
||||
"lora_name": (filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions), ),
|
||||
"lora_name": (filter_files_extensions(recursive_search(PathServer().get('loras')), supported_pt_extensions), ),
|
||||
"strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
@ -277,16 +271,14 @@ class LoraLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||
lora_path = os.path.join(self.lora_dir, lora_name)
|
||||
lora_path = os.path.join(PathServer().get('loras'), lora_name)
|
||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
||||
return (model_lora, clip_lora)
|
||||
|
||||
class VAELoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
vae_dir = os.path.join(models_dir, "vae")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "vae_name": (filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions), )}}
|
||||
return {"required": { "vae_name": (filter_files_extensions(recursive_search(PathServer().get('vae')), supported_pt_extensions), )}}
|
||||
RETURN_TYPES = ("VAE",)
|
||||
FUNCTION = "load_vae"
|
||||
|
||||
@ -294,16 +286,14 @@ class VAELoader:
|
||||
|
||||
#TODO: scale factor?
|
||||
def load_vae(self, vae_name):
|
||||
vae_path = os.path.join(self.vae_dir, vae_name)
|
||||
vae_path = os.path.join(PathServer().get('vae'), vae_name)
|
||||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
||||
return (vae,)
|
||||
|
||||
class ControlNetLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
controlnet_dir = os.path.join(models_dir, "controlnet")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
|
||||
return {"required": { "control_net_name": (filter_files_extensions(recursive_search(PathServer().get('controlnet')), supported_pt_extensions), )}}
|
||||
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
FUNCTION = "load_controlnet"
|
||||
@ -311,17 +301,15 @@ class ControlNetLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_controlnet(self, control_net_name):
|
||||
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
|
||||
controlnet_path = os.path.join(PathServer().get('controlnet'), control_net_name)
|
||||
controlnet = comfy.sd.load_controlnet(controlnet_path)
|
||||
return (controlnet,)
|
||||
|
||||
class DiffControlNetLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
controlnet_dir = os.path.join(models_dir, "controlnet")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
|
||||
"control_net_name": (filter_files_extensions(recursive_search(PathServer().get('controlnet')), supported_pt_extensions), )}}
|
||||
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
FUNCTION = "load_controlnet"
|
||||
@ -329,7 +317,7 @@ class DiffControlNetLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_controlnet(self, model, control_net_name):
|
||||
controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
|
||||
controlnet_path = os.path.join(PathServer().get('controlnet'), control_net_name)
|
||||
controlnet = comfy.sd.load_controlnet(controlnet_path, model)
|
||||
return (controlnet,)
|
||||
|
||||
@ -361,11 +349,9 @@ class ControlNetApply:
|
||||
return (c, )
|
||||
|
||||
class T2IAdapterLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
t2i_adapter_dir = os.path.join(models_dir, "t2i_adapter")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(s.t2i_adapter_dir), supported_pt_extensions), )}}
|
||||
return {"required": { "t2i_adapter_name": (filter_files_extensions(recursive_search(PathServer().get('t2i_adapter')), supported_pt_extensions), )}}
|
||||
|
||||
RETURN_TYPES = ("CONTROL_NET",)
|
||||
FUNCTION = "load_t2i_adapter"
|
||||
@ -373,16 +359,14 @@ class T2IAdapterLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_t2i_adapter(self, t2i_adapter_name):
|
||||
t2i_path = os.path.join(self.t2i_adapter_dir, t2i_adapter_name)
|
||||
t2i_path = os.path.join(PathServer().get('t2i_adapter'), t2i_adapter_name)
|
||||
t2i_adapter = comfy.sd.load_t2i_adapter(t2i_path)
|
||||
return (t2i_adapter,)
|
||||
|
||||
class CLIPLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
clip_dir = os.path.join(models_dir, "clip")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
|
||||
return {"required": { "clip_name": (filter_files_extensions(recursive_search(PathServer().get('clip')), supported_pt_extensions), ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -390,16 +374,14 @@ class CLIPLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
clip_path = os.path.join(self.clip_dir, clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
|
||||
clip_path = os.path.join(PathServer().get('clip'), clip_name)
|
||||
clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=PathServer().get('embeddings'))
|
||||
return (clip,)
|
||||
|
||||
class CLIPVisionLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
clip_dir = os.path.join(models_dir, "clip_vision")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions), ),
|
||||
return {"required": { "clip_name": (filter_files_extensions(recursive_search(PathServer().get('clip_vision')), supported_pt_extensions), ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP_VISION",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -407,7 +389,7 @@ class CLIPVisionLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
clip_path = os.path.join(self.clip_dir, clip_name)
|
||||
clip_path = os.path.join(PathServer().get('clip_vision'), clip_name)
|
||||
clip_vision = comfy_extras.clip_vision.load(clip_path)
|
||||
return (clip_vision,)
|
||||
|
||||
@ -427,11 +409,9 @@ class CLIPVisionEncode:
|
||||
return (output,)
|
||||
|
||||
class StyleModelLoader:
|
||||
models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
|
||||
style_model_dir = os.path.join(models_dir, "style_models")
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "style_model_name": (filter_files_extensions(recursive_search(s.style_model_dir), supported_pt_extensions), )}}
|
||||
return {"required": { "style_model_name": (filter_files_extensions(recursive_search(PathServer().get('style_models')), supported_pt_extensions), )}}
|
||||
|
||||
RETURN_TYPES = ("STYLE_MODEL",)
|
||||
FUNCTION = "load_style_model"
|
||||
@ -439,7 +419,7 @@ class StyleModelLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_style_model(self, style_model_name):
|
||||
style_model_path = os.path.join(self.style_model_dir, style_model_name)
|
||||
style_model_path = os.path.join(PathServer().get('style_models'), style_model_name)
|
||||
style_model = comfy.sd.load_style_model(style_model_path)
|
||||
return (style_model,)
|
||||
|
||||
|
||||
37
path_server.py
Normal file
37
path_server.py
Normal file
@ -0,0 +1,37 @@
|
||||
from singleton_decorator import singleton
|
||||
import os
|
||||
|
||||
@singleton
|
||||
class PathServer():
|
||||
|
||||
def __init__(self):
|
||||
self.paths = {
|
||||
'checkpoints': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'checkpoints'),
|
||||
'clip': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'clip'),
|
||||
'clip_vision': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'clip_vision'),
|
||||
'configs': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'configs'),
|
||||
'controlnet': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'controlnet'),
|
||||
'embeddings': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'embeddings'),
|
||||
'loras': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'loras'),
|
||||
'style_models': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'style_models'),
|
||||
't2i_adapter': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 't2i_adapter'),
|
||||
'upscale_models': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'upscale_models'),
|
||||
'vae': os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models', 'vae'),
|
||||
}
|
||||
|
||||
def set_a1111_path(self, a1111_path):
|
||||
self.paths['checkpoints'] = os.path.join(a1111_path, 'models', 'Stable-diffusion')
|
||||
self.paths['clip'] = os.path.join(a1111_path, 'models', 'clip-interrogator')
|
||||
# self.paths['clip_vision'] = os.path.join(a1111_path, 'models', '')
|
||||
self.paths['configs'] = os.path.join(a1111_path, 'models', 'Stable-diffusion')
|
||||
self.paths['controlnet'] = os.path.join(a1111_path, 'models', 'ControlNet')
|
||||
self.paths['embeddings'] = os.path.join(a1111_path, 'embeddings')
|
||||
self.paths['loras'] = os.path.join(a1111_path, 'models', 'Lora')
|
||||
#self.paths['style_models'] = os.path.join(a1111_path, 'models', '')
|
||||
#self.paths['t2i_adapter'] = os.path.join(a1111_path, 'models', '')
|
||||
self.paths['upscale_models'] = os.path.join(a1111_path, 'models', 'ESRGAN')
|
||||
self.paths['vae'] = os.path.join(a1111_path, 'models', 'VAE')
|
||||
|
||||
def get(self, key):
|
||||
return self.paths[key]
|
||||
|
||||
@ -9,3 +9,5 @@ pytorch_lightning
|
||||
aiohttp
|
||||
accelerate
|
||||
pyyaml
|
||||
singleton-decorator
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user