Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-12-09 15:54:37 -08:00
commit 2d1676c717
118 changed files with 66956 additions and 3748 deletions

View File

@@ -1,2 +1,22 @@
* @comfyanonymous
* @doctorpangloss
# Admins
* @doctorpangloss
# Note: Github teams syntax cannot be used here as the repo is not owned by Comfy-Org.
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink

View File

@@ -236,6 +236,12 @@ export HSA_OVERRIDE_GFX_VERSION=10.3.0
comfyui
```
You can enable experimental memory-efficient attention on PyTorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
You can also try setting the env variable `PYTORCH_TUNABLEOP_ENABLED=1`, which might speed things up at the cost of a very slow initial run.
### Model Downloading
ComfyUI LTS supports downloading models on demand. Its list of known models includes the most notable and common Stable Diffusion architecture checkpoints, slider LoRAs, all the notable ControlNets for SD1.5 and SDXL, and a small selection of LLM models. Additionally, all other supported LTS nodes download models using the same mechanism. This saves storage space and time: you will never again have to figure out the "right name" for a model, where to download it from, or where to put it.
@@ -715,37 +721,37 @@ The default installation includes a fast latent preview method that's low-resolu
| Keybind | Explanation |
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Alt + Enter | Cancel current generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| `Ctrl` + `Enter` | Queue up current graph for generation |
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
| `Ctrl` + `S` | Save workflow |
| `Ctrl` + `O` | Load workflow |
| `Ctrl` + `A` | Select all nodes |
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
| `Ctrl` + `M` | Mute/unmute selected nodes |
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| `Delete`/`Backspace` | Delete selected nodes |
| `Ctrl` + `Backspace` | Delete the current graph |
| `Space` | Move the canvas around when held and moving the cursor |
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
| `Ctrl` + `D` | Load default graph |
| `Alt` + `+` | Canvas Zoom in |
| `Alt` + `-` | Canvas Zoom out |
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
| `P` | Pin/Unpin selected nodes |
| `Ctrl` + `G` | Group selected nodes |
| `Q` | Toggle visibility of the queue |
| `H` | Toggle visibility of history |
| `R` | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| Shift + Drag | Move multiple wires at once |
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
| `Shift` + Drag | Move multiple wires at once |
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
Ctrl can also be replaced with Cmd instead for macOS users
`Ctrl` can also be replaced with `Cmd` for macOS users
# Configuration

View File

@@ -14,7 +14,6 @@ class InternalRoutes:
The top level web router for internal routes: /internal/*
The endpoints here should NOT be depended upon. They are for ComfyUI frontend use only.
Check README.md for more information.
'''
def __init__(self, prompt_server):

View File

@@ -39,7 +39,7 @@ class UserManager():
self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

View File

@@ -2,11 +2,9 @@
#and modified
import torch
import torch as th
import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import (
zero_module,
timestep_embedding,
)

comfy/cldm/dit_embedder.py Normal file (120 lines)
View File

@@ -0,0 +1,120 @@
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
attention_head_dim: int,
num_attention_heads: int,
adm_in_channels: int,
num_layers: int,
main_model_double: int,
double_y_emb: bool,
device: torch.device,
dtype: torch.dtype,
pos_embed_max_size: Optional[int] = None,
operations = None,
):
super().__init__()
self.main_model_double = main_model_double
self.dtype = dtype
self.hidden_size = num_attention_heads * attention_head_dim
self.patch_size = patch_size
self.x_embedder = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=pos_embed_max_size is None,
device=device,
dtype=dtype,
operations=operations,
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
self.double_y_emb = double_y_emb
if self.double_y_emb:
self.orig_y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.y_embedder = VectorEmbedder(
self.hidden_size, self.hidden_size, dtype, device, operations=operations
)
else:
self.y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
DismantledBlock(
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(num_layers)
)
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
# TODO double check this logic when 8b
self.use_y_embedder = True
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=False,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Tuple[Tensor, List[Tensor]]:
x_shape = list(x.shape)
x = self.x_embedder(x)
if not self.double_y_emb:
h = (x_shape[-2] + 1) // self.patch_size
w = (x_shape[-1] + 1) // self.patch_size
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
if self.double_y_emb:
y = self.orig_y_embedder(y)
y = self.y_embedder(y)
c = c + y
x = x + self.pos_embed_input(hint)
block_out = ()
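        # Each ControlNet block's output is repeated so that the flat tuple of
        # control residuals lines up one-to-one with the main model's double blocks.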
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
for i in range(len(self.transformer_blocks)):
out = self.transformer_blocks[i](x, c)
if not self.double_y_emb:
x = out
block_out += (self.controlnet_blocks[i](out),) * repeat
return {"output": block_out}

View File

@@ -1,6 +1,6 @@
import torch
from typing import Dict, Optional, List
from typing import Optional, List
import torch
from einops import einops
from torch import Tensor
@@ -12,15 +12,16 @@ def default(x, y):
return x
return y
class ControlNet(MMDiT):
def __init__(
self,
num_blocks = None,
control_latent_channels = None,
dtype = None,
device = None,
operations = None,
**kwargs,
self,
num_blocks=None,
control_latent_channels=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
# controlnet_blocks
@@ -44,15 +45,15 @@ class ControlNet(MMDiT):
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Dict[str, List[Tensor]]:
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint=None,
) -> dict[str, List[Tensor]]:
#weird sd3 controlnet specific stuff
# weird sd3 controlnet specific stuff
y = torch.zeros_like(y)
if self.context_processor is not None:

View File

@@ -58,9 +58,11 @@ def _create_parser() -> EnhancedConfigArgParser:
fp_group.add_argument("--force-bf16", action="store_true", help="Force bf16.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true",
help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")

View File

@@ -60,8 +60,10 @@ class Configuration(dict):
force_bf16 (bool): Force using BF16 precision.
bf16_unet (bool): Use BF16 precision for UNet.
fp16_unet (bool): Use FP16 precision for UNet.
fp8_e4m3fn_unet (bool): Use FP8 precision (e4m3fn variant) for UNet.
fp8_e5m2_unet (bool): Use FP8 precision (e5m2 variant) for UNet.
fp32_unet (bool): Run the diffusion model in fp32.
fp64_unet (bool): Run the diffusion model in fp64.
fp8_e4m3fn_unet (bool): Store unet weights in fp8_e4m3fn.
fp8_e5m2_unet (bool): Store unet weights in fp8_e5m2.
fp16_vae (bool): Run the VAE in FP16 precision.
fp32_vae (bool): Run the VAE in FP32 precision.
bf16_vae (bool): Run the VAE in BF16 precision.
@@ -145,6 +147,8 @@ class Configuration(dict):
self.force_bf16: bool = False
self.bf16_unet: bool = False
self.fp16_unet: bool = False
self.fp32_unet: bool = False
self.fp64_unet: bool = False
self.fp8_e4m3fn_unet: bool = False
self.fp8_e5m2_unet: bool = False
self.fp16_vae: bool = False

View File

@@ -20,13 +20,18 @@ class Output:
setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
if crop:
scale = (size / min(image.shape[2], image.shape[3]))
scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size) // 2
w = (image.shape[3] - size) // 2
image = image[:, :, h:h + size, w:w + size]
@@ -64,9 +69,9 @@ class ClipVisionModel():
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
def encode_image(self, image, crop=True):
load_models_gpu([self.patcher])
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
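A hedged usage sketch of the new `crop` flag (assumes a loaded `ClipVisionModel` instance named `clip_vision` and an image tensor in ComfyUI's `[B, H, W, C]` layout):

```python
# crop=True (default): scale the short side to `size`, then center-crop.
# crop=False: resize (squash) the whole image to size x size instead.
out_cropped = clip_vision.encode_image(image)               # previous behavior
out_squashed = clip_vision.encode_image(image, crop=False)  # new option
```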

View File

@@ -39,6 +39,7 @@ nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from ..graph_utils import is_link, GraphBuilder
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from ..validation import validate_node_input
class IsChangedCache:
@@ -699,7 +700,8 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
received_type = r[val[1]]
received_types[x] = received_type
any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
if 'input_types' not in validate_function_inputs and received_type != type_input and not any_enum:
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input) and not any_enum:
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",

View File

@@ -21,8 +21,8 @@ MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image) -> Image:
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=model_management.device_supports_non_blocking(latent_image.device))
.mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8, non_blocking=model_management.device_supports_non_blocking(latent_image.device))
return Image.fromarray(latents_ubyte.numpy())
@@ -78,7 +78,7 @@ def get_previewer(device, latent_format):
if latent_format.taesd_decoder_name is not None:
taesd_decoder_path = next(
(fn for fn in folder_paths.get_filename_list("vae_approx")
if fn.startswith(latent_format.taesd_decoder_name)),
if fn.startswith(latent_format.taesd_decoder_name)),
""
)
taesd_decoder_path = get_or_download("vae_approx", taesd_decoder_path, KNOWN_APPROX_VAES)
@@ -98,6 +98,7 @@ def get_previewer(device, latent_format):
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
return previewer
def prepare_callback(model, steps, x0_output_dict=None):
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
@@ -115,5 +116,5 @@ def prepare_callback(model, steps, x0_output_dict=None):
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback
return callback

View File

@@ -13,19 +13,17 @@ import sys
import traceback
import urllib
import uuid
from asyncio import Future, AbstractEventLoop, Task
from asyncio import AbstractEventLoop, Task
from enum import Enum
from io import BytesIO
from posixpath import join as urljoin
from typing import List, Optional
from urllib.parse import quote, urlencode
from urllib.parse import quote
import aiofiles
import aiohttp
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from aiohttp import web
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple
from .latent_preview_image_encoding import encode_preview_image
@@ -37,17 +35,13 @@ from ..api_server.routes.internal.internal_routes import InternalRoutes
from ..app.frontend_management import FrontendManager
from ..app.user_manager import UserManager
from ..cli_args import args
from ..client.client_types import FileOutput
from ..cmd import execution
from ..cmd import folder_paths
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, StatusMessage, QueueInfo, ExecInfo
from ..component_model.file_output_path import file_output_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
ExecutionStatus
from ..digest import digest
from ..component_model.queue_types import QueueItem, BinaryEventTypes
from ..images import open_image
from ..model_filemanager import download_model, DownloadModelStatus
from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
from ..nodes.package_typing import ExportedNodes
@@ -264,11 +258,11 @@ class PromptServer(ExecutorToClientProgress):
@routes.get("/extensions")
async def get_extensions(request):
files = glob.glob(os.path.join(glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
extensions = list(map(lambda f: "/" + os.path.relpath(str(f), self.web_root).replace("\\", "/"), files))
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
extensions.extend(list(map(lambda f: "/extensions/" + quote(name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
extensions.extend(list(map(lambda f: "/extensions/" + quote(name) + "/" + os.path.relpath(str(f), dir).replace("\\", "/"), files)))
return web.json_response(extensions)
@@ -688,248 +682,6 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=200)
# Internal route. Should not be depended upon and is subject to change at any time.
# TODO(robinhuang): Move to internal route table class once we refactor PromptServer to pass around Websocket.
# NOTE: This was an experiment and WILL BE REMOVED
@routes.post("/internal/models/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadModelStatus):
payload = status.to_dict()
payload['download_path'] = filename
await self.send_json("download_progress", payload)
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
folder_path = data.get('folder_path')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename or not folder_path:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
session = self.client_session
if session is None:
logger.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
await task
return web.json_response(task.result().to_dict())
@routes.get("/api/v1/prompts/{prompt_id}")
async def get_api_v1_prompts_prompt_id(request: web.Request) -> web.Response | web.FileResponse:
prompt_id: str = request.match_info.get("prompt_id", "")
if prompt_id == "":
return web.json_response(status=404)
history_items = self.prompt_queue.get_history(prompt_id)
if len(history_items) == 0 or prompt_id not in history_items:
# todo: this should really be moved to a stateful queue abstraction
if prompt_id in self.background_tasks:
return web.json_response(status=204)
else:
# todo: this should check a stateful queue abstraction
return web.json_response(status=404)
elif prompt_id in history_items:
history_entry = history_items[prompt_id]
return web.json_response(history_entry["outputs"])
else:
return web.json_response(status=500)
@routes.post("/api/v1/prompts")
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
accept = request.headers.get("accept", "application/json")
if accept == '*/*':
accept = "application/json"
content_type = request.headers.get("content-type", "application/json")
preferences = request.headers.get("prefer", "") + request.query.get("prefer", "") + " " + content_type
if "+" in content_type:
content_type = content_type.split("+")[0]
wait = not "respond-async" in preferences
if accept not in ("application/json", "image/png"):
return web.json_response(status=400, reason=f"invalid accept content type, expected application/json or image/png, got {accept}")
# check if the queue is too long
queue_size = self.prompt_queue.size()
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
if queue_size > queue_too_busy_size:
return web.json_response(status=429,
reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
# read the request
prompt_dict: dict = {}
if content_type == 'application/json':
prompt_dict = await request.json()
elif content_type == 'multipart/form-data':
try:
reader = await request.multipart()
async for part in reader:
if part is None:
break
if part.headers[aiohttp.hdrs.CONTENT_TYPE] == 'application/json':
prompt_dict = await part.json()
if 'prompt' in prompt_dict:
prompt_dict = prompt_dict['prompt']
elif part.filename:
file_data = await part.read(decode=True)
# overwrite existing files
upload_dir = PromptServer.get_upload_dir()
async with aiofiles.open(os.path.join(upload_dir, part.filename), mode='wb') as file:
await file.write(file_data)
except IOError as ioError:
return web.Response(status=507, reason=str(ioError))
except MemoryError as memoryError:
return web.Response(status=507, reason=str(memoryError))
except Exception as ex:
return web.Response(status=400, reason=str(ex))
if len(prompt_dict) == 0:
return web.Response(status=400, reason="no prompt was specified")
content_digest = digest(prompt_dict)
valid = execution.validate_prompt(prompt_dict)
if not valid[0]:
return web.Response(status=400, content_type="application/json", body=json.dumps(valid[1]))
# convert a valid prompt to the queue tuple this expects
number = self.number
self.number += 1
result: TaskInvocation
completed: Future[TaskInvocation | dict] = self.loop.create_future()
# todo: actually implement idempotency keys
# we would need some kind of more durable, distributed task queue
task_id = str(uuid.uuid4())
item = QueueItem(queue_tuple=(number, task_id, prompt_dict, {}, valid[2]), completed=completed)
try:
if hasattr(self.prompt_queue, "put_async") or isinstance(self.prompt_queue, AsyncAbstractPromptQueue):
# this enables span propagation seamlessly
fut = self.prompt_queue.put_async(item)
if wait:
result = await fut
if result is None:
return web.Response(body="the queue is shutting down", status=503)
else:
return await self._schedule_background_task_with_web_response(fut, task_id)
else:
self.prompt_queue.put(item)
if wait:
await completed
else:
return await self._schedule_background_task_with_web_response(completed, task_id)
task_invocation_or_dict: TaskInvocation | dict = completed.result()
if isinstance(task_invocation_or_dict, dict):
result = TaskInvocation(item_id=item.prompt_id, outputs=task_invocation_or_dict, status=ExecutionStatus("success", True, []))
else:
result = task_invocation_or_dict
except ExecutionError as exec_exc:
result = exec_exc.as_task_invocation()
except Exception as ex:
return web.Response(body=str(ex), status=500)
if result.status is not None and result.status.status_str == "error":
return web.Response(body=json.dumps(result.status._asdict()), status=500, content_type="application/json")
# find images and read them
output_images: List[FileOutput] = []
for node_id, node in result.outputs.items():
images: List[FileOutput] = []
if 'images' in node:
images = node['images']
# todo: does this ever occur?
elif (isinstance(node, dict)
and 'ui' in node and isinstance(node['ui'], dict)
and 'images' in node['ui']):
images = node['ui']['images']
for image_tuple in images:
output_images.append(image_tuple)
if len(output_images) > 0:
main_image = output_images[0]
filename = main_image["filename"]
digest_headers_ = {
"Digest": f"SHA-256={content_digest}",
}
urls_ = []
if len(output_images) == 1:
digest_headers_.update({
"Content-Disposition": f"filename=\"{filename}\""
})
for image_indv_ in output_images:
local_address = f"http://{self.address}:{self.port}"
external_address = self.external_address
for base in (local_address, external_address):
try:
url: URL = urlparse(urljoin(base, "view"))
except ValueError:
continue
url_search_dict: FileOutput = dict(image_indv_)
del url_search_dict["abs_path"]
if "name" in url_search_dict:
del url_search_dict["name"]
if url_search_dict["subfolder"] == "":
del url_search_dict["subfolder"]
url.search = f"?{urlencode(url_search_dict)}"
urls_.append(str(url))
if accept == "application/json":
return web.Response(status=200,
content_type="application/json",
headers=digest_headers_,
body=json.dumps({
'urls': urls_,
'outputs': result.outputs
}))
elif accept == "image/png" or accept == "image/jpeg":
return web.FileResponse(main_image["abs_path"],
headers=digest_headers_)
else:
return web.Response(status=500,
reason="unreachable")
else:
return web.Response(status=204)
@routes.get("/api/v1/prompts")
async def get_api_prompt(_: web.Request) -> web.Response:
history = self.prompt_queue.get_history()
history_items = list(history.values())
if len(history_items) == 0:
return web.Response(status=404)
last_history_item: HistoryEntry = history_items[-1]
prompt = last_history_item['prompt'][2]
return web.json_response(prompt, status=200)
async def _schedule_background_task_with_web_response(self, fut, task_id):
task = asyncio.create_task(fut, name=task_id)
self.background_tasks[task_id] = task
task.add_done_callback(lambda _: self.background_tasks.pop(task_id))
# todo: type this from the OpenAPI spec
return web.json_response({
"prompt_id": task_id
}, status=202, headers={
"Location": f"api/v1/prompts/{task_id}",
"Retry-After": "60"
})
@property
def external_address(self):
return self._external_address if self._external_address is not None else f"http://{'localhost' if self.address == '0.0.0.0' else self.address}:{self.port}"
@external_address.setter
def external_address(self, value):
self._external_address = value
@property
def receive_all_progress_notifications(self) -> bool:
return True
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)

View File

@@ -0,0 +1,43 @@
# Comfy Typing
## Type hinting for ComfyUI Node development
This module provides type hinting and concrete convenience types for node developers.
If cloned to the custom_nodes directory of ComfyUI, types can be imported using:
```python
from comfy_types import IO, ComfyNodeABC, CheckLazyMixin
class ExampleNode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {"required": {}}
```
Full example is in [examples/example_nodes.py](examples/example_nodes.py).
# Types
A few primary types are documented below. More complete information is available via the docstrings on each type.
## `IO`
A string enum of built-in and a few custom data types. Includes the following special types and their requisite plumbing:
- `ANY`: `"*"`
- `NUMBER`: `"FLOAT,INT"`
- `PRIMITIVE`: `"STRING,FLOAT,INT,BOOLEAN"`
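Because these multi-type strings are compared subset-wise (see the `IO.__ne__` implementation in `node_typing.py` later in this diff), a `NUMBER` slot accepts either a `FLOAT` or an `INT` link. A minimal sketch:

```python
from comfy_types import IO

assert not (IO.NUMBER != "FLOAT")  # "FLOAT" is a subset of "FLOAT,INT"
assert not (IO.NUMBER != "INT")
assert IO.NUMBER != "IMAGE"        # disjoint types do not match
assert not (IO.ANY != "IMAGE")     # "*" matches everything
```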
## `ComfyNodeABC`
An abstract base class for nodes, offering type-hinting / autocomplete, and somewhat-alright docstrings.
### Type hinting for `INPUT_TYPES`
![INPUT_TYPES auto-completion in Visual Studio Code](examples/input_types.png)
### `INPUT_TYPES` return dict
![INPUT_TYPES return value type hinting in Visual Studio Code](examples/required_hint.png)
### Options for individual inputs
![INPUT_TYPES return value option auto-completion in Visual Studio Code](examples/input_options.png)

View File

@@ -1,5 +1,6 @@
import torch
from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
class UnetApplyFunction(Protocol):
@@ -30,3 +31,15 @@ class UnetParams(TypedDict):
UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
__all__ = [
"UnetWrapperFunction",
UnetApplyConds.__name__,
UnetParams.__name__,
UnetApplyFunction.__name__,
IO.__name__,
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
]

View File

@@ -0,0 +1,28 @@
from comfy_types import IO, ComfyNodeABC, InputTypeDict
from inspect import cleandoc
class ExampleNode(ComfyNodeABC):
"""An example node that just adds 1 to an input integer.
* Requires an IDE configured with analysis paths etc to be worth looking at.
* Not intended for use in ComfyUI.
"""
DESCRIPTION = cleandoc(__doc__)
CATEGORY = "examples"
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"input_int": (IO.INT, {"defaultInput": True}),
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("input_plus_one",)
FUNCTION = "execute"
def execute(self, input_int: int):
return (input_int + 1,)

Binary file not shown (added image, 19 KiB)

Binary file not shown (added image, 16 KiB)

Binary file not shown (added image, 19 KiB)

View File

@@ -0,0 +1,274 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict
from abc import ABC, abstractmethod
from enum import Enum
class StrEnum(str, Enum):
"""Base class for string enums. Python's StrEnum is not available until 3.11."""
def __str__(self) -> str:
return self.value
class IO(StrEnum):
"""Node input/output data types.
Includes functionality for ``"*"`` (`ANY`) and ``"MULTI,TYPES"``.
"""
STRING = "STRING"
IMAGE = "IMAGE"
MASK = "MASK"
LATENT = "LATENT"
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
GUIDER = "GUIDER"
NOISE = "NOISE"
CLIP = "CLIP"
CONTROL_NET = "CONTROL_NET"
VAE = "VAE"
MODEL = "MODEL"
CLIP_VISION = "CLIP_VISION"
CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
STYLE_MODEL = "STYLE_MODEL"
GLIGEN = "GLIGEN"
UPSCALE_MODEL = "UPSCALE_MODEL"
AUDIO = "AUDIO"
WEBCAM = "WEBCAM"
POINT = "POINT"
FACE_ANALYSIS = "FACE_ANALYSIS"
BBOX = "BBOX"
SEGS = "SEGS"
ANY = "*"
"""Always matches any type, but at a price.
Causes some functionality issues (e.g. reroutes, link types), and should be avoided whenever possible.
"""
NUMBER = "FLOAT,INT"
"""A float or an int - could be either"""
PRIMITIVE = "STRING,FLOAT,INT,BOOLEAN"
"""Could be any of: string, float, int, or bool"""
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
"""
default: bool | str | float | int | list | tuple
"""The default value of the widget"""
defaultInput: bool
"""Defaults to an input slot rather than a widget"""
forceInput: bool
"""`defaultInput` and also don't allow converting to a widget"""
lazy: bool
"""Declares that this input uses lazy evaluation"""
rawLink: bool
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", <outputIndex>]`). Designed for node expansion."""
tooltip: str
"""Tooltip for the input (or widget), shown on pointer hover"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
min: float
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
max: float
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
step: float
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
round: float
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
label_on: str
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
label_off: str
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
multiline: bool
"""Use a multiline text box (``STRING``)"""
placeholder: str
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
dynamicPrompts: bool
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
node_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
unique_id: Literal["UNIQUE_ID"]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
prompt: Literal["PROMPT"]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
extra_pnginfo: Literal["EXTRA_PNGINFO"]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
dynprompt: Literal["DYNPROMPT"]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
"""
required: dict[str, tuple[IO, InputTypeOptions]]
"""Describes all inputs that must be connected for the node to execute."""
optional: dict[str, tuple[IO, InputTypeOptions]]
"""Describes inputs which do not need to be connected."""
hidden: HiddenInputTypeDict
"""Offers advanced functionality and server-client communication.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
"""
DESCRIPTION: str
"""Node description, shown as a tooltip when hovering over the node.
Usage::
# Explicitly define the description
DESCRIPTION = "Example description here."
# Use the docstring of the node class.
DESCRIPTION = cleandoc(__doc__)
"""
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
@classmethod
@abstractmethod
def INPUT_TYPES(s) -> InputTypeDict:
"""Defines node inputs.
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
* The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
"""
return {"required": {}}
OUTPUT_NODE: bool
"""Flags this node as an output node, causing any inputs it requires to be executed.
If a node is not connected to any output nodes, that node will not be executed. Usage::
OUTPUT_NODE = True
From the docs:
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``.
From the docs:
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
OUTPUT_IS_LIST: tuple[bool]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
A ``tuple[bool]``, where the items match those in `RETURN_TYPES`::
RETURN_TYPES = (IO.INT, IO.INT, IO.STRING)
OUTPUT_IS_LIST = (True, True, False) # The string output will be handled normally
From the docs:
In order to tell Comfy that the list being returned should not be wrapped, but treated as a series of data for sequential processing,
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool]``, of the same length as `RETURN_TYPES`,
specifying which outputs should be so treated.
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
"""
RETURN_TYPES: tuple[IO]
"""A tuple representing the outputs of this node.
Usage::
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
"""
RETURN_NAMES: tuple[str]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
"""
OUTPUT_TOOLTIPS: tuple[str]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
"""
class CheckLazyMixin:
"""Provides a basic check_lazy_status implementation and type hinting for nodes that use lazy inputs."""
def check_lazy_status(self, **kwargs) -> list[str]:
"""Returns a list of input names that should be evaluated.
This basic mixin impl. requires all inputs.
:kwargs: All node inputs will be included here. If an input is ``None``, it should be assumed that it has not yet been evaluated. \
When using ``INPUT_IS_LIST = True``, unevaluated inputs will instead be ``(None,)``.
Params should match the node's execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params).
Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need

View File

@@ -20,6 +20,7 @@ import logging
import math
import os
from enum import Enum
from typing import TYPE_CHECKING
import torch
@@ -30,11 +31,15 @@ from . import model_patcher
from . import ops
from . import utils
from .cldm import cldm, mmdit
from .cldm.dit_embedder import ControlNetEmbedder
from .ldm.cascade import controlnet as cascade_controlnet
from .ldm.flux import controlnet as controlnet_flux
from .ldm.hydit.controlnet import HunYuanControlNet
from .t2i_adapter import adapter
if TYPE_CHECKING:
from .hooks import HookGroup
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
@@ -79,6 +84,8 @@ class ControlBase:
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -116,6 +123,14 @@ class ControlBase:
out += self.previous_controlnet.get_models()
return out
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
out.append(self.extra_hooks)
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_extra_hooks()
return out
def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
@@ -130,6 +145,8 @@ class ControlBase:
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
c.preprocess_image = self.preprocess_image
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -182,7 +199,7 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, ckpt_name: str = None):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a, ckpt_name: str = None):
super().__init__()
self.control_model = control_model
self.load_device = load_device
@@ -198,11 +215,12 @@ class ControlNet(ControlBase):
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
self.preprocess_image = preprocess_image
def get_control(self, x_noisy, t, cond, batched_number):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@@ -226,6 +244,7 @@ class ControlNet(ControlBase):
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -458,6 +477,82 @@ def load_controlnet_mmdit(sd, model_options=None):
return control
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_sd35(sd, model_options={}):
control_type = -1
if "control_type" in sd:
control_type = round(sd.pop("control_type").item())
# blur_cnet = control_type == 0
canny_cnet = control_type == 1
depth_cnet = control_type == 2
new_sd = {}
for k in utils.MMDIT_MAP_BASIC:
if k[1] in sd:
new_sd[k[0]] = sd.pop(k[1])
for k in sd:
new_sd[k] = sd[k]
sd = new_sd
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
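    # Per-head dim is fixed at 64 here, so the y-embedder width determines the
    # depth, hidden size, and head count derived below.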
depth = y_emb_shape[0] // 64
hidden_size = 64 * depth
num_heads = depth
head_dim = hidden_size // num_heads
num_blocks = model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
load_device = model_management.get_torch_device()
offload_device = model_management.unet_offload_device()
unet_dtype = model_management.unet_dtype(model_params=-1)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
control_model = ControlNetEmbedder(img_size=None,
patch_size=2,
in_chans=16,
num_layers=num_blocks,
main_model_double=depth,
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
attention_head_dim=head_dim,
num_attention_heads=num_heads,
adm_in_channels=2048,
device=offload_device,
dtype=unet_dtype,
operations=operations)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = latent_formats.SD3()
preprocess_image = lambda a: a
if canny_cnet:
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control
def load_controlnet_hunyuandit(controlnet_data, model_options=None):
if model_options is None:
model_options = {}
@@ -636,7 +731,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) # SD3 diffusers controlnet
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
return load_controlnet_sd35(controlnet_data, model_options=model_options) # Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) # SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux
@@ -756,10 +854,10 @@ class T2IAdapter(ControlBase):
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
return width, height
def get_control(self, x_noisy, t, cond, batched_number):
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:

View File

@@ -27,5 +27,8 @@ def load_extra_path_config(yaml_path):
full_path = y
if base_path is not None:
full_path = os.path.join(base_path, full_path)
elif not os.path.isabs(full_path):
yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
full_path = os.path.abspath(os.path.join(yaml_dir, y))
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path, is_default=is_default)
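In effect, relative entries in an extra model paths YAML now resolve against the YAML file's own directory when no `base_path` is set. A minimal sketch of just that resolution rule (not the full function):

```python
import os
from typing import Optional

def resolve_extra_path(entry: str, base_path: Optional[str], yaml_path: str) -> str:
    if base_path is not None:
        # an explicit base_path still wins
        return os.path.join(base_path, entry)
    if not os.path.isabs(entry):
        # otherwise, resolve relative entries against the YAML's directory
        yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
        return os.path.abspath(os.path.join(yaml_dir, entry))
    return entry
```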

View File

@@ -1,10 +1,9 @@
#code taken from: https://github.com/wl-zhao/UniPC and modified
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange, tqdm
from tqdm.auto import trange
class NoiseScheduleVP:

comfy/hooks.py Normal file (690 lines)
View File

@@ -0,0 +1,690 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable
import enum
import math
import torch
import numpy as np
import itertools
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher, PatcherInjection
from comfy.model_base import BaseModel
from comfy.sd import CLIP
import comfy.lora
import comfy.model_management
import comfy.patcher_extension
from node_helpers import conditioning_set_values
class EnumHookMode(enum.Enum):
MinVram = "minvram"
MaxSpeed = "maxspeed"
class EnumHookType(enum.Enum):
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"
class _HookRef:
pass
# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return True
class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
self.hook_type = hook_type
self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_id = hook_id
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
@property
def strength(self):
return self.hook_keyframe.strength
def initialize_timesteps(self, model: 'BaseModel'):
self.reset()
self.hook_keyframe.initialize_timesteps(model)
def reset(self):
self.hook_keyframe.reset()
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: Hook = subtype()
c.hook_type = self.hook_type
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c
def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")
def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
pass
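    # Identity follows hook_ref: a hook and its clones share the same hook_ref,
    # so they compare equal and hash together (e.g. for dict/set membership).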
def __eq__(self, other: 'Hook'):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
class WeightHook(Hook):
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
@property
def strength_model(self):
return self._strength_model * self.strength
@property
def strength_clip(self):
return self._strength_clip * self.strength
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:
strength = self._strength_clip
if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
weights = self.weights_clip
k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
return True
# TODO: add logs about any keys that were not applied
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WeightHook = super().clone(subtype)
c.weights = self.weights
c.weights_clip = self.weights_clip
c.need_weight_init = self.need_weight_init
c._strength_model = self._strength_model
c._strength_clip = self._strength_clip
return c
class PatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.Patch)
self.patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: PatchHook = super().clone(subtype)
c.patches = self.patches
return c
# TODO: add functionality
class ObjectPatchHook(Hook):
def __init__(self):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches: dict = None
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: ObjectPatchHook = super().clone(subtype)
c.object_patches = self.object_patches
return c
# TODO: add functionality
class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
self.append_when_same = True
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: AddModelsHook = super().clone(subtype)
c.key = self.key
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality
class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
class WrapperHook(Hook):
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: WrapperHook = super().clone(subtype)
c.wrappers_dict = self.wrappers_dict
return c
def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
return True
class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
self.key = key
self.injections = injections
def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: SetInjectionsHook = super().clone(subtype)
c.key = self.key
c.injections = self.injections.copy() if self.injections else self.injections
return c
def add_hook_injections(self, model: 'ModelPatcher'):
# TODO: add functionality
pass
class HookGroup:
def __init__(self):
self.hooks: list[Hook] = []
def add(self, hook: Hook):
if hook not in self.hooks:
self.hooks.append(hook)
def contains(self, hook: Hook):
return hook in self.hooks
def clone(self):
c = HookGroup()
for hook in self.hooks:
c.add(hook.clone())
return c
def clone_and_combine(self, other: 'HookGroup'):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c
def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
hook_kf = hook_kf.clone()
for hook in self.hooks:
hook.hook_keyframe = hook_kf
def get_dict_repr(self):
d: dict[EnumHookType, dict[Hook, None]] = {}
for hook in self.hooks:
with_type = d.setdefault(hook.hook_type, {})
with_type[hook] = None
return d
def get_hooks_for_clip_schedule(self):
scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {}
for hook in self.hooks:
# only care about WeightHooks, for now
if hook.hook_type == EnumHookType.Weight:
hook_schedule = []
# if no hook keyframes, assign default value
if len(hook.hook_keyframe.keyframes) == 0:
hook_schedule.append(((0.0, 1.0), None))
scheduled_hooks[hook] = hook_schedule
continue
# find ranges of values
prev_keyframe = hook.hook_keyframe.keyframes[0]
for keyframe in hook.hook_keyframe.keyframes:
if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength):
hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe))
prev_keyframe = keyframe
elif keyframe.start_percent == prev_keyframe.start_percent:
prev_keyframe = keyframe
# create final range, assuming last start_percent was not 1.0
if not math.isclose(prev_keyframe.start_percent, 1.0):
hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe))
scheduled_hooks[hook] = hook_schedule
        # hooks should now have their schedules in a list of tuples
all_ranges: list[tuple[float, float]] = []
for range_kfs in scheduled_hooks.values():
for t_range, keyframe in range_kfs:
all_ranges.append(t_range)
# turn list of ranges into boundaries
boundaries_set = set(itertools.chain.from_iterable(all_ranges))
boundaries_set.add(0.0)
boundaries = sorted(boundaries_set)
real_ranges = [(boundaries[i], boundaries[i + 1]) for i in range(len(boundaries) - 1)]
# with real ranges defined, give appropriate hooks w/ keyframes for each range
scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = []
for t_range in real_ranges:
hooks_schedule = []
for hook, val in scheduled_hooks.items():
keyframe = None
# check if is a keyframe that works for the current t_range
for stored_range, stored_kf in val:
# if stored start is less than current end, then fits - give it assigned keyframe
if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]:
keyframe = stored_kf
break
hooks_schedule.append((hook, keyframe))
scheduled_keyframes.append((t_range, hooks_schedule))
return scheduled_keyframes
def reset(self):
for hook in self.hooks:
hook.reset()
@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
actual.append(group)
if len(actual) < require_count:
raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.")
# if no hooks, then return None
if len(actual) == 0:
return None
# if only 1 hook, just return itself without cloning
elif len(actual) == 1:
return actual[0]
final_hook: HookGroup = None
for hook in actual:
if final_hook is None:
final_hook = hook.clone()
else:
final_hook = final_hook.clone_and_combine(hook)
return final_hook
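# Hedged usage sketch, not part of the original module: combining two hook groups.
# Assumes `hook_a` and `hook_b` are WeightHook instances that already carry weights.
def _example_combine_hook_groups(hook_a: 'WeightHook', hook_b: 'WeightHook') -> 'HookGroup':
    group_a = HookGroup()
    group_a.add(hook_a)
    group_b = HookGroup()
    group_b.add(hook_b)
    # clones the hooks of both groups into a single new group; duplicates are skipped by add()
    return HookGroup.combine_all_hooks([group_a, group_b], require_count=2)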
class HookKeyframe:
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
self.strength = strength
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def clone(self):
c = HookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
class HookKeyframeGroup:
def __init__(self):
self.keyframes: list[HookKeyframe] = []
self._current_keyframe: HookKeyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
self._curr_t = -1.
    # properties shadow those of HookKeyframe
@property
def strength(self):
if self._current_keyframe is not None:
return self._current_keyframe.strength
return 1.0
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._current_strength = None
        self._curr_t = -1.
self._set_first_as_current()
def add(self, keyframe: HookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def has_index(self, index: int):
return index >= 0 and index < len(self.keyframes)
def is_empty(self):
return len(self.keyframes) == 0
def clone(self):
c = HookKeyframeGroup()
for keyframe in self.keyframes:
c.keyframes.append(keyframe.clone())
c._set_first_as_current()
return c
def initialize_timesteps(self, model: 'BaseModel'):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
prev_index = self._current_index
prev_strength = self._current_strength
        # once the guaranteed steps are met, look ahead in case we need to switch keyframes
        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
            # if there is a next index, scan forward to see if a switch is needed
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
eval_c = self.keyframes[i]
# check if start_t is greater or equal to curr_t
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
if eval_c.start_t >= curr_t:
self._current_index = i
self._current_strength = eval_c.strength
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
        # update how many steps the current keyframe has been used
        self._current_used_steps += 1
        # record the timestep this preparation was performed on
        self._curr_t = curr_t
        # report a change only when both the keyframe index and its strength changed,
        # since equal strength means the applied weights do not need refreshing
        return prev_index != self._current_index and prev_strength != self._current_strength
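# Hedged usage sketch, not part of the original module: fading a hook's strength over
# the sampling run. start_percent is relative sampling progress; initialize_timesteps()
# later resolves it to sigmas for the active model.
def _example_keyframe_schedule() -> 'HookKeyframeGroup':
    kf_group = HookKeyframeGroup()
    kf_group.add(HookKeyframe(strength=1.0, start_percent=0.0))
    kf_group.add(HookKeyframe(strength=0.5, start_percent=0.5))
    kf_group.add(HookKeyframe(strength=0.0, start_percent=0.8))
    return kf_group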
class InterpolationMethod:
LINEAR = "linear"
EASE_IN = "ease_in"
EASE_OUT = "ease_out"
EASE_IN_OUT = "ease_in_out"
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
@classmethod
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
diff = num_to - num_from
if method == cls.LINEAR:
weights = torch.linspace(num_from, num_to, length)
elif method == cls.EASE_IN:
index = torch.linspace(0, 1, length)
weights = diff * np.power(index, 2) + num_from
elif method == cls.EASE_OUT:
index = torch.linspace(0, 1, length)
weights = diff * (1 - np.power(1 - index, 2)) + num_from
elif method == cls.EASE_IN_OUT:
index = torch.linspace(0, 1, length)
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
else:
raise ValueError(f"Unrecognized interpolation method '{method}'.")
if reverse:
weights = weights.flip(dims=(0,))
return weights
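# Hedged usage sketch, not part of the original module: eight strengths eased from
# 0.0 to 1.0, e.g. for spreading hook strengths across a batch of conds.
def _example_interpolated_strengths() -> torch.Tensor:
    return InterpolationMethod.get_weights(num_from=0.0, num_to=1.0, length=8,
                                           method=InterpolationMethod.EASE_IN_OUT)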
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
if not objects:
return objects
elif len(objects) <= 1:
return [x for x in objects]
# now that we know we have to sort, do it following these rules:
# a) if objects have same value of attribute, maintain their relative order
# b) perform sorting of the groups of objects with same attributes
unique_attrs = {}
for o in objects:
        val_attr = getattr(o, attr)
        unique_attrs.setdefault(val_attr, []).append(o)
# now that we have the unique attr values grouped together in relative order, sort them by key
sorted_attrs = dict(sorted(unique_attrs.items()))
# now flatten out the dict into a list to return
sorted_list = []
for object_list in sorted_attrs.values():
sorted_list.extend(object_list)
return sorted_list
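# Hedged illustration, not part of the original module: objects with equal attribute
# values stay grouped in insertion order while the groups themselves are sorted by key.
def _example_stable_attr_sort() -> list:
    a = HookKeyframe(strength=1.0, start_percent=0.5)
    b = HookKeyframe(strength=0.5, start_percent=0.0)
    c = HookKeyframe(strength=0.7, start_percent=0.5)
    # returns [b, a, c]: sorted by start_percent, with a kept ahead of c
    return get_sorted_list_via_attr([a, b, c], "start_percent")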
def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
hook.weights = lora
return hook_group
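# Hedged usage sketch, not part of the original module: turning a lora file into a
# hook group. comfy.utils.load_torch_file is the loader used elsewhere in this
# codebase; the path here is purely illustrative.
def _example_lora_hook_group() -> 'HookGroup':
    import comfy.utils
    lora_sd = comfy.utils.load_torch_file("loras/example_lora.safetensors", safe_load=True)
    return create_hook_lora(lora_sd, strength_model=0.8, strength_clip=1.0)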
def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
hook_group.add(hook)
patches_model = None
patches_clip = None
if weights_model is not None:
patches_model = {}
for key in weights_model:
patches_model[key] = ("model_as_lora", (weights_model[key],))
if weights_clip is not None:
patches_clip = {}
for key in weights_clip:
patches_clip[key] = ("model_as_lora", (weights_clip[key],))
hook.weights = patches_model
hook.weights_clip = patches_clip
hook.need_weight_init = False
return hook_group
def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
if discard_model_sampling:
# do not include ANY model_sampling components of the model that should act as a patch
for key in list(patches_model.keys()):
if key.startswith("model_sampling"):
patches_model.pop(key, None)
return patches_model
# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
hook_group = HookGroup()
hook = WeightHook()
hook_group.add(hook)
    loaded = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model)
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
print(f"NOT LOADED {x}")
return (new_modelpatcher, new_clip, hook_group)
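# Hedged usage sketch, not part of the original module: register lora hooks on cloned
# patchers, then ride the returned hook group on conditioning so the patches only
# apply while that conditioning is active.
def _example_register_lora_hooks(model: 'ModelPatcher', clip: 'CLIP', lora_sd: dict, cond):
    new_model, new_clip, hook_group = load_hook_lora_for_models(
        model, clip, lora_sd, strength_model=1.0, strength_clip=1.0)
    hooked_cond = set_hooks_for_conditioning(cond, hook_group)
    return new_model, new_clip, hooked_cond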
def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
hooks_key = 'hooks'
# if hooks only exist in one dict, do what's needed so that it ends up in c_dict
if hooks_key not in values:
return
if hooks_key not in c_dict:
hooks_value = values.get(hooks_key, None)
if hooks_value is not None:
c_dict[hooks_key] = hooks_value
return
# otherwise, need to combine with minimum duplication via cache
hooks_tuple = (c_dict[hooks_key], values[hooks_key])
cached_hooks = cache.get(hooks_tuple, None)
if cached_hooks is None:
new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1])
cache[hooks_tuple] = new_hooks
c_dict[hooks_key] = new_hooks
else:
c_dict[hooks_key] = cache[hooks_tuple]
def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
c = []
hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
if append_hooks and k == 'hooks':
_combine_hooks_from_values(n[1], values, hooks_combine_cache)
else:
n[1][k] = values[k]
c.append(n)
return c
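# Hedged note, not part of the original module: conditioning here is the usual list of
# [tensor, options_dict] pairs; each options dict is shallow-copied, and 'hooks' entries
# are merged through a small cache so an identical pair of HookGroups is only combined once.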
def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True):
if hooks is None:
return cond
return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks)
def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]):
if timestep_range is None:
return cond
return conditioning_set_values(cond, {"start_percent": timestep_range[0],
"end_percent": timestep_range[1]})
def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float):
if mask is None:
return cond
set_area_to_bounds = False
if set_cond_area != 'default':
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
return conditioning_set_values(cond, {'mask': mask,
'set_area_to_bounds': set_area_to_bounds,
'mask_strength': strength})
def combine_conditioning(conds: list):
combined_conds = []
for cond in conds:
combined_conds.extend(cond)
return combined_conds
def combine_with_new_conds(conds: list, new_conds: list):
combined_conds = []
for c, new_c in zip(conds, new_conds):
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
def set_conds_props(conds: list, strength: float, set_cond_area: str,
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
final_conds = []
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks)
# next, apply mask to conditioning
c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range)
        # finally, store the prepared conditioning
final_conds.append(c)
return final_conds
def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, masked_c]))
return combined_conds
def set_default_conds_and_combine(conds: list, new_conds: list,
hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
combined_conds = []
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {'default': True})
# apply timesteps, if present
new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds

View File

@ -176,12 +176,14 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if sigma_down == 0:
x = denoised
else:
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
@ -193,19 +195,22 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
if sigmas[i + 1] == 0:
x = denoised
else:
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
sigma_down = sigmas[i + 1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i + 1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
# Euler method
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
if eta > 0:
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
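# Hedged derivation note, not part of the original diff: for rectified flow,
# x_t = (1 - sigma_t) * x0 + sigma_t * eps, so after the deterministic step to
# sigma_down the renoise coefficient must restore the target marginal:
# sigma_{i+1}^2 == (alpha_{i+1} / alpha_down)^2 * sigma_down^2 + renoise_coeff^2,
# which is exactly the renoise_coeff expression used above. A quick numeric check:
def _check_rf_renoise_coeff(sigma_i=0.8, sigma_ip1=0.6, eta=1.0):
    sigma_down = sigma_ip1 * (1 + (sigma_ip1 / sigma_i - 1) * eta)
    alpha_ip1, alpha_down = 1 - sigma_ip1, 1 - sigma_down
    renoise_coeff = (sigma_ip1**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2) ** 0.5
    total_var = ((alpha_ip1 / alpha_down) * sigma_down) ** 2 + renoise_coeff ** 2
    assert abs(total_var - sigma_ip1 ** 2) < 1e-9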
@torch.no_grad()
@ -281,6 +286,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -307,6 +315,38 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:

View File

@ -233,3 +233,136 @@ class Mochi(LatentFormat):
class LTXV(LatentFormat):
latent_channels = 128
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
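# Hedged usage sketch, not part of the original diff: previews project the 128 latent
# channels to RGB with a single matmul plus bias; the [B, C, T, H, W] layout for LTXV
# video latents is an assumption here.
def _example_ltxv_latent_to_rgb(latent: torch.Tensor) -> torch.Tensor:
    fmt = LTXV()
    factors = torch.tensor(fmt.latent_rgb_factors, dtype=latent.dtype)    # [128, 3]
    bias = torch.tensor(fmt.latent_rgb_factors_bias, dtype=latent.dtype)  # [3]
    return torch.einsum("bcthw,cr->bthwr", latent, factors) + bias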

View File

@ -2,7 +2,7 @@
import torch
from torch import nn
from typing import Literal, Dict, Any
from typing import Literal
import math
from ... import ops
ops = ops.disable_weight_init

View File

@ -2,8 +2,8 @@
import torch
import torch.nn as nn
from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from torch import Tensor
from typing import List, Union
from einops import rearrange
import math
from ... import ops

View File

@ -16,7 +16,6 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn
from .common import LayerNorm2d_op

View File

@ -4,7 +4,7 @@ from comfy import ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]

View File

@ -1,7 +1,7 @@
#original code from https://github.com/genmoai/models under apache 2.0 license
#adapted to ComfyUI
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn as nn

View File

@ -1,8 +1,6 @@
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ... import ops
from ..modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
@ -287,7 +285,7 @@ class HunYuanDiT(nn.Module):
style=None,
return_dict=False,
control=None,
transformer_options=None,
transformer_options={},
):
"""
Forward pass of the encoder.
@ -315,8 +313,7 @@ class HunYuanDiT(nn.Module):
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
patches_replace = transformer_options.get("patches_replace", {})
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@ -364,6 +361,8 @@ class HunYuanDiT(nn.Module):
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
blocks_replace = patches_replace.get("dit", {})
controls = None
if control:
controls = control.get("output", None)
@ -375,9 +374,20 @@ class HunYuanDiT(nn.Module):
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
skip = None
if ("double_block", layer) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
return out
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)

View File

@ -1,6 +1,5 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..modules.attention import optimized_attention
from ... import ops

View File

@ -380,6 +380,7 @@ class LTXVModel(torch.nn.Module):
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
@ -416,13 +417,15 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs):
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
@ -430,10 +433,22 @@ class LTXVModel(torch.nn.Module):
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = 0.0
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
@ -469,14 +484,24 @@ class LTXVModel(torch.nn.Module):
batch_size, -1, x.shape[-1]
)
for block in self.transformer_blocks:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
)
# 3. Output
scale_shift_values = (

View File

@ -3,7 +3,7 @@ from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Any, Mapping, Optional, Tuple, Union, List
from typing import Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm

View File

@ -1,6 +1,5 @@
from typing import Tuple, Union
import torch
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d

View File

@ -1,7 +1,7 @@
import torch
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
import logging as logpy
from ..modules.distributions.distributions import DiagonalGaussianDistribution

View File

@ -1,16 +1,16 @@
import logging
import math
from functools import partial
from typing import Dict, Optional, List
from typing import Dict, Optional
import numpy as np
import torch
import torch.nn as nn
from ..attention import optimized_attention
from einops import rearrange, repeat
from .util import timestep_embedding
from .... import ops
from ..attention import optimized_attention
from ... import common_dit
from .... import ops
def default(x, y):
if x is not None:

View File

@ -3,7 +3,6 @@ import math
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Any
import logging
from .... import model_management

View File

@ -1,23 +1,25 @@
import logging
from abc import abstractmethod
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import logging
from .util import (
checkpoint,
avg_pool_nd,
zero_module,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from ...util import exists
from .... import ops
from .... import patcher_extension
ops = ops.disable_weight_init
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
@ -29,7 +31,8 @@ class TimestepBlock(nn.Module):
Apply the module to `x` given `emb` timestep embeddings.
"""
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
# This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if isinstance(layer, VideoResBlock):
@ -47,9 +50,19 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x)
return x
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
@ -59,6 +72,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, *args, **kwargs):
return forward_timestep_embed(self, *args, **kwargs)
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
@ -95,6 +109,7 @@ class Upsample(nn.Module):
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
@ -141,23 +156,23 @@ class ResBlock(TimestepBlock):
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
dtype=None,
device=None,
operations=ops
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
dtype=None,
device=None,
operations=ops
):
super().__init__()
self.channels = channels
@ -231,7 +246,6 @@ class ResBlock(TimestepBlock):
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
@ -266,23 +280,23 @@ class ResBlock(TimestepBlock):
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size=3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels=None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
dtype=None,
device=None,
operations=ops
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size=3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels=None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
dtype=None,
device=None,
operations=ops
):
super().__init__(
channels,
@ -324,11 +338,11 @@ class VideoResBlock(ResBlock):
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator = None,
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator=None,
) -> th.Tensor:
x = super().forward(x, emb)
@ -353,6 +367,7 @@ class Timestep(nn.Module):
def forward(self, t):
return timestep_embedding(t, self.dim)
def apply_control(h, control, name):
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
@ -363,6 +378,7 @@ def apply_control(h, control, name):
logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
return h
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
@ -390,50 +406,50 @@ class UNetModel(nn.Module):
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
dtype=th.float32,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
use_temporal_resblock=False,
use_temporal_attention=False,
time_context_dim=None,
extra_ff_mix_layer=False,
use_spatial_context=False,
merge_strategy=None,
merge_factor=0.0,
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
attn_precision=None,
device=None,
operations=ops,
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
dtype=th.float32,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
use_temporal_resblock=False,
use_temporal_attention=False,
time_context_dim=None,
extra_ff_mix_layer=False,
use_spatial_context=False,
merge_strategy=None,
merge_factor=0.0,
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
attn_precision=None,
device=None,
operations=ops,
):
super().__init__()
@ -525,13 +541,13 @@ class UNetModel(nn.Module):
ds = 1
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disable_self_attn=False,
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disable_self_attn=False,
):
if use_temporal_attention:
return SpatialVideoTransformer(
@ -556,27 +572,27 @@ class UNetModel(nn.Module):
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_channels,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
dtype=None,
device=None,
operations=ops
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_channels,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
dtype=None,
device=None,
operations=ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
@ -640,7 +656,7 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
@ -649,8 +665,8 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
@ -692,7 +708,7 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
mid_block = [
get_resblock(
@ -715,24 +731,24 @@ class UNetModel(nn.Module):
if transformer_depth_middle >= -1:
if transformer_depth_middle >= 0:
mid_block += [get_attention_layer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
)]
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self._feature_size += ch
@ -766,7 +782,7 @@ class UNetModel(nn.Module):
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
@ -813,12 +829,19 @@ class UNetModel(nn.Module):
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
return patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@ -836,7 +859,7 @@ class UNetModel(nn.Module):
time_context = kwargs.get("time_context", None)
assert (y is not None) == (
self.num_classes is not None
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
@ -872,7 +895,6 @@ class UNetModel(nn.Module):
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()

View File

@ -1,10 +1,10 @@
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from .util import extract_into_tensor, make_beta_schedule
from ...util import default
class AbstractLowScaleModel(nn.Module):
@ -80,6 +80,3 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level

View File

@ -8,7 +8,6 @@
# thanks!
import os
import math
import torch
import torch.nn as nn

View File

@ -1,5 +1,5 @@
import functools
from typing import Callable, Iterable, Union
from typing import Iterable, Union
import torch
from einops import rearrange, repeat

View File

@ -36,7 +36,7 @@ LORA_CLIP_MAP = {
}
def load_lora(lora, to_load) -> PatchDict:
def load_lora(lora, to_load, log_missing=True) -> PatchDict:
patch_dict: PatchDict = {}
loaded_keys = set()
for x in to_load:
@ -65,6 +65,7 @@ def load_lora(lora, to_load) -> PatchDict:
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = B_name = None
@ -82,6 +83,9 @@ def load_lora(lora, to_load) -> PatchDict:
elif diffusers3_lora in lora.keys():
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name = "{}.lora_linear_layer.down.weight".format(x)
@ -206,9 +210,10 @@ def load_lora(lora, to_load) -> PatchDict:
patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name)
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
if log_missing:
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
return patch_dict
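# Hedged note, not part of the original diff: the weight hooks earlier in this commit
# call load_lora(..., log_missing=False) because unet and clip key maps are built
# separately, so each pass legitimately leaves the other model's keys unmatched.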
@ -307,7 +312,7 @@ def model_lora_keys_unet(model, key_map=None):
key_map["lora_prior_unet_{}".format(key_lora)] = k # cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k # generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
key_map["{}".format(k)] = k # generic lora format for not .weight without any weird key names
diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
@ -364,6 +369,12 @@ def model_lora_keys_unet(model, key_map=None):
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # onetrainer
if isinstance(model, model_base.GenmoMochi):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): # Official Mochi lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map
@ -422,7 +433,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_dtype=torch.float32):
def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]
v = p[1]
@ -465,6 +476,11 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "set":
weight.copy_(v[0])
elif patch_type == "model_as_lora":
target_weight: torch.Tensor = v[0]
diff_weight = model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
elif patch_type == "lora": # lora/locon
mat1 = model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = model_management.cast_to_device(v[1], weight.device, intermediate_dtype)

View File

@ -34,14 +34,16 @@ from .ldm.aura.mmdit import MMDiT as AuraMMDiT
from .ldm.cascade.stage_b import StageB
from .ldm.cascade.stage_c import StageC
from .ldm.flux import model as flux_model
from .ldm.lightricks.model import LTXVModel
from .ldm.genmo.joint_model.asymm_models_joint import AsymmDiTJoint
from .ldm.hydit.models import HunYuanDiT
from .ldm.lightricks.model import LTXVModel
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .model_management_types import ModelManageable
from .ops import Operations
from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers
class ModelType(Enum):
@ -109,6 +111,8 @@ class BaseModel(torch.nn.Module):
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device: torch.device = device
self.operations: Optional[Operations]
self.current_patcher: Optional[ModelManageable] = None
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
@ -137,6 +141,13 @@ class BaseModel(torch.nn.Module):
self.training = False
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return WrapperExecutor.new_class_executor(
self._apply_model,
self,
get_all_wrappers(WrappersMP.APPLY_MODEL, transformer_options)
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
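    # Hedged sketch, not part of the original diff: a wrapper registered under
    # WrappersMP.APPLY_MODEL receives the next executor first and forwards the call, e.g.
    #   def my_wrapper(executor, x, t, c_concat, c_crossattn, control, transformer_options, **kwargs):
    #       # inspect or adjust inputs here
    #       return executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)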
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
@ -576,7 +587,6 @@ class IP2P(BaseModel):
return self.process_ip2p_image_in(image)
class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
@ -746,7 +756,13 @@ class Flux(BaseModel):
super().__init__(model_config, model_type, device=device, unet_model=flux_model.Flux)
def concat_cond(self, **kwargs):
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
try:
# Handle Flux control loras dynamically changing the img_in weight.
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
        except Exception:
# Some cases like tensorrt might not have the weights accessible
num_channels = self.model_config.unet_config["in_channels"]
out_channels = self.model_config.unet_config["out_channels"]
if num_channels <= out_channels:
@ -765,7 +781,7 @@ class Flux(BaseModel):
if num_channels <= out_channels * 2:
return image
#inpaint model
# inpaint model
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.ones_like(noise)[:, :1]
@ -788,6 +804,7 @@ class Flux(BaseModel):
out['guidance'] = conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=AsymmDiTJoint)
@ -803,9 +820,10 @@ class GenmoMochi(BaseModel):
out['c_crossattn'] = conds.CONDRegular(cross_attn)
return out
class LTXV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=LTXVModel) #TODO
super().__init__(model_config, model_type, device=device, unet_model=LTXVModel) # TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@ -820,5 +838,9 @@ class LTXV(BaseModel):
if guiding_latent is not None:
out['guiding_latent'] = conds.CONDRegular(guiding_latent)
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
if guiding_latent_noise_scale is not None:
out["guiding_latent_noise_scale"] = conds.CONDConstant(guiding_latent_noise_scale)
out['frame_rate'] = conds.CONDConstant(kwargs.get("frame_rate", 25))
return out

View File

@ -1,2 +0,0 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename

View File

@ -1,238 +0,0 @@
# NOTE: This was an experiment and WILL BE REMOVED
from __future__ import annotations
import logging
import os
import re
import time
import traceback
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Any, Optional, Awaitable, Dict
import aiohttp
from ..cmd.folder_paths import folder_names_and_paths, get_folder_paths # pylint: disable=import-error
class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class DownloadModelStatus():
status: str
progress_percentage: float
message: str
already_existed: bool = False
def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed
def to_dict(self) -> Dict[str, Any]:
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}
async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
model_url: str,
model_directory: str,
folder_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
progress_interval: float = 1.0) -> DownloadModelStatus:
"""
Download a model file from a given URL into the models directory.
Args:
model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]):
A function that makes an HTTP request. This makes it easier to mock in unit tests.
model_name (str):
The name of the model file to be downloaded. This will be the filename on disk.
model_url (str):
The URL from which to download the model.
model_directory (str):
The subdirectory within the main models directory where the model
should be saved (e.g., 'checkpoints', 'loras', etc.).
folder_path (str):
The folder path to use as the root directory for the download.
progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]):
An asynchronous function to call with progress updates.
progress_interval (float):
The minimum number of seconds between progress callbacks.
Returns:
DownloadModelStatus: The result of the download operation.
"""
if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model name",
False
)
if model_directory not in folder_names_and_paths:
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.",
False
)
if folder_path not in get_folder_paths(model_directory):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.",
False
)
file_path = create_model_path(model_name, folder_path)
existing_file = await check_file_exists(file_path, model_name, progress_callback)
if existing_file:
return existing_file
try:
logging.info(f"Downloading {model_name} from {model_url}")
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(model_name, status)
response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(model_name, status)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval)
except Exception as e:
logging.error(f"Error in downloading model: {e}")
return await handle_download_error(e, model_name, progress_callback)
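For illustration, a hypothetical invocation of the function above might look like the sketch below; the session, URL, and callback names are made up for the example and are not part of the module. `session.get` satisfies the `model_download_request` contract because awaiting it yields an `aiohttp.ClientResponse`.

```python
import asyncio
import aiohttp

async def report(name: str, status: DownloadModelStatus):
    # print each progress update as a plain dict
    print(name, status.to_dict())

async def main():
    async with aiohttp.ClientSession() as session:
        result = await download_model(
            session.get,
            "model.safetensors",
            "https://example.com/model.safetensors",
            "checkpoints",
            get_folder_paths("checkpoints")[0],
            report,
        )
        print(result.to_dict())

asyncio.run(main())
```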
def create_model_path(model_name: str, folder_path: str) -> str:
os.makedirs(folder_path, exist_ok=True)
file_path = os.path.join(folder_path, model_name)
# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(folder_path)
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {folder_path}/{model_name}")
return file_path
async def check_file_exists(file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]]
) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(model_name, status)
return status
return None
async def track_download_progress(response: aiohttp.ClientResponse,
file_path: str,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
interval: float = 1.0) -> DownloadModelStatus:
try:
total_size = int(response.headers.get('Content-Length', 0))
downloaded = 0
last_update_time = time.time()
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(model_name, status)
last_update_time = time.time()
temp_file_path = file_path + '.tmp'
with open(temp_file_path, 'wb') as f:
chunk_iterator = response.content.iter_chunked(8192)
while True:
try:
chunk = await chunk_iterator.__anext__()
except StopAsyncIteration:
break
f.write(chunk)
downloaded += len(chunk)
if time.time() - last_update_time >= interval:
await update_progress()
os.rename(temp_file_path, file_path)
await update_progress()
logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(model_name, status)
return status
except Exception as e:
logging.error(f"Error in track_download_progress: {e}")
logging.error(traceback.format_exc())
return await handle_download_error(e, model_name, progress_callback)
async def handle_download_error(e: Exception,
model_name: str,
progress_callback: Callable[[str, DownloadModelStatus], Any]
) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(model_name, status)
return status
def validate_filename(filename: str) -> bool:
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate
Returns:
bool: True if the filename is valid, False otherwise
"""
if not filename.lower().endswith(('.sft', '.safetensors')):
return False
# Check if the filename is empty, None, or just whitespace
if not filename or not filename.strip():
return False
# Check for any directory traversal attempts or invalid characters
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
return False
# Check if the filename starts with a dot (hidden file)
if filename.startswith('.'):
return False
# Use a whitelist of allowed characters
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
return False
# Ensure the filename isn't too long
if len(filename) > 255:
return False
return True
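As a quick sketch of the rules above (assuming `validate_filename` is importable from this module):

```python
# Accepted: safetensors/sft extension, whitelisted characters only
assert validate_filename("model.safetensors")
assert validate_filename("sd15 pruned-v2.sft")

# Rejected: path traversal, hidden files, wrong extension
assert not validate_filename("../evil.safetensors")
assert not validate_filename(".hidden.safetensors")
assert not validate_filename("model.ckpt")
```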

View File

@ -425,6 +425,9 @@ class LoadedModel:
if self._patcher_finalizer is not None:
self._patcher_finalizer.detach()
def is_dead(self):
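# "dead" means the underlying model is still referenced somewhere (the weakref resolves) while this entry no longer holds its patcher, which indicates a leak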
return self.real_model() is not None and self.model is None
def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models:
@ -480,7 +483,7 @@ def _free_memory(memory_required, device, keep_loaded=[]):
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
@ -621,7 +624,7 @@ def cleanup_models_gc():
do_gc = False
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.real_model() is not None and cur.model is None:
if cur.is_dead():
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
do_gc = True
break
@ -632,7 +635,7 @@ def cleanup_models_gc():
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.real_model() is not None and cur.model is None:
if cur.is_dead():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
@ -694,6 +697,10 @@ def maximum_vram_for_weights(device=None) -> int:
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
return torch.float32
if args.fp64_unet:
return torch.float64
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@ -741,7 +748,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
if weight_dtype == torch.float32:
if weight_dtype == torch.float32 or weight_dtype == torch.float64:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)

File diff suppressed because it is too large

View File

@ -265,7 +265,7 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
return time_snr_shift(self.shift, 1.0 - percent)
class StableCascadeSampling(ModelSamplingDiscrete):
@ -360,4 +360,4 @@ class ModelSamplingFlux(torch.nn.Module):
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
return flux_time_shift(self.shift, 1.0, 1.0 - percent)
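Both fixes make `percent_to_sigma` apply the same time shift the samplers use when building sigma schedules, instead of returning the raw linear value. A minimal sketch of the discrete-flow case, with an assumed definition of the `time_snr_shift` helper referenced above:

```python
def time_snr_shift(alpha: float, t: float) -> float:
    # assumed definition of the shift used by ModelSamplingDiscreteFlow
    return alpha * t / (1 + (alpha - 1) * t)

# Old behaviour: percent 0.5 mapped straight to sigma 0.5.
# New behaviour with shift=3.0: the value matches the shifted schedule.
print(time_snr_shift(3.0, 1.0 - 0.5))  # 0.75
```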

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import torch
import os
@ -20,6 +21,7 @@ from .. import sd
from .. import utils
from .. import clip_vision as clip_vision_module
from .. import model_management
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict
from ..cli_args import args
from ..cmd import folder_paths, latent_preview
@ -43,16 +45,16 @@ def interrupt_processing(value=True):
interrupt_current_processing(value)
class CLIPTextEncode:
class CLIPTextEncode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s):
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"text": ("STRING", {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."})
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
}
}
RETURN_TYPES = ("CONDITIONING",)
RETURN_TYPES = (IO.CONDITIONING,)
OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
FUNCTION = "encode"
@ -61,9 +63,8 @@ class CLIPTextEncode:
def encode(self, clip, text):
tokens = clip.tokenize(text)
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
return ([[cond, output]], )
return (clip.encode_from_tokens_scheduled(tokens), )
class ConditioningCombine:
@classmethod
@ -644,9 +645,7 @@ class LoraLoader:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
else:
temp = self.loaded_lora
self.loaded_lora = None
del temp
if lora is None:
lora = utils.load_torch_file(lora_path, safe_load=True)
@ -1003,15 +1002,19 @@ class CLIPVisionEncode:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"image": ("IMAGE",)
"image": ("IMAGE",),
"crop": (["center", "none"],)
}}
RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
FUNCTION = "encode"
CATEGORY = "conditioning"
def encode(self, clip_vision, image):
output = clip_vision.encode_image(image)
def encode(self, clip_vision, image, crop):
crop_image = True
if crop != "center":
crop_image = False
output = clip_vision.encode_image(image, crop=crop_image)
return (output,)
class StyleModelLoader:
@ -1036,14 +1039,19 @@ class StyleModelApply:
return {"required": {"conditioning": ("CONDITIONING", ),
"style_model": ("STYLE_MODEL", ),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"strength_type": (["multiply"], ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_stylemodel"
CATEGORY = "conditioning/style_model"
def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength, strength_type):
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
if strength_type == "multiply":
cond *= strength
c = []
for t in conditioning:
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]

View File

@ -321,7 +321,7 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
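The clamp bounds come from the float8_e4m3fn format, whose largest finite value is 448; without clamping, activations outside that range would overflow when cast. A small demonstration, assuming a PyTorch build with float8 support (values chosen to be exactly representable):

```python
import torch

x = torch.tensor([500.0, -500.0, 96.0])
clamped = torch.clamp(x, min=-448, max=448).to(torch.float8_e4m3fn)
print(clamped.to(torch.float32))  # tensor([ 448., -448.,   96.])
```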

comfy/patcher_extension.py (new file, 156 lines)
View File

@ -0,0 +1,156 @@
from __future__ import annotations
from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"
ON_PRE_RUN = "on_pre_run"
ON_PREPARE_STATE = "on_prepare_state"
ON_APPLY_HOOKS = "on_apply_hooks"
ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
ON_INJECT_MODEL = "on_inject_model"
ON_EJECT_MODEL = "on_eject_model"
# callbacks dict is in the format:
# {"call_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
add_callback_with_key(call_type, None, callback, transformer_options, is_model_options)
def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.setdefault("transformer_options", {})
callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {})
c = callbacks.setdefault(call_type, {}).setdefault(key, [])
c.append(callback)
def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
c_list = []
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
c_list.extend(callbacks.get(call_type, {}).get(key, []))
return c_list
def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
c_list = []
callbacks: dict[str, list] = transformer_options.get("callbacks", {})
for c in callbacks.get(call_type, {}).values():
c_list.extend(c)
return c_list
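A minimal sketch of registering a callback on a model_options dict and collecting it later; the callback body is made up for illustration:

```python
model_options: dict = {}

def on_load(*args, **kwargs):
    print("patcher finished loading")

# stored under model_options["transformer_options"]["callbacks"]
add_callback(CallbacksMP.ON_LOAD, on_load, model_options, is_model_options=True)

for cb in get_all_callbacks(CallbacksMP.ON_LOAD, model_options, is_model_options=True):
    cb()  # prints "patcher finished loading"
```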
class WrappersMP:
OUTER_SAMPLE = "outer_sample"
SAMPLER_SAMPLE = "sampler_sample"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
DIFFUSION_MODEL = "diffusion_model"
# wrappers dict is in the format:
# {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
@classmethod
def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
return {}
def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options)
def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.setdefault("transformer_options", {})
wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {})
w = wrappers.setdefault(wrapper_type, {}).setdefault(key, [])
w.append(wrapper)
def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
w_list = []
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
w_list.extend(wrappers.get(wrapper_type, {}).get(key, []))
return w_list
def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
if is_model_options:
transformer_options = transformer_options.get("transformer_options", {})
w_list = []
wrappers: dict[str, list] = transformer_options.get("wrappers", {})
for w in wrappers.get(wrapper_type, {}).values():
w_list.extend(w)
return w_list
class WrapperExecutor:
"""Handles call stack of wrappers around a function in an ordered manner."""
def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
# NOTE: class_obj exists so that wrappers surrounding a class method can access
# the class instance at runtime via executor.class_obj
self.original = original
self.class_obj = class_obj
self.wrappers = wrappers.copy()
self.idx = idx
self.is_last = idx == len(wrappers)
def __call__(self, *args, **kwargs):
"""Calls the next wrapper or original function, whichever is appropriate."""
new_executor = self._create_next_executor()
return new_executor.execute(*args, **kwargs)
def execute(self, *args, **kwargs):
"""Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
args = list(args)
kwargs = dict(kwargs)
if self.is_last:
return self.original(*args, **kwargs)
return self.wrappers[self.idx](self, *args, **kwargs)
def _create_next_executor(self) -> 'WrapperExecutor':
new_idx = self.idx + 1
if new_idx > len(self.wrappers):
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
if self.class_obj is None:
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
@classmethod
def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
@classmethod
def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
return cls(original, class_obj, wrappers, idx=idx)
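To make the executor flow concrete, here is a self-contained sketch: a wrapper receives the executor as its first argument and calls it to continue down the chain to the next wrapper or the original function.

```python
def double(x: int) -> int:
    return x * 2

def logging_wrapper(executor, x: int) -> int:
    print("before")
    result = executor(x)  # runs the next wrapper, or the original function
    print("after")
    return result

executor = WrapperExecutor.new_executor(double, [logging_wrapper])
print(executor.execute(5))  # prints before, after, then 10
```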
class PatcherInjection:
def __init__(self, inject: Callable, eject: Callable):
self.inject = inject
self.eject = eject
def copy_nested_dicts(input_dict: dict):
new_dict = input_dict.copy()
for key, value in input_dict.items():
if isinstance(value, dict):
new_dict[key] = copy_nested_dicts(value)
elif isinstance(value, list):
new_dict[key] = value.copy()
return new_dict
def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
if copy_dict1:
merged_dict = copy_nested_dicts(dict1)
else:
merged_dict = dict1
for key, value in dict2.items():
if isinstance(value, dict):
curr_value = merged_dict.setdefault(key, {})
merged_dict[key] = merge_nested_dicts(value, curr_value)
elif isinstance(value, list):
merged_dict.setdefault(key, []).extend(value)
else:
merged_dict[key] = value
return merged_dict
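The merge semantics matter when wrapper and callback dicts are combined: nested dicts merge recursively, lists concatenate, and scalar values from dict2 win. A small check of that behaviour:

```python
a = {"wrappers": {"outer_sample": {None: [1]}}, "flag": True}
b = {"wrappers": {"outer_sample": {None: [2]}}, "flag": False}

merged = merge_nested_dicts(a, b)
assert merged["wrappers"]["outer_sample"][None] == [1, 2]
assert merged["flag"] is False
```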

View File

@ -1,38 +1,88 @@
import torch
from . import model_management
from . import utils
import uuid
from .controlnet import ControlBase
from . import conds
from . import hooks
from . import model_management
from . import patcher_extension
from . import utils
from .model_base import BaseModel
from .model_patcher import ModelPatcher
def prepare_mask(noise_mask, shape, device):
return utils.reshape_mask(noise_mask, shape).to(device)
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c:
models += [c[model_type]]
if isinstance(c[model_type], list):
models += c[model_type]
else:
models += [c[model_type]]
return models
def get_hooks_from_cond(cond, hooks_dict: dict[hooks.EnumHookType, dict[hooks.Hook, None]]):
# get hooks from conds, and collect cnets so they can be checked for extra_hooks
cnets: list[ControlBase] = []
for c in cond:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook: hooks.Hook
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
if 'control' in c:
cnets.append(c['control'])
def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list):
if cnet.extra_hooks is not None:
_list.append(cnet.extra_hooks)
if cnet.previous_controlnet is None:
return _list
return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list)
hooks_list = []
cnets = set(cnets)
for base_cnet in cnets:
get_extra_hooks_from_cnet(base_cnet, hooks_list)
extra_hooks = hooks.HookGroup.combine_all_hooks(hooks_list)
if extra_hooks is not None:
for hook in extra_hooks.hooks:
with_type = hooks_dict.setdefault(hook.hook_type, {})
with_type[hook] = None
return hooks_dict
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) # TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
temp["uuid"] = uuid.uuid4()
out.append(temp)
return out
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets = []
cnets: list[ControlBase] = []
gligen = []
add_models = []
hooks_dict: dict[hooks.EnumHookType, dict[hooks.Hook, None]] = {}
for k in conds:
cnets += get_models_from_cond(conds[k], "control")
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
get_hooks_from_cond(conds[k], hooks_dict)
control_nets = set(cnets)
@ -43,9 +93,12 @@ def get_additional_models(conds, dtype):
inference_memory += m.inference_memory_requirements(dtype)
gligen = [x[1] for x in gligen]
models = control_models + gligen
hook_models = [x.model for x in hooks_dict.get(hooks.EnumHookType.AddModels, {}).keys()]
models = control_models + gligen + add_models + hook_models
return models, inference_memory
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
@ -53,10 +106,11 @@ def cleanup_additional_models(models):
m.cleanup()
def prepare_sampling(model, noise_shape, conds):
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
device = model.load_device
real_model = None
real_model: 'BaseModel' = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
@ -64,6 +118,7 @@ def prepare_sampling(model, noise_shape, conds):
return real_model, conds, models
def cleanup_models(conds, models):
cleanup_additional_models(models)
@ -72,3 +127,15 @@ def cleanup_models(conds, models):
control_cleanup += get_models_from_cond(conds[k], "control")
cleanup_additional_models(set(control_cleanup))
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
# check for hooks in conds - if not registered, see if can be applied
hooks_dict = {}
for k in conds:
get_hooks_from_cond(conds[k], hooks_dict)
# add wrappers and callbacks from ModelPatcher to transformer_options
model_options["transformer_options"]["wrappers"] = patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks_dict, hooks.EnumWeightTarget.Model, model_options)

View File

@ -1,17 +1,25 @@
from .component_model.deprecation import _deprecate_method
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
import torch
from __future__ import annotations
import collections
from . import model_management
import math
import logging
import scipy.stats
import math
import numpy
import scipy.stats
import torch
from . import hooks
from . import model_management
from . import model_patcher
from . import patcher_extension
from . import sampler_helpers
from .component_model.deprecation import _deprecate_method
from .controlnet import ControlBase
from .extra_samplers import uni_pc
from .k_diffusion import sampling as k_diffusion_sampling
from .model_base import BaseModel
from .model_management_types import ModelOptions
from .model_patcher import ModelPatcher
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
@ -46,7 +54,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1:] == x_in.shape[2:])
assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@ -65,17 +73,18 @@ def get_area_and_mult(conds, x_in, timestep_in):
if area[len(dims) + i] != 0:
for t in range(rr):
m = mult.narrow(i + 2, t, 1)
m *= ((1.0/rr) * (t + 1))
m *= ((1.0 / rr) * (t + 1))
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
for t in range(rr):
m = mult.narrow(i + 2, area[i] - 1 - t, 1)
m *= ((1.0/rr) * (t + 1))
m *= ((1.0 / rr) * (t + 1))
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
hooks = conds.get('hooks', None)
control = conds.get('control', None)
patches = None
@ -91,8 +100,9 @@ def get_area_and_mult(conds, x_in, timestep_in):
patches['middle_patch'] = [gligen_patch]
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
return cond_obj(input_x, mult, conditioning, area, control, patches)
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks'])
return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks)
def cond_equal_size(c1, c2):
if c1 is c2:
@ -104,6 +114,7 @@ def cond_equal_size(c1, c2):
return False
return True
def can_concat_cond(c1, c2):
if c1.input_x.shape != c2.input_x.shape:
return False
@ -124,6 +135,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
c_crossattn = []
c_concat = []
@ -144,121 +156,200 @@ def cond_cat(c_list):
return out
def calc_cond_batch(model: BaseModel, conds, x_in: torch.Tensor, timestep: torch.Tensor, model_options: ModelOptions):
def finalize_default_conds(model: BaseModel, hooked_to_run: dict[hooks.HookGroup, list[tuple[tuple, int]]], default_conds: list[list[dict]], x_in, timestep):
# need to figure out remaining unmasked area for conds
default_mults = []
for _ in default_conds:
default_mults.append(torch.ones_like(x_in))
# look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond
for lora_hooks, to_run in hooked_to_run.items():
for cond_obj, i in to_run:
# if no default_cond for cond_type, do nothing
if len(default_conds[i]) == 0:
continue
area: list[int] = cond_obj.area
if area is not None:
curr_default_mult: torch.Tensor = default_mults[i]
dims = len(area) // 2
for d in range(dims):
curr_default_mult = curr_default_mult.narrow(d + 2, area[d + dims], area[d])
curr_default_mult -= cond_obj.mult
else:
default_mults[i] -= cond_obj.mult
# for each default_mult, ReLU to make negatives=0, and then check for any nonzeros
for i, mult in enumerate(default_mults):
# if no default_cond for cond type, do nothing
if len(default_conds[i]) == 0:
continue
torch.nn.functional.relu(mult, inplace=True)
# if mult is all zeros, then don't add default_cond
if torch.max(mult) == 0.0:
continue
cond = default_conds[i]
for x in cond:
# do get_area_and_mult to get all the expected values
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
# replace p's mult with calculated mult
p = p._replace(mult=mult)
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
executor = patcher_extension.WrapperExecutor.new_executor(
_calc_cond_batch,
patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
)
return executor.execute(model, conds, x_in, timestep, model_options)
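Because of this executor indirection, external code can intercept every conditioning batch by registering a wrapper on model_options; a hypothetical sketch using the helpers from patcher_extension:

```python
def my_wrapper(executor, model, conds, x_in, timestep, model_options):
    # inspect or rewrite the inputs here, then continue the chain
    return executor(model, conds, x_in, timestep, model_options)

model_options: dict = {}
patcher_extension.add_wrapper_with_key(
    patcher_extension.WrappersMP.CALC_COND_BATCH,
    "my_extension", my_wrapper, model_options, is_model_options=True)
```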
def _calc_cond_batch(model: BaseModel, conds, x_in: torch.Tensor, timestep: torch.Tensor, model_options: ModelOptions):
out_conds = []
out_counts = []
to_run = []
# separate conds by matching hooks
hooked_to_run: dict[hooks.HookGroup, list[tuple[tuple, int]]] = {}
default_conds = []
has_default_conds = False
for i in range(len(conds)):
out_conds.append(torch.zeros_like(x_in))
out_counts.append(torch.ones_like(x_in) * 1e-37)
cond = conds[i]
default_c = []
if cond is not None:
for x in cond:
if 'default' in x:
default_c.append(x)
has_default_conds = True
continue
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
hooked_to_run.setdefault(p.hooks, list())
hooked_to_run[p.hooks] += [(p, i)]
default_conds.append(default_c)
to_run += [(p, i)]
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep)
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
model.current_patcher.prepare_state(timestep)
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
# run every hooked_to_run separately
for cur_hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp) // i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
input_x = []
mult = []
c = []
cond_or_uncond = []
uuids = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
uuids.append(p.uuid)
control = p.control
patches = p.patches
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
transformer_options = model.current_patcher.apply_hooks(hooks=cur_hooks)
if 'transformer_options' in model_options:
transformer_options = patcher_extension.merge_nested_dicts(transformer_options,
model_options['transformer_options'],
copy_dict1=False)
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
if patches is not None:
# TODO: replace with merge_nested_dicts function
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
transformer_options["patches"] = cur_patches
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["uuids"] = uuids[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
transformer_options["patches"] = patches
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for o in range(batch_chunks):
cond_index = cond_or_uncond[o]
a = area[o]
if a is None:
out_conds[cond_index] += output[o] * mult[o]
out_counts[cond_index] += mult[o]
else:
out_c = out_conds[cond_index]
out_cts = out_counts[cond_index]
dims = len(a) // 2
for i in range(dims):
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
out_c += output[o] * mult[o]
out_cts += mult[o]
for i in range(len(out_conds)):
out_conds[i] /= out_counts[i]
return out_conds
@_deprecate_method(version="0.0.2", message="The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): # TODO: remove
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options:ModelOptions=None, cond=None, uncond=None):
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options: ModelOptions = None, cond=None, uncond=None):
if model_options is None:
model_options = {}
if "sampler_cfg_function" in model_options:
@ -275,9 +366,10 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
return cfg_result
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options:ModelOptions=None, seed=None):
# The main sampling function shared by all the samplers
# Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options: ModelOptions = None, seed=None):
if model_options is None:
model_options = {}
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
@ -289,9 +381,9 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
out = calc_cond_batch(model, conds, x, timestep, model_options)
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
args = {"conds": conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
out = fn(args)
out = fn(args)
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
@ -300,7 +392,8 @@ class KSamplerX0Inpaint:
def __init__(self, model, sigmas):
self.inner_model = model
self.sigmas = sigmas
def __call__(self, x, sigma, denoise_mask, model_options:ModelOptions=None, seed=None):
def __call__(self, x, sigma, denoise_mask, model_options: ModelOptions = None, seed=None):
if model_options is None:
model_options = {}
if denoise_mask is not None:
@ -313,6 +406,7 @@ class KSamplerX0Inpaint:
out = out * denoise_mask + self.latent_image * latent_mask
return out
def simple_scheduler(model_sampling, steps):
s = model_sampling
sigs = []
@ -322,6 +416,7 @@ def simple_scheduler(model_sampling, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)
def ddim_scheduler(model_sampling, steps):
s = model_sampling
sigs = []
@ -339,6 +434,7 @@ def ddim_scheduler(model_sampling, steps):
sigs = sigs[::-1]
return torch.FloatTensor(sigs)
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
s = model_sampling
start = s.timestep(s.sigma_max)
@ -363,6 +459,7 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
return torch.FloatTensor(sigs)
# Implemented based on: https://arxiv.org/abs/2407.12173
def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
total_timesteps = (len(model_sampling.sigmas) - 1)
@ -378,6 +475,7 @@ def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
sigs += [0.0]
return torch.FloatTensor(sigs)
# from: https://github.com/genmoai/models/blob/main/src/mochi_preview/infer.py#L41
def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, linear_steps=None):
if steps == 1:
@ -399,6 +497,7 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
sigma_schedule = [1.0 - x for x in sigma_schedule]
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@ -422,6 +521,7 @@ def get_mask_aabb(masks):
return bounding_boxes, is_empty
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
@ -452,8 +552,8 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
if mask.shape[1:] != dims:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
if modified.get("set_area_to_bounds", False): # TODO: handle dim != 2
bounds = torch.max(torch.abs(mask), dim=0).values.unsqueeze(0)
boxes, is_empty = get_mask_aabb(bounds)
if is_empty[0]:
# Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
@ -469,11 +569,13 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
modified['mask'] = mask
conditions[i] = modified
@_deprecate_method(version="0.3.2", message="WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
def resolve_areas_and_cond_masks(conditions, h, w, device):
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
def create_cond_with_same_area_if_none(conds, c): # TODO: handle dim != 2
if 'area' not in c:
return
@ -502,9 +604,10 @@ def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
return
out = c.copy()
out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
out['model_conds'] = smallest['model_conds'].copy() # TODO: which fields should be copied?
conds += [out]
def calculate_start_end_timesteps(model, conds):
s = model.model_sampling
for t in range(len(conds)):
@ -512,10 +615,15 @@ def calculate_start_end_timesteps(model, conds):
timestep_start = None
timestep_end = None
if 'start_percent' in x:
timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x:
timestep_end = s.percent_to_sigma(x['end_percent'])
# handle clip hook schedule, if needed
if 'clip_start_percent' in x:
timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0)))
timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0)))
else:
if 'start_percent' in x:
timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x:
timestep_end = s.percent_to_sigma(x['end_percent'])
if (timestep_start is not None) or (timestep_end is not None):
n = x.copy()
@ -525,6 +633,7 @@ def calculate_start_end_timesteps(model, conds):
n['timestep_end'] = timestep_end
conds[t] = n
def pre_run_control(model, conds):
s = model.model_sampling
for t in range(len(conds)):
@ -536,6 +645,7 @@ def pre_run_control(model, conds):
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
cond_other = []
@ -571,6 +681,7 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = n
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
for t in range(len(conds)):
x = conds[t]
@ -578,7 +689,7 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar
params["device"] = device
params["noise"] = noise
default_width = None
if len(noise.shape) >= 4: #TODO: 8 multiple should be set by the model
if len(noise.shape) >= 4: # TODO: 8 multiple should be set by the model
default_width = noise.shape[3] * 8
params["width"] = params.get("width", default_width)
params["height"] = params.get("height", noise.shape[2] * 8)
@ -596,6 +707,7 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar
conds[t] = x
return conds
class Sampler:
def sample(self):
pass
@ -605,11 +717,13 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
self.sampler_function = sampler_function
@ -620,7 +734,7 @@ class KSAMPLER(Sampler):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap, sigmas)
model_k.latent_image = latent_image
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
if self.inpaint_options.get("random", False): # TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
@ -649,6 +763,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
sigma_min = sigmas[-2]
total_steps = len(sigmas) - 1
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive":
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
@ -659,6 +774,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
if sigma_min == 0:
sigma_min = sigmas[-2]
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
sampler_function = dpm_adaptive_function
else:
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
@ -678,13 +794,19 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area
# make sure each cond area has an opposite one with the same area
for k in conds:
for c in conds[k]:
for kk in conds:
if k != kk:
create_cond_with_same_area_if_none(conds[kk], c)
for k in conds:
for c in conds[k]:
if 'hooks' in c:
for hook in c['hooks'].hooks:
hook.initialize_timesteps(model)
for k in conds:
pre_run_control(model, conds[k])
@ -697,9 +819,46 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
return conds
def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]):
# determine which ControlNets have extra_hooks that should be combined with normal hooks
hook_replacement: dict[tuple[ControlBase, hooks.HookGroup], list[dict]] = {}
for k in conds:
for kk in conds[k]:
if 'control' in kk:
control: 'ControlBase' = kk['control']
extra_hooks = control.get_extra_hooks()
if len(extra_hooks) > 0:
cond_hooks: hooks.HookGroup = kk.get('hooks', None)
to_replace = hook_replacement.setdefault((control, cond_hooks), [])
to_replace.append(kk)
# if nothing to replace, do nothing
if len(hook_replacement) == 0:
return
# for optimal sampling performance, common ControlNets + hook combos should have identical hooks
# on the cond dicts
for key, conds_to_modify in hook_replacement.items():
control = key[0]
cond_hooks = key[1]
combined_hooks = hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [cond_hooks])
# if combined hooks are not None, set as new hooks for all relevant conds
if combined_hooks is not None:
for cond in conds_to_modify:
cond['hooks'] = combined_hooks
def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]):
hooks_set = set()
for k in conds:
for kk in conds[k]:
hooks_set.add(kk.get('hooks', None))
return len(hooks_set)
class CFGGuider:
def __init__(self, model_patcher):
self.model_patcher = model_patcher
self.model_patcher: 'ModelPatcher' = model_patcher
self.model_options = model_patcher.model_options
self.original_conds = {}
self.cfg = 1.0
@ -721,24 +880,22 @@ class CFGGuider:
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler: KSAMPLER, sigmas, denoise_mask, callback, disable_pbar, seed):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
if latent_image is not None and torch.count_nonzero(latent_image) > 0: # Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
extra_args = {"model_options": self.model_options, "seed":seed}
extra_args = {"model_options": model_patcher.create_model_options_clone(self.model_options), "seed": seed}
samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
executor = patcher_extension.WrapperExecutor.new_class_executor(
sampler.sample,
sampler,
patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True)
)
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def sample(self, noise, latent_image, sampler: KSAMPLER, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
def outer_sample(self, noise, latent_image, sampler: KSAMPLER, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
device = self.model_patcher.load_device
@ -749,14 +906,48 @@ class CFGGuider:
latent_image = latent_image.to(device)
sigmas = sigmas.to(device)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_patcher.cleanup()
sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
del self.conds
del self.loaded_models
return output
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
if sigmas.shape[-1] == 0:
return latent_image
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
preprocess_conds_hooks(self.conds)
try:
orig_model_options = self.model_options
self.model_options = model_patcher.create_model_options_clone(self.model_options)
# if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step)
orig_hook_mode = self.model_patcher.hook_mode
if get_total_hook_groups_in_conds(self.conds) <= 1:
self.model_patcher.hook_mode = hooks.EnumHookMode.MinVram
sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options)
executor = patcher_extension.WrapperExecutor.new_class_executor(
self.outer_sample,
self,
patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
finally:
self.model_options = orig_model_options
self.model_patcher.hook_mode = orig_hook_mode
self.model_patcher.restore_hook_patches()
del self.conds
return output
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
cfg_guider = CFGGuider(model)
@ -765,7 +956,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = None
@ -791,6 +981,7 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
return sigmas
def sampler_object(name):
if name == "uni_pc":
sampler = KSAMPLER(uni_pc.sample_unipc)
@ -802,6 +993,7 @@ def sampler_object(name):
sampler = ksampler(name)
return sampler
class KSampler:
SCHEDULERS = SCHEDULER_NAMES
SAMPLERS = SAMPLER_NAMES
@ -842,7 +1034,7 @@ class KSampler:
if denoise <= 0.0:
self.sigmas = torch.FloatTensor([])
else:
new_steps = int(steps/denoise)
new_steps = int(steps / denoise)
sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1):]

View File

@ -6,14 +6,13 @@ import os.path
from enum import Enum
from typing import Any, Optional
import comfy.ldm.flux.redux
import comfy.text_encoders.lt
import torch
import yaml
from . import clip_vision
from . import diffusers_convert
from . import gligen
from . import lora
from . import model_detection
from . import model_management
from . import model_patcher
@ -21,10 +20,11 @@ from . import model_sampling
from . import sd1_clip
from . import sdxl_clip
from . import utils
from . import lora
from . import hooks
from .ldm.audio.autoencoder import AudioOobleckVAE
from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.flux.redux import ReduxImageEncoder
from .ldm.genmo.vae import model as genmo_model
from .ldm.lightricks.vae import causal_video_autoencoder as lightricks
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
@ -37,6 +37,7 @@ from .text_encoders import flux
from .text_encoders import genmo
from .text_encoders import hydit
from .text_encoders import long_clipl
from .text_encoders import lt
from .text_encoders import sa_t5
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
@ -107,9 +108,13 @@ class CLIP:
self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher.hook_mode = hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
self.use_clip_schedule = False
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
def clone(self):
@ -119,6 +124,8 @@ class CLIP:
# cloning the tokenizer allows the vocab updates to work more idiomatically
n.tokenizer = self.tokenizer.clone()
n.layer_idx = self.layer_idx
n.use_clip_schedule = self.use_clip_schedule
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
@ -130,6 +137,69 @@ class CLIP:
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds:
pooled_dict["hooks"] = self.apply_hooks_to_conds
return pooled_dict
def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str] = {}, show_pbar=True):
all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
all_hooks = self.patcher.forced_hooks
if all_hooks is None or not self.use_clip_schedule:
# if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
return_pooled = "unprojected" if unprojected else True
pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
cond = pooled_dict.pop("cond")
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
all_cond_pooled.append([cond, pooled_dict])
else:
scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled
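A hypothetical call site, assuming a CLIP instance with use_clip_schedule enabled and hooks set on its patcher; the prompt and keys here are illustrative:

```python
tokens = clip.tokenize("a photo of a cat")
conds = clip.encode_from_tokens_scheduled(
    tokens, add_dict={"start_percent": 0.0, "end_percent": 1.0})
# each entry: [cond_tensor, {"pooled_output": ..., "clip_start_percent": ..., "clip_end_percent": ...}]
```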
def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False):
self.cond_stage_model.reset_clip_options()
@ -147,6 +217,7 @@ class CLIP:
if len(o) > 2:
for k in o[2]:
out[k] = o[2][k]
self.add_hooks_to_dict(out)
return out
if return_pooled:
@ -470,7 +541,7 @@ def load_style_model(ckpt_path):
if "style_embedding" in keys:
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
elif "redux_down.weight" in keys:
model = comfy.ldm.flux.redux.ReduxImageEncoder()
model = ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
model.load_state_dict(model_data)
@ -577,8 +648,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
clip_target.tokenizer = sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
clip_target.clip = lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = lt.LTXVT5Tokenizer
else: # CLIPType.MOCHI
clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = genmo.MochiT5Tokenizer
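Loading an LTXV text encoder end to end might look like the following sketch (the checkpoint name is an assumption; the keyword names follow the signature shown in the hunk header above):

```python
# Hedged sketch: assumes a T5-XXL checkpoint on disk and the CLIPType enum
# exposed by comfy.sd.
import comfy.sd
import comfy.utils

sd = comfy.utils.load_torch_file("t5xxl_fp16.safetensors", safe_load=True)
clip = comfy.sd.load_text_encoder_state_dicts(
    state_dicts=[sd],
    clip_type=comfy.sd.CLIPType.LTXV,
)
cond = clip.encode_from_tokens(clip.tokenize("a drone shot of a coastline"))
```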

View File

@ -659,6 +659,15 @@ class Flux(supported_models_base.BASE):
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(flux.FluxTokenizer, flux.flux_clip(**t5_detect))
class FluxInpaint(Flux):
unet_config = {
"image_model": "flux",
"guidance_embed": True,
"in_channels": 96,
}
supported_inference_dtypes = [torch.bfloat16, torch.float32]
class FluxSchnell(Flux):
unet_config = {
"image_model": "flux",
@ -733,6 +742,6 @@ class LTXV(supported_models_base.BASE):
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(lt.LTXVT5Tokenizer, lt.ltxv_te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi, LTXV]
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
models += [SVD_img2vid]
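Note that `FluxInpaint` is inserted ahead of `Flux` in the `models` list: detection walks the list in order, so the more specific config (`in_channels: 96`) must be tried before the generic one. A simplified, assumed matcher for illustration (real detection lives in `comfy.model_detection` and is more involved):

```python
# Assumed first-match walk over the models list, for illustration only.
def first_match(unet_config, candidates):
    for model_class in candidates:
        if all(unet_config.get(k) == v for k, v in model_class.unet_config.items()):
            return model_class
    return None

# With FluxInpaint listed before Flux, a state dict reporting in_channels == 96
# matches FluxInpaint first; ordinary Flux checkpoints fall through to Flux.
```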

View File

@ -220,6 +220,11 @@ class T5Stack(torch.nn.Module):
intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
past_bias = None
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.block) + intermediate_output
for i, l in enumerate(self.block):
x, past_bias = l(x, mask, past_bias, optimized_attention)
if i == intermediate_output:
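The new lines normalize a negative `intermediate_output` into an absolute block index before the loop, Python-style. A standalone illustration with assumed values:

```python
# Assumed values for illustration: a 24-block T5 stack.
num_blocks = 24
intermediate_output = -1          # "the last block"
if intermediate_output < 0:
    intermediate_output = num_blocks + intermediate_output
assert intermediate_output == 23  # now directly comparable to the loop index `i`
```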

39
comfy/validation.py Normal file
View File

@ -0,0 +1,39 @@
from __future__ import annotations
def validate_node_input(
received_type: str, input_type: str, strict: bool = False
) -> bool:
"""
received_type and input_type are both strings of the form "T1,T2,...".
If strict is True, input_type must contain every type in received_type.
For example, if received_type is "STRING" and input_type is "STRING,INT",
this will return True. But if received_type is "STRING,INT" and input_type is
"INT", this will return False.
If strict is False, input_type only needs to share at least one type with received_type.
For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
this will return True.
Also preserves the pre-union type-extension behaviour, where custom types override ``__ne__``.
"""
# If the types are exactly the same, we can return immediately
# Use pre-union behaviour: inverse of `__ne__`
if not received_type != input_type:
return True
# Not equal, and at least one side is not a string, so no further comparison is possible
if not isinstance(received_type, str) or not isinstance(input_type, str):
return False
# Split the type strings into sets for comparison
received_types = set(t.strip() for t in received_type.split(","))
input_types = set(t.strip() for t in input_type.split(","))
if strict:
# In strict mode, all received types must be in the input types
return received_types.issubset(input_types)
else:
# In non-strict mode, there must be at least one type in common
return len(received_types.intersection(input_types)) > 0
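The behaviour described in the docstring, as a few concrete checks:

```python
from comfy.validation import validate_node_input

assert validate_node_input("STRING", "STRING,INT")                # subset of the input types
assert validate_node_input("STRING,BOOLEAN", "STRING,INT")        # lax mode: any overlap passes
assert not validate_node_input("BOOLEAN", "STRING,INT")           # no overlap at all
assert not validate_node_input("STRING,INT", "INT", strict=True)  # strict mode: not a subset
assert validate_node_input("INT", "INT,FLOAT", strict=True)       # strict mode: a real subset
```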

View File

@ -1,103 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, c8 as useExtensionStore, u as useSettingStore, r as ref, o as onMounted, q as computed, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, z as unref, bS as script$1, A as createBaseVNode, x as createBlock, N as Fragment, O as renderList, a4 as toDisplayString, av as createTextVNode, bQ as script$3, j as createCommentVNode, D as script$4 } from "./index-bi78Y1IN.js";
import { s as script, a as script$2 } from "./index-ftUEqmu1.js";
import "./index-bCeMLtLM.js";
const _hoisted_1 = { class: "extension-panel" };
const _hoisted_2 = { class: "mt-4" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ExtensionPanel",
setup(__props) {
const extensionStore = useExtensionStore();
const settingStore = useSettingStore();
const editingEnabledExtensions = ref({});
onMounted(() => {
extensionStore.extensions.forEach((ext) => {
editingEnabledExtensions.value[ext.name] = extensionStore.isExtensionEnabled(ext.name);
});
});
const changedExtensions = computed(() => {
return extensionStore.extensions.filter(
(ext) => editingEnabledExtensions.value[ext.name] !== extensionStore.isExtensionEnabled(ext.name)
);
});
const hasChanges = computed(() => {
return changedExtensions.value.length > 0;
});
const updateExtensionStatus = /* @__PURE__ */ __name(() => {
const editingDisabledExtensionNames = Object.entries(
editingEnabledExtensions.value
).filter(([_, enabled]) => !enabled).map(([name]) => name);
settingStore.set("Comfy.Extension.Disabled", [
...extensionStore.inactiveDisabledExtensionNames,
...editingDisabledExtensionNames
]);
}, "updateExtensionStatus");
const applyChanges = /* @__PURE__ */ __name(() => {
window.location.reload();
}, "applyChanges");
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1, [
createVNode(unref(script$2), {
value: unref(extensionStore).extensions,
stripedRows: "",
size: "small"
}, {
default: withCtx(() => [
createVNode(unref(script), {
field: "name",
header: _ctx.$t("extensionName"),
sortable: ""
}, null, 8, ["header"]),
createVNode(unref(script), { pt: {
bodyCell: "flex items-center justify-end"
} }, {
body: withCtx((slotProps) => [
createVNode(unref(script$1), {
modelValue: editingEnabledExtensions.value[slotProps.data.name],
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
onChange: updateExtensionStatus
}, null, 8, ["modelValue", "onUpdate:modelValue"])
]),
_: 1
})
]),
_: 1
}, 8, ["value"]),
createBaseVNode("div", _hoisted_2, [
hasChanges.value ? (openBlock(), createBlock(unref(script$3), {
key: 0,
severity: "info"
}, {
default: withCtx(() => [
createBaseVNode("ul", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(changedExtensions.value, (ext) => {
return openBlock(), createElementBlock("li", {
key: ext.name
}, [
createBaseVNode("span", null, toDisplayString(unref(extensionStore).isExtensionEnabled(ext.name) ? "[-]" : "[+]"), 1),
createTextVNode(" " + toDisplayString(ext.name), 1)
]);
}), 128))
])
]),
_: 1
})) : createCommentVNode("", true),
createVNode(unref(script$4), {
label: _ctx.$t("reloadToApplyChanges"),
icon: "pi pi-refresh",
onClick: applyChanges,
disabled: !hasChanges.value,
text: "",
fluid: "",
severity: "danger"
}, null, 8, ["label", "disabled"])
])
]);
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ExtensionPanel-C-oQXg_k.js.map

View File

@ -1 +0,0 @@
{"version":3,"file":"ExtensionPanel-C-oQXg_k.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <div class=\"extension-panel\">\n <DataTable :value=\"extensionStore.extensions\" stripedRows size=\"small\">\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n <div class=\"mt-4\">\n <Message v-if=\"hasChanges\" severity=\"info\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n </Message>\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n icon=\"pi pi-refresh\"\n @click=\"applyChanges\"\n :disabled=\"!hasChanges\"\n text\n fluid\n severity=\"danger\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;AAmDA,UAAM,iBAAiB,kBAAkB;AACzC,UAAM,eAAe,gBAAgB;AAE/B,UAAA,2BAA2B,IAA6B,EAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAC9C;AAAA,IAAA,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IACH,GAX8B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS,OAAO;AAAA,IACzB,GAHqB;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

117
comfy/web/assets/ExtensionPanel-DsD42OtO.js generated vendored Normal file
View File

@ -0,0 +1,117 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, r as ref, c6 as FilterMatchMode, ca as useExtensionStore, u as useSettingStore, o as onMounted, q as computed, g as openBlock, x as createBlock, y as withCtx, i as createVNode, c7 as SearchBox, z as unref, bT as script, A as createBaseVNode, h as createElementBlock, O as renderList, a6 as toDisplayString, aw as createTextVNode, N as Fragment, D as script$1, j as createCommentVNode, bV as script$3, c8 as _sfc_main$1 } from "./index-CoOvI8ZH.js";
import { s as script$2, a as script$4 } from "./index-DK6Kev7f.js";
import "./index-D4DWQPPQ.js";
const _hoisted_1 = { class: "flex justify-end" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ExtensionPanel",
setup(__props) {
const filters = ref({
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
});
const extensionStore = useExtensionStore();
const settingStore = useSettingStore();
const editingEnabledExtensions = ref({});
onMounted(() => {
extensionStore.extensions.forEach((ext) => {
editingEnabledExtensions.value[ext.name] = extensionStore.isExtensionEnabled(ext.name);
});
});
const changedExtensions = computed(() => {
return extensionStore.extensions.filter(
(ext) => editingEnabledExtensions.value[ext.name] !== extensionStore.isExtensionEnabled(ext.name)
);
});
const hasChanges = computed(() => {
return changedExtensions.value.length > 0;
});
const updateExtensionStatus = /* @__PURE__ */ __name(() => {
const editingDisabledExtensionNames = Object.entries(
editingEnabledExtensions.value
).filter(([_, enabled]) => !enabled).map(([name]) => name);
settingStore.set("Comfy.Extension.Disabled", [
...extensionStore.inactiveDisabledExtensionNames,
...editingDisabledExtensionNames
]);
}, "updateExtensionStatus");
const applyChanges = /* @__PURE__ */ __name(() => {
window.location.reload();
}, "applyChanges");
return (_ctx, _cache) => {
return openBlock(), createBlock(_sfc_main$1, {
value: "Extension",
class: "extension-panel"
}, {
header: withCtx(() => [
createVNode(SearchBox, {
modelValue: filters.value["global"].value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
placeholder: _ctx.$t("searchExtensions") + "..."
}, null, 8, ["modelValue", "placeholder"]),
hasChanges.value ? (openBlock(), createBlock(unref(script), {
key: 0,
severity: "info",
"pt:text": "w-full"
}, {
default: withCtx(() => [
createBaseVNode("ul", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(changedExtensions.value, (ext) => {
return openBlock(), createElementBlock("li", {
key: ext.name
}, [
createBaseVNode("span", null, toDisplayString(unref(extensionStore).isExtensionEnabled(ext.name) ? "[-]" : "[+]"), 1),
createTextVNode(" " + toDisplayString(ext.name), 1)
]);
}), 128))
]),
createBaseVNode("div", _hoisted_1, [
createVNode(unref(script$1), {
label: _ctx.$t("reloadToApplyChanges"),
onClick: applyChanges,
outlined: "",
severity: "danger"
}, null, 8, ["label"])
])
]),
_: 1
})) : createCommentVNode("", true)
]),
default: withCtx(() => [
createVNode(unref(script$4), {
value: unref(extensionStore).extensions,
stripedRows: "",
size: "small",
filters: filters.value
}, {
default: withCtx(() => [
createVNode(unref(script$2), {
field: "name",
header: _ctx.$t("extensionName"),
sortable: ""
}, null, 8, ["header"]),
createVNode(unref(script$2), { pt: {
bodyCell: "flex items-center justify-end"
} }, {
body: withCtx((slotProps) => [
createVNode(unref(script$3), {
modelValue: editingEnabledExtensions.value[slotProps.data.name],
"onUpdate:modelValue": /* @__PURE__ */ __name(($event) => editingEnabledExtensions.value[slotProps.data.name] = $event, "onUpdate:modelValue"),
onChange: updateExtensionStatus
}, null, 8, ["modelValue", "onUpdate:modelValue"])
]),
_: 1
})
]),
_: 1
}, 8, ["value", "filters"])
]),
_: 1
});
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ExtensionPanel-DsD42OtO.js.map

1
comfy/web/assets/ExtensionPanel-DsD42OtO.js.map generated vendored Normal file
View File

@ -0,0 +1 @@
{"version":3,"file":"ExtensionPanel-DsD42OtO.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Extension\" class=\"extension-panel\">\n <template #header>\n <SearchBox\n v-model=\"filters['global'].value\"\n :placeholder=\"$t('searchExtensions') + '...'\"\n />\n <Message v-if=\"hasChanges\" severity=\"info\" pt:text=\"w-full\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n <div class=\"flex justify-end\">\n <Button\n :label=\"$t('reloadToApplyChanges')\"\n @click=\"applyChanges\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n </template>\n <DataTable\n :value=\"extensionStore.extensions\"\n stripedRows\n size=\"small\"\n :filters=\"filters\"\n >\n <Column field=\"name\" :header=\"$t('extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport { FilterMatchMode } from '@primevue/core/api'\nimport PanelTemplate from './PanelTemplate.vue'\nimport SearchBox from '@/components/common/SearchBox.vue'\n\nconst filters = ref({\n global: { value: '', matchMode: FilterMatchMode.CONTAINS }\n})\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n 
window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;AA8DA,UAAM,UAAU,IAAI;AAAA,MAClB,QAAQ,EAAE,OAAO,IAAI,WAAW,gBAAgB,SAAS;AAAA,IAAA,CAC1D;AAED,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

File diff suppressed because one or more lines are too long

1
comfy/web/assets/GraphView-BW5soyxY.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -158,10 +158,10 @@
z-index: 9999;
}
[data-v-9eb975c3] .p-togglebutton::before {
[data-v-783f8efe] .p-togglebutton::before {
display: none
}
[data-v-9eb975c3] .p-togglebutton {
[data-v-783f8efe] .p-togglebutton {
position: relative;
flex-shrink: 0;
border-radius: 0px;
@ -169,14 +169,14 @@
padding-left: 0.5rem;
padding-right: 0.5rem
}
[data-v-9eb975c3] .p-togglebutton.p-togglebutton-checked {
[data-v-783f8efe] .p-togglebutton.p-togglebutton-checked {
border-bottom-width: 2px;
border-bottom-color: var(--p-button-text-primary-color)
}
[data-v-9eb975c3] .p-togglebutton-checked .close-button,[data-v-9eb975c3] .p-togglebutton:hover .close-button {
[data-v-783f8efe] .p-togglebutton-checked .close-button,[data-v-783f8efe] .p-togglebutton:hover .close-button {
visibility: visible
}
.status-indicator[data-v-9eb975c3] {
.status-indicator[data-v-783f8efe] {
position: absolute;
font-weight: 700;
font-size: 1.5rem;
@ -184,10 +184,10 @@
left: 50%;
transform: translate(-50%, -50%)
}
[data-v-9eb975c3] .p-togglebutton:hover .status-indicator {
[data-v-783f8efe] .p-togglebutton:hover .status-indicator {
display: none
}
[data-v-9eb975c3] .p-togglebutton .close-button {
[data-v-783f8efe] .p-togglebutton .close-button {
visibility: hidden
}

File diff suppressed because one or more lines are too long

1
comfy/web/assets/InstallView-C6UIhIu4.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -1,8 +0,0 @@
[data-v-2d8b3a76] .p-datatable-tbody > tr > td {
padding: 0.25rem;
min-height: 2rem
}
[data-v-2d8b3a76] .p-datatable-row-selected .actions,[data-v-2d8b3a76] .p-datatable-selectable-row:hover .actions {
visibility: visible
}

View File

@ -1,273 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, av as createTextVNode, a4 as toDisplayString, z as unref, az as script, j as createCommentVNode, r as ref, c5 as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aI as watchEffect, be as useToast, t as resolveDirective, c6 as SearchBox, A as createBaseVNode, D as script$2, x as createBlock, an as script$4, bi as withModifiers, bQ as script$5, aG as script$6, v as withDirectives, c0 as KeyComboImpl, c7 as KeybindingImpl, _ as _export_sfc } from "./index-bi78Y1IN.js";
import { s as script$1, a as script$3 } from "./index-ftUEqmu1.js";
import "./index-bCeMLtLM.js";
const _hoisted_1$1 = {
key: 0,
class: "px-2"
};
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
__name: "KeyComboDisplay",
props: {
keyCombo: {},
isModified: { type: Boolean, default: false }
},
setup(__props) {
const props = __props;
const keySequences = computed(() => props.keyCombo.getKeySequences());
return (_ctx, _cache) => {
return openBlock(), createElementBlock("span", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(keySequences.value, (sequence, index) => {
return openBlock(), createElementBlock(Fragment, { key: index }, [
createVNode(unref(script), {
severity: _ctx.isModified ? "info" : "secondary"
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(sequence), 1)
]),
_: 2
}, 1032, ["severity"]),
index < keySequences.value.length - 1 ? (openBlock(), createElementBlock("span", _hoisted_1$1, "+")) : createCommentVNode("", true)
], 64);
}), 128))
]);
};
}
});
const _hoisted_1 = { class: "keybinding-panel" };
const _hoisted_2 = { class: "actions invisible flex flex-row" };
const _hoisted_3 = ["title"];
const _hoisted_4 = { key: 1 };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "KeybindingPanel",
setup(__props) {
const filters = ref({
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
});
const keybindingStore = useKeybindingStore();
const commandStore = useCommandStore();
const commandsData = computed(() => {
return Object.values(commandStore.commands).map((command) => ({
id: command.id,
keybinding: keybindingStore.getKeybindingByCommandId(command.id)
}));
});
const selectedCommandData = ref(null);
const editDialogVisible = ref(false);
const newBindingKeyCombo = ref(null);
const currentEditingCommand = ref(null);
const keybindingInput = ref(null);
const existingKeybindingOnCombo = computed(() => {
if (!currentEditingCommand.value) {
return null;
}
if (currentEditingCommand.value.keybinding?.combo?.equals(
newBindingKeyCombo.value
)) {
return null;
}
if (!newBindingKeyCombo.value) {
return null;
}
return keybindingStore.getKeybinding(newBindingKeyCombo.value);
});
function editKeybinding(commandData) {
currentEditingCommand.value = commandData;
newBindingKeyCombo.value = commandData.keybinding ? commandData.keybinding.combo : null;
editDialogVisible.value = true;
}
__name(editKeybinding, "editKeybinding");
watchEffect(() => {
if (editDialogVisible.value) {
setTimeout(() => {
keybindingInput.value?.$el?.focus();
}, 300);
}
});
function removeKeybinding(commandData) {
if (commandData.keybinding) {
keybindingStore.unsetKeybinding(commandData.keybinding);
keybindingStore.persistUserKeybindings();
}
}
__name(removeKeybinding, "removeKeybinding");
function captureKeybinding(event) {
const keyCombo = KeyComboImpl.fromEvent(event);
newBindingKeyCombo.value = keyCombo;
}
__name(captureKeybinding, "captureKeybinding");
function cancelEdit() {
editDialogVisible.value = false;
currentEditingCommand.value = null;
newBindingKeyCombo.value = null;
}
__name(cancelEdit, "cancelEdit");
function saveKeybinding() {
if (currentEditingCommand.value && newBindingKeyCombo.value) {
const updated = keybindingStore.updateKeybindingOnCommand(
new KeybindingImpl({
commandId: currentEditingCommand.value.id,
combo: newBindingKeyCombo.value
})
);
if (updated) {
keybindingStore.persistUserKeybindings();
}
}
cancelEdit();
}
__name(saveKeybinding, "saveKeybinding");
const toast = useToast();
async function resetKeybindings() {
keybindingStore.resetKeybindings();
await keybindingStore.persistUserKeybindings();
toast.add({
severity: "info",
summary: "Info",
detail: "Keybindings reset",
life: 3e3
});
}
__name(resetKeybindings, "resetKeybindings");
return (_ctx, _cache) => {
const _directive_tooltip = resolveDirective("tooltip");
return openBlock(), createElementBlock("div", _hoisted_1, [
createVNode(unref(script$3), {
value: commandsData.value,
selection: selectedCommandData.value,
"onUpdate:selection": _cache[1] || (_cache[1] = ($event) => selectedCommandData.value = $event),
"global-filter-fields": ["id"],
filters: filters.value,
selectionMode: "single",
stripedRows: "",
pt: {
header: "px-0"
}
}, {
header: withCtx(() => [
createVNode(SearchBox, {
modelValue: filters.value["global"].value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
placeholder: _ctx.$t("searchKeybindings") + "..."
}, null, 8, ["modelValue", "placeholder"])
]),
default: withCtx(() => [
createVNode(unref(script$1), {
field: "actions",
header: ""
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", _hoisted_2, [
createVNode(unref(script$2), {
icon: "pi pi-pencil",
class: "p-button-text",
onClick: /* @__PURE__ */ __name(($event) => editKeybinding(slotProps.data), "onClick")
}, null, 8, ["onClick"]),
createVNode(unref(script$2), {
icon: "pi pi-trash",
class: "p-button-text p-button-danger",
onClick: /* @__PURE__ */ __name(($event) => removeKeybinding(slotProps.data), "onClick"),
disabled: !slotProps.data.keybinding
}, null, 8, ["onClick", "disabled"])
])
]),
_: 1
}),
createVNode(unref(script$1), {
field: "id",
header: "Command ID",
sortable: "",
class: "max-w-64 2xl:max-w-full"
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", {
class: "overflow-hidden text-ellipsis whitespace-nowrap",
title: slotProps.data.id
}, toDisplayString(slotProps.data.id), 9, _hoisted_3)
]),
_: 1
}),
createVNode(unref(script$1), {
field: "keybinding",
header: "Keybinding"
}, {
body: withCtx((slotProps) => [
slotProps.data.keybinding ? (openBlock(), createBlock(_sfc_main$1, {
key: 0,
keyCombo: slotProps.data.keybinding.combo,
isModified: unref(keybindingStore).isCommandKeybindingModified(slotProps.data.id)
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_4, "-"))
]),
_: 1
})
]),
_: 1
}, 8, ["value", "selection", "filters"]),
createVNode(unref(script$6), {
class: "min-w-96",
visible: editDialogVisible.value,
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => editDialogVisible.value = $event),
modal: "",
header: currentEditingCommand.value?.id,
onHide: cancelEdit
}, {
footer: withCtx(() => [
createVNode(unref(script$2), {
label: "Save",
icon: "pi pi-check",
onClick: saveKeybinding,
disabled: !!existingKeybindingOnCombo.value,
autofocus: ""
}, null, 8, ["disabled"])
]),
default: withCtx(() => [
createBaseVNode("div", null, [
createVNode(unref(script$4), {
class: "mb-2 text-center",
ref_key: "keybindingInput",
ref: keybindingInput,
modelValue: newBindingKeyCombo.value?.toString() ?? "",
placeholder: "Press keys for new binding",
onKeydown: withModifiers(captureKeybinding, ["stop", "prevent"]),
autocomplete: "off",
fluid: "",
invalid: !!existingKeybindingOnCombo.value
}, null, 8, ["modelValue", "invalid"]),
existingKeybindingOnCombo.value ? (openBlock(), createBlock(unref(script$5), {
key: 0,
severity: "error"
}, {
default: withCtx(() => [
_cache[3] || (_cache[3] = createTextVNode(" Keybinding already exists on ")),
createVNode(unref(script), {
severity: "secondary",
value: existingKeybindingOnCombo.value.commandId
}, null, 8, ["value"])
]),
_: 1
})) : createCommentVNode("", true)
])
]),
_: 1
}, 8, ["visible", "header"]),
withDirectives(createVNode(unref(script$2), {
class: "mt-4",
label: _ctx.$t("reset"),
icon: "pi pi-trash",
severity: "danger",
fluid: "",
text: "",
onClick: resetKeybindings
}, null, 8, ["label"]), [
[_directive_tooltip, _ctx.$t("resetKeybindingsTooltip")]
])
]);
};
}
});
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-2d8b3a76"]]);
export {
KeybindingPanel as default
};
//# sourceMappingURL=KeybindingPanel-BlBHQABM.js.map

File diff suppressed because one or more lines are too long

8
comfy/web/assets/KeybindingPanel-C-7KE-Kw.css generated vendored Normal file
View File

@ -0,0 +1,8 @@
[data-v-9d7e362e] .p-datatable-tbody > tr > td {
padding: 0.25rem;
min-height: 2rem
}
[data-v-9d7e362e] .p-datatable-row-selected .actions,[data-v-9d7e362e] .p-datatable-selectable-row:hover .actions {
visibility: visible
}

279
comfy/web/assets/KeybindingPanel-lcJrxHwZ.js generated vendored Normal file
View File

@ -0,0 +1,279 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, q as computed, g as openBlock, h as createElementBlock, N as Fragment, O as renderList, i as createVNode, y as withCtx, aw as createTextVNode, a6 as toDisplayString, z as unref, aA as script, j as createCommentVNode, r as ref, c6 as FilterMatchMode, M as useKeybindingStore, F as useCommandStore, aJ as watchEffect, bg as useToast, t as resolveDirective, x as createBlock, c7 as SearchBox, A as createBaseVNode, D as script$2, ao as script$4, bk as withModifiers, bT as script$5, aH as script$6, v as withDirectives, c8 as _sfc_main$2, P as pushScopeId, Q as popScopeId, c1 as KeyComboImpl, c9 as KeybindingImpl, _ as _export_sfc } from "./index-CoOvI8ZH.js";
import { s as script$1, a as script$3 } from "./index-DK6Kev7f.js";
import "./index-D4DWQPPQ.js";
const _hoisted_1$1 = {
key: 0,
class: "px-2"
};
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
__name: "KeyComboDisplay",
props: {
keyCombo: {},
isModified: { type: Boolean, default: false }
},
setup(__props) {
const props = __props;
const keySequences = computed(() => props.keyCombo.getKeySequences());
return (_ctx, _cache) => {
return openBlock(), createElementBlock("span", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(keySequences.value, (sequence, index) => {
return openBlock(), createElementBlock(Fragment, { key: index }, [
createVNode(unref(script), {
severity: _ctx.isModified ? "info" : "secondary"
}, {
default: withCtx(() => [
createTextVNode(toDisplayString(sequence), 1)
]),
_: 2
}, 1032, ["severity"]),
index < keySequences.value.length - 1 ? (openBlock(), createElementBlock("span", _hoisted_1$1, "+")) : createCommentVNode("", true)
], 64);
}), 128))
]);
};
}
});
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-9d7e362e"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "actions invisible flex flex-row" };
const _hoisted_2 = ["title"];
const _hoisted_3 = { key: 1 };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "KeybindingPanel",
setup(__props) {
const filters = ref({
global: { value: "", matchMode: FilterMatchMode.CONTAINS }
});
const keybindingStore = useKeybindingStore();
const commandStore = useCommandStore();
const commandsData = computed(() => {
return Object.values(commandStore.commands).map((command) => ({
id: command.id,
keybinding: keybindingStore.getKeybindingByCommandId(command.id)
}));
});
const selectedCommandData = ref(null);
const editDialogVisible = ref(false);
const newBindingKeyCombo = ref(null);
const currentEditingCommand = ref(null);
const keybindingInput = ref(null);
const existingKeybindingOnCombo = computed(() => {
if (!currentEditingCommand.value) {
return null;
}
if (currentEditingCommand.value.keybinding?.combo?.equals(
newBindingKeyCombo.value
)) {
return null;
}
if (!newBindingKeyCombo.value) {
return null;
}
return keybindingStore.getKeybinding(newBindingKeyCombo.value);
});
function editKeybinding(commandData) {
currentEditingCommand.value = commandData;
newBindingKeyCombo.value = commandData.keybinding ? commandData.keybinding.combo : null;
editDialogVisible.value = true;
}
__name(editKeybinding, "editKeybinding");
watchEffect(() => {
if (editDialogVisible.value) {
setTimeout(() => {
keybindingInput.value?.$el?.focus();
}, 300);
}
});
function removeKeybinding(commandData) {
if (commandData.keybinding) {
keybindingStore.unsetKeybinding(commandData.keybinding);
keybindingStore.persistUserKeybindings();
}
}
__name(removeKeybinding, "removeKeybinding");
function captureKeybinding(event) {
const keyCombo = KeyComboImpl.fromEvent(event);
newBindingKeyCombo.value = keyCombo;
}
__name(captureKeybinding, "captureKeybinding");
function cancelEdit() {
editDialogVisible.value = false;
currentEditingCommand.value = null;
newBindingKeyCombo.value = null;
}
__name(cancelEdit, "cancelEdit");
function saveKeybinding() {
if (currentEditingCommand.value && newBindingKeyCombo.value) {
const updated = keybindingStore.updateKeybindingOnCommand(
new KeybindingImpl({
commandId: currentEditingCommand.value.id,
combo: newBindingKeyCombo.value
})
);
if (updated) {
keybindingStore.persistUserKeybindings();
}
}
cancelEdit();
}
__name(saveKeybinding, "saveKeybinding");
const toast = useToast();
async function resetKeybindings() {
keybindingStore.resetKeybindings();
await keybindingStore.persistUserKeybindings();
toast.add({
severity: "info",
summary: "Info",
detail: "Keybindings reset",
life: 3e3
});
}
__name(resetKeybindings, "resetKeybindings");
return (_ctx, _cache) => {
const _directive_tooltip = resolveDirective("tooltip");
return openBlock(), createBlock(_sfc_main$2, {
value: "Keybinding",
class: "keybinding-panel"
}, {
header: withCtx(() => [
createVNode(SearchBox, {
modelValue: filters.value["global"].value,
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => filters.value["global"].value = $event),
placeholder: _ctx.$t("searchKeybindings") + "..."
}, null, 8, ["modelValue", "placeholder"])
]),
default: withCtx(() => [
createVNode(unref(script$3), {
value: commandsData.value,
selection: selectedCommandData.value,
"onUpdate:selection": _cache[1] || (_cache[1] = ($event) => selectedCommandData.value = $event),
"global-filter-fields": ["id"],
filters: filters.value,
selectionMode: "single",
stripedRows: "",
pt: {
header: "px-0"
}
}, {
default: withCtx(() => [
createVNode(unref(script$1), {
field: "actions",
header: ""
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", _hoisted_1, [
createVNode(unref(script$2), {
icon: "pi pi-pencil",
class: "p-button-text",
onClick: /* @__PURE__ */ __name(($event) => editKeybinding(slotProps.data), "onClick")
}, null, 8, ["onClick"]),
createVNode(unref(script$2), {
icon: "pi pi-trash",
class: "p-button-text p-button-danger",
onClick: /* @__PURE__ */ __name(($event) => removeKeybinding(slotProps.data), "onClick"),
disabled: !slotProps.data.keybinding
}, null, 8, ["onClick", "disabled"])
])
]),
_: 1
}),
createVNode(unref(script$1), {
field: "id",
header: "Command ID",
sortable: "",
class: "max-w-64 2xl:max-w-full"
}, {
body: withCtx((slotProps) => [
createBaseVNode("div", {
class: "overflow-hidden text-ellipsis whitespace-nowrap",
title: slotProps.data.id
}, toDisplayString(slotProps.data.id), 9, _hoisted_2)
]),
_: 1
}),
createVNode(unref(script$1), {
field: "keybinding",
header: "Keybinding"
}, {
body: withCtx((slotProps) => [
slotProps.data.keybinding ? (openBlock(), createBlock(_sfc_main$1, {
key: 0,
keyCombo: slotProps.data.keybinding.combo,
isModified: unref(keybindingStore).isCommandKeybindingModified(slotProps.data.id)
}, null, 8, ["keyCombo", "isModified"])) : (openBlock(), createElementBlock("span", _hoisted_3, "-"))
]),
_: 1
})
]),
_: 1
}, 8, ["value", "selection", "filters"]),
createVNode(unref(script$6), {
class: "min-w-96",
visible: editDialogVisible.value,
"onUpdate:visible": _cache[2] || (_cache[2] = ($event) => editDialogVisible.value = $event),
modal: "",
header: currentEditingCommand.value?.id,
onHide: cancelEdit
}, {
footer: withCtx(() => [
createVNode(unref(script$2), {
label: "Save",
icon: "pi pi-check",
onClick: saveKeybinding,
disabled: !!existingKeybindingOnCombo.value,
autofocus: ""
}, null, 8, ["disabled"])
]),
default: withCtx(() => [
createBaseVNode("div", null, [
createVNode(unref(script$4), {
class: "mb-2 text-center",
ref_key: "keybindingInput",
ref: keybindingInput,
modelValue: newBindingKeyCombo.value?.toString() ?? "",
placeholder: "Press keys for new binding",
onKeydown: withModifiers(captureKeybinding, ["stop", "prevent"]),
autocomplete: "off",
fluid: "",
invalid: !!existingKeybindingOnCombo.value
}, null, 8, ["modelValue", "invalid"]),
existingKeybindingOnCombo.value ? (openBlock(), createBlock(unref(script$5), {
key: 0,
severity: "error"
}, {
default: withCtx(() => [
createTextVNode(" Keybinding already exists on "),
createVNode(unref(script), {
severity: "secondary",
value: existingKeybindingOnCombo.value.commandId
}, null, 8, ["value"])
]),
_: 1
})) : createCommentVNode("", true)
])
]),
_: 1
}, 8, ["visible", "header"]),
withDirectives(createVNode(unref(script$2), {
class: "mt-4",
label: _ctx.$t("reset"),
icon: "pi pi-trash",
severity: "danger",
fluid: "",
text: "",
onClick: resetKeybindings
}, null, 8, ["label"]), [
[_directive_tooltip, _ctx.$t("resetKeybindingsTooltip")]
])
]),
_: 1
});
};
}
});
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-9d7e362e"]]);
export {
KeybindingPanel as default
};
//# sourceMappingURL=KeybindingPanel-lcJrxHwZ.js.map

1
comfy/web/assets/KeybindingPanel-lcJrxHwZ.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

150
comfy/web/assets/ServerConfigPanel-x68ubY-c.js generated vendored Normal file
View File

@ -0,0 +1,150 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { A as createBaseVNode, g as openBlock, h as createElementBlock, aU as markRaw, d as defineComponent, u as useSettingStore, bw as storeToRefs, w as watch, cy as useCopyToClipboard, x as createBlock, y as withCtx, z as unref, bT as script, a6 as toDisplayString, O as renderList, N as Fragment, i as createVNode, D as script$1, j as createCommentVNode, bI as script$2, cz as formatCamelCase, cA as FormItem, c8 as _sfc_main$1, bN as electronAPI } from "./index-CoOvI8ZH.js";
import { u as useServerConfigStore } from "./serverConfigStore-cctR8PGG.js";
const _hoisted_1$1 = {
viewBox: "0 0 24 24",
width: "1.2em",
height: "1.2em"
};
const _hoisted_2$1 = /* @__PURE__ */ createBaseVNode("path", {
fill: "none",
stroke: "currentColor",
"stroke-linecap": "round",
"stroke-linejoin": "round",
"stroke-width": "2",
d: "m4 17l6-6l-6-6m8 14h8"
}, null, -1);
const _hoisted_3$1 = [
_hoisted_2$1
];
function render(_ctx, _cache) {
return openBlock(), createElementBlock("svg", _hoisted_1$1, [..._hoisted_3$1]);
}
__name(render, "render");
const __unplugin_components_0 = markRaw({ name: "lucide-terminal", render });
const _hoisted_1 = { class: "flex flex-col gap-2" };
const _hoisted_2 = { class: "flex justify-end gap-2" };
const _hoisted_3 = { class: "flex items-center justify-between" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerConfigPanel",
setup(__props) {
const settingStore = useSettingStore();
const serverConfigStore = useServerConfigStore();
const {
serverConfigsByCategory,
serverConfigValues,
launchArgs,
commandLineArgs,
modifiedConfigs
} = storeToRefs(serverConfigStore);
const revertChanges = /* @__PURE__ */ __name(() => {
serverConfigStore.revertChanges();
}, "revertChanges");
const restartApp = /* @__PURE__ */ __name(() => {
electronAPI().restartApp();
}, "restartApp");
watch(launchArgs, (newVal) => {
settingStore.set("Comfy.Server.LaunchArgs", newVal);
});
watch(serverConfigValues, (newVal) => {
settingStore.set("Comfy.Server.ServerConfigValues", newVal);
});
const { copyToClipboard } = useCopyToClipboard();
const copyCommandLineArgs = /* @__PURE__ */ __name(async () => {
await copyToClipboard(commandLineArgs.value);
}, "copyCommandLineArgs");
return (_ctx, _cache) => {
const _component_i_lucide58terminal = __unplugin_components_0;
return openBlock(), createBlock(_sfc_main$1, {
value: "Server-Config",
class: "server-config-panel"
}, {
header: withCtx(() => [
createBaseVNode("div", _hoisted_1, [
unref(modifiedConfigs).length > 0 ? (openBlock(), createBlock(unref(script), {
key: 0,
severity: "info",
"pt:text": "w-full"
}, {
default: withCtx(() => [
createBaseVNode("p", null, toDisplayString(_ctx.$t("serverConfig.modifiedConfigs")), 1),
createBaseVNode("ul", null, [
(openBlock(true), createElementBlock(Fragment, null, renderList(unref(modifiedConfigs), (config) => {
return openBlock(), createElementBlock("li", {
key: config.id
}, toDisplayString(config.name) + ": " + toDisplayString(config.initialValue) + " → " + toDisplayString(config.value), 1);
}), 128))
]),
createBaseVNode("div", _hoisted_2, [
createVNode(unref(script$1), {
label: _ctx.$t("serverConfig.revertChanges"),
onClick: revertChanges,
outlined: ""
}, null, 8, ["label"]),
createVNode(unref(script$1), {
label: _ctx.$t("serverConfig.restart"),
onClick: restartApp,
outlined: "",
severity: "danger"
}, null, 8, ["label"])
])
]),
_: 1
})) : createCommentVNode("", true),
unref(commandLineArgs) ? (openBlock(), createBlock(unref(script), {
key: 1,
severity: "secondary",
"pt:text": "w-full"
}, {
icon: withCtx(() => [
createVNode(_component_i_lucide58terminal, { class: "text-xl font-bold" })
]),
default: withCtx(() => [
createBaseVNode("div", _hoisted_3, [
createBaseVNode("p", null, toDisplayString(unref(commandLineArgs)), 1),
createVNode(unref(script$1), {
icon: "pi pi-clipboard",
onClick: copyCommandLineArgs,
severity: "secondary",
text: ""
})
])
]),
_: 1
})) : createCommentVNode("", true)
])
]),
default: withCtx(() => [
(openBlock(true), createElementBlock(Fragment, null, renderList(Object.entries(unref(serverConfigsByCategory)), ([label, items], i) => {
return openBlock(), createElementBlock("div", { key: label }, [
i > 0 ? (openBlock(), createBlock(unref(script$2), { key: 0 })) : createCommentVNode("", true),
createBaseVNode("h3", null, toDisplayString(unref(formatCamelCase)(label)), 1),
(openBlock(true), createElementBlock(Fragment, null, renderList(items, (item) => {
return openBlock(), createElementBlock("div", {
key: item.name,
class: "flex items-center mb-4"
}, [
createVNode(FormItem, {
item,
formValue: item.value,
"onUpdate:formValue": /* @__PURE__ */ __name(($event) => item.value = $event, "onUpdate:formValue"),
id: item.id,
labelClass: {
"text-highlight": item.initialValue !== item.value
}
}, null, 8, ["item", "formValue", "onUpdate:formValue", "id", "labelClass"])
]);
}), 128))
]);
}), 128))
]),
_: 1
});
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ServerConfigPanel-x68ubY-c.js.map

1
comfy/web/assets/ServerConfigPanel-x68ubY-c.js.map generated vendored Normal file
View File

@ -0,0 +1 @@
{"version":3,"file":"ServerConfigPanel-x68ubY-c.js","sources":["../../src/components/dialog/content/setting/ServerConfigPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Server-Config\" class=\"server-config-panel\">\n <template #header>\n <div class=\"flex flex-col gap-2\">\n <Message\n v-if=\"modifiedConfigs.length > 0\"\n severity=\"info\"\n pt:text=\"w-full\"\n >\n <p>\n {{ $t('serverConfig.modifiedConfigs') }}\n </p>\n <ul>\n <li v-for=\"config in modifiedConfigs\" :key=\"config.id\">\n {{ config.name }}: {{ config.initialValue }} → {{ config.value }}\n </li>\n </ul>\n <div class=\"flex justify-end gap-2\">\n <Button\n :label=\"$t('serverConfig.revertChanges')\"\n @click=\"revertChanges\"\n outlined\n />\n <Button\n :label=\"$t('serverConfig.restart')\"\n @click=\"restartApp\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n <Message v-if=\"commandLineArgs\" severity=\"secondary\" pt:text=\"w-full\">\n <template #icon>\n <i-lucide:terminal class=\"text-xl font-bold\" />\n </template>\n <div class=\"flex items-center justify-between\">\n <p>{{ commandLineArgs }}</p>\n <Button\n icon=\"pi pi-clipboard\"\n @click=\"copyCommandLineArgs\"\n severity=\"secondary\"\n text\n />\n </div>\n </Message>\n </div>\n </template>\n <div\n v-for=\"([label, items], i) in Object.entries(serverConfigsByCategory)\"\n :key=\"label\"\n >\n <Divider v-if=\"i > 0\" />\n <h3>{{ formatCamelCase(label) }}</h3>\n <div\n v-for=\"item in items\"\n :key=\"item.name\"\n class=\"flex items-center mb-4\"\n >\n <FormItem\n :item=\"item\"\n v-model:formValue=\"item.value\"\n :id=\"item.id\"\n :labelClass=\"{\n 'text-highlight': item.initialValue !== item.value\n }\"\n />\n </div>\n </div>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport Divider from 'primevue/divider'\nimport FormItem from '@/components/common/FormItem.vue'\nimport PanelTemplate from './PanelTemplate.vue'\nimport { formatCamelCase } from '@/utils/formatUtil'\nimport { useServerConfigStore } from '@/stores/serverConfigStore'\nimport { storeToRefs } from 'pinia'\nimport { electronAPI } from '@/utils/envUtil'\nimport { useSettingStore } from '@/stores/settingStore'\nimport { watch } from 'vue'\nimport { useCopyToClipboard } from '@/hooks/clipboardHooks'\n\nconst settingStore = useSettingStore()\nconst serverConfigStore = useServerConfigStore()\nconst {\n serverConfigsByCategory,\n serverConfigValues,\n launchArgs,\n commandLineArgs,\n modifiedConfigs\n} = storeToRefs(serverConfigStore)\n\nconst revertChanges = () => {\n serverConfigStore.revertChanges()\n}\n\nconst restartApp = () => {\n electronAPI().restartApp()\n}\n\nwatch(launchArgs, (newVal) => {\n settingStore.set('Comfy.Server.LaunchArgs', newVal)\n})\n\nwatch(serverConfigValues, (newVal) => {\n settingStore.set('Comfy.Server.ServerConfigValues', newVal)\n})\n\nconst { copyToClipboard } = useCopyToClipboard()\nconst copyCommandLineArgs = async () => {\n await 
copyToClipboard(commandLineArgs.value)\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAqFA,UAAM,eAAe;AACrB,UAAM,oBAAoB;AACpB,UAAA;AAAA,MACJ;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,IAAA,IACE,YAAY,iBAAiB;AAEjC,UAAM,gBAAgB,6BAAM;AAC1B,wBAAkB,cAAc;AAAA,IAAA,GADZ;AAItB,UAAM,aAAa,6BAAM;AACvB,kBAAA,EAAc;IAAW,GADR;AAIb,UAAA,YAAY,CAAC,WAAW;AACf,mBAAA,IAAI,2BAA2B,MAAM;AAAA,IAAA,CACnD;AAEK,UAAA,oBAAoB,CAAC,WAAW;AACvB,mBAAA,IAAI,mCAAmC,MAAM;AAAA,IAAA,CAC3D;AAEK,UAAA,EAAE,oBAAoB;AAC5B,UAAM,sBAAsB,mCAAY;AAChC,YAAA,gBAAgB,gBAAgB,KAAK;AAAA,IAAA,GADjB;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

View File

@ -1,102 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, r as ref, o as onMounted, w as watch, I as onBeforeUnmount, g as openBlock, h as createElementBlock, i as createVNode, y as withCtx, A as createBaseVNode, a4 as toDisplayString, z as unref, bJ as script, bK as electronAPI } from "./index-bi78Y1IN.js";
import { o as oe, r as rr } from "./index-Ba5g1c58.js";
const _hoisted_1$1 = { class: "p-terminal rounded-none h-full w-full" };
const _hoisted_2$1 = { class: "px-4 whitespace-pre-wrap" };
const _sfc_main$1 = /* @__PURE__ */ defineComponent({
__name: "LogTerminal",
props: {
fetchLogs: { type: Function },
fetchInterval: {}
},
setup(__props) {
const props = __props;
const log = ref("");
const scrollPanelRef = ref(null);
const scrolledToBottom = ref(false);
let intervalId = 0;
onMounted(async () => {
const element = scrollPanelRef.value?.$el;
const scrollContainer = element?.querySelector(".p-scrollpanel-content");
if (scrollContainer) {
scrollContainer.addEventListener("scroll", () => {
scrolledToBottom.value = scrollContainer.scrollTop + scrollContainer.clientHeight === scrollContainer.scrollHeight;
});
}
const scrollToBottom = /* @__PURE__ */ __name(() => {
if (scrollContainer) {
scrollContainer.scrollTop = scrollContainer.scrollHeight;
}
}, "scrollToBottom");
watch(log, () => {
if (scrolledToBottom.value) {
scrollToBottom();
}
});
const fetchLogs = /* @__PURE__ */ __name(async () => {
log.value = await props.fetchLogs();
}, "fetchLogs");
await fetchLogs();
scrollToBottom();
intervalId = window.setInterval(fetchLogs, props.fetchInterval);
});
onBeforeUnmount(() => {
window.clearInterval(intervalId);
});
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1$1, [
createVNode(unref(script), {
class: "h-full w-full",
ref_key: "scrollPanelRef",
ref: scrollPanelRef
}, {
default: withCtx(() => [
createBaseVNode("pre", _hoisted_2$1, toDisplayString(log.value), 1)
]),
_: 1
}, 512)
]);
};
}
});
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
const _hoisted_2 = { class: "text-2xl font-bold" };
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerStartView",
setup(__props) {
const electron = electronAPI();
const status = ref(oe.INITIAL_STATE);
const logs = ref([]);
const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
status.value = newStatus;
logs.value = [];
}, "updateProgress");
const addLogMessage = /* @__PURE__ */ __name((message) => {
logs.value = [...logs.value, message];
}, "addLogMessage");
const fetchLogs = /* @__PURE__ */ __name(async () => {
return logs.value.join("\n");
}, "fetchLogs");
onMounted(() => {
electron.sendReady();
electron.onProgressUpdate(updateProgress);
electron.onLogMessage((message) => {
addLogMessage(message);
});
});
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1, [
createBaseVNode("h2", _hoisted_2, toDisplayString(unref(rr)[status.value]), 1),
createVNode(_sfc_main$1, {
"fetch-logs": fetchLogs,
"fetch-interval": 500
})
]);
};
}
});
export {
_sfc_main as default
};
//# sourceMappingURL=ServerStartView-Bb00SCJ5.js.map

View File

@ -1 +0,0 @@
{"version":3,"file":"ServerStartView-Bb00SCJ5.js","sources":["../../src/components/common/LogTerminal.vue","../../src/views/ServerStartView.vue"],"sourcesContent":["<!-- A simple read-only terminal component that displays logs. -->\n<template>\n <div class=\"p-terminal rounded-none h-full w-full\">\n <ScrollPanel class=\"h-full w-full\" ref=\"scrollPanelRef\">\n <pre class=\"px-4 whitespace-pre-wrap\">{{ log }}</pre>\n </ScrollPanel>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport ScrollPanel from 'primevue/scrollpanel'\nimport { onBeforeUnmount, onMounted, ref, watch } from 'vue'\n\nconst props = defineProps<{\n fetchLogs: () => Promise<string>\n fetchInterval: number\n}>()\n\nconst log = ref<string>('')\nconst scrollPanelRef = ref<InstanceType<typeof ScrollPanel> | null>(null)\n/**\n * Whether the user has scrolled to the bottom of the terminal.\n * This is used to prevent the terminal from scrolling to the bottom\n * when new logs are fetched.\n */\nconst scrolledToBottom = ref(false)\n\nlet intervalId: number = 0\n\nonMounted(async () => {\n const element = scrollPanelRef.value?.$el\n const scrollContainer = element?.querySelector('.p-scrollpanel-content')\n\n if (scrollContainer) {\n scrollContainer.addEventListener('scroll', () => {\n scrolledToBottom.value =\n scrollContainer.scrollTop + scrollContainer.clientHeight ===\n scrollContainer.scrollHeight\n })\n }\n\n const scrollToBottom = () => {\n if (scrollContainer) {\n scrollContainer.scrollTop = scrollContainer.scrollHeight\n }\n }\n\n watch(log, () => {\n if (scrolledToBottom.value) {\n scrollToBottom()\n }\n })\n\n const fetchLogs = async () => {\n log.value = await props.fetchLogs()\n }\n\n await fetchLogs()\n scrollToBottom()\n intervalId = window.setInterval(fetchLogs, props.fetchInterval)\n})\n\nonBeforeUnmount(() => {\n window.clearInterval(intervalId)\n})\n</script>\n","<template>\n <div\n class=\"font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto\"\n >\n <h2 class=\"text-2xl font-bold\">{{ ProgressMessages[status] }}</h2>\n <LogTerminal :fetch-logs=\"fetchLogs\" :fetch-interval=\"500\" />\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, onMounted } from 'vue'\nimport LogTerminal from '@/components/common/LogTerminal.vue'\nimport {\n ProgressStatus,\n ProgressMessages\n} from '@comfyorg/comfyui-electron-types'\nimport { electronAPI } from '@/utils/envUtil'\n\nconst electron = electronAPI()\n\nconst status = ref<ProgressStatus>(ProgressStatus.INITIAL_STATE)\nconst logs = ref<string[]>([])\n\nconst updateProgress = ({ status: newStatus }: { status: ProgressStatus }) => {\n status.value = newStatus\n logs.value = [] // Clear logs when status changes\n}\n\nconst addLogMessage = (message: string) => {\n logs.value = [...logs.value, message]\n}\n\nconst fetchLogs = async () => {\n return logs.value.join('\\n')\n}\n\nonMounted(() => {\n electron.sendReady()\n electron.onProgressUpdate(updateProgress)\n electron.onLogMessage((message: string) => {\n addLogMessage(message)\n 
})\n})\n</script>\n"],"names":["ProgressStatus"],"mappings":";;;;;;;;;;;;;AAaA,UAAM,QAAQ;AAKR,UAAA,MAAM,IAAY,EAAE;AACpB,UAAA,iBAAiB,IAA6C,IAAI;AAMlE,UAAA,mBAAmB,IAAI,KAAK;AAElC,QAAI,aAAqB;AAEzB,cAAU,YAAY;AACd,YAAA,UAAU,eAAe,OAAO;AAChC,YAAA,kBAAkB,SAAS,cAAc,wBAAwB;AAEvE,UAAI,iBAAiB;AACH,wBAAA,iBAAiB,UAAU,MAAM;AAC/C,2BAAiB,QACf,gBAAgB,YAAY,gBAAgB,iBAC5C,gBAAgB;AAAA,QAAA,CACnB;AAAA,MAAA;AAGH,YAAM,iBAAiB,6BAAM;AAC3B,YAAI,iBAAiB;AACnB,0BAAgB,YAAY,gBAAgB;AAAA,QAAA;AAAA,MAEhD,GAJuB;AAMvB,YAAM,KAAK,MAAM;AACf,YAAI,iBAAiB,OAAO;AACX,yBAAA;AAAA,QAAA;AAAA,MACjB,CACD;AAED,YAAM,YAAY,mCAAY;AACxB,YAAA,QAAQ,MAAM,MAAM,UAAU;AAAA,MACpC,GAFkB;AAIlB,YAAM,UAAU;AACD,qBAAA;AACf,mBAAa,OAAO,YAAY,WAAW,MAAM,aAAa;AAAA,IAAA,CAC/D;AAED,oBAAgB,MAAM;AACpB,aAAO,cAAc,UAAU;AAAA,IAAA,CAChC;;;;;;;;;;;;;;;;;;;;;;AC9CD,UAAM,WAAW,YAAY;AAEvB,UAAA,SAAS,IAAoBA,GAAe,aAAa;AACzD,UAAA,OAAO,IAAc,EAAE;AAE7B,UAAM,iBAAiB,wBAAC,EAAE,QAAQ,gBAA4C;AAC5E,aAAO,QAAQ;AACf,WAAK,QAAQ,CAAC;AAAA,IAChB,GAHuB;AAKjB,UAAA,gBAAgB,wBAAC,YAAoB;AACzC,WAAK,QAAQ,CAAC,GAAG,KAAK,OAAO,OAAO;AAAA,IACtC,GAFsB;AAItB,UAAM,YAAY,mCAAY;AACrB,aAAA,KAAK,MAAM,KAAK,IAAI;AAAA,IAC7B,GAFkB;AAIlB,cAAU,MAAM;AACd,eAAS,UAAU;AACnB,eAAS,iBAAiB,cAAc;AAC/B,eAAA,aAAa,CAAC,YAAoB;AACzC,sBAAc,OAAO;AAAA,MAAA,CACtB;AAAA,IAAA,CACF;;;;;;;;;;;;"}

79
comfy/web/assets/ServerStartView-CqRVtr1h.js generated vendored Normal file
View File

@ -0,0 +1,79 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, aD as useI18n, r as ref, o as onMounted, g as openBlock, h as createElementBlock, A as createBaseVNode, aw as createTextVNode, a6 as toDisplayString, z as unref, j as createCommentVNode, i as createVNode, D as script, bM as BaseTerminal, P as pushScopeId, Q as popScopeId, bN as electronAPI, _ as _export_sfc } from "./index-CoOvI8ZH.js";
import { P as ProgressStatus } from "./index-BppSBmxJ.js";
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-f5429be7"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
const _hoisted_2 = { class: "text-2xl font-bold" };
const _hoisted_3 = { key: 0 };
const _hoisted_4 = {
key: 0,
class: "flex items-center my-4 gap-2"
};
const _sfc_main = /* @__PURE__ */ defineComponent({
__name: "ServerStartView",
setup(__props) {
const electron = electronAPI();
const { t } = useI18n();
const status = ref(ProgressStatus.INITIAL_STATE);
const electronVersion = ref("");
let xterm;
const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
status.value = newStatus;
xterm?.clear();
}, "updateProgress");
const terminalCreated = /* @__PURE__ */ __name(({ terminal, useAutoSize }, root) => {
xterm = terminal;
useAutoSize(root, true, true);
electron.onLogMessage((message) => {
terminal.write(message);
});
terminal.options.cursorBlink = false;
terminal.options.disableStdin = true;
terminal.options.cursorInactiveStyle = "block";
}, "terminalCreated");
const reinstall = /* @__PURE__ */ __name(() => electron.reinstall(), "reinstall");
const reportIssue = /* @__PURE__ */ __name(() => {
window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
}, "reportIssue");
const openLogs = /* @__PURE__ */ __name(() => electron.openLogsFolder(), "openLogs");
onMounted(async () => {
electron.sendReady();
electron.onProgressUpdate(updateProgress);
electronVersion.value = await electron.getElectronVersion();
});
return (_ctx, _cache) => {
return openBlock(), createElementBlock("div", _hoisted_1, [
createBaseVNode("h2", _hoisted_2, [
createTextVNode(toDisplayString(unref(t)(`serverStart.process.${status.value}`)) + " ", 1),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("span", _hoisted_3, " v" + toDisplayString(electronVersion.value), 1)) : createCommentVNode("", true)
]),
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("div", _hoisted_4, [
createVNode(unref(script), {
icon: "pi pi-flag",
severity: "secondary",
label: unref(t)("serverStart.reportIssue"),
onClick: reportIssue
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-file",
severity: "secondary",
label: unref(t)("serverStart.openLogs"),
onClick: openLogs
}, null, 8, ["label"]),
createVNode(unref(script), {
icon: "pi pi-refresh",
label: unref(t)("serverStart.reinstall"),
onClick: reinstall
}, null, 8, ["label"])
])) : createCommentVNode("", true),
createVNode(BaseTerminal, { onCreated: terminalCreated })
]);
};
}
});
const ServerStartView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-f5429be7"]]);
export {
ServerStartView as default
};
//# sourceMappingURL=ServerStartView-CqRVtr1h.js.map
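
The chunk above drives an @xterm/xterm terminal purely as a log sink: stdin is disabled and the cursor is made static before server output is streamed in. A minimal standalone sketch of that configuration (the container element and sample log line are illustrative, not from the bundle):

```typescript
// Read-only xterm.js terminal, mirroring the three options set in the
// generated code above. The element id is hypothetical.
import { Terminal } from '@xterm/xterm'

const terminal = new Terminal()
terminal.options.cursorBlink = false           // static cursor in a log-only view
terminal.options.disableStdin = true           // ignore all viewer keystrokes
terminal.options.cursorInactiveStyle = 'block' // solid block when unfocused

terminal.open(document.getElementById('terminal-root')!) // hypothetical container
terminal.write('Starting ComfyUI server...\r\n')         // log lines arrive via IPC
```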

1
comfy/web/assets/ServerStartView-CqRVtr1h.js.map generated vendored Normal file
View File

@ -0,0 +1 @@
{"version":3,"file":"ServerStartView-CqRVtr1h.js","sources":["../../src/views/ServerStartView.vue"],"sourcesContent":["<template>\n <div\n class=\"font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto\"\n >\n <h2 class=\"text-2xl font-bold\">\n {{ t(`serverStart.process.${status}`) }}\n <span v-if=\"status === ProgressStatus.ERROR\">\n v{{ electronVersion }}\n </span>\n </h2>\n <div\n v-if=\"status === ProgressStatus.ERROR\"\n class=\"flex items-center my-4 gap-2\"\n >\n <Button\n icon=\"pi pi-flag\"\n severity=\"secondary\"\n :label=\"t('serverStart.reportIssue')\"\n @click=\"reportIssue\"\n />\n <Button\n icon=\"pi pi-file\"\n severity=\"secondary\"\n :label=\"t('serverStart.openLogs')\"\n @click=\"openLogs\"\n />\n <Button\n icon=\"pi pi-refresh\"\n :label=\"t('serverStart.reinstall')\"\n @click=\"reinstall\"\n />\n </div>\n <BaseTerminal @created=\"terminalCreated\" />\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { ref, onMounted, Ref } from 'vue'\nimport BaseTerminal from '@/components/bottomPanel/tabs/terminal/BaseTerminal.vue'\nimport { ProgressStatus } from '@comfyorg/comfyui-electron-types'\nimport { electronAPI } from '@/utils/envUtil'\nimport type { useTerminal } from '@/hooks/bottomPanelTabs/useTerminal'\nimport { Terminal } from '@xterm/xterm'\nimport { useI18n } from 'vue-i18n'\n\nconst electron = electronAPI()\nconst { t } = useI18n()\n\nconst status = ref<ProgressStatus>(ProgressStatus.INITIAL_STATE)\nconst electronVersion = ref<string>('')\nlet xterm: Terminal | undefined\n\nconst updateProgress = ({ status: newStatus }: { status: ProgressStatus }) => {\n status.value = newStatus\n xterm?.clear()\n}\n\nconst terminalCreated = (\n { terminal, useAutoSize }: ReturnType<typeof useTerminal>,\n root: Ref<HTMLElement>\n) => {\n xterm = terminal\n\n useAutoSize(root, true, true)\n electron.onLogMessage((message: string) => {\n terminal.write(message)\n })\n\n terminal.options.cursorBlink = false\n terminal.options.disableStdin = true\n terminal.options.cursorInactiveStyle = 'block'\n}\n\nconst reinstall = () => electron.reinstall()\nconst reportIssue = () => {\n window.open('https://forum.comfy.org/c/v1-feedback/', '_blank')\n}\nconst openLogs = () => electron.openLogsFolder()\n\nonMounted(async () => {\n electron.sendReady()\n electron.onProgressUpdate(updateProgress)\n electronVersion.value = await electron.getElectronVersion()\n})\n</script>\n\n<style scoped>\n:deep(.xterm-helper-textarea) {\n /* Hide this as it moves all over when uv is running */\n display: none;\n}\n</style>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;AA8CA,UAAM,WAAW;AACX,UAAA,EAAE,MAAM;AAER,UAAA,SAAS,IAAoB,eAAe,aAAa;AACzD,UAAA,kBAAkB,IAAY,EAAE;AAClC,QAAA;AAEJ,UAAM,iBAAiB,wBAAC,EAAE,QAAQ,gBAA4C;AAC5E,aAAO,QAAQ;AACf,aAAO,MAAM;AAAA,IAAA,GAFQ;AAKvB,UAAM,kBAAkB,wBACtB,EAAE,UAAU,YAAA,GACZ,SACG;AACK,cAAA;AAEI,kBAAA,MAAM,MAAM,IAAI;AACnB,eAAA,aAAa,CAAC,YAAoB;AACzC,iBAAS,MAAM,OAAO;AAAA,MAAA,CACvB;AAED,eAAS,QAAQ,cAAc;AAC/B,eAAS,QAAQ,eAAe;AAChC,eAAS,QAAQ,sBAAsB;AAAA,IAAA,GAbjB;AAgBlB,UAAA,YAAY,6BAAM,SAAS,aAAf;AAClB,UAAM,cAAc,6BAAM;AACjB,aAAA,KAAK,0CAA0C,QAAQ;AAAA,IAAA,GAD5C;AAGd,UAAA,WAAW,6BAAM,SAAS,kBAAf;AAEjB,cAAU,YAAY;AACpB,eAAS,UAAU;AACnB,eAAS,iBAAiB,cAAc;AACxB,sBAAA,QAAQ,MAAM,SAAS,mBAAmB;AAAA,IAAA,CAC3D;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

5
comfy/web/assets/ServerStartView-Djq8v91B.css generated vendored Normal file
View File

@ -0,0 +1,5 @@
[data-v-f5429be7] .xterm-helper-textarea {
/* Hide this as it moves all over when uv is running */
display: none;
}

View File

@ -1,4 +1,11 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { d as defineComponent, g as openBlock, h as createElementBlock, A as createBaseVNode, a6 as toDisplayString, i as createVNode, z as unref, D as script, P as pushScopeId, Q as popScopeId, _ as _export_sfc } from "./index-CoOvI8ZH.js";
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-12b8b11b"), n = n(), popScopeId(), n), "_withScopeId");
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
const _hoisted_2 = { class: "flex flex-col items-center justify-center gap-8 p-8" };
const _hoisted_3 = { class: "animated-gradient-text text-glow select-none" };
@ -27,4 +34,8 @@ const WelcomeView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-
export {
WelcomeView as default
};
//# sourceMappingURL=WelcomeView-C4D1cggT.js.map

1
comfy/web/assets/WelcomeView-C4D1cggT.js.map generated vendored Normal file
View File

@ -0,0 +1 @@
{"version":3,"file":"WelcomeView-C4D1cggT.js","sources":[],"sourcesContent":[],"names":[],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}

52310
comfy/web/assets/index-Ba7IybyO.js generated vendored Normal file

File diff suppressed because one or more lines are too long

1
comfy/web/assets/index-Ba7IybyO.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

106
comfy/web/assets/index-BppSBmxJ.js generated vendored Normal file
View File

@ -0,0 +1,106 @@
const IPC_CHANNELS = {
LOADING_PROGRESS: "loading-progress",
IS_PACKAGED: "is-packaged",
RENDERER_READY: "renderer-ready",
RESTART_APP: "restart-app",
REINSTALL: "reinstall",
LOG_MESSAGE: "log-message",
OPEN_DIALOG: "open-dialog",
DOWNLOAD_PROGRESS: "download-progress",
START_DOWNLOAD: "start-download",
PAUSE_DOWNLOAD: "pause-download",
RESUME_DOWNLOAD: "resume-download",
CANCEL_DOWNLOAD: "cancel-download",
DELETE_MODEL: "delete-model",
GET_ALL_DOWNLOADS: "get-all-downloads",
GET_ELECTRON_VERSION: "get-electron-version",
SEND_ERROR_TO_SENTRY: "send-error-to-sentry",
GET_BASE_PATH: "get-base-path",
GET_MODEL_CONFIG_PATH: "get-model-config-path",
OPEN_PATH: "open-path",
OPEN_LOGS_PATH: "open-logs-path",
OPEN_DEV_TOOLS: "open-dev-tools",
TERMINAL_WRITE: "execute-terminal-command",
TERMINAL_RESIZE: "resize-terminal",
TERMINAL_RESTORE: "restore-terminal",
TERMINAL_ON_OUTPUT: "terminal-output",
IS_FIRST_TIME_SETUP: "is-first-time-setup",
GET_SYSTEM_PATHS: "get-system-paths",
VALIDATE_INSTALL_PATH: "validate-install-path",
VALIDATE_COMFYUI_SOURCE: "validate-comfyui-source",
SHOW_DIRECTORY_PICKER: "show-directory-picker",
INSTALL_COMFYUI: "install-comfyui"
};
var ProgressStatus = /* @__PURE__ */ ((ProgressStatus2) => {
ProgressStatus2["INITIAL_STATE"] = "initial-state";
ProgressStatus2["PYTHON_SETUP"] = "python-setup";
ProgressStatus2["STARTING_SERVER"] = "starting-server";
ProgressStatus2["READY"] = "ready";
ProgressStatus2["ERROR"] = "error";
return ProgressStatus2;
})(ProgressStatus || {});
const ProgressMessages = {
[
"initial-state"
/* INITIAL_STATE */
]: "Loading...",
[
"python-setup"
/* PYTHON_SETUP */
]: "Setting up Python Environment...",
[
"starting-server"
/* STARTING_SERVER */
]: "Starting ComfyUI server...",
[
"ready"
/* READY */
]: "Finishing...",
[
"error"
/* ERROR */
]: "Was not able to start ComfyUI. Please check the logs for more details. You can open it from the Help menu. Please report issues to: https://forum.comfy.org"
};
const ELECTRON_BRIDGE_API = "electronAPI";
const SENTRY_URL_ENDPOINT = "https://942cadba58d247c9cab96f45221aa813@o4507954455314432.ingest.us.sentry.io/4508007940685824";
const MigrationItems = [
{
id: "user_files",
label: "User Files",
description: "Settings and user-created workflows"
},
{
id: "models",
label: "Models",
description: "Reference model files from existing ComfyUI installations. (No copy)"
}
// TODO: Decide whether we want to auto-migrate custom nodes, and install their dependencies.
// huchenlei: This is a very essential thing for migration experience.
// {
// id: 'custom_nodes',
// label: 'Custom Nodes',
// description: 'Reference custom node files from existing ComfyUI installations. (No copy)',
// },
];
const DEFAULT_SERVER_ARGS = {
/** The host to use for the ComfyUI server. */
host: "127.0.0.1",
/** The port to use for the ComfyUI server. */
port: 8e3,
// Extra arguments to pass to the ComfyUI server.
extraServerArgs: {}
};
var DownloadStatus = /* @__PURE__ */ ((DownloadStatus2) => {
DownloadStatus2["PENDING"] = "pending";
DownloadStatus2["IN_PROGRESS"] = "in_progress";
DownloadStatus2["COMPLETED"] = "completed";
DownloadStatus2["PAUSED"] = "paused";
DownloadStatus2["ERROR"] = "error";
DownloadStatus2["CANCELLED"] = "cancelled";
return DownloadStatus2;
})(DownloadStatus || {});
export {
MigrationItems as M,
ProgressStatus as P
};
//# sourceMappingURL=index-BppSBmxJ.js.map
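
The bracketed computed keys with inline comments (`["initial-state" /* INITIAL_STATE */]`) are what a TypeScript string enum becomes after the bundler inlines its members. A hypothetical reconstruction of the pre-build source for the status maps, with member names and message strings taken verbatim from the output above (the typing is inferred):

```typescript
export enum ProgressStatus {
  INITIAL_STATE = 'initial-state',
  PYTHON_SETUP = 'python-setup',
  STARTING_SERVER = 'starting-server',
  READY = 'ready',
  ERROR = 'error'
}

export const ProgressMessages: Record<ProgressStatus, string> = {
  [ProgressStatus.INITIAL_STATE]: 'Loading...',
  [ProgressStatus.PYTHON_SETUP]: 'Setting up Python Environment...',
  [ProgressStatus.STARTING_SERVER]: 'Starting ComfyUI server...',
  [ProgressStatus.READY]: 'Finishing...',
  [ProgressStatus.ERROR]:
    'Was not able to start ComfyUI. Please check the logs for more details. ' +
    'You can open it from the Help menu. Please report issues to: https://forum.comfy.org'
}
```

Note that `port: 8e3` in `DEFAULT_SERVER_ARGS` is simply the minifier's spelling of `8000`.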

1
comfy/web/assets/index-BppSBmxJ.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

1
comfy/web/assets/index-CoOvI8ZH.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

50
comfy/web/assets/index-D4DWQPPQ.js generated vendored Normal file
View File

@ -0,0 +1,50 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { cb as script$2, A as createBaseVNode, g as openBlock, h as createElementBlock, m as mergeProps } from "./index-CoOvI8ZH.js";
var script$1 = {
name: "BarsIcon",
"extends": script$2
};
var _hoisted_1$1 = /* @__PURE__ */ createBaseVNode("path", {
"fill-rule": "evenodd",
"clip-rule": "evenodd",
d: "M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z",
fill: "currentColor"
}, null, -1);
var _hoisted_2$1 = [_hoisted_1$1];
function render$1(_ctx, _cache, $props, $setup, $data, $options) {
return openBlock(), createElementBlock("svg", mergeProps({
width: "14",
height: "14",
viewBox: "0 0 14 14",
fill: "none",
xmlns: "http://www.w3.org/2000/svg"
}, _ctx.pti()), _hoisted_2$1, 16);
}
__name(render$1, "render$1");
script$1.render = render$1;
var script = {
name: "PlusIcon",
"extends": script$2
};
var _hoisted_1 = /* @__PURE__ */ createBaseVNode("path", {
d: "M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z",
fill: "currentColor"
}, null, -1);
var _hoisted_2 = [_hoisted_1];
function render(_ctx, _cache, $props, $setup, $data, $options) {
return openBlock(), createElementBlock("svg", mergeProps({
width: "14",
height: "14",
viewBox: "0 0 14 14",
fill: "none",
xmlns: "http://www.w3.org/2000/svg"
}, _ctx.pti()), _hoisted_2, 16);
}
__name(render, "render");
script.render = render;
export {
script as a,
script$1 as s
};
//# sourceMappingURL=index-D4DWQPPQ.js.map

1
comfy/web/assets/index-D4DWQPPQ.js.map generated vendored Normal file
View File

@ -0,0 +1 @@
{"version":3,"file":"index-D4DWQPPQ.js","sources":["../../node_modules/@primevue/icons/bars/index.mjs","../../node_modules/@primevue/icons/plus/index.mjs"],"sourcesContent":["import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'BarsIcon',\n \"extends\": BaseIcon\n};\n\nvar _hoisted_1 = /*#__PURE__*/createElementVNode(\"path\", {\n \"fill-rule\": \"evenodd\",\n \"clip-rule\": \"evenodd\",\n d: \"M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z\",\n fill: \"currentColor\"\n}, null, -1);\nvar _hoisted_2 = [_hoisted_1];\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), _hoisted_2, 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n","import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'PlusIcon',\n \"extends\": BaseIcon\n};\n\nvar _hoisted_1 = /*#__PURE__*/createElementVNode(\"path\", {\n d: \"M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z\",\n fill: \"currentColor\"\n}, null, -1);\nvar _hoisted_2 = [_hoisted_1];\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), 
_hoisted_2, 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n"],"names":["script","BaseIcon","_hoisted_1","createElementVNode","_hoisted_2","render"],"mappings":";;;AAGG,IAACA,WAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWC;AACb;AAEA,IAAIC,eAA0BC,gCAAmB,QAAQ;AAAA,EACvD,aAAa;AAAA,EACb,aAAa;AAAA,EACb,GAAG;AAAA,EACH,MAAM;AACR,GAAG,MAAM,EAAE;AACX,IAAIC,eAAa,CAACF,YAAU;AAC5B,SAASG,SAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,IAAG,CAAE,GAAGD,cAAY,EAAE;AAChC;AARSC;AAUTL,SAAO,SAASK;ACtBb,IAAC,SAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWJ;AACb;AAEA,IAAI,aAA0BE,gCAAmB,QAAQ;AAAA,EACvD,GAAG;AAAA,EACH,MAAM;AACR,GAAG,MAAM,EAAE;AACX,IAAI,aAAa,CAAC,UAAU;AAC5B,SAAS,OAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,IAAG,CAAE,GAAG,YAAY,EAAE;AAChC;AARS;AAUT,OAAO,SAAS;","x_google_ignoreList":[0,1]}

View File

@ -1,7 +1,12 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { cb as script$s, A as createBaseVNode, g as openBlock, h as createElementBlock, m as mergeProps, B as BaseStyle, R as script$t, a6 as toDisplayString, a1 as Ripple, t as resolveDirective, v as withDirectives, x as createBlock, J as resolveDynamicComponent, cc as script$u, l as resolveComponent, C as normalizeClass, av as createSlots, y as withCtx, by as script$v, bo as script$w, N as Fragment, O as renderList, aw as createTextVNode, be as setAttribute, ad as UniqueComponentId, bc as normalizeProps, p as renderSlot, j as createCommentVNode, a4 as equals, b8 as script$x, bU as script$y, cd as getFirstFocusableElement, ag as OverlayEventBus, a8 as getVNodeProp, af as resolveFieldData, ce as invokeElementMethod, a2 as getAttribute, cf as getNextElementSibling, Y as getOuterWidth, cg as getPreviousElementSibling, D as script$z, ar as script$A, a0 as script$B, bb as script$D, ac as isNotEmpty, bk as withModifiers, W as getOuterHeight, ch as _default, ae as ZIndex, a3 as focus, ai as addStyle, ak as absolutePosition, al as ConnectedOverlayScrollHandler, am as isTouchDevice, ci as FilterOperator, aq as script$E, cj as FocusTrap, i as createVNode, au as Transition, ck as withKeys, cl as getIndex, s as script$G, cm as isClickable, cn as clearSelection, co as localeComparator, cp as sort, cq as FilterService, c6 as FilterMatchMode, V as findSingle, bO as findIndexInList, bP as find, cr as exportCSV, X as getOffset, cs as getHiddenElementOuterWidth, ct as getHiddenElementOuterHeight, cu as reorderArray, cv as getWindowScrollTop, cw as removeClass, cx as addClass, ah as isEmpty, ap as script$H, as as script$I } from "./index-CoOvI8ZH.js";
import { s as script$C, a as script$F } from "./index-D4DWQPPQ.js";
var script$r = {
name: "ArrowDownIcon",
"extends": script$s
@ -8829,4 +8834,8 @@ export {
script$d as a,
script as s
};
//# sourceMappingURL=index-DK6Kev7f.js.map

1
comfy/web/assets/index-DK6Kev7f.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

90
comfy/web/assets/serverConfigStore-cctR8PGG.js generated vendored Normal file
View File

@ -0,0 +1,90 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { cB as defineStore, r as ref, q as computed } from "./index-CoOvI8ZH.js";
const useServerConfigStore = defineStore("serverConfig", () => {
const serverConfigById = ref({});
const serverConfigs = computed(() => {
return Object.values(serverConfigById.value);
});
const modifiedConfigs = computed(
() => {
return serverConfigs.value.filter((config) => {
return config.initialValue !== config.value;
});
}
);
const revertChanges = /* @__PURE__ */ __name(() => {
for (const config of modifiedConfigs.value) {
config.value = config.initialValue;
}
}, "revertChanges");
const serverConfigsByCategory = computed(() => {
return serverConfigs.value.reduce(
(acc, config) => {
const category = config.category?.[0] ?? "General";
acc[category] = acc[category] || [];
acc[category].push(config);
return acc;
},
{}
);
});
const serverConfigValues = computed(() => {
return Object.fromEntries(
serverConfigs.value.map((config) => {
return [
config.id,
config.value === config.defaultValue || config.value === null || config.value === void 0 ? void 0 : config.value
];
})
);
});
const launchArgs = computed(() => {
const args = Object.assign(
{},
...serverConfigs.value.map((config) => {
if (config.value === config.defaultValue || config.value === null || config.value === void 0) {
return {};
}
return config.getValue ? config.getValue(config.value) : { [config.id]: config.value };
})
);
return Object.fromEntries(
Object.entries(args).map(([key, value]) => {
if (value === true) {
return [key, ""];
}
return [key, value.toString()];
})
);
});
const commandLineArgs = computed(() => {
return Object.entries(launchArgs.value).map(([key, value]) => [`--${key}`, value]).flat().filter((arg) => arg !== "").join(" ");
});
function loadServerConfig(configs, values) {
for (const config of configs) {
const value = values[config.id] ?? config.defaultValue;
serverConfigById.value[config.id] = {
...config,
value,
initialValue: value
};
}
}
__name(loadServerConfig, "loadServerConfig");
return {
serverConfigById,
serverConfigs,
modifiedConfigs,
serverConfigsByCategory,
serverConfigValues,
launchArgs,
commandLineArgs,
revertChanges,
loadServerConfig
};
});
export {
useServerConfigStore as u
};
//# sourceMappingURL=serverConfigStore-cctR8PGG.js.map
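
For reference, the store above only serializes settings that differ from their defaults: `launchArgs` drops default/null values and maps `true` to an empty string, and `commandLineArgs` flattens the result into a `--key value` string. A hypothetical usage sketch (the config entries and import path are illustrative; only the fields the store actually reads are supplied):

```typescript
import { useServerConfigStore } from './serverConfigStore' // illustrative path

const store = useServerConfigStore()
store.loadServerConfig(
  [
    { id: 'listen', defaultValue: '127.0.0.1' }, // left at default -> omitted from args
    { id: 'port', defaultValue: 8188 }           // overridden below -> emitted
  ],
  { port: 8080 } // previously saved user values, keyed by config id
)

store.commandLineArgs // "--port 8080"
store.revertChanges() // resets any config modified since load back to its initialValue
```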

1
comfy/web/assets/serverConfigStore-cctR8PGG.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,10 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { aZ as api, bY as $el } from "./index-CoOvI8ZH.js";
function createSpinner() {
const div = document.createElement("div");
div.innerHTML = `<div class="lds-ring"><div></div><div></div><div></div><div></div></div>`;
@ -126,4 +130,8 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
UserSelectionScreen
};
//# sourceMappingURL=userSelection-C6c30qSU.js.map

5
comfy/web/assets/userSelection-C6c30qSU.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,10 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { e as LGraphNode, c as app, c3 as applyTextReplacements, c2 as ComfyWidgets, c5 as addValueControlWidgets, k as LiteGraph } from "./index-CoOvI8ZH.js";
const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = [
"STRING",
@ -759,4 +763,8 @@ export {
mergeIfValid,
setWidgetConfig
};
//# sourceMappingURL=widgetInputs-CRPRgKEi.js.map

Some files were not shown because too many files have changed in this diff