Mirror of https://github.com/comfyanonymous/ComfyUI.git

Commit 076f7eaa4b: Merge remote-tracking branch 'upstream/master' into addBatchIndex
@@ -31,7 +31,7 @@ jobs:
         echo 'import site' >> ./python311._pth
         curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
         ./python.exe get-pip.py
-        python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+        python -m pip wheel torch torchvision torchaudio aiohttp==3.8.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
         ls ../temp_wheel_dir
         ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
         sed -i '1i../ComfyUI' ./python311._pth
@@ -29,7 +29,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
 - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
 - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
-- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
+- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
+- Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@@ -69,7 +70,7 @@ There is a portable standalone build for Windows that should work for running on

 ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z)

-Just download, extract and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
+Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints

 #### How do I share models between another UI and ComfyUI?

@@ -53,7 +53,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

 attn_group = parser.add_mutually_exclusive_group()
-attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")

 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
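Note: the three attention backends are registered in a single mutually exclusive argparse group, so passing more than one of them aborts argument parsing with a clear error. A minimal standalone sketch of the same pattern (flag names copied from the hunk, everything else is illustrative, not repo code):

import argparse

parser = argparse.ArgumentParser()
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true")
attn_group.add_argument("--use-quad-cross-attention", action="store_true")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true")

# Parsing one flag works; combining two of them makes argparse exit with an error.
args = parser.parse_args(["--use-quad-cross-attention"])
print(args.use_quad_cross_attention)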
@@ -17,7 +17,7 @@
   "num_attention_heads": 20,
   "num_hidden_layers": 32,
   "pad_token_id": 1,
-  "projection_dim": 512,
+  "projection_dim": 1280,
   "torch_dtype": "float32",
   "vocab_size": 49408
 }
@@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 code2idx = {"q": 0, "k": 1, "v": 2}


-def convert_text_enc_state_dict_v20(text_enc_dict):
+def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
     capture_qkv_weight = {}
     capture_qkv_bias = {}
     for k, v in text_enc_dict.items():
+        if not k.startswith(prefix):
+            continue
         if (
             k.endswith(".self_attn.q_proj.weight")
             or k.endswith(".self_attn.k_proj.weight")
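Note: the new prefix parameter lets the converter touch only the keys that belong to one text encoder inside a combined state dict; the SDXL saving hooks added later in this commit call it with "clip_g" for exactly that reason. A rough sketch of the filtering idea (the keys below are fabricated for illustration):

import torch

def keep_prefixed(sd, prefix):
    # Same idea as the added startswith() check: ignore everything outside the requested sub-model.
    return {k: v for k, v in sd.items() if k.startswith(prefix)}

combined = {
    "clip_g.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
    "clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
}
print(list(keep_prefixed(combined, "clip_g")))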
@@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
+from . import utils

 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, v_prediction=False):
@@ -11,6 +12,7 @@ class BaseModel(torch.nn.Module):

         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
+        self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
         self.v_prediction = v_prediction
@@ -83,6 +85,16 @@ class BaseModel(torch.nn.Module):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)

+    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
+        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
+        unet_state_dict = self.diffusion_model.state_dict()
+        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
+        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
+        if self.get_dtype() == torch.float16:
+            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
+            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
+        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
+

 class SD21UNCLIP(BaseModel):
     def __init__(self, model_config, noise_aug_config, v_prediction=True):
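Note: state_dict_for_saving is the piece that makes whole-checkpoint export possible. Each component state dict is re-prefixed by the model config's *_state_dict_for_saving hooks, CLIP and VAE weights are cast to fp16 when the UNet itself is fp16, and the three dicts are merged into one flat mapping. A simplified sketch of just the merge step (tensor names are illustrative):

import torch

unet_sd = {"model.diffusion_model.out.2.weight": torch.zeros(4)}
vae_sd = {"first_stage_model.decoder.conv_out.weight": torch.zeros(4)}
clip_sd = {"cond_stage_model.transformer.text_model.final_layer_norm.weight": torch.zeros(4)}

# Dict unpacking merges left to right; later keys would win on a collision.
checkpoint = {**unet_sd, **vae_sd, **clip_sd}
print(sorted(checkpoint))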
@@ -144,10 +156,10 @@ class SDXLRefiner(BaseModel):

         print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
+        out.append(self.embedder(torch.Tensor([width])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([aesthetic_score])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
@@ -168,11 +180,11 @@ class SDXL(BaseModel):

         print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
         out = []
-        out.append(self.embedder(torch.Tensor([width])))
         out.append(self.embedder(torch.Tensor([height])))
+        out.append(self.embedder(torch.Tensor([width])))
-        out.append(self.embedder(torch.Tensor([crop_w])))
         out.append(self.embedder(torch.Tensor([crop_h])))
+        out.append(self.embedder(torch.Tensor([crop_w])))
-        out.append(self.embedder(torch.Tensor([target_width])))
         out.append(self.embedder(torch.Tensor([target_height])))
+        out.append(self.embedder(torch.Tensor([target_width])))
         flat = torch.flatten(torch.cat(out))[None, ]
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
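Note: both methods now feed the size embedder in (height, width) and (crop_h, crop_w) order rather than (width, height) and (crop_w, crop_h); as far as I can tell this matches the original_size / crop_coords_top_left tuple order the reference SDXL conditioning code uses. A compact equivalent of the corrected SDXL ordering, written as a free function for illustration:

import torch

def sdxl_extra_cond(embedder, clip_pooled, width, height, crop_w, crop_h, target_width, target_height):
    # Ordering follows the patch: (h, w), then (crop_h, crop_w), then (target_h, target_w).
    values = [height, width, crop_h, crop_w, target_height, target_width]
    out = [embedder(torch.Tensor([v])) for v in values]
    flat = torch.flatten(torch.cat(out))[None, ]
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)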
@@ -16,13 +16,11 @@ def count_blocks(state_dict_keys, prefix_string):

 def detect_unet_config(state_dict, key_prefix, use_fp16):
     state_dict_keys = list(state_dict.keys())
-    num_res_blocks = 2

     unet_config = {
         "use_checkpoint": False,
         "image_size": 32,
         "out_channels": 4,
-        "num_res_blocks": num_res_blocks,
         "use_spatial_transformer": True,
         "legacy": False
     }
@@ -139,7 +139,23 @@ else:
     except:
         XFORMERS_IS_AVAILABLE = False

+def is_nvidia():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.cuda:
+            return True
+
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+
+if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+    try:
+        if is_nvidia():
+            torch_version = torch.version.__version__
+            if int(torch_version[0]) >= 2:
+                ENABLE_PYTORCH_ATTENTION = True
+    except:
+        pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
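Note: with this block, PyTorch's built-in scaled dot product attention becomes the default on NVIDIA GPUs running PyTorch 2.x whenever no other attention backend (xformers, split, quad) was requested on the command line; the try/except keeps an unparsable version string from breaking startup. A tiny standalone sketch of the version gate (not the repo's helper):

import torch

def pytorch_sdp_default_ok():
    # Mirrors the check above: a CUDA build of PyTorch with major version 2 or newer.
    try:
        return bool(torch.version.cuda) and int(torch.__version__[0]) >= 2
    except (ValueError, IndexError):
        return False

print(pytorch_sdp_default_ok())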
@@ -200,6 +216,11 @@ current_gpu_controlnets = []

 model_accelerated = False

+def unet_offload_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")

 def unload_model():
     global current_loaded_model
@@ -212,10 +233,9 @@ def unload_model():
             accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
             model_accelerated = False

-        #never unload models from GPU on high vram
-        if vram_state != VRAMState.HIGH_VRAM:
-            current_loaded_model.model.cpu()
-            current_loaded_model.model_patches_to("cpu")
+        current_loaded_model.model.to(unet_offload_device())
+        current_loaded_model.model_patches_to(unet_offload_device())
         current_loaded_model.unpatch_model()
         current_loaded_model = None

@@ -347,7 +367,7 @@ def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
         #TODO: more reliable way of checking for flash attention?
-        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+        if is_nvidia(): #pytorch flash attention only works on Nvidia
             return True
     return False

@@ -438,7 +458,7 @@ def soft_empty_cache():
     elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
-        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+        if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
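Note: unload_model no longer special-cases high-VRAM setups with an inline if; it always moves the model and its patches to unet_offload_device(), which stays on the GPU in HIGH_VRAM/SHARED modes and falls back to the CPU otherwise, and is_nvidia() replaces the bare torch.version.cuda checks so ROCm builds keep their old behaviour. A hypothetical standalone analogue of that offload decision:

import torch

def offload_device(high_vram: bool) -> torch.device:
    # Illustrative only: keep weights on the accelerator when VRAM is plentiful,
    # otherwise park them in system RAM, which is what unet_offload_device() does.
    if high_vram and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

print(offload_device(high_vram=False))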
comfy/sd.py (57 lines changed)
@@ -89,8 +89,7 @@ LORA_UNET_MAP_RESNET = {
     "skip_connection": "resnets_{}_conv_shortcut"
 }

-def load_lora(path, to_load):
-    lora = utils.load_torch_file(path, safe_load=True)
+def load_lora(lora, to_load):
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@@ -223,13 +222,28 @@ def model_lora_keys(model, key_map={}):
             counter += 1
     counter = 0
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
-    for b in range(24):
+    clip_l_present = False
+    for b in range(32):
         for c in LORA_CLIP_MAP:
             k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
                 lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                 key_map[lora_key] = k

+            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
+                key_map[lora_key] = k
+                clip_l_present = True
+
+            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
+            if k in sdk:
+                if clip_l_present:
+                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
+                else:
+                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
+                key_map[lora_key] = k
+

     #Locon stuff
     ds_counter = 0
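Note: load_lora now takes the already loaded LoRA state dict instead of a file path (so callers can cache the dict, as LoraLoader does later in this commit), and model_lora_keys learns the SDXL text-encoder naming: clip_l layers map to lora_te1_* and clip_g layers to lora_te2_* when both encoders are present. A small sketch of how those key names are built (format string copied from the hunk, everything else illustrative):

LORA_CLIP_MAP_EXAMPLE = {"self_attn.q_proj": "self_attn_q_proj"}  # tiny illustrative subset

def sdxl_text_encoder_lora_key(block, suffix, encoder):
    # encoder is "clip_l" (first text encoder, te1) or "clip_g" (second, te2).
    prefix = {"clip_l": "lora_te1", "clip_g": "lora_te2"}[encoder]
    return "{}_text_model_encoder_layers_{}_{}".format(prefix, block, suffix)

print(sdxl_text_encoder_lora_key(0, LORA_CLIP_MAP_EXAMPLE["self_attn.q_proj"], "clip_g"))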
@@ -486,10 +500,10 @@ class ModelPatcher:

         self.backup = {}

-def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     key_map = model_lora_keys(model.model)
     key_map = model_lora_keys(clip.cond_stage_model, key_map)
-    loaded = load_lora(lora_path, key_map)
+    loaded = load_lora(lora, key_map)
     new_modelpatcher = model.clone()
     k = new_modelpatcher.add_patches(loaded, strength_model)
     new_clip = clip.clone()
@@ -545,11 +559,11 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
-            self.patcher.patch_model()
+            self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
-            self.patcher.unpatch_model()
+            self.unpatch_model()
         except Exception as e:
-            self.patcher.unpatch_model()
+            self.unpatch_model()
             raise e

         cond_out = cond
@@ -564,6 +578,15 @@ class CLIP:
     def load_sd(self, sd):
         return self.cond_stage_model.load_sd(sd)

+    def get_sd(self):
+        return self.cond_stage_model.state_dict()
+
+    def patch_model(self):
+        self.patcher.patch_model()
+
+    def unpatch_model(self):
+        self.patcher.unpatch_model()
+
 class VAE:
     def __init__(self, ckpt_path=None, device=None, config=None):
         if config is None:
@@ -665,6 +688,10 @@ class VAE:
             self.first_stage_model = self.first_stage_model.cpu()
         return samples

+    def get_sd(self):
+        return self.first_stage_model.state_dict()
+

 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
     #print(current_batch_size, target_batch_size)
@@ -1114,6 +1141,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

     model = model_config.get_model(sd)
+    model = model.to(model_management.unet_offload_device())
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
@@ -1135,3 +1163,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         print("left over keys:", left_over)

     return (ModelPatcher(model), clip, vae, clipvision)
+
+def save_checkpoint(output_path, model, clip, vae, metadata=None):
+    try:
+        model.patch_model()
+        clip.patch_model()
+        sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
+        utils.save_torch_file(sd, output_path, metadata=metadata)
+        model.unpatch_model()
+        clip.unpatch_model()
+    except Exception as e:
+        model.unpatch_model()
+        clip.unpatch_model()
+        raise e
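Note: save_checkpoint first patches the model and CLIP so any LoRA or merge patches are baked into the weights, gathers everything through state_dict_for_saving, writes the result with utils.save_torch_file, and unpatches again even if the write fails. A hedged usage sketch from a plain script (the checkpoint paths are assumptions, and it has to run inside a ComfyUI checkout):

import comfy.sd

# load_checkpoint_guess_config returns (ModelPatcher, CLIP, VAE, clipvision), as above.
model, clip, vae, _ = comfy.sd.load_checkpoint_guess_config("models/checkpoints/example.safetensors")
comfy.sd.save_checkpoint("output/resaved.safetensors", model, clip, vae, metadata={"note": "example"})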
@@ -95,7 +95,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]

         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1])
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:
@@ -9,6 +9,8 @@ from . import sdxl_clip
 from . import supported_models_base
 from . import latent_formats

+from . import diffusers_convert
+
 class SD15(supported_models_base.BASE):
     unet_config = {
         "context_dim": 768,
@@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
         state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        replace_prefix[""] = "cond_stage_model.model."
+        state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
+        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
+        return state_dict
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

@@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE):
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

@@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE):
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        keys_to_replace = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        for k in state_dict:
+            if k.startswith("clip_l"):
+                state_dict_g[k] = state_dict[k]
+
+        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
+        replace_prefix["clip_l"] = "conditioner.embedders.0"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

@@ -64,3 +64,15 @@ class BASE:
     def process_clip_state_dict(self, state_dict):
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "cond_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_unet_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "model.diffusion_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_vae_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "first_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
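Note: the three *_for_saving hooks are the inverse of the prefix stripping done at load time: a bare component state dict gets back the prefix a full checkpoint expects. A hypothetical re-implementation of the prefix rewrite, only to show what an empty-string key in replace_prefix does:

import torch

def prefix_replace(state_dict, replace_prefix):
    # Keys starting with `old` (the empty string matches every key) are rewritten
    # to start with `new`, so {"": "first_stage_model."} simply prepends the prefix.
    out = {}
    for k, v in state_dict.items():
        for old, new in replace_prefix.items():
            if k.startswith(old):
                k = new + k[len(old):]
                break
        out[k] = v
    return out

print(prefix_replace({"decoder.conv_out.weight": torch.zeros(1)}, {"": "first_stage_model."}))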
@@ -2,10 +2,10 @@ import torch
 import math
 import struct
 import comfy.checkpoint_pickle
+import safetensors.torch

 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
-        import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
         if safe_load:
@@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
             sd = pl_sd
     return sd

+def save_torch_file(sd, ckpt, metadata=None):
+    if metadata is not None:
+        safetensors.torch.save_file(sd, ckpt, metadata=metadata)
+    else:
+        safetensors.torch.save_file(sd, ckpt)
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
                 sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd

+def convert_sd_to(state_dict, dtype):
+    keys = list(state_dict.keys())
+    for k in keys:
+        state_dict[k] = state_dict[k].to(dtype)
+    return state_dict
+
 def safetensors_header(safetensors_path, max_size=100*1024*1024):
     with open(safetensors_path, "rb") as f:
         header = f.read(8)
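Note: save_torch_file is a thin wrapper over safetensors, and the metadata dict lands in the file's JSON header; safetensors only accepts string keys and values there, which is why callers json.dumps their payloads first. A minimal example of the underlying call (output filename is illustrative):

import json
import torch
import safetensors.torch

sd = {"weight": torch.zeros(2, 2)}
metadata = {"prompt": json.dumps({"example": True})}  # metadata values must be plain strings
safetensors.torch.save_file(sd, "example.safetensors", metadata=metadata)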
@@ -1,4 +1,8 @@
+import comfy.sd
+import comfy.utils
+import folder_paths
+import json
+import os

 class ModelMergeSimple:
     @classmethod
@@ -10,7 +14,7 @@ class ModelMergeSimple:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "merge"

-    CATEGORY = "_for_testing/model_merging"
+    CATEGORY = "advanced/model_merging"

     def merge(self, model1, model2, ratio):
         m = model1.clone()
@@ -31,7 +35,7 @@ class ModelMergeBlocks:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "merge"

-    CATEGORY = "_for_testing/model_merging"
+    CATEGORY = "advanced/model_merging"

     def merge(self, model1, model2, **kwargs):
         m = model1.clone()
@@ -49,7 +53,43 @@ class ModelMergeBlocks:
                 m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
         return (m, )

+class CheckpointSave:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP",),
+                              "vae": ("VAE",),
+                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+    OUTPUT_NODE = True
+
+    CATEGORY = "advanced/model_merging"
+
+    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        prompt_info = ""
+        if prompt is not None:
+            prompt_info = json.dumps(prompt)
+
+        metadata = {"prompt": prompt_info}
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
+
+        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+
+        comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
+        return {}
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSimple": ModelMergeSimple,
-    "ModelMergeBlocks": ModelMergeBlocks
+    "ModelMergeBlocks": ModelMergeBlocks,
+    "CheckpointSave": CheckpointSave,
 }
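Note: CheckpointSave is a terminal node (OUTPUT_NODE = True with empty RETURN_TYPES): it resolves an output path through folder_paths, serializes the workflow prompt and any extra PNG info into string metadata, and hands the actual write to comfy.sd.save_checkpoint. A bare-bones node following the same interface, offered as a sketch rather than anything from the repo:

class ExampleSaveNode:
    # Hypothetical output node: INPUT_TYPES declares the sockets, FUNCTION names the
    # method the executor calls, OUTPUT_NODE marks the node as a graph sink.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",)},
                "hidden": {"prompt": "PROMPT"}}
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, prompt=None):
        # A real node would persist something here; returning {} means "no UI output".
        return {}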
execution.py (24 lines changed)
@@ -110,7 +110,7 @@ def format_value(x):
     else:
         return str(x)

-def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
+def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
     class_type = prompt[unique_id]['class_type']
@@ -125,7 +125,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
             input_unique_id = input_data[0]
             output_index = input_data[1]
             if input_unique_id not in outputs:
-                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
+                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
                 if result[0] is not True:
                     # Another node failed further upstream
                     return result
@@ -136,7 +136,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
         if server.client_id is not None:
             server.last_node_id = unique_id
             server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
-        obj = class_def()
+
+        obj = object_storage.get((unique_id, class_type), None)
+        if obj is None:
+            obj = class_def()
+            object_storage[(unique_id, class_type)] = obj

         output_data, output_ui = get_output_data(obj, input_data_all)
         outputs[unique_id] = output_data
@@ -256,6 +260,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
 class PromptExecutor:
     def __init__(self, server):
         self.outputs = {}
+        self.object_storage = {}
         self.outputs_ui = {}
         self.old_prompt = {}
         self.server = server
@@ -322,6 +327,17 @@ class PromptExecutor:
         for o in to_delete:
             d = self.outputs.pop(o)
             del d
+        to_delete = []
+        for o in self.object_storage:
+            if o[0] not in prompt:
+                to_delete += [o]
+            else:
+                p = prompt[o[0]]
+                if o[1] != p['class_type']:
+                    to_delete += [o]
+        for o in to_delete:
+            d = self.object_storage.pop(o)
+            del d

         for x in prompt:
             recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
@@ -349,7 +365,7 @@ class PromptExecutor:
         # This call shouldn't raise anything if there's an error deep in
         # the actual SD code, instead it will report the node where the
         # error was raised
-        success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
+        success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
         if success is not True:
             self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
             break
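Note: the executor now keeps one node object per (unique_id, class_type) pair in object_storage, so per-instance state such as LoraLoader.loaded_lora survives between prompt executions; entries are evicted when a node id disappears from the prompt or changes class. A condensed, standalone sketch of the cache lookup with a placeholder node class:

class DummyNode:
    pass

object_storage = {}

def get_node_object(unique_id, class_type, class_def):
    # Reuse the previously constructed node when the graph still contains the same
    # node id with the same class; otherwise build one and remember it.
    obj = object_storage.get((unique_id, class_type), None)
    if obj is None:
        obj = class_def()
        object_storage[(unique_id, class_type)] = obj
    return obj

a = get_node_object("7", "DummyNode", DummyNode)
b = get_node_object("7", "DummyNode", DummyNode)
print(a is b)  # True: the same instance is handed back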
@@ -8,7 +8,9 @@ a111:
     checkpoints: models/Stable-diffusion
     configs: models/Stable-diffusion
     vae: models/VAE
-    loras: models/Lora
+    loras: |
+         models/Lora
+         models/LyCORIS
     upscale_models: |
                   models/ESRGAN
                   models/SwinIR
@@ -21,5 +23,3 @@ a111:
 # checkpoints: models/checkpoints
 # gligen: models/gligen
 # custom_nodes: path/custom_nodes
-
-
nodes.py (40 lines changed)
@@ -148,6 +148,25 @@ class ConditioningSetMask:
         c.append(n)
         return (c, )

+class ConditioningZeroOut:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", )}}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "zero_out"
+
+    CATEGORY = "advanced/conditioning"
+
+    def zero_out(self, conditioning):
+        c = []
+        for t in conditioning:
+            d = t[1].copy()
+            if "pooled_output" in d:
+                d["pooled_output"] = torch.zeros_like(d["pooled_output"])
+            n = [torch.zeros_like(t[0]), d]
+            c.append(n)
+        return (c, )
+
 class VAEDecode:
     @classmethod
     def INPUT_TYPES(s):
@@ -286,8 +305,7 @@ class SaveLatent:
         output["latent_tensor"] = samples["samples"]
         output["latent_format_version_0"] = torch.tensor([])

-
-        safetensors.torch.save_file(output, file, metadata=metadata)
+        comfy.utils.save_torch_file(output, file, metadata=metadata)

         return {}

@@ -416,6 +434,9 @@ class CLIPSetLastLayer:
         return (clip,)

 class LoraLoader:
+    def __init__(self):
+        self.loaded_lora = None
+
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
@@ -434,7 +455,18 @@ class LoraLoader:
             return (model, clip)

         lora_path = folder_paths.get_full_path("loras", lora_name)
-        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
+        lora = None
+        if self.loaded_lora is not None:
+            if self.loaded_lora[0] == lora_path:
+                lora = self.loaded_lora[1]
+            else:
+                del self.loaded_lora
+
+        if lora is None:
+            lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
+            self.loaded_lora = (lora_path, lora)
+
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
         return (model_lora, clip_lora)

 class VAELoader:
@@ -1351,6 +1383,8 @@ NODE_CLASS_MAPPINGS = {

     "LoadLatent": LoadLatent,
     "SaveLatent": SaveLatent,
+
+    "ConditioningZeroOut": ConditioningZeroOut,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
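Note: because load_lora_for_models now accepts the loaded state dict, LoraLoader can keep the last LoRA in memory and skip re-reading the file when the same node runs again with an unchanged path. A hypothetical helper showing the new call shape (must run inside a ComfyUI checkout; the path argument is whatever the caller supplies):

import comfy.utils
import comfy.sd

def apply_lora_from_path(model, clip, lora_path, strength_model=1.0, strength_clip=1.0):
    # The caller now loads the state dict itself (and may cache it, as LoraLoader does)
    # before handing it to load_lora_for_models.
    lora_sd = comfy.utils.load_torch_file(lora_path, safe_load=True)
    return comfy.sd.load_lora_for_models(model, clip, lora_sd, strength_model, strength_clip)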
@@ -144,6 +144,7 @@
     "\n",
     "\n",
     "# ESRGAN upscale model\n",
+    "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
     "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
     "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
     "\n",
@@ -1484,7 +1484,7 @@ export class ComfyApp {
                     this.loadGraphData(JSON.parse(reader.result));
                 };
                 reader.readAsText(file);
-            } else if (file.name?.endsWith(".latent")) {
+            } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
                 const info = await getLatentMetadata(file);
                 if (info.workflow) {
                     this.loadGraphData(JSON.parse(info.workflow));
@@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
             const dataView = new DataView(safetensorsData.buffer);
             let header_size = dataView.getUint32(0, true);
             let offset = 8;
-            let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
+            let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
             r(header.__metadata__);
         };

-        reader.readAsArrayBuffer(file);
+        var slice = file.slice(0, 1024 * 1024 * 4);
+        reader.readAsArrayBuffer(slice);
     });
 }

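Note: decoding the header with TextDecoder keeps multi-byte UTF-8 metadata intact (String.fromCharCode would mangle it), and reading only the first 4 MiB of the file is presumably there so the browser does not pull an entire multi-gigabyte checkpoint into memory just to fetch the embedded workflow. The same parse in Python, as a sketch (the safetensors header length is an 8-byte little-endian integer followed by that many bytes of JSON):

import json
import struct

def read_safetensors_metadata(path, max_header_bytes=4 * 1024 * 1024):
    # The "__metadata__" key of the JSON header holds the embedded prompt/workflow strings.
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        if header_size > max_header_bytes:
            raise ValueError("unexpectedly large safetensors header")
        header = json.loads(f.read(header_size))
    return header.get("__metadata__", {})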
@@ -545,7 +545,7 @@ export class ComfyUI {
         const fileInput = $el("input", {
             id: "comfy-file-input",
             type: "file",
-            accept: ".json,image/png,.latent",
+            accept: ".json,image/png,.latent,.safetensors",
             style: {display: "none"},
             parent: document.body,
             onchange: () => {