mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 06:40:16 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
db423f8013
@ -87,6 +87,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
|||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
|
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
|
||||||
help="Default preview method for sampler nodes.", action=EnumAction)
|
help="Default preview method for sampler nodes.", action=EnumAction)
|
||||||
|
|
||||||
|
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||||
cache_group = parser.add_mutually_exclusive_group()
|
cache_group = parser.add_mutually_exclusive_group()
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
|
|||||||
@ -110,6 +110,7 @@ class Configuration(dict):
|
|||||||
force_channels_last (bool): Force channels last format when inferencing the models.
|
force_channels_last (bool): Force channels last format when inferencing the models.
|
||||||
force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
|
force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
|
||||||
executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
|
executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
|
||||||
|
preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@ -184,6 +185,7 @@ class Configuration(dict):
|
|||||||
self.max_queue_size: int = 65536
|
self.max_queue_size: int = 65536
|
||||||
self.force_channels_last: bool = False
|
self.force_channels_last: bool = False
|
||||||
self.force_hf_local_dir_mode = False
|
self.force_hf_local_dir_mode = False
|
||||||
|
self.preview_size: int = 512
|
||||||
|
|
||||||
# from guill
|
# from guill
|
||||||
self.cache_lru: int = 0
|
self.cache_lru: int = 0
|
||||||
|
|||||||
@ -100,6 +100,38 @@ def _cleanup():
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddedComfyClient:
|
class EmbeddedComfyClient:
|
||||||
|
"""
|
||||||
|
Embedded client for comfy executing prompts as a library.
|
||||||
|
|
||||||
|
This client manages a single-threaded executor to run long-running or blocking tasks
|
||||||
|
asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor
|
||||||
|
in a dedicated thread for executing prompts and handling server-stub communications.
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
Asynchronous (non-blocking) usage with async-await:
|
||||||
|
```
|
||||||
|
# Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your
|
||||||
|
# workspace.
|
||||||
|
prompt_dict = {
|
||||||
|
"1": {"class_type": "KSamplerAdvanced", ...}
|
||||||
|
...
|
||||||
|
}
|
||||||
|
# Validate your workflow (the prompt)
|
||||||
|
from comfy.api.components.schema.prompt import Prompt
|
||||||
|
prompt = Prompt.validate(prompt_dict)
|
||||||
|
# Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor.
|
||||||
|
# It does not connect to a remote server.
|
||||||
|
async def main():
|
||||||
|
async with EmbeddedComfyClient() as client:
|
||||||
|
outputs = await client.queue_prompt(prompt)
|
||||||
|
print(outputs)
|
||||||
|
print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI")
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
```
|
||||||
|
|
||||||
|
In order to use this in blocking methods, learn more about asyncio online.
|
||||||
|
"""
|
||||||
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
|
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
|
||||||
self._progress_handler = progress_handler or ServerStub()
|
self._progress_handler = progress_handler or ServerStub()
|
||||||
self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
|
self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from .. import model_management
|
|||||||
from .. import utils
|
from .. import utils
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
MAX_PREVIEW_RESOLUTION = 512
|
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||||
|
|
||||||
|
|
||||||
def preview_to_image(latent_image):
|
def preview_to_image(latent_image):
|
||||||
|
|||||||
@ -200,6 +200,7 @@ async def main(from_script_dir: Optional[Path] = None):
|
|||||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||||
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||||
|
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
|
||||||
|
|
||||||
if args.input_directory:
|
if args.input_directory:
|
||||||
input_dir = os.path.abspath(args.input_directory)
|
input_dir = os.path.abspath(args.input_directory)
|
||||||
|
|||||||
@ -6,9 +6,10 @@ from enum import Enum
|
|||||||
from typing import Optional, Literal, Protocol, Union, NamedTuple, List
|
from typing import Optional, Literal, Protocol, Union, NamedTuple, List
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
from typing_extensions import NotRequired, TypedDict, runtime_checkable
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
from .queue_types import BinaryEventTypes
|
from .queue_types import BinaryEventTypes
|
||||||
|
from ..cli_args_types import Configuration
|
||||||
from ..nodes.package_typing import InputTypeSpec
|
from ..nodes.package_typing import InputTypeSpec
|
||||||
|
|
||||||
|
|
||||||
@ -201,4 +202,7 @@ class Executor(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
def shutdown(self, wait=True, *, cancel_futures=False):
|
def shutdown(self, wait=True, *, cancel_futures=False):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
ExecutePromptArgs = tuple[dict, str, str, dict, ExecutorToClientProgress | None, Configuration | None]
|
||||||
|
|||||||
@ -493,14 +493,15 @@ def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_flux_xlabs(sd):
|
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
||||||
control_model = controlnet_flux.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
control_model = controlnet_flux.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, sd)
|
control_model = controlnet_load_state_dict(control_model, sd)
|
||||||
extra_conds = ['y', 'guidance']
|
extra_conds = ['y', 'guidance']
|
||||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet_flux_instantx(sd):
|
def load_controlnet_flux_instantx(sd):
|
||||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||||
@ -520,6 +521,11 @@ def load_controlnet_flux_instantx(sd):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
def convert_mistoline(sd):
|
||||||
|
return utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
|
def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
|
||||||
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
|
||||||
@ -580,13 +586,15 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
|
|||||||
if len(leftover_keys) > 0:
|
if len(leftover_keys) > 0:
|
||||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||||
controlnet_data = new_sd
|
controlnet_data = new_sd
|
||||||
elif "controlnet_blocks.0.weight" in controlnet_data: # SD3 diffusers format
|
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||||
return load_controlnet_flux_xlabs(controlnet_data)
|
return load_controlnet_flux_xlabs_mistoline(controlnet_data)
|
||||||
elif "pos_embed_input.proj.weight" in controlnet_data:
|
elif "pos_embed_input.proj.weight" in controlnet_data:
|
||||||
return load_controlnet_mmdit(controlnet_data)
|
return load_controlnet_mmdit(controlnet_data) # SD3 diffusers controlnet
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data)
|
return load_controlnet_flux_instantx(controlnet_data)
|
||||||
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux
|
||||||
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)
|
||||||
|
|
||||||
pth_key = 'control_model.zero_convs.0.0.weight'
|
pth_key = 'control_model.zero_convs.0.0.weight'
|
||||||
pth = False
|
pth = False
|
||||||
|
|||||||
@ -1,11 +1,16 @@
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from pebble import ProcessPool
|
from pebble import ProcessPool, ProcessFuture
|
||||||
|
|
||||||
from ..component_model.executor_types import Executor
|
from ..component_model.executor_types import Executor, ExecutePromptArgs
|
||||||
|
|
||||||
|
|
||||||
class ProcessPoolExecutor(ProcessPool, Executor):
|
class ProcessPoolExecutor(ProcessPool, Executor):
|
||||||
|
def __init__(self, max_workers: int = 1):
|
||||||
|
super().__init__(max_workers=1)
|
||||||
|
|
||||||
|
|
||||||
def shutdown(self, wait=True, *, cancel_futures=False):
|
def shutdown(self, wait=True, *, cancel_futures=False):
|
||||||
if cancel_futures:
|
if cancel_futures:
|
||||||
raise NotImplementedError("cannot cancel futures in this implementation")
|
raise NotImplementedError("cannot cancel futures in this implementation")
|
||||||
@ -15,5 +20,17 @@ class ProcessPoolExecutor(ProcessPool, Executor):
|
|||||||
self.stop()
|
self.stop()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def schedule(self, function: Callable,
             args: list = (),
             kwargs: dict | None = None,
             timeout: float = None) -> ProcessFuture:
    """Schedule ``function`` on the process pool and return its future.

    Thin override of the parent pool's ``schedule``. The previous version
    called ``super().schedule(...)`` but dropped its result, so ``submit``
    (which returns whatever this method returns) handed ``None`` back to
    callers instead of a future.

    :param function: callable to run in a worker process
    :param args: positional arguments forwarded to ``function``
    :param kwargs: keyword arguments forwarded to ``function``; ``None``
        (the default) is treated as an empty dict
    :param timeout: forwarded unchanged to the parent ``schedule``
    :return: the future produced by the parent pool
    """
    # Avoid sharing one mutable default dict across calls.
    if kwargs is None:
        kwargs = {}
    # NOTE(review): the earlier body unpacked ``args`` as ExecutePromptArgs
    # inside a try/except ValueError and then discarded the values. That
    # had no effect, so it is omitted here; reintroduce it when the
    # unpacked values are actually needed.
    return super().schedule(function, args, kwargs, timeout)
|
||||||
|
|
||||||
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
|
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
|
||||||
return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
|
return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
|
||||||
|
|||||||
@ -1,32 +1,78 @@
|
|||||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||||
|
# modified to support different types of flux controlnets
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing_extensions import Never
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from typing_extensions import Never
|
||||||
|
|
||||||
from .layers import (timestep_embedding)
|
from .layers import (timestep_embedding)
|
||||||
from .model import Flux
|
from .model import Flux
|
||||||
from .. import common_dit
|
from .. import common_dit
|
||||||
|
|
||||||
|
|
||||||
|
class MistolineCondDownsamplBlock(nn.Module):
|
||||||
|
def __init__(self, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.encoder(x)
|
||||||
|
|
||||||
|
|
||||||
|
class MistolineControlnetBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
class ControlNetFlux(Flux):
|
||||||
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
self.main_model_double = 19
|
self.main_model_double = 19
|
||||||
self.main_model_single = 38
|
self.main_model_single = 38
|
||||||
|
|
||||||
|
self.mistoline = mistoline
|
||||||
# add ControlNet blocks
|
# add ControlNet blocks
|
||||||
|
if self.mistoline:
|
||||||
|
control_block = lambda: MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
else:
|
||||||
|
control_block = lambda: operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
for _ in range(self.params.depth):
|
for _ in range(self.params.depth):
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
self.controlnet_blocks.append(control_block())
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
self.controlnet_single_blocks = nn.ModuleList([])
|
self.controlnet_single_blocks = nn.ModuleList([])
|
||||||
for _ in range(self.params.depth_single_blocks):
|
for _ in range(self.params.depth_single_blocks):
|
||||||
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
self.controlnet_single_blocks.append(control_block())
|
||||||
|
|
||||||
self.num_union_modes = num_union_modes
|
self.num_union_modes = num_union_modes
|
||||||
self.controlnet_mode_embedder = None
|
self.controlnet_mode_embedder = None
|
||||||
@ -37,44 +83,44 @@ class ControlNetFlux(Flux):
|
|||||||
self.latent_input = latent_input
|
self.latent_input = latent_input
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
if not self.latent_input:
|
if not self.latent_input:
|
||||||
self.input_hint_block = nn.Sequential(
|
if self.mistoline:
|
||||||
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
|
||||||
nn.SiLU(),
|
else:
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
self.input_hint_block = nn.Sequential(
|
||||||
nn.SiLU(),
|
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
nn.SiLU(),
|
||||||
nn.SiLU(),
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
|
||||||
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
nn.SiLU(),
|
||||||
)
|
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
|
||||||
|
nn.SiLU(),
|
||||||
|
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
|
||||||
|
)
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
img: Tensor,
|
img: Tensor,
|
||||||
img_ids: Tensor,
|
img_ids: Tensor,
|
||||||
controlnet_cond: Tensor,
|
controlnet_cond: Tensor,
|
||||||
txt: Tensor,
|
txt: Tensor,
|
||||||
txt_ids: Tensor,
|
txt_ids: Tensor,
|
||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor = None,
|
guidance: Tensor = None,
|
||||||
control_type: Tensor | list[Never] | None = None,
|
control_type: Tensor | list[Never] | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
if not self.latent_input:
|
|
||||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
|
||||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
|
||||||
|
|
||||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||||
img = img + controlnet_cond
|
img = img + controlnet_cond
|
||||||
@ -87,7 +133,7 @@ class ControlNetFlux(Flux):
|
|||||||
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
||||||
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
||||||
txt = torch.cat([control_cond, txt], dim=1)
|
txt = torch.cat([control_cond, txt], dim=1)
|
||||||
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
@ -104,13 +150,13 @@ class ControlNetFlux(Flux):
|
|||||||
|
|
||||||
for i in range(len(self.single_blocks)):
|
for i in range(len(self.single_blocks)):
|
||||||
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
||||||
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
|
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1]:, ...]),)
|
||||||
|
|
||||||
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
||||||
if self.latent_input:
|
if self.latent_input:
|
||||||
out_input = ()
|
out_input = ()
|
||||||
for x in controlnet_double:
|
for x in controlnet_double:
|
||||||
out_input += (x,) * repeat
|
out_input += (x,) * repeat
|
||||||
else:
|
else:
|
||||||
out_input = (controlnet_double * repeat)
|
out_input = (controlnet_double * repeat)
|
||||||
|
|
||||||
@ -120,7 +166,7 @@ class ControlNetFlux(Flux):
|
|||||||
out_output = ()
|
out_output = ()
|
||||||
if self.latent_input:
|
if self.latent_input:
|
||||||
for x in controlnet_single:
|
for x in controlnet_single:
|
||||||
out_output += (x,) * repeat
|
out_output += (x,) * repeat
|
||||||
else:
|
else:
|
||||||
out_output = (controlnet_single * repeat)
|
out_output = (controlnet_single * repeat)
|
||||||
out["output"] = out_output[:self.main_model_single]
|
out["output"] = out_output[:self.main_model_single]
|
||||||
@ -130,9 +176,14 @@ class ControlNetFlux(Flux):
|
|||||||
patch_size = 2
|
patch_size = 2
|
||||||
if self.latent_input:
|
if self.latent_input:
|
||||||
hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
|
||||||
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
elif self.mistoline:
|
||||||
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_cond_block(hint)
|
||||||
else:
|
else:
|
||||||
hint = hint * 2.0 - 1.0
|
hint = hint * 2.0 - 1.0
|
||||||
|
hint = self.input_hint_block(hint)
|
||||||
|
|
||||||
|
hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
|
|||||||
@ -842,6 +842,11 @@ class UNetModel(nn.Module):
|
|||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
|
if "emb_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["emb_patch"]
|
||||||
|
for p in patch:
|
||||||
|
emb = p(emb, self.model_channels, transformer_options)
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
assert y.shape[0] == x.shape[0]
|
assert y.shape[0] == x.shape[0]
|
||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
|||||||
@ -495,7 +495,7 @@ def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
|
|||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
if shift_model.device == device:
|
if shift_model.device == device:
|
||||||
if shift_model not in keep_loaded:
|
if shift_model not in keep_loaded:
|
||||||
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
for x in sorted(can_unload):
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
import os
|
|
||||||
|
|
||||||
from comfy.component_model.files import get_path_as_dict
|
from ..component_model.files import get_path_as_dict
|
||||||
|
|
||||||
|
|
||||||
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||||
|
|||||||
91
comfy_extras/nodes_lora_extract.py
Normal file
91
comfy_extras/nodes_lora_extract.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Quantile of the combined factor values used as the symmetric clamp bound.
CLAMP_QUANTILE = 0.99


def extract_lora(diff, rank):
    """Factor a weight difference into a low-rank ``(up, down)`` pair via SVD.

    :param diff: weight-difference tensor; 2-D ``(out, in)`` or 4-D
        ``(out, in, kh, kw)`` inputs are handled explicitly
    :param rank: requested rank; reduced to ``min(rank, in, out)``
    :return: tuple ``(U, Vh)``. For 2-D input the shapes are ``(out, rank)``
        and ``(rank, in)``; for 4-D input they are ``(out, rank, 1, 1)`` and
        ``(rank, in, kh, kw)``.
    """
    conv2d = diff.ndim == 4
    kernel_size = diff.size()[2:4] if conv2d else None
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)

    if conv2d:
        # Collapse everything after the output axis into one axis. For 1x1
        # kernels this yields (out, in) exactly like the former squeeze(),
        # but it does not also drop out/in axes of size 1, which used to
        # produce a 1-D tensor that SVD rejects.
        diff = diff.flatten(start_dim=1)

    # Only the leading `rank` singular vectors are kept, so the reduced
    # decomposition is sufficient and avoids building full square factors.
    U, S, Vh = torch.linalg.svd(diff.float(), full_matrices=False)
    U = U[:, :rank]
    S = S[:rank]
    # Fold the singular values into the "up" factor.
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]

    # Clamp both factors symmetrically at the CLAMP_QUANTILE of their
    # combined values.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)
|
||||||
|
|
||||||
|
class LoraSave:
    """Output node that writes LoRA factors extracted from ``model_diff``.

    Every ``*.weight`` entry with at least two dimensions under the
    ``diffusion_model.`` prefix is passed to ``extract_lora`` and the
    resulting pair is stored as ``<key>.lora_up.weight`` /
    ``<key>.lora_down.weight`` in a single ``.safetensors`` file.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
                             "rank": ("INT", {"default": 8, "min": 1, "max": 1024, "step": 1}),
                             },
                "optional": {"model_diff": ("MODEL",),},
                }
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "_for_testing"

    def save(self, filename_prefix, rank, model_diff=None):
        """Extract LoRA weights from ``model_diff`` and save them to disk.

        :param filename_prefix: prefix handed to
            ``folder_paths.get_save_image_path`` to derive the output path
        :param rank: rank forwarded to ``extract_lora``
        :param model_diff: model whose state dict is decomposed
            (presumably it holds weight differences; confirm against the
            node that produces it). When ``None`` nothing is written.
        :return: an empty dict
        """
        if model_diff is None:
            return {}

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        output_sd = {}
        prefix_key = "diffusion_model."

        comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
        sd = model_diff.model_state_dict(filter_prefix=prefix_key)

        for k in sd:
            if not k.endswith(".weight"):
                continue
            weight_diff = sd[k]
            # Tensors with fewer than two dimensions cannot be factored.
            if weight_diff.ndim < 2:
                continue
            try:
                out = extract_lora(weight_diff, rank)
                # k[:-7] strips the trailing ".weight".
                output_sd["{}.lora_up.weight".format(k[:-7])] = out[0].contiguous().half().cpu()
                output_sd["{}.lora_down.weight".format(k[:-7])] = out[1].contiguous().half().cpu()
            except Exception:
                # Best effort: skip keys that fail to decompose, but let
                # KeyboardInterrupt / SystemExit propagate.
                logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))

        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LoraSave": LoraSave
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user