Merge branch 'master' of github.com:comfyanonymous/ComfyUI

doctorpangloss 2024-09-05 09:23:00 -07:00
commit db423f8013
13 changed files with 266 additions and 55 deletions

View File

@@ -87,6 +87,7 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
attn_group = parser.add_mutually_exclusive_group()
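Below, a minimal sketch of how the new flag parses alongside the cache group, with plain argparse standing in for EnhancedConfigArgParser and the EnumAction-based --preview-method omitted:

```
# Sketch only: plain argparse stands in for EnhancedConfigArgParser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--preview-size", type=int, default=512,
                    help="Sets the maximum preview size for sampler nodes.")
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-lru", type=int, default=0,
                         help="Use LRU caching with a maximum of N node results cached.")

args = parser.parse_args(["--preview-size", "768", "--cache-lru", "16"])
assert (args.preview_size, args.cache_lru) == (768, 16)
```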

View File

@@ -110,6 +110,7 @@ class Configuration(dict):
force_channels_last (bool): Force channels last format when inferencing the models.
force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
executor_factory (str): Either ThreadPoolExecutor or ProcessPoolExecutor, defaulting to ThreadPoolExecutor
preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
"""
def __init__(self, **kwargs):
@@ -184,6 +185,7 @@ class Configuration(dict):
self.max_queue_size: int = 65536
self.force_channels_last: bool = False
self.force_hf_local_dir_mode = False
self.preview_size: int = 512
# from guill
self.cache_lru: int = 0
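A hedged usage sketch: Configuration subclasses dict, so the new default should be readable right after construction; the import path is taken from this diff and may differ in your checkout.

```
from comfy.cli_args_types import Configuration  # path assumed from this diff

config = Configuration()
print(config.preview_size)  # 512, the default added above
config.preview_size = 1024  # raise the preview cap before handing off the config
```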

View File

@@ -100,6 +100,38 @@ def _cleanup():
class EmbeddedComfyClient:
"""
Embedded client for executing ComfyUI prompts as a library.
This client manages a single-threaded executor to run long-running or blocking tasks
asynchronously without blocking the asyncio event loop. It initializes a PromptExecutor
in a dedicated thread for executing prompts and handling server-stub communications.
Example usage:
Asynchronous (non-blocking) usage with async-await:
```
# Write a workflow, or enable Dev Mode in the UI settings, then Save (API Format) to get the workflow in your
# workspace.
prompt_dict = {
"1": {"class_type": "KSamplerAdvanced", ...}
...
}
# Validate your workflow (the prompt)
from comfy.api.components.schema.prompt import Prompt
prompt = Prompt.validate(prompt_dict)
# Then use the client to run your workflow. This will start, then stop, a local ComfyUI workflow executor.
# It does not connect to a remote server.
async def main():
async with EmbeddedComfyClient() as client:
outputs = await client.queue_prompt(prompt)
print(outputs)
print("Now that we've exited the with statement, all your VRAM has been cleared from ComfyUI")
if __name__ == "__main__":
asyncio.run(main())
```
To call this from blocking (synchronous) code, wrap the coroutine with asyncio.run as shown above, or see the asyncio documentation.
"""
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
self._progress_handler = progress_handler or ServerStub()
self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
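For purely synchronous callers, a minimal sketch that wraps the async API from the docstring above with asyncio.run:

```
import asyncio

from comfy.api.components.schema.prompt import Prompt

def run_prompt_blocking(prompt_dict: dict):
    prompt = Prompt.validate(prompt_dict)

    async def _run():
        async with EmbeddedComfyClient() as client:
            return await client.queue_prompt(prompt)

    # asyncio.run creates and tears down an event loop around the coroutine.
    return asyncio.run(_run())
```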

View File

@@ -12,7 +12,7 @@ from .. import model_management
from .. import utils
import logging
-MAX_PREVIEW_RESOLUTION = 512
+MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image):
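Illustrative arithmetic only: fit_preview below is a hypothetical helper showing how a larger --preview-size lets previews keep more resolution before downscaling.

```
# fit_preview is hypothetical; it only illustrates the clamp arithmetic.
def fit_preview(width, height, max_resolution=512):
    scale = min(max_resolution / width, max_resolution / height, 1.0)
    return round(width * scale), round(height * scale)

assert fit_preview(1920, 1080, 512) == (512, 288)
assert fit_preview(1920, 1080, 1024) == (1024, 576)
```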

View File

@@ -200,6 +200,7 @@ async def main(from_script_dir: Optional[Path] = None):
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)

View File

@@ -6,9 +6,10 @@ from enum import Enum
from typing import Optional, Literal, Protocol, Union, NamedTuple, List
import PIL.Image
-from typing_extensions import NotRequired, TypedDict, runtime_checkable
+from typing_extensions import NotRequired, TypedDict
from .queue_types import BinaryEventTypes
from ..cli_args_types import Configuration
from ..nodes.package_typing import InputTypeSpec
@@ -201,4 +202,7 @@ class Executor(Protocol):
...
def shutdown(self, wait=True, *, cancel_futures=False):
        ...
ExecutePromptArgs = tuple[dict, str, str, dict, ExecutorToClientProgress | None, Configuration | None]
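A sketch of building and unpacking one of these tuples; the field order mirrors the unpacking in ProcessPoolExecutor.schedule below, and the payload values are placeholders.

```
args = (
    {"1": {"class_type": "KSamplerAdvanced"}},  # prompt
    "prompt-id",                                # prompt_id
    "client-id",                                # client_id
    {},                                         # span_context
    None,                                       # ExecutorToClientProgress | None
    None,                                       # Configuration | None
)
prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
```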

View File

@@ -493,14 +493,15 @@ def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full
return control
-def load_controlnet_flux_xlabs(sd):
+def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
-    control_model = controlnet_flux.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_flux.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def load_controlnet_flux_instantx(sd):
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
@@ -520,6 +521,11 @@ def load_controlnet_flux_instantx(sd):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):
return utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
@@ -580,13 +586,15 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: # SD3 diffusers format
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
-            return load_controlnet_flux_xlabs(controlnet_data)
+            return load_controlnet_flux_xlabs_mistoline(controlnet_data)
elif "pos_embed_input.proj.weight" in controlnet_data:
-            return load_controlnet_mmdit(controlnet_data)
+            return load_controlnet_mmdit(controlnet_data) # SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
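The mistoline conversion is a key-prefix rename; a self-contained approximation of utils.state_dict_prefix_replace shows the effect (the real helper may differ in details):

```
def prefix_replace(sd, replace_prefix):
    # Approximate stand-in for utils.state_dict_prefix_replace.
    out = {}
    for key, value in sd.items():
        for old, new in replace_prefix.items():
            if key.startswith(old):
                key = new + key[len(old):]
        out[key] = value
    return out

sd = {"single_controlnet_blocks.0.linear.weight": "w"}
print(prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}))
# {'controlnet_single_blocks.0.linear.weight': 'w'}
```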

View File

@@ -1,11 +1,16 @@
import concurrent.futures
from typing import Callable
-from pebble import ProcessPool
+from pebble import ProcessPool, ProcessFuture
-from ..component_model.executor_types import Executor
+from ..component_model.executor_types import Executor, ExecutePromptArgs
class ProcessPoolExecutor(ProcessPool, Executor):
def __init__(self, max_workers: int = 1):
        super().__init__(max_workers=max_workers)
def shutdown(self, wait=True, *, cancel_futures=False):
if cancel_futures:
raise NotImplementedError("cannot cancel futures in this implementation")
@@ -15,5 +20,17 @@ class ProcessPoolExecutor(ProcessPool, Executor):
self.stop()
return
    def schedule(self, function: Callable,
                 args: list = (),
                 kwargs: dict = {},
                 timeout: float = None) -> ProcessFuture:
        try:
            args: ExecutePromptArgs
            prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
        except ValueError:
            pass
        return super().schedule(function, args, kwargs, timeout)
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
        return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
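A hedged usage sketch for this pebble-backed executor; it assumes only the class as shown above and pebble's ProcessFuture.result().

```
def square(x):
    return x * x

if __name__ == "__main__":
    pool = ProcessPoolExecutor(max_workers=1)
    future = pool.submit(square, 7)  # delegates to pebble via schedule()
    print(future.result())           # 49
    pool.shutdown(wait=True)
```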

View File

@@ -1,32 +1,78 @@
-#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
+# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
# modified to support different types of flux controlnets
import math
-from typing_extensions import Never
import torch
from einops import rearrange, repeat
from torch import Tensor, nn
+from typing_extensions import Never
from .layers import (timestep_embedding)
from .model import Flux
from .. import common_dit
class MistolineCondDownsamplBlock(nn.Module):
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
self.encoder = nn.Sequential(
operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
)
def forward(self, x):
return self.encoder(x)
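A quick shape check: the three stride-2 convolutions downsample each spatial side by 8x while expanding 3 hint channels to 16. Passing torch.nn as operations is an assumption; real callers pass comfy's operations wrapper, whose constructors mirror torch.nn's.

```
import torch
import torch.nn as nn

block = MistolineCondDownsamplBlock(operations=nn)  # torch.nn as a stand-in
x = torch.randn(1, 3, 512, 512)
assert block(x).shape == (1, 16, 64, 64)
```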
class MistolineControlnetBlock(nn.Module):
def __init__(self, hidden_size, dtype=None, device=None, operations=None):
super().__init__()
self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.linear(x))
class ControlNetFlux(Flux):
-    def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
self.main_model_single = 38
self.mistoline = mistoline
# add ControlNet blocks
if self.mistoline:
control_block = lambda: MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
control_block = lambda: operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks = nn.ModuleList([])
for _ in range(self.params.depth):
-            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
-            self.controlnet_blocks.append(controlnet_block)
+            self.controlnet_blocks.append(control_block())
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(self.params.depth_single_blocks):
-            self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
+            self.controlnet_single_blocks.append(control_block())
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
@@ -37,44 +83,44 @@ class ControlNetFlux(Flux):
self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
if not self.latent_input:
-            self.input_hint_block = nn.Sequential(
-                operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-                nn.SiLU(),
-                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
-            )
+            if self.mistoline:
+                self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
+            else:
+                self.input_hint_block = nn.Sequential(
+                    operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                    nn.SiLU(),
+                    operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+                )
def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control_type: Tensor | list[Never] | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
-        if not self.latent_input:
-            controlnet_cond = self.input_hint_block(controlnet_cond)
-            controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
@@ -87,7 +133,7 @@ class ControlNetFlux(Flux):
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
txt = torch.cat([control_cond, txt], dim=1)
-            txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
+            txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -104,13 +150,13 @@ class ControlNetFlux(Flux):
for i in range(len(self.single_blocks)):
img = self.single_blocks[i](img, vec=vec, pe=pe)
-            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
+            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1]:, ...]),)
repeat = math.ceil(self.main_model_double / len(controlnet_double))
if self.latent_input:
out_input = ()
for x in controlnet_double:
                out_input += (x,) * repeat
else:
out_input = (controlnet_double * repeat)
@@ -120,7 +166,7 @@ class ControlNetFlux(Flux):
out_output = ()
if self.latent_input:
for x in controlnet_single:
                out_output += (x,) * repeat
else:
out_output = (controlnet_single * repeat)
out["output"] = out_output[:self.main_model_single]
@@ -130,9 +176,14 @@ class ControlNetFlux(Flux):
patch_size = 2
if self.latent_input:
hint = common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
-            hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        elif self.mistoline:
+            hint = hint * 2.0 - 1.0
+            hint = self.input_cond_block(hint)
        else:
            hint = hint * 2.0 - 1.0
+            hint = self.input_hint_block(hint)
+        hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
bs, c, h, w = x.shape
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
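The rearrange above packs 2x2 patches of the 16-channel hint map into tokens; a standalone shape check:

```
import torch
from einops import rearrange

hint = torch.randn(1, 16, 64, 64)
tokens = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
assert tokens.shape == (1, 32 * 32, 16 * 2 * 2)  # (1, 1024, 64)
```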

View File

@@ -842,6 +842,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
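A hedged sketch of an emb_patch callable matching the call site above: it receives the current embedding, the model's channel count, and the transformer options, and returns the patched embedding. How patches are registered into transformer_options is not shown in this hunk, so the patch list is built by hand here.

```
import torch

def emb_patch(emb, model_channels, transformer_options):
    return emb * 0.5  # toy patch: damp the timestep embedding

transformer_patches = {"emb_patch": [emb_patch]}
emb = torch.randn(2, 1280)
for p in transformer_patches["emb_patch"]:
    emb = p(emb, 1280, {})
```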

View File

@@ -495,7 +495,7 @@ def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
-                can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
for x in sorted(can_unload):
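Why the new leading tuple element matters: tuples sort lexicographically, so negating the offloaded byte count makes sorted() pick the model with the most memory already offloaded first, before falling back to refcount and resident size.

```
candidates = [
    # (-offloaded_bytes, refcount, model_bytes, index)
    (0,              2, 4_000_000_000, 0),  # nothing offloaded yet
    (-3_000_000_000, 2, 4_000_000_000, 1),  # mostly offloaded already
]
assert sorted(candidates)[0][-1] == 1  # unloads the mostly-offloaded model first
```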

View File

@@ -1,7 +1,6 @@
from comfy import sd1_clip
import os
-from comfy.component_model.files import get_path_as_dict
+from ..component_model.files import get_path_as_dict
class LongClipTokenizer_(sd1_clip.SDTokenizer):

View File

@@ -0,0 +1,91 @@
import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
CLAMP_QUANTILE = 0.99
def extract_lora(diff, rank):
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
rank = min(rank, in_dim, out_dim)
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
U, S, Vh = torch.linalg.svd(diff.float())
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
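A standalone check of extract_lora: the factors carry the requested rank, and (up to the quantile clamp) their product approximates the weight delta.

```
import torch

diff = torch.randn(64, 32)
up, down = extract_lora(diff, rank=8)
assert up.shape == (64, 8) and down.shape == (8, 32)
approx = up @ down  # rank-8, quantile-clamped approximation of diff
assert approx.shape == diff.shape
```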
class LoraSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 1024, "step": 1}),
},
"optional": {"model_diff": ("MODEL",),},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, model_diff=None):
if model_diff is None:
return {}
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_sd = {}
prefix_key = "diffusion_model."
stored = set()
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
sd = model_diff.model_state_dict(filter_prefix=prefix_key)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if weight_diff.ndim < 2:
continue
try:
out = extract_lora(weight_diff, rank)
output_sd["{}.lora_up.weight".format(k[:-7])] = out[0].contiguous().half().cpu()
output_sd["{}.lora_down.weight".format(k[:-7])] = out[1].contiguous().half().cpu()
            except Exception:
                logging.warning("Could not generate lora weights for key {}, is the weight difference all zeros?".format(k))
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}