Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-12 15:20:51 +08:00.

Commit db423f8013: Merge branch 'master' of github.com:comfyanonymous/ComfyUI
@@ -87,6 +87,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
                         help="Default preview method for sampler nodes.", action=EnumAction)
+    parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
     cache_group = parser.add_mutually_exclusive_group()
     cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
     attn_group = parser.add_mutually_exclusive_group()
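With this flag in place, the preview cap can be raised at launch; a minimal sketch, assuming the usual `main.py` entry point:
```
# Cap sampler previews at 1024 px instead of the default 512 px
python main.py --preview-method auto --preview-size 1024
```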
@@ -110,6 +110,7 @@ class Configuration(dict):
         force_channels_last (bool): Force channels last format when inferencing the models.
         force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
         executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
+        preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
     """

     def __init__(self, **kwargs):
@@ -184,6 +185,7 @@ class Configuration(dict):
         self.max_queue_size: int = 65536
         self.force_channels_last: bool = False
         self.force_hf_local_dir_mode = False
+        self.preview_size: int = 512

         # from guill
         self.cache_lru: int = 0
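These options can also be set programmatically; a minimal sketch, assuming `Configuration` instances accept attribute assignment as the hunk above shows, with the import path taken from the executor_types diff further down:
```
from comfy.cli_args_types import Configuration

config = Configuration()
config.preview_size = 1024              # raise the sampler preview cap (default 512)
config.force_hf_local_dir_mode = True   # recreate the traditional models/huggingface layout
config.cache_lru = 32                   # keep up to 32 node results in the LRU cache
```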
@@ -100,6 +100,38 @@ def _cleanup():


 class EmbeddedComfyClient:
+    """
+    Embedded client for ComfyUI that executes prompts as a library.
+
+    This client manages a single-threaded executor to run long-running or blocking tasks
+    asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor
+    in a dedicated thread for executing prompts and handling server-stub communications.
+
+    Example usage:
+
+    Asynchronous (non-blocking) usage with async-await:
+    ```
+    # Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your
+    # workspace.
+    prompt_dict = {
+        "1": {"class_type": "KSamplerAdvanced", ...},
+        ...
+    }
+
+    # Validate your workflow (the prompt).
+    from comfy.api.components.schema.prompt import Prompt
+    prompt = Prompt.validate(prompt_dict)
+
+    # Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor.
+    # It does not connect to a remote server.
+    async def main():
+        async with EmbeddedComfyClient() as client:
+            outputs = await client.queue_prompt(prompt)
+            print(outputs)
+        print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI")
+
+    if __name__ == "__main__":
+        asyncio.run(main())
+    ```
+
+    To call this from blocking (synchronous) code, see the asyncio documentation for running coroutines.
+    """
     def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
         self._progress_handler = progress_handler or ServerStub()
         self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
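Since the docstring only shows async usage, here is a minimal sketch of driving the client from synchronous code with `asyncio.run`. The `EmbeddedComfyClient` import path is an assumption; `Prompt` comes from the path shown in the docstring:
```
import asyncio

from comfy.api.components.schema.prompt import Prompt
# Import path assumed; adjust to wherever EmbeddedComfyClient lives in your tree.
from comfy.client.embedded_comfy_client import EmbeddedComfyClient


def run_workflow_blocking(prompt_dict: dict) -> dict:
    """Validate a workflow dict and execute it, blocking until outputs are ready."""
    prompt = Prompt.validate(prompt_dict)

    async def _run():
        async with EmbeddedComfyClient() as client:
            return await client.queue_prompt(prompt)

    # asyncio.run creates an event loop, runs the coroutine to completion, then closes the loop.
    return asyncio.run(_run())
```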
@@ -12,7 +12,7 @@ from .. import model_management
 from .. import utils
 import logging

-MAX_PREVIEW_RESOLUTION = 512
+MAX_PREVIEW_RESOLUTION = args.preview_size


 def preview_to_image(latent_image):
@@ -200,6 +200,7 @@ async def main(from_script_dir: Optional[Path] = None):
     folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
     folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
+    folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
     folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))

     if args.input_directory:
         input_dir = os.path.abspath(args.input_directory)
@@ -6,9 +6,10 @@ from enum import Enum
 from typing import Optional, Literal, Protocol, Union, NamedTuple, List

 import PIL.Image
-from typing_extensions import NotRequired, TypedDict, runtime_checkable
+from typing_extensions import NotRequired, TypedDict

 from .queue_types import BinaryEventTypes
+from ..cli_args_types import Configuration
 from ..nodes.package_typing import InputTypeSpec


@@ -201,4 +202,7 @@ class Executor(Protocol):
         ...

     def shutdown(self, wait=True, *, cancel_futures=False):
         ...
+
+
+ExecutePromptArgs = tuple[dict, str, str, dict, ExecutorToClientProgress | None, Configuration | None]
@@ -493,14 +493,15 @@ def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full
     return control


-def load_controlnet_flux_xlabs(sd):
+def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
-    control_model = controlnet_flux.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_flux.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control


 def load_controlnet_flux_instantx(sd):
     new_sd = model_detection.convert_diffusers_mmdit(sd, "")
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
@@ -520,6 +521,11 @@ def load_controlnet_flux_instantx(sd):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control


+def convert_mistoline(sd):
+    return utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
+
+
 def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
     if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
@@ -580,13 +586,15 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
             if len(leftover_keys) > 0:
                 logging.warning("leftover keys: {}".format(leftover_keys))
             controlnet_data = new_sd
-    elif "controlnet_blocks.0.weight" in controlnet_data: # SD3 diffusers format
+    elif "controlnet_blocks.0.weight" in controlnet_data:
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
-            return load_controlnet_flux_xlabs(controlnet_data)
+            return load_controlnet_flux_xlabs_mistoline(controlnet_data)
         elif "pos_embed_input.proj.weight" in controlnet_data:
-            return load_controlnet_mmdit(controlnet_data)
+            return load_controlnet_mmdit(controlnet_data) # SD3 diffusers controlnet
     elif "controlnet_x_embedder.weight" in controlnet_data:
         return load_controlnet_flux_instantx(controlnet_data)
+    elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux
+        return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)

     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
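For illustration, `convert_mistoline` only renames key prefixes, mapping a MistoLine checkpoint's keys onto the names `ControlNetFlux` expects; a standalone sketch of the same mapping (the sample key is hypothetical):
```
import torch

sd = {"single_controlnet_blocks.0.linear.weight": torch.zeros(16, 16)}  # hypothetical key

# Same mapping that utils.state_dict_prefix_replace applies above.
converted = {k.replace("single_controlnet_blocks.", "controlnet_single_blocks."): v
             for k, v in sd.items()}
print(list(converted))  # ['controlnet_single_blocks.0.linear.weight']
```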
@@ -1,11 +1,16 @@
 import concurrent.futures
+from typing import Callable

-from pebble import ProcessPool
+from pebble import ProcessPool, ProcessFuture

-from ..component_model.executor_types import Executor
+from ..component_model.executor_types import Executor, ExecutePromptArgs


 class ProcessPoolExecutor(ProcessPool, Executor):
     def __init__(self, max_workers: int = 1):
         super().__init__(max_workers=1)

     def shutdown(self, wait=True, *, cancel_futures=False):
         if cancel_futures:
             raise NotImplementedError("cannot cancel futures in this implementation")
@@ -15,5 +20,17 @@ class ProcessPoolExecutor(ProcessPool, Executor):
         self.stop()
         return

+    def schedule(self, function: Callable,
+                 args: list = (),
+                 kwargs: dict = {},
+                 timeout: float = None) -> ProcessFuture:
+        try:
+            args: ExecutePromptArgs
+            prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
+        except ValueError:
+            pass
+        return super().schedule(function, args, kwargs, timeout)
+
     def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
         return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
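Because `EmbeddedComfyClient` accepts any `Executor`, this pebble-backed pool can replace the default `ThreadPoolExecutor`; a sketch under assumed import paths (both classes appear in the diffs above, but their module locations are not shown):
```
from comfy.client.embedded_comfy_client import EmbeddedComfyClient  # path assumed
from comfy.cmd.process_pool_executor import ProcessPoolExecutor     # path assumed


async def run_in_subprocess(prompt) -> dict:
    # Prompt execution happens in a worker process instead of a thread,
    # isolating CUDA and model state from the calling process.
    async with EmbeddedComfyClient(executor=ProcessPoolExecutor(max_workers=1)) as client:
        return await client.queue_prompt(prompt)
```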
@@ -1,32 +1,78 @@
-#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
+# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
+# modified to support different types of flux controlnets

 import math
-from typing_extensions import Never

 import torch
 from einops import rearrange, repeat
 from torch import Tensor, nn
+from typing_extensions import Never

 from .layers import (timestep_embedding)
 from .model import Flux
 from .. import common_dit


+class MistolineCondDownsamplBlock(nn.Module):
+    def __init__(self, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
+            nn.SiLU(),
+            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+        )
+
+    def forward(self, x):
+        return self.encoder(x)
+
+
+class MistolineControlnetBlock(nn.Module):
+    def __init__(self, hidden_size, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        return self.act(self.linear(x))
+
+
 class ControlNetFlux(Flux):
-    def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
         super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

         self.main_model_double = 19
         self.main_model_single = 38

+        self.mistoline = mistoline
         # add ControlNet blocks
+        if self.mistoline:
+            control_block = lambda: MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            control_block = lambda: operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+
         self.controlnet_blocks = nn.ModuleList([])
         for _ in range(self.params.depth):
-            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
-            self.controlnet_blocks.append(controlnet_block)
+            self.controlnet_blocks.append(control_block())

         self.controlnet_single_blocks = nn.ModuleList([])
         for _ in range(self.params.depth_single_blocks):
-            self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
+            self.controlnet_single_blocks.append(control_block())

         self.num_union_modes = num_union_modes
         self.controlnet_mode_embedder = None
@@ -37,44 +83,44 @@ class ControlNetFlux(Flux):
         self.latent_input = latent_input
         self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
         if not self.latent_input:
-            self.input_hint_block = nn.Sequential(
-                operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
-            )
+            if self.mistoline:
+                self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
+            else:
+                self.input_hint_block = nn.Sequential(
+                    operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+                )

     def forward_orig(
         self,
         img: Tensor,
         img_ids: Tensor,
         controlnet_cond: Tensor,
         txt: Tensor,
         txt_ids: Tensor,
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
         control_type: Tensor | list[Never] | None = None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")

         # running on sequences img
         img = self.img_in(img)
-        if not self.latent_input:
-            controlnet_cond = self.input_hint_block(controlnet_cond)
-            controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

         controlnet_cond = self.pos_embed_input(controlnet_cond)
         img = img + controlnet_cond
@@ -87,7 +133,7 @@ class ControlNetFlux(Flux):
         if self.controlnet_mode_embedder is not None and len(control_type) > 0:
             control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
             txt = torch.cat([control_cond, txt], dim=1)
-            txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
+            txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)

         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
@@ -104,13 +150,13 @@ class ControlNetFlux(Flux):

         for i in range(len(self.single_blocks)):
             img = self.single_blocks[i](img, vec=vec, pe=pe)
-            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
+            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1]:, ...]),)

         repeat = math.ceil(self.main_model_double / len(controlnet_double))
         if self.latent_input:
             out_input = ()
             for x in controlnet_double:
                 out_input += (x,) * repeat
         else:
             out_input = (controlnet_double * repeat)
@@ -120,7 +166,7 @@ class ControlNetFlux(Flux):
         out_output = ()
         if self.latent_input:
             for x in controlnet_single:
                 out_output += (x,) * repeat
         else:
             out_output = (controlnet_single * repeat)
         out["output"] = out_output[:self.main_model_single]
@@ -130,9 +176,14 @@ class ControlNetFlux(Flux):
         patch_size = 2
         if self.latent_input:
             hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
-            hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        elif self.mistoline:
+            hint = hint * 2.0 - 1.0
+            hint = self.input_cond_block(hint)
         else:
             hint = hint * 2.0 - 1.0
+            hint = self.input_hint_block(hint)
+
+        hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

         bs, c, h, w = x.shape
         x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
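The `repeat` logic stretches however many controlnet outputs exist to cover the 19 double and 38 single blocks of the main Flux model, then slices to the exact count; a quick worked example:
```
import math

main_model_double = 19
controlnet_double = ("d0", "d1", "d2", "d3", "d4")  # e.g. a controlnet with 5 double-block outputs

repeat = math.ceil(main_model_double / len(controlnet_double))  # ceil(19 / 5) = 4
out_input = controlnet_double * repeat                          # 20 entries, d0..d4 cycling
print(out_input[:main_model_double])                            # the main model consumes the first 19
```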
@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
         emb = self.time_embed(t_emb)

+        if "emb_patch" in transformer_patches:
+            patch = transformer_patches["emb_patch"]
+            for p in patch:
+                emb = p(emb, self.model_channels, transformer_options)
+
         if self.num_classes is not None:
             assert y.shape[0] == x.shape[0]
             emb = emb + self.label_emb(y)
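A sketch of what an `emb_patch` might look like. The callable signature follows the call site above; the registration shape (a "patches" dict inside `transformer_options`) is an assumption based on how `transformer_patches` is usually populated:
```
import torch


def example_emb_patch(emb: torch.Tensor, model_channels: int, transformer_options: dict) -> torch.Tensor:
    # Receives the timestep embedding before class conditioning is added;
    # must return a tensor of the same shape. This example is a no-op.
    return emb


# Assumed registration shape:
transformer_options = {"patches": {"emb_patch": [example_emb_patch]}}
```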
@@ -495,7 +495,7 @@ def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded:
-                can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                 shift_model.currently_used = False

     for x in sorted(can_unload):
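Tuples sort lexicographically, so prefixing the key with the negated offloaded memory makes models that already have the most memory offloaded come first in the unload order; a toy illustration:
```
# (negated offloaded bytes, refcount, resident bytes, index)
can_unload = [(0, 3, 4096, 0), (-2048, 2, 4096, 1)]
print(sorted(can_unload))  # [(-2048, 2, 4096, 1), (0, 3, 4096, 0)], most-offloaded first
```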
@@ -1,7 +1,6 @@
 from comfy import sd1_clip
 import os

-from comfy.component_model.files import get_path_as_dict
+from ..component_model.files import get_path_as_dict


 class LongClipTokenizer_(sd1_clip.SDTokenizer):
comfy_extras/nodes_lora_extract.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+import torch
+import comfy.model_management
+import comfy.utils
+import folder_paths
+import os
+import logging
+
+CLAMP_QUANTILE = 0.99
+
+
+def extract_lora(diff, rank):
+    conv2d = (len(diff.shape) == 4)
+    kernel_size = None if not conv2d else diff.size()[2:4]
+    conv2d_3x3 = conv2d and kernel_size != (1, 1)
+    out_dim, in_dim = diff.size()[0:2]
+    rank = min(rank, in_dim, out_dim)
+
+    if conv2d:
+        if conv2d_3x3:
+            diff = diff.flatten(start_dim=1)
+        else:
+            diff = diff.squeeze()
+
+    U, S, Vh = torch.linalg.svd(diff.float())
+    U = U[:, :rank]
+    S = S[:rank]
+    U = U @ torch.diag(S)
+    Vh = Vh[:rank, :]
+
+    dist = torch.cat([U.flatten(), Vh.flatten()])
+    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+    low_val = -hi_val
+
+    U = U.clamp(low_val, hi_val)
+    Vh = Vh.clamp(low_val, hi_val)
+    if conv2d:
+        U = U.reshape(out_dim, rank, 1, 1)
+        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
+    return (U, Vh)
+
+
+class LoraSave:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
+                             "rank": ("INT", {"default": 8, "min": 1, "max": 1024, "step": 1}),
+                             },
+                "optional": {"model_diff": ("MODEL",),},
+                }
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+    OUTPUT_NODE = True
+
+    CATEGORY = "_for_testing"
+
+    def save(self, filename_prefix, rank, model_diff=None):
+        if model_diff is None:
+            return {}
+
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+
+        output_sd = {}
+        prefix_key = "diffusion_model."
+        stored = set()
+
+        comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
+        sd = model_diff.model_state_dict(filter_prefix=prefix_key)
+
+        for k in sd:
+            if k.endswith(".weight"):
+                weight_diff = sd[k]
+                if weight_diff.ndim < 2:
+                    continue
+                try:
+                    out = extract_lora(weight_diff, rank)
+                    output_sd["{}.lora_up.weight".format(k[:-7])] = out[0].contiguous().half().cpu()
+                    output_sd["{}.lora_down.weight".format(k[:-7])] = out[1].contiguous().half().cpu()
+                except:
+                    logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
+
+        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+
+        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
+        return {}
+
+
+NODE_CLASS_MAPPINGS = {
+    "LoraSave": LoraSave
+}
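A short sketch exercising `extract_lora` on a synthetic weight difference (shapes are illustrative; the node applies this per `.weight` key of the model diff):
```
import torch

diff = torch.randn(320, 768)  # hypothetical 2-D weight difference, as for a Linear layer
rank = 8

U, Vh = extract_lora(diff, rank)  # defined above
print(U.shape, Vh.shape)          # torch.Size([320, 8]) torch.Size([8, 768])

# U @ Vh is a low-rank approximation of diff (rank at most 8).
print(torch.linalg.matrix_rank(U @ Vh).item())  # typically 8
```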