mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 20:30:25 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
38bcd9fcbd
21
.github/workflows/stale-issues.yml
vendored
Normal file
21
.github/workflows/stale-issues.yml
vendored
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
name: 'Close stale issues'
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
# Run daily at 430 am PT
|
||||||
|
- cron: '30 11 * * *'
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
stale:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/stale@v9
|
||||||
|
with:
|
||||||
|
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
|
||||||
|
days-before-stale: 30
|
||||||
|
days-before-close: 7
|
||||||
|
stale-issue-label: 'Stale'
|
||||||
|
only-labels: 'User Support'
|
||||||
|
exempt-all-assignees: true
|
||||||
|
exempt-all-milestones: true
|
||||||
@ -1,7 +1,11 @@
|
|||||||
from aiohttp import web
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from ....cmd.folder_paths import models_dir, user_directory, output_directory
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
from ...services.file_service import FileService
|
from ...services.file_service import FileService
|
||||||
|
from ....cmd.folder_paths import models_dir, user_directory, output_directory
|
||||||
|
|
||||||
|
|
||||||
class InternalRoutes:
|
class InternalRoutes:
|
||||||
'''
|
'''
|
||||||
@ -10,6 +14,7 @@ class InternalRoutes:
|
|||||||
Check README.md for more information.
|
Check README.md for more information.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||||
self._app: Optional[web.Application] = None
|
self._app: Optional[web.Application] = None
|
||||||
@ -31,6 +36,10 @@ class InternalRoutes:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return web.json_response({"error": str(e)}, status=500)
|
return web.json_response({"error": str(e)}, status=500)
|
||||||
|
|
||||||
|
@self.routes.get('/logs')
|
||||||
|
async def get_logs(request):
|
||||||
|
# todo: applications really shouldn't serve logs like this
|
||||||
|
return web.json_response({})
|
||||||
|
|
||||||
def get_app(self):
|
def get_app(self):
|
||||||
if self._app is None:
|
if self._app is None:
|
||||||
|
|||||||
@ -7,13 +7,13 @@ Use this instead of cli_args to import the args:
|
|||||||
|
|
||||||
It will enable command line argument parsing. If this isn't desired, you must author your own implementation of these fixes.
|
It will enable command line argument parsing. If this isn't desired, you must author your own implementation of these fixes.
|
||||||
"""
|
"""
|
||||||
|
import ctypes
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
import ctypes
|
|
||||||
|
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||||
@ -110,6 +110,15 @@ def _create_tracer():
|
|||||||
return trace.get_tracer(args.otel_service_name)
|
return trace.get_tracer(args.otel_service_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_logging():
|
||||||
|
logging_level = logging.INFO
|
||||||
|
if args.verbose:
|
||||||
|
logging_level = logging.DEBUG
|
||||||
|
|
||||||
|
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||||
|
|
||||||
|
|
||||||
|
_configure_logging()
|
||||||
_fix_pytorch_240()
|
_fix_pytorch_240()
|
||||||
tracer = _create_tracer()
|
tracer = _create_tracer()
|
||||||
__all__ = ["args", "tracer"]
|
__all__ = ["args", "tracer"]
|
||||||
|
|||||||
@ -25,13 +25,11 @@ from aiohttp import web
|
|||||||
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
|
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
|
||||||
from typing_extensions import NamedTuple
|
from typing_extensions import NamedTuple
|
||||||
|
|
||||||
from ..api_server.routes.internal.internal_routes import InternalRoutes
|
|
||||||
from ..model_filemanager import download_model, DownloadModelStatus
|
|
||||||
from .latent_preview_image_encoding import encode_preview_image
|
from .latent_preview_image_encoding import encode_preview_image
|
||||||
from .. import interruption
|
from .. import interruption
|
||||||
from .. import model_management
|
|
||||||
from .. import node_helpers
|
from .. import node_helpers
|
||||||
from .. import utils
|
from .. import utils
|
||||||
|
from ..api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
from ..app.frontend_management import FrontendManager
|
from ..app.frontend_management import FrontendManager
|
||||||
from ..app.user_manager import UserManager
|
from ..app.user_manager import UserManager
|
||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
@ -45,6 +43,8 @@ from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTy
|
|||||||
ExecutionStatus
|
ExecutionStatus
|
||||||
from ..digest import digest
|
from ..digest import digest
|
||||||
from ..images import open_image
|
from ..images import open_image
|
||||||
|
from ..model_filemanager import download_model, DownloadModelStatus
|
||||||
|
from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
|
||||||
from ..nodes.package_typing import ExportedNodes
|
from ..nodes.package_typing import ExportedNodes
|
||||||
|
|
||||||
|
|
||||||
@ -60,6 +60,22 @@ async def send_socket_catch_exception(function, message):
|
|||||||
logging.warning("send error: {}".format(err))
|
logging.warning("send error: {}".format(err))
|
||||||
|
|
||||||
|
|
||||||
|
def get_comfyui_version():
|
||||||
|
comfyui_version = "unknown"
|
||||||
|
repo_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
try:
|
||||||
|
import pygit2
|
||||||
|
repo = pygit2.Repository(repo_path)
|
||||||
|
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to get ComfyUI version: {e}")
|
||||||
|
return comfyui_version.strip()
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def cache_control(request: web.Request, handler):
|
async def cache_control(request: web.Request, handler):
|
||||||
response: web.Response = await handler(request)
|
response: web.Response = await handler(request)
|
||||||
@ -104,7 +120,7 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
|
self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
|
||||||
self.loop: AbstractEventLoop = loop
|
self.loop: AbstractEventLoop = loop
|
||||||
self.messages: asyncio.Queue = asyncio.Queue()
|
self.messages: asyncio.Queue = asyncio.Queue()
|
||||||
self.client_session:Optional[aiohttp.ClientSession] = None
|
self.client_session: Optional[aiohttp.ClientSession] = None
|
||||||
self.number: int = 0
|
self.number: int = 0
|
||||||
self.port: int = 8188
|
self.port: int = 8188
|
||||||
self._external_address: Optional[str] = None
|
self._external_address: Optional[str] = None
|
||||||
@ -418,16 +434,20 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
return web.json_response(dt["__metadata__"])
|
return web.json_response(dt["__metadata__"])
|
||||||
|
|
||||||
@routes.get("/system_stats")
|
@routes.get("/system_stats")
|
||||||
async def get_system_stats(request):
|
async def system_stats(request):
|
||||||
device = model_management.get_torch_device()
|
device = get_torch_device()
|
||||||
device_name = model_management.get_torch_device_name(device)
|
device_name = get_torch_device_name(device)
|
||||||
vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True)
|
vram_total, torch_vram_total = get_total_memory(device, torch_total_too=True)
|
||||||
vram_free, torch_vram_free = model_management.get_free_memory(device, torch_free_too=True)
|
vram_free, torch_vram_free = get_free_memory(device, torch_free_too=True)
|
||||||
|
|
||||||
system_stats = {
|
system_stats = {
|
||||||
"system": {
|
"system": {
|
||||||
"os": os.name,
|
"os": os.name,
|
||||||
|
"comfyui_version": get_comfyui_version(),
|
||||||
"python_version": sys.version,
|
"python_version": sys.version,
|
||||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
|
"pytorch_version": torch_version,
|
||||||
|
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||||
|
"argv": sys.argv
|
||||||
},
|
},
|
||||||
"devices": [
|
"devices": [
|
||||||
{
|
{
|
||||||
@ -611,7 +631,7 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
url = data.get('url')
|
url = data.get('url')
|
||||||
model_directory = data.get('model_directory')
|
model_directory = data.get('model_directory')
|
||||||
model_filename = data.get('model_filename')
|
model_filename = data.get('model_filename')
|
||||||
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
|
||||||
|
|
||||||
if not url or not model_directory or not model_filename:
|
if not url or not model_directory or not model_filename:
|
||||||
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
|
||||||
@ -776,7 +796,7 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
self._external_address = value
|
self._external_address = value
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
|
|||||||
@ -33,8 +33,6 @@ from .cldm import cldm, mmdit
|
|||||||
from .ldm import hydit
|
from .ldm import hydit
|
||||||
from .ldm.cascade import controlnet as cascade_controlnet
|
from .ldm.cascade import controlnet as cascade_controlnet
|
||||||
from .ldm.flux import controlnet as controlnet_flux
|
from .ldm.flux import controlnet as controlnet_flux
|
||||||
from .ldm.flux.controlnet_instantx import InstantXControlNetFlux
|
|
||||||
from .ldm.flux.controlnet_instantx_format2 import InstantXControlNetFluxFormat2
|
|
||||||
from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
|
from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
|
||||||
from .t2i_adapter import adapter
|
from .t2i_adapter import adapter
|
||||||
|
|
||||||
@ -509,7 +507,12 @@ def load_controlnet_flux_instantx(sd):
|
|||||||
for k in sd:
|
for k in sd:
|
||||||
new_sd[k] = sd[k]
|
new_sd[k] = sd[k]
|
||||||
|
|
||||||
control_model = controlnet_flux.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
num_union_modes = 0
|
||||||
|
union_cnet = "controlnet_mode_embedder.weight"
|
||||||
|
if union_cnet in new_sd:
|
||||||
|
num_union_modes = new_sd[union_cnet].shape[0]
|
||||||
|
|
||||||
|
control_model = controlnet_flux.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||||
|
|
||||||
latent_format = latent_formats.Flux()
|
latent_format = latent_formats.Flux()
|
||||||
@ -519,10 +522,6 @@ def load_controlnet_flux_instantx(sd):
|
|||||||
|
|
||||||
def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
|
def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
|
||||||
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
if "controlnet_mode_embedder.weight" in controlnet_data:
|
|
||||||
return load_controlnet_flux_instantx_union(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path)
|
|
||||||
if "controlnet_mode_embedder.fc.weight" in controlnet_data:
|
|
||||||
return load_controlnet_flux_instantx_union(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path)
|
|
||||||
if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
|
||||||
return load_controlnet_hunyuandit(controlnet_data)
|
return load_controlnet_hunyuandit(controlnet_data)
|
||||||
if "lora_controlnet" in controlnet_data:
|
if "lora_controlnet" in controlnet_data:
|
||||||
|
|||||||
@ -41,9 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
|||||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||||
)
|
)
|
||||||
del abs_x
|
|
||||||
|
|
||||||
return sign.to(dtype=dtype)
|
return sign
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -57,6 +56,11 @@ def stochastic_rounding(value, dtype, seed=0):
|
|||||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
generator = torch.Generator(device=value.device)
|
generator = torch.Generator(device=value.device)
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
|
output = torch.empty_like(value, dtype=dtype)
|
||||||
|
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||||
|
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||||
|
for i in range(0, value.shape[0], slice_size):
|
||||||
|
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
|
||||||
|
return output
|
||||||
|
|
||||||
return value.to(dtype=dtype)
|
return value.to(dtype=dtype)
|
||||||
|
|||||||
@ -1,20 +1,19 @@
|
|||||||
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
||||||
|
|
||||||
import torch
|
|
||||||
import math
|
import math
|
||||||
from torch import Tensor, nn
|
from typing import Never
|
||||||
|
|
||||||
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
from .layers import (timestep_embedding)
|
||||||
MLPEmbedder, SingleStreamBlock,
|
|
||||||
timestep_embedding)
|
|
||||||
|
|
||||||
from .model import Flux
|
from .model import Flux
|
||||||
from .. import common_dit
|
from .. import common_dit
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFlux(Flux):
|
class ControlNetFlux(Flux):
|
||||||
def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
||||||
|
|
||||||
self.main_model_double = 19
|
self.main_model_double = 19
|
||||||
@ -29,6 +28,11 @@ class ControlNetFlux(Flux):
|
|||||||
for _ in range(self.params.depth_single_blocks):
|
for _ in range(self.params.depth_single_blocks):
|
||||||
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
self.num_union_modes = num_union_modes
|
||||||
|
self.controlnet_mode_embedder = None
|
||||||
|
if self.num_union_modes > 0:
|
||||||
|
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
self.latent_input = latent_input
|
self.latent_input = latent_input
|
||||||
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
@ -61,6 +65,7 @@ class ControlNetFlux(Flux):
|
|||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
y: Tensor,
|
y: Tensor,
|
||||||
guidance: Tensor = None,
|
guidance: Tensor = None,
|
||||||
|
control_type: Tensor | list[Never] | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
@ -79,6 +84,11 @@ class ControlNetFlux(Flux):
|
|||||||
vec = vec + self.vector_in(y)
|
vec = vec + self.vector_in(y)
|
||||||
txt = self.txt_in(txt)
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
|
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
||||||
|
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
||||||
|
txt = torch.cat([control_cond, txt], dim=1)
|
||||||
|
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
@ -137,4 +147,4 @@ class ControlNetFlux(Flux):
|
|||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
|
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
||||||
|
|||||||
@ -1,309 +0,0 @@
|
|||||||
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
|
||||||
|
|
||||||
import numbers
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from diffusers.models.normalization import AdaLayerNormContinuous
|
|
||||||
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
|
||||||
from diffusers.utils.import_utils import is_torch_version
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from ...ldm import common_dit
|
|
||||||
from .layers import timestep_embedding
|
|
||||||
from .model import Flux
|
|
||||||
|
|
||||||
if is_torch_version(">=", "2.1.0"):
|
|
||||||
LayerNorm = nn.LayerNorm
|
|
||||||
else:
|
|
||||||
# Has optional bias parameter compared to torch layer norm
|
|
||||||
# TODO: replace with torch layernorm once min required torch version >= 2.1
|
|
||||||
class LayerNorm(nn.Module):
|
|
||||||
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
if isinstance(dim, numbers.Integral):
|
|
||||||
dim = (dim,)
|
|
||||||
|
|
||||||
self.dim = torch.Size(dim)
|
|
||||||
|
|
||||||
if elementwise_affine:
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
|
|
||||||
else:
|
|
||||||
self.weight = None
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class FluxUnionControlNetModeEmbedder(nn.Module):
|
|
||||||
def __init__(self, num_mode, out_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.mode_embber = nn.Embedding(num_mode, out_channels)
|
|
||||||
self.norm = nn.LayerNorm(out_channels, eps=1e-6)
|
|
||||||
self.fc = nn.Linear(out_channels, out_channels)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x_emb = self.mode_embber(x)
|
|
||||||
x_emb = self.norm(x_emb)
|
|
||||||
x_emb = self.fc(x_emb)
|
|
||||||
x_emb = x_emb[:, 0]
|
|
||||||
return x_emb
|
|
||||||
|
|
||||||
|
|
||||||
def zero_module(module):
|
|
||||||
for p in module.parameters():
|
|
||||||
nn.init.zeros_(p)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
# YiYi to-do: refactor rope related functions/classes
|
|
||||||
def apply_rope(xq, xk, freqs_cis):
|
|
||||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
|
||||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
|
||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
|
||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
|
||||||
|
|
||||||
|
|
||||||
class FluxUnionControlNetInputEmbedder(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, num_attention_heads=24, mlp_ratio=4.0, attention_head_dim=128, dtype=None, device=None, operations=None, depth=2):
|
|
||||||
super().__init__()
|
|
||||||
self.x_embedder = nn.Sequential(nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels))
|
|
||||||
self.norm = AdaLayerNormContinuous(out_channels, out_channels, elementwise_affine=False, eps=1e-6)
|
|
||||||
self.fc = nn.Linear(out_channels, out_channels)
|
|
||||||
self.emb_embedder = nn.Sequential(nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels))
|
|
||||||
|
|
||||||
""" self.single_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
SingleStreamBlock(
|
|
||||||
out_channels, num_attention_heads, dtype=dtype, device=device, operations=operations
|
|
||||||
)
|
|
||||||
for i in range(2)
|
|
||||||
]
|
|
||||||
) """
|
|
||||||
self.single_transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
FluxSingleTransformerBlock(
|
|
||||||
dim=out_channels,
|
|
||||||
num_attention_heads=num_attention_heads,
|
|
||||||
attention_head_dim=attention_head_dim,
|
|
||||||
)
|
|
||||||
for i in range(depth)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.out = zero_module(nn.Linear(out_channels, out_channels))
|
|
||||||
|
|
||||||
def forward(self, x, mode_emb):
|
|
||||||
mode_token = self.emb_embedder(mode_emb)[:, None]
|
|
||||||
x_emb = self.fc(self.norm(self.x_embedder(x), mode_emb))
|
|
||||||
hidden_states = torch.cat([mode_token, x_emb], dim=1)
|
|
||||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
|
||||||
hidden_states = block(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
temb=mode_emb,
|
|
||||||
)
|
|
||||||
hidden_states = self.out(hidden_states)
|
|
||||||
res = hidden_states[:, 1:]
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class InstantXControlNetFlux(Flux):
|
|
||||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
|
|
||||||
kwargs["depth"] = 0
|
|
||||||
kwargs["depth_single_blocks"] = 0
|
|
||||||
depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
|
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
|
||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
FluxTransformerBlock(
|
|
||||||
dim=self.hidden_size,
|
|
||||||
num_attention_heads=24,
|
|
||||||
attention_head_dim=128,
|
|
||||||
).to(dtype=dtype)
|
|
||||||
for i in range(5)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.single_transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
FluxSingleTransformerBlock(
|
|
||||||
dim=self.hidden_size,
|
|
||||||
num_attention_heads=24,
|
|
||||||
attention_head_dim=128,
|
|
||||||
).to(dtype=dtype)
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.require_vae = True
|
|
||||||
# add ControlNet blocks
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
|
||||||
for _ in range(len(self.transformer_blocks)):
|
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
||||||
controlnet_block = zero_module(controlnet_block)
|
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
self.controlnet_single_blocks = nn.ModuleList([])
|
|
||||||
for _ in range(len(self.single_transformer_blocks)):
|
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
||||||
controlnet_block = zero_module(controlnet_block)
|
|
||||||
self.controlnet_single_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
# TODO support both union and unimodal
|
|
||||||
self.union = True # num_mode is not None
|
|
||||||
num_mode = 10
|
|
||||||
if self.union:
|
|
||||||
self.controlnet_mode_embedder = zero_module(FluxUnionControlNetModeEmbedder(num_mode, self.hidden_size)).to(device=device, dtype=dtype)
|
|
||||||
self.controlnet_x_embedder = FluxUnionControlNetInputEmbedder(self.in_channels, self.hidden_size, operations=operations, depth=depth_single_blocks_controlnet).to(device=device, dtype=dtype)
|
|
||||||
self.controlnet_mode_token_embedder = nn.Sequential(nn.LayerNorm(self.hidden_size, eps=1e-6), nn.Linear(self.hidden_size, self.hidden_size)).to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size)).to(device=device, dtype=dtype)
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
|
||||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
|
||||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
|
||||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
|
||||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
|
||||||
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def set_hint_latents(self, hint_latents):
|
|
||||||
vae_shift_factor = 0.1159
|
|
||||||
vae_scaling_factor = 0.3611
|
|
||||||
num_channels_latents = self.in_channels // 4
|
|
||||||
hint_latents = (hint_latents - vae_shift_factor) * vae_scaling_factor
|
|
||||||
|
|
||||||
height, width = hint_latents.shape[2:]
|
|
||||||
hint_latents = self._pack_latents(
|
|
||||||
hint_latents,
|
|
||||||
hint_latents.shape[0],
|
|
||||||
num_channels_latents,
|
|
||||||
height,
|
|
||||||
width,
|
|
||||||
)
|
|
||||||
self.hint_latents = hint_latents.to(device=self.device, dtype=self.dtype)
|
|
||||||
|
|
||||||
def forward_orig(
|
|
||||||
self,
|
|
||||||
img: Tensor,
|
|
||||||
img_ids: Tensor,
|
|
||||||
controlnet_cond: Tensor,
|
|
||||||
txt: Tensor,
|
|
||||||
txt_ids: Tensor,
|
|
||||||
timesteps: Tensor,
|
|
||||||
y: Tensor,
|
|
||||||
guidance: Tensor = None,
|
|
||||||
controlnet_mode=None
|
|
||||||
) -> Tensor:
|
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
|
||||||
|
|
||||||
batch_size = img.shape[0]
|
|
||||||
|
|
||||||
img = self.img_in(img)
|
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
|
|
||||||
if self.params.guidance_embed:
|
|
||||||
vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
|
|
||||||
vec.add_(self.vector_in(y))
|
|
||||||
|
|
||||||
if self.union:
|
|
||||||
if controlnet_mode is None:
|
|
||||||
raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty')
|
|
||||||
controlnet_mode = torch.tensor([[controlnet_mode]], device=self.device)
|
|
||||||
emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
|
|
||||||
vec = vec + emb_controlnet_mode
|
|
||||||
img = img + self.controlnet_x_embedder(controlnet_cond, emb_controlnet_mode)
|
|
||||||
else:
|
|
||||||
img = img + self.controlnet_x_embedder(controlnet_cond)
|
|
||||||
|
|
||||||
txt = self.txt_in(txt)
|
|
||||||
|
|
||||||
if self.union:
|
|
||||||
token_controlnet_mode = self.controlnet_mode_token_embedder(emb_controlnet_mode)[:, None]
|
|
||||||
token_controlnet_mode = token_controlnet_mode.expand(txt.size(0), -1, -1)
|
|
||||||
txt = torch.cat([token_controlnet_mode, txt], dim=1)
|
|
||||||
txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)
|
|
||||||
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
||||||
pe = self.pe_embedder(ids).to(dtype=self.dtype, device=self.device)
|
|
||||||
|
|
||||||
block_res_samples = ()
|
|
||||||
for block in self.transformer_blocks:
|
|
||||||
txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
|
|
||||||
block_res_samples = block_res_samples + (img,)
|
|
||||||
|
|
||||||
img = torch.cat([txt, img], dim=1)
|
|
||||||
|
|
||||||
single_block_res_samples = ()
|
|
||||||
for block in self.single_transformer_blocks:
|
|
||||||
img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
|
|
||||||
single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)
|
|
||||||
|
|
||||||
controlnet_block_res_samples = ()
|
|
||||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
|
||||||
block_res_sample = controlnet_block(block_res_sample)
|
|
||||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
|
||||||
|
|
||||||
controlnet_single_block_res_samples = ()
|
|
||||||
for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
|
|
||||||
single_block_res_sample = single_controlnet_block(single_block_res_sample)
|
|
||||||
controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)
|
|
||||||
|
|
||||||
n_single_blocks = 38
|
|
||||||
n_double_blocks = 19
|
|
||||||
|
|
||||||
# Expand controlnet_block_res_samples to match n_double_blocks
|
|
||||||
expanded_controlnet_block_res_samples = []
|
|
||||||
interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
|
|
||||||
for i in range(n_double_blocks):
|
|
||||||
index = i // interval_control_double
|
|
||||||
expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])
|
|
||||||
|
|
||||||
# Expand controlnet_single_block_res_samples to match n_single_blocks
|
|
||||||
expanded_controlnet_single_block_res_samples = []
|
|
||||||
interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
|
|
||||||
for i in range(n_single_blocks):
|
|
||||||
index = i // interval_control_single
|
|
||||||
expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input": expanded_controlnet_block_res_samples,
|
|
||||||
"output": expanded_controlnet_single_block_res_samples
|
|
||||||
}
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
|
|
||||||
bs, c, h, w = x.shape
|
|
||||||
patch_size = 2
|
|
||||||
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
|
||||||
|
|
||||||
height_control_image, width_control_image = hint.shape[2:]
|
|
||||||
num_channels_latents = self.in_channels // 4
|
|
||||||
hint = self._pack_latents(
|
|
||||||
hint,
|
|
||||||
hint.shape[0],
|
|
||||||
num_channels_latents,
|
|
||||||
height_control_image,
|
|
||||||
width_control_image,
|
|
||||||
)
|
|
||||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
|
||||||
|
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
|
||||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)
|
|
||||||
@ -1,227 +0,0 @@
|
|||||||
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
|
|
||||||
|
|
||||||
import numbers
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
|
||||||
from diffusers.utils.import_utils import is_torch_version
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from .layers import timestep_embedding
|
|
||||||
from .model import Flux
|
|
||||||
from ..common_dit import pad_to_patch_size
|
|
||||||
|
|
||||||
if is_torch_version(">=", "2.1.0"):
|
|
||||||
LayerNorm = nn.LayerNorm
|
|
||||||
else:
|
|
||||||
# Has optional bias parameter compared to torch layer norm
|
|
||||||
# TODO: replace with torch layernorm once min required torch version >= 2.1
|
|
||||||
class LayerNorm(nn.Module):
|
|
||||||
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
if isinstance(dim, numbers.Integral):
|
|
||||||
dim = (dim,)
|
|
||||||
|
|
||||||
self.dim = torch.Size(dim)
|
|
||||||
|
|
||||||
if elementwise_affine:
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
|
|
||||||
else:
|
|
||||||
self.weight = None
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
def zero_module(module):
|
|
||||||
for p in module.parameters():
|
|
||||||
nn.init.zeros_(p)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
# YiYi to-do: refactor rope related functions/classes
|
|
||||||
def apply_rope(xq, xk, freqs_cis):
|
|
||||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
|
||||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
|
||||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
|
||||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
|
||||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
|
||||||
|
|
||||||
|
|
||||||
class InstantXControlNetFluxFormat2(Flux):
|
|
||||||
def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
|
|
||||||
kwargs["depth"] = 0
|
|
||||||
kwargs["depth_single_blocks"] = 0
|
|
||||||
depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
|
|
||||||
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
|
||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
FluxTransformerBlock(
|
|
||||||
dim=self.hidden_size,
|
|
||||||
num_attention_heads=24,
|
|
||||||
attention_head_dim=128,
|
|
||||||
).to(dtype=dtype)
|
|
||||||
for i in range(5)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.single_transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
FluxSingleTransformerBlock(
|
|
||||||
dim=self.hidden_size,
|
|
||||||
num_attention_heads=24,
|
|
||||||
attention_head_dim=128,
|
|
||||||
).to(dtype=dtype)
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.require_vae = True
|
|
||||||
# add ControlNet blocks
|
|
||||||
self.controlnet_blocks = nn.ModuleList([])
|
|
||||||
for _ in range(len(self.transformer_blocks)):
|
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
||||||
controlnet_block = zero_module(controlnet_block)
|
|
||||||
self.controlnet_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
self.controlnet_single_blocks = nn.ModuleList([])
|
|
||||||
for _ in range(len(self.single_transformer_blocks)):
|
|
||||||
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
||||||
controlnet_block = zero_module(controlnet_block)
|
|
||||||
self.controlnet_single_blocks.append(controlnet_block)
|
|
||||||
|
|
||||||
# TODO support both union and unimodal
|
|
||||||
self.union = True # num_mode is not None
|
|
||||||
num_mode = 10
|
|
||||||
if self.union:
|
|
||||||
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.hidden_size)
|
|
||||||
self.controlnet_x_embedder = zero_module(operations.Linear(self.in_channels, self.hidden_size).to(device=device, dtype=dtype))
|
|
||||||
self.gradient_checkpointing = False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
|
||||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
|
||||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
|
||||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
|
||||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
|
||||||
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def forward_orig(
|
|
||||||
self,
|
|
||||||
img: Tensor,
|
|
||||||
img_ids: Tensor,
|
|
||||||
controlnet_cond: Tensor,
|
|
||||||
txt: Tensor,
|
|
||||||
txt_ids: Tensor,
|
|
||||||
timesteps: Tensor,
|
|
||||||
y: Tensor,
|
|
||||||
guidance: Tensor = None,
|
|
||||||
controlnet_mode=None
|
|
||||||
) -> Tensor:
|
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
|
||||||
|
|
||||||
batch_size = img.shape[0]
|
|
||||||
|
|
||||||
img = self.img_in(img)
|
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
|
|
||||||
if self.params.guidance_embed:
|
|
||||||
vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
|
|
||||||
vec.add_(self.vector_in(y))
|
|
||||||
|
|
||||||
txt = self.txt_in(txt)
|
|
||||||
|
|
||||||
if self.union:
|
|
||||||
if controlnet_mode is None:
|
|
||||||
raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty')
|
|
||||||
controlnet_mode = torch.tensor(controlnet_mode).to(self.device, dtype=torch.long)
|
|
||||||
controlnet_mode = controlnet_mode.reshape([-1, 1])
|
|
||||||
emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
|
|
||||||
txt = torch.cat([emb_controlnet_mode, txt], dim=1)
|
|
||||||
txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)
|
|
||||||
|
|
||||||
img = img + self.controlnet_x_embedder(controlnet_cond)
|
|
||||||
|
|
||||||
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
|
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
||||||
pe = self.pe_embedder(ids)
|
|
||||||
|
|
||||||
block_res_samples = ()
|
|
||||||
for block in self.transformer_blocks:
|
|
||||||
txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
|
|
||||||
block_res_samples = block_res_samples + (img,)
|
|
||||||
|
|
||||||
img = torch.cat([txt, img], dim=1)
|
|
||||||
|
|
||||||
single_block_res_samples = ()
|
|
||||||
for block in self.single_transformer_blocks:
|
|
||||||
img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
|
|
||||||
single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)
|
|
||||||
|
|
||||||
controlnet_block_res_samples = ()
|
|
||||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
|
||||||
block_res_sample = controlnet_block(block_res_sample)
|
|
||||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
|
||||||
|
|
||||||
controlnet_single_block_res_samples = ()
|
|
||||||
for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
|
|
||||||
single_block_res_sample = single_controlnet_block(single_block_res_sample)
|
|
||||||
controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)
|
|
||||||
|
|
||||||
n_single_blocks = 38
|
|
||||||
n_double_blocks = 19
|
|
||||||
|
|
||||||
# Expand controlnet_block_res_samples to match n_double_blocks
|
|
||||||
expanded_controlnet_block_res_samples = []
|
|
||||||
interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
|
|
||||||
for i in range(n_double_blocks):
|
|
||||||
index = i // interval_control_double
|
|
||||||
expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])
|
|
||||||
|
|
||||||
# Expand controlnet_single_block_res_samples to match n_single_blocks
|
|
||||||
expanded_controlnet_single_block_res_samples = []
|
|
||||||
interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
|
|
||||||
for i in range(n_single_blocks):
|
|
||||||
index = i // interval_control_single
|
|
||||||
expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input": expanded_controlnet_block_res_samples,
|
|
||||||
"output": expanded_controlnet_single_block_res_samples
|
|
||||||
}
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
|
|
||||||
bs, c, h, w = x.shape
|
|
||||||
patch_size = 2
|
|
||||||
x = pad_to_patch_size(x, (patch_size, patch_size))
|
|
||||||
|
|
||||||
height_control_image, width_control_image = hint.shape[2:]
|
|
||||||
num_channels_latents = self.in_channels // 4
|
|
||||||
hint = self._pack_latents(
|
|
||||||
hint,
|
|
||||||
hint.shape[0],
|
|
||||||
num_channels_latents,
|
|
||||||
height_control_image,
|
|
||||||
width_control_image,
|
|
||||||
)
|
|
||||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
|
||||||
|
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
|
||||||
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
|
|
||||||
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
|
||||||
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)
|
|
||||||
@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
|
|||||||
for layer, block in enumerate(self.blocks):
|
for layer, block in enumerate(self.blocks):
|
||||||
if layer > self.depth // 2:
|
if layer > self.depth // 2:
|
||||||
if controls is not None:
|
if controls is not None:
|
||||||
skip = skips.pop() + controls.pop()
|
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
|
||||||
else:
|
else:
|
||||||
skip = skips.pop()
|
skip = skips.pop()
|
||||||
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
|
||||||
|
|||||||
@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
to = diffusers_keys[k]
|
to = diffusers_keys[k]
|
||||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to # simpletrainer and probably regular diffusers flux lora format
|
key_map["transformer.{}".format(k[:-len(".weight")])] = to # simpletrainer and probably regular diffusers flux lora format
|
||||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris
|
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris
|
||||||
|
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
@ -528,20 +529,40 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||||
elif patch_type == "glora":
|
elif patch_type == "glora":
|
||||||
if v[4] is not None:
|
|
||||||
alpha = v[4] / v[0].shape[0]
|
|
||||||
else:
|
|
||||||
alpha = 1.0
|
|
||||||
|
|
||||||
dora_scale = v[5]
|
dora_scale = v[5]
|
||||||
|
|
||||||
|
old_glora = False
|
||||||
|
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
|
||||||
|
rank = v[0].shape[0]
|
||||||
|
old_glora = True
|
||||||
|
|
||||||
|
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
|
||||||
|
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
old_glora = False
|
||||||
|
rank = v[1].shape[0]
|
||||||
|
|
||||||
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||||
|
|
||||||
|
if v[4] is not None:
|
||||||
|
alpha = v[4] / rank
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
if old_glora:
|
||||||
|
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
|
||||||
|
else:
|
||||||
|
if weight.dim() > 2:
|
||||||
|
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
else:
|
||||||
|
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
|
||||||
|
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
|
||||||
|
|
||||||
if dora_scale is not None:
|
if dora_scale is not None:
|
||||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -61,6 +61,7 @@ cpu_state = CPUState.GPU
|
|||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
xpu_available = False
|
xpu_available = False
|
||||||
|
torch_version = ""
|
||||||
try:
|
try:
|
||||||
torch_version = torch.version.__version__
|
torch_version = torch.version.__version__
|
||||||
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||||
@ -419,13 +420,11 @@ def offloaded_memory(loaded_models, device):
|
|||||||
return offloaded_mem
|
return offloaded_mem
|
||||||
|
|
||||||
|
|
||||||
def minimum_inference_memory():
|
WINDOWS = any(platform.win32_ver())
|
||||||
return (1024 * 1024 * 1024) * 1.2
|
|
||||||
|
|
||||||
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||||
EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
|
if WINDOWS:
|
||||||
if any(platform.win32_ver()):
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 # Windows is higher because of the shared vram issue
|
||||||
EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 # Windows is higher because of the shared vram issue
|
|
||||||
|
|
||||||
if args.reserve_vram is not None:
|
if args.reserve_vram is not None:
|
||||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||||
@ -436,6 +435,10 @@ def extra_reserved_memory():
|
|||||||
return EXTRA_RESERVED_VRAM
|
return EXTRA_RESERVED_VRAM
|
||||||
|
|
||||||
|
|
||||||
|
def minimum_inference_memory():
|
||||||
|
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||||
|
|
||||||
|
|
||||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
|
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
|
||||||
with model_management_lock:
|
with model_management_lock:
|
||||||
return _unload_model_clones(model, unload_weights_only, force_unload)
|
return _unload_model_clones(model, unload_weights_only, force_unload)
|
||||||
@ -1119,7 +1122,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||||
for x in nvidia_10_series:
|
for x in nvidia_10_series:
|
||||||
if x in props.name.lower():
|
if x in props.name.lower():
|
||||||
return True
|
if WINDOWS or manual_cast:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False # weird linux behavior where fp32 is faster
|
||||||
|
|
||||||
if manual_cast:
|
if manual_cast:
|
||||||
free_model_memory = maximum_vram_for_weights(device)
|
free_model_memory = maximum_vram_for_weights(device)
|
||||||
|
|||||||
2628
comfy/web/assets/index-DkvOTKox.js → comfy/web/assets/index-BD-Ia1C4.js
generated
vendored
2628
comfy/web/assets/index-DkvOTKox.js → comfy/web/assets/index-BD-Ia1C4.js
generated
vendored
File diff suppressed because it is too large
Load Diff
1
comfy/web/assets/index-BD-Ia1C4.js.map
generated
vendored
Normal file
1
comfy/web/assets/index-BD-Ia1C4.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
84783
comfy/web/assets/index-CaD4RONs.js → comfy/web/assets/index-CI3N807S.js
generated
vendored
84783
comfy/web/assets/index-CaD4RONs.js → comfy/web/assets/index-CI3N807S.js
generated
vendored
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-CI3N807S.js.map
generated
vendored
Normal file
1
comfy/web/assets/index-CI3N807S.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-CaD4RONs.js.map
generated
vendored
1
comfy/web/assets/index-CaD4RONs.js.map
generated
vendored
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-DkvOTKox.js.map
generated
vendored
1
comfy/web/assets/index-DkvOTKox.js.map
generated
vendored
File diff suppressed because one or more lines are too long
637
comfy/web/assets/index-DAK31IJJ.css → comfy/web/assets/index-_5czGnTA.css
generated
vendored
637
comfy/web/assets/index-DAK31IJJ.css → comfy/web/assets/index-_5czGnTA.css
generated
vendored
File diff suppressed because it is too large
Load Diff
120
comfy/web/assets/userSelection-CyXKCVy3.js
generated
vendored
Normal file
120
comfy/web/assets/userSelection-CyXKCVy3.js
generated
vendored
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
var __defProp = Object.defineProperty;
|
||||||
|
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||||
|
import { j as createSpinner, h as api, $ as $el } from "./index-CI3N807S.js";
|
||||||
|
class UserSelectionScreen {
|
||||||
|
static {
|
||||||
|
__name(this, "UserSelectionScreen");
|
||||||
|
}
|
||||||
|
async show(users, user) {
|
||||||
|
const userSelection = document.getElementById("comfy-user-selection");
|
||||||
|
userSelection.style.display = "";
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
const input = userSelection.getElementsByTagName("input")[0];
|
||||||
|
const select = userSelection.getElementsByTagName("select")[0];
|
||||||
|
const inputSection = input.closest("section");
|
||||||
|
const selectSection = select.closest("section");
|
||||||
|
const form = userSelection.getElementsByTagName("form")[0];
|
||||||
|
const error = userSelection.getElementsByClassName("comfy-user-error")[0];
|
||||||
|
const button = userSelection.getElementsByClassName(
|
||||||
|
"comfy-user-button-next"
|
||||||
|
)[0];
|
||||||
|
let inputActive = null;
|
||||||
|
input.addEventListener("focus", () => {
|
||||||
|
inputSection.classList.add("selected");
|
||||||
|
selectSection.classList.remove("selected");
|
||||||
|
inputActive = true;
|
||||||
|
});
|
||||||
|
select.addEventListener("focus", () => {
|
||||||
|
inputSection.classList.remove("selected");
|
||||||
|
selectSection.classList.add("selected");
|
||||||
|
inputActive = false;
|
||||||
|
select.style.color = "";
|
||||||
|
});
|
||||||
|
select.addEventListener("blur", () => {
|
||||||
|
if (!select.value) {
|
||||||
|
select.style.color = "var(--descrip-text)";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
form.addEventListener("submit", async (e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
if (inputActive == null) {
|
||||||
|
error.textContent = "Please enter a username or select an existing user.";
|
||||||
|
} else if (inputActive) {
|
||||||
|
const username = input.value.trim();
|
||||||
|
if (!username) {
|
||||||
|
error.textContent = "Please enter a username.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
input.disabled = select.disabled = // @ts-expect-error
|
||||||
|
input.readonly = // @ts-expect-error
|
||||||
|
select.readonly = true;
|
||||||
|
const spinner = createSpinner();
|
||||||
|
button.prepend(spinner);
|
||||||
|
try {
|
||||||
|
const resp = await api.createUser(username);
|
||||||
|
if (resp.status >= 300) {
|
||||||
|
let message = "Error creating user: " + resp.status + " " + resp.statusText;
|
||||||
|
try {
|
||||||
|
const res = await resp.json();
|
||||||
|
if (res.error) {
|
||||||
|
message = res.error;
|
||||||
|
}
|
||||||
|
} catch (error2) {
|
||||||
|
}
|
||||||
|
throw new Error(message);
|
||||||
|
}
|
||||||
|
resolve({ username, userId: await resp.json(), created: true });
|
||||||
|
} catch (err) {
|
||||||
|
spinner.remove();
|
||||||
|
error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred.";
|
||||||
|
input.disabled = select.disabled = // @ts-expect-error
|
||||||
|
input.readonly = // @ts-expect-error
|
||||||
|
select.readonly = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else if (!select.value) {
|
||||||
|
error.textContent = "Please select an existing user.";
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
resolve({
|
||||||
|
username: users[select.value],
|
||||||
|
userId: select.value,
|
||||||
|
created: false
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (user) {
|
||||||
|
const name = localStorage["Comfy.userName"];
|
||||||
|
if (name) {
|
||||||
|
input.value = name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (input.value) {
|
||||||
|
input.focus();
|
||||||
|
}
|
||||||
|
const userIds = Object.keys(users ?? {});
|
||||||
|
if (userIds.length) {
|
||||||
|
for (const u of userIds) {
|
||||||
|
$el("option", { textContent: users[u], value: u, parent: select });
|
||||||
|
}
|
||||||
|
select.style.color = "var(--descrip-text)";
|
||||||
|
if (select.value) {
|
||||||
|
select.focus();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
userSelection.classList.add("no-users");
|
||||||
|
input.focus();
|
||||||
|
}
|
||||||
|
}).then((r) => {
|
||||||
|
userSelection.remove();
|
||||||
|
return r;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
window.comfyAPI = window.comfyAPI || {};
|
||||||
|
window.comfyAPI.userSelection = window.comfyAPI.userSelection || {};
|
||||||
|
window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
|
||||||
|
export {
|
||||||
|
UserSelectionScreen
|
||||||
|
};
|
||||||
|
//# sourceMappingURL=userSelection-CyXKCVy3.js.map
|
||||||
1
comfy/web/assets/userSelection-CyXKCVy3.js.map
generated
vendored
Normal file
1
comfy/web/assets/userSelection-CyXKCVy3.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
142
comfy/web/assets/userSelection-GRU1gtOt.js
generated
vendored
142
comfy/web/assets/userSelection-GRU1gtOt.js
generated
vendored
@ -1,142 +0,0 @@
|
|||||||
var __defProp = Object.defineProperty;
|
|
||||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
|
||||||
var __async = (__this, __arguments, generator) => {
|
|
||||||
return new Promise((resolve, reject) => {
|
|
||||||
var fulfilled = (value) => {
|
|
||||||
try {
|
|
||||||
step(generator.next(value));
|
|
||||||
} catch (e) {
|
|
||||||
reject(e);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
var rejected = (value) => {
|
|
||||||
try {
|
|
||||||
step(generator.throw(value));
|
|
||||||
} catch (e) {
|
|
||||||
reject(e);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
var step = (x) => x.done ? resolve(x.value) : Promise.resolve(x.value).then(fulfilled, rejected);
|
|
||||||
step((generator = generator.apply(__this, __arguments)).next());
|
|
||||||
});
|
|
||||||
};
|
|
||||||
import { j as createSpinner, g as api, $ as $el } from "./index-CaD4RONs.js";
|
|
||||||
const _UserSelectionScreen = class _UserSelectionScreen {
|
|
||||||
show(users, user) {
|
|
||||||
return __async(this, null, function* () {
|
|
||||||
const userSelection = document.getElementById("comfy-user-selection");
|
|
||||||
userSelection.style.display = "";
|
|
||||||
return new Promise((resolve) => {
|
|
||||||
const input = userSelection.getElementsByTagName("input")[0];
|
|
||||||
const select = userSelection.getElementsByTagName("select")[0];
|
|
||||||
const inputSection = input.closest("section");
|
|
||||||
const selectSection = select.closest("section");
|
|
||||||
const form = userSelection.getElementsByTagName("form")[0];
|
|
||||||
const error = userSelection.getElementsByClassName("comfy-user-error")[0];
|
|
||||||
const button = userSelection.getElementsByClassName(
|
|
||||||
"comfy-user-button-next"
|
|
||||||
)[0];
|
|
||||||
let inputActive = null;
|
|
||||||
input.addEventListener("focus", () => {
|
|
||||||
inputSection.classList.add("selected");
|
|
||||||
selectSection.classList.remove("selected");
|
|
||||||
inputActive = true;
|
|
||||||
});
|
|
||||||
select.addEventListener("focus", () => {
|
|
||||||
inputSection.classList.remove("selected");
|
|
||||||
selectSection.classList.add("selected");
|
|
||||||
inputActive = false;
|
|
||||||
select.style.color = "";
|
|
||||||
});
|
|
||||||
select.addEventListener("blur", () => {
|
|
||||||
if (!select.value) {
|
|
||||||
select.style.color = "var(--descrip-text)";
|
|
||||||
}
|
|
||||||
});
|
|
||||||
form.addEventListener("submit", (e) => __async(this, null, function* () {
|
|
||||||
var _a, _b, _c;
|
|
||||||
e.preventDefault();
|
|
||||||
if (inputActive == null) {
|
|
||||||
error.textContent = "Please enter a username or select an existing user.";
|
|
||||||
} else if (inputActive) {
|
|
||||||
const username = input.value.trim();
|
|
||||||
if (!username) {
|
|
||||||
error.textContent = "Please enter a username.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
input.disabled = select.disabled = // @ts-expect-error
|
|
||||||
input.readonly = // @ts-expect-error
|
|
||||||
select.readonly = true;
|
|
||||||
const spinner = createSpinner();
|
|
||||||
button.prepend(spinner);
|
|
||||||
try {
|
|
||||||
const resp = yield api.createUser(username);
|
|
||||||
if (resp.status >= 300) {
|
|
||||||
let message = "Error creating user: " + resp.status + " " + resp.statusText;
|
|
||||||
try {
|
|
||||||
const res = yield resp.json();
|
|
||||||
if (res.error) {
|
|
||||||
message = res.error;
|
|
||||||
}
|
|
||||||
} catch (error2) {
|
|
||||||
}
|
|
||||||
throw new Error(message);
|
|
||||||
}
|
|
||||||
resolve({ username, userId: yield resp.json(), created: true });
|
|
||||||
} catch (err) {
|
|
||||||
spinner.remove();
|
|
||||||
error.textContent = (_c = (_b = (_a = err.message) != null ? _a : err.statusText) != null ? _b : err) != null ? _c : "An unknown error occurred.";
|
|
||||||
input.disabled = select.disabled = // @ts-expect-error
|
|
||||||
input.readonly = // @ts-expect-error
|
|
||||||
select.readonly = false;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
} else if (!select.value) {
|
|
||||||
error.textContent = "Please select an existing user.";
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
resolve({
|
|
||||||
username: users[select.value],
|
|
||||||
userId: select.value,
|
|
||||||
created: false
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
if (user) {
|
|
||||||
const name = localStorage["Comfy.userName"];
|
|
||||||
if (name) {
|
|
||||||
input.value = name;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (input.value) {
|
|
||||||
input.focus();
|
|
||||||
}
|
|
||||||
const userIds = Object.keys(users != null ? users : {});
|
|
||||||
if (userIds.length) {
|
|
||||||
for (const u of userIds) {
|
|
||||||
$el("option", { textContent: users[u], value: u, parent: select });
|
|
||||||
}
|
|
||||||
select.style.color = "var(--descrip-text)";
|
|
||||||
if (select.value) {
|
|
||||||
select.focus();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
userSelection.classList.add("no-users");
|
|
||||||
input.focus();
|
|
||||||
}
|
|
||||||
}).then((r) => {
|
|
||||||
userSelection.remove();
|
|
||||||
return r;
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
__name(_UserSelectionScreen, "UserSelectionScreen");
|
|
||||||
let UserSelectionScreen = _UserSelectionScreen;
|
|
||||||
window.comfyAPI = window.comfyAPI || {};
|
|
||||||
window.comfyAPI.userSelection = window.comfyAPI.userSelection || {};
|
|
||||||
window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
|
|
||||||
export {
|
|
||||||
UserSelectionScreen
|
|
||||||
};
|
|
||||||
//# sourceMappingURL=userSelection-GRU1gtOt.js.map
|
|
||||||
1
comfy/web/assets/userSelection-GRU1gtOt.js.map
generated
vendored
1
comfy/web/assets/userSelection-GRU1gtOt.js.map
generated
vendored
File diff suppressed because one or more lines are too long
4
comfy/web/index.html
vendored
4
comfy/web/index.html
vendored
@ -14,8 +14,8 @@
|
|||||||
</style> -->
|
</style> -->
|
||||||
<link rel="stylesheet" type="text/css" href="user.css" />
|
<link rel="stylesheet" type="text/css" href="user.css" />
|
||||||
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
||||||
<script type="module" crossorigin src="./assets/index-CaD4RONs.js"></script>
|
<script type="module" crossorigin src="./assets/index-CI3N807S.js"></script>
|
||||||
<link rel="stylesheet" crossorigin href="./assets/index-DAK31IJJ.css">
|
<link rel="stylesheet" crossorigin href="./assets/index-_5czGnTA.css">
|
||||||
</head>
|
</head>
|
||||||
<body class="litegraph">
|
<body class="litegraph">
|
||||||
<div id="vue-app"></div>
|
<div id="vue-app"></div>
|
||||||
|
|||||||
4
comfy/web/materialdesignicons.min.css
vendored
4
comfy/web/materialdesignicons.min.css
vendored
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user