Merge branch 'master' of github.com:comfyanonymous/ComfyUI

doctorpangloss 2024-09-03 15:28:52 -07:00
commit 38bcd9fcbd
25 changed files with 47494 additions and 41559 deletions

.github/workflows/stale-issues.yml vendored Normal file (21 changed lines)

@@ -0,0 +1,21 @@
name: 'Close stale issues'
on:
schedule:
# Run daily at 4:30 am PT
- cron: '30 11 * * *'
permissions:
issues: write
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v9
with:
stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
days-before-stale: 30
days-before-close: 7
stale-issue-label: 'Stale'
only-labels: 'User Support'
exempt-all-assignees: true
exempt-all-milestones: true


@@ -1,7 +1,11 @@
from aiohttp import web
import logging
from typing import Optional
from ....cmd.folder_paths import models_dir, user_directory, output_directory
from aiohttp import web
from ...services.file_service import FileService
from ....cmd.folder_paths import models_dir, user_directory, output_directory
class InternalRoutes:
'''
@@ -10,6 +14,7 @@ class InternalRoutes:
Check README.md for more information.
'''
def __init__(self):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
@@ -31,6 +36,10 @@ class InternalRoutes:
except Exception as e:
return web.json_response({"error": str(e)}, status=500)
@self.routes.get('/logs')
async def get_logs(request):
# todo: applications really shouldn't serve logs like this
return web.json_response({})
def get_app(self):
if self._app is None:


@@ -7,13 +7,13 @@ Use this instead of cli_args to import the args:
It will enable command line argument parsing. If this isn't desired, you must author your own implementation of these fixes.
"""
import ctypes
import importlib.util
import logging
import os
import shutil
import sys
import warnings
import ctypes
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@@ -110,6 +110,15 @@ def _create_tracer():
return trace.get_tracer(args.otel_service_name)
def _configure_logging():
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)
_configure_logging()
_fix_pytorch_240()
tracer = _create_tracer()
__all__ = ["args", "tracer"]


@@ -25,13 +25,11 @@ from aiohttp import web
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple
from ..api_server.routes.internal.internal_routes import InternalRoutes
from ..model_filemanager import download_model, DownloadModelStatus
from .latent_preview_image_encoding import encode_preview_image
from .. import interruption
from .. import model_management
from .. import node_helpers
from .. import utils
from ..api_server.routes.internal.internal_routes import InternalRoutes
from ..app.frontend_management import FrontendManager
from ..app.user_manager import UserManager
from ..cli_args import args
@ -45,6 +43,8 @@ from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTy
ExecutionStatus
from ..digest import digest
from ..images import open_image
from ..model_filemanager import download_model, DownloadModelStatus
from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
from ..nodes.package_typing import ExportedNodes
@@ -60,6 +60,22 @@ async def send_socket_catch_exception(function, message):
logging.warning("send error: {}".format(err))
def get_comfyui_version():
comfyui_version = "unknown"
repo_path = os.path.dirname(os.path.realpath(__file__))
try:
import pygit2
repo = pygit2.Repository(repo_path)
comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
except Exception:
try:
import subprocess
comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
except Exception as e:
logging.warning(f"Failed to get ComfyUI version: {e}")
return comfyui_version.strip()
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
@@ -104,7 +120,7 @@ class PromptServer(ExecutorToClientProgress):
self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
self.loop: AbstractEventLoop = loop
self.messages: asyncio.Queue = asyncio.Queue()
self.client_session:Optional[aiohttp.ClientSession] = None
self.client_session: Optional[aiohttp.ClientSession] = None
self.number: int = 0
self.port: int = 8188
self._external_address: Optional[str] = None
@@ -418,16 +434,20 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_system_stats(request):
device = model_management.get_torch_device()
device_name = model_management.get_torch_device_name(device)
vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = model_management.get_free_memory(device, torch_free_too=True)
async def system_stats(request):
device = get_torch_device()
device_name = get_torch_device_name(device)
vram_total, torch_vram_total = get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = get_free_memory(device, torch_free_too=True)
system_stats = {
"system": {
"os": os.name,
"comfyui_version": get_comfyui_version(),
"python_version": sys.version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
"pytorch_version": torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
"argv": sys.argv
},
"devices": [
{
@@ -611,7 +631,7 @@ class PromptServer(ExecutorToClientProgress):
url = data.get('url')
model_directory = data.get('model_directory')
model_filename = data.get('model_filename')
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress.
if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
@@ -776,7 +796,7 @@ class PromptServer(ExecutorToClientProgress):
self._external_address = value
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
def add_routes(self):
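A hedged usage sketch for the /system_stats route shown above, assuming a locally running server on this file's default port of 8188; the keys mirror the response built in the renamed system_stats handler.

import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8188/system_stats") as resp:
    stats = json.load(resp)

system = stats["system"]
print(system["comfyui_version"], system["pytorch_version"], system["python_version"])
for device in stats["devices"]:  # one entry per torch device reported by the server
    print(device)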


@@ -33,8 +33,6 @@ from .cldm import cldm, mmdit
from .ldm import hydit
from .ldm.cascade import controlnet as cascade_controlnet
from .ldm.flux import controlnet as controlnet_flux
from .ldm.flux.controlnet_instantx import InstantXControlNetFlux
from .ldm.flux.controlnet_instantx_format2 import InstantXControlNetFluxFormat2
from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
from .t2i_adapter import adapter
@@ -509,7 +507,12 @@ def load_controlnet_flux_instantx(sd):
for k in sd:
new_sd[k] = sd[k]
control_model = controlnet_flux.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
num_union_modes = 0
union_cnet = "controlnet_mode_embedder.weight"
if union_cnet in new_sd:
num_union_modes = new_sd[union_cnet].shape[0]
control_model = controlnet_flux.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = latent_formats.Flux()
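A small sketch of the union-mode detection added above: the number of union control modes is taken from the first dimension of controlnet_mode_embedder.weight when that key exists in the checkpoint state dict, and defaults to 0 otherwise. The tensor shape below is illustrative only, not taken from a real checkpoint.

import torch

def detect_union_modes(state_dict: dict) -> int:
    key = "controlnet_mode_embedder.weight"
    return state_dict[key].shape[0] if key in state_dict else 0

example_sd = {"controlnet_mode_embedder.weight": torch.zeros(10, 3072)}  # 10 modes, hidden size assumed
assert detect_union_modes(example_sd) == 10
assert detect_union_modes({}) == 0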
@@ -519,10 +522,6 @@ def load_controlnet_flux_instantx(sd):
def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
if "controlnet_mode_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx_union(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path)
if "controlnet_mode_embedder.fc.weight" in controlnet_data:
return load_controlnet_flux_instantx_union(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path)
if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
if "lora_controlnet" in controlnet_data:


@@ -41,9 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None):
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
del abs_x
return sign.to(dtype=dtype)
return sign
@@ -57,6 +56,11 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))
for i in range(0, value.shape[0], slice_size):
output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
return output
return value.to(dtype=dtype)
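A worked sketch of the chunked float8 cast added above: the tensor is split along dim 0 so each slice holds roughly 4096*4096 elements, which keeps the temporaries created by manual_stochastic_round_to_float8 bounded. For a 4096x16384 fp32 weight, numel is 67,108,864, num_slices is 4.0 and slice_size is 1024, so the cast runs over four 1024-row chunks. The helper below mirrors the slicing arithmetic, with a plain dtype cast standing in for the stochastic rounding call.

import torch

def chunked_cast(value: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    output = torch.empty_like(value, dtype=dtype)
    num_slices = max(1, value.numel() / (4096 * 4096))
    slice_size = max(1, round(value.shape[0] / num_slices))
    for i in range(0, value.shape[0], slice_size):
        # stand-in for manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
        output[i:i + slice_size].copy_(value[i:i + slice_size].to(dtype=dtype))
    return output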


@@ -1,20 +1,19 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import torch
import math
from torch import Tensor, nn
from typing import Never
import torch
from einops import rearrange, repeat
from torch import Tensor, nn
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
from .layers import (timestep_embedding)
from .model import Flux
from .. import common_dit
class ControlNetFlux(Flux):
def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.main_model_double = 19
@@ -29,6 +28,11 @@ class ControlNetFlux(Flux):
for _ in range(self.params.depth_single_blocks):
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
self.num_union_modes = num_union_modes
self.controlnet_mode_embedder = None
if self.num_union_modes > 0:
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
self.gradient_checkpointing = False
self.latent_input = latent_input
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
@@ -61,6 +65,7 @@ class ControlNetFlux(Flux):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control_type: Tensor | list[Never] | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -79,6 +84,11 @@ class ControlNetFlux(Flux):
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
txt = torch.cat([control_cond, txt], dim=1)
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -137,4 +147,4 @@ class ControlNetFlux(Flux):
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))


@@ -1,309 +0,0 @@
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import numbers
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn
from ...ldm import common_dit
from .layers import timestep_embedding
from .model import Flux
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
else:
self.weight = None
self.bias = None
def forward(self, input):
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
class FluxUnionControlNetModeEmbedder(nn.Module):
def __init__(self, num_mode, out_channels):
super().__init__()
self.mode_embber = nn.Embedding(num_mode, out_channels)
self.norm = nn.LayerNorm(out_channels, eps=1e-6)
self.fc = nn.Linear(out_channels, out_channels)
def forward(self, x):
x_emb = self.mode_embber(x)
x_emb = self.norm(x_emb)
x_emb = self.fc(x_emb)
x_emb = x_emb[:, 0]
return x_emb
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
class FluxUnionControlNetInputEmbedder(nn.Module):
def __init__(self, in_channels, out_channels, num_attention_heads=24, mlp_ratio=4.0, attention_head_dim=128, dtype=None, device=None, operations=None, depth=2):
super().__init__()
self.x_embedder = nn.Sequential(nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels))
self.norm = AdaLayerNormContinuous(out_channels, out_channels, elementwise_affine=False, eps=1e-6)
self.fc = nn.Linear(out_channels, out_channels)
self.emb_embedder = nn.Sequential(nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels))
""" self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
out_channels, num_attention_heads, dtype=dtype, device=device, operations=operations
)
for i in range(2)
]
) """
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=out_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(depth)
]
)
self.out = zero_module(nn.Linear(out_channels, out_channels))
def forward(self, x, mode_emb):
mode_token = self.emb_embedder(mode_emb)[:, None]
x_emb = self.fc(self.norm(self.x_embedder(x), mode_emb))
hidden_states = torch.cat([mode_token, x_emb], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
temb=mode_emb,
)
hidden_states = self.out(hidden_states)
res = hidden_states[:, 1:]
return res
class InstantXControlNetFlux(Flux):
def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
kwargs["depth"] = 0
kwargs["depth_single_blocks"] = 0
depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.hidden_size,
num_attention_heads=24,
attention_head_dim=128,
).to(dtype=dtype)
for i in range(5)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.hidden_size,
num_attention_heads=24,
attention_head_dim=128,
).to(dtype=dtype)
for i in range(10)
]
)
self.require_vae = True
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
controlnet_block = zero_module(controlnet_block)
self.controlnet_single_blocks.append(controlnet_block)
# TODO support both union and unimodal
self.union = True # num_mode is not None
num_mode = 10
if self.union:
self.controlnet_mode_embedder = zero_module(FluxUnionControlNetModeEmbedder(num_mode, self.hidden_size)).to(device=device, dtype=dtype)
self.controlnet_x_embedder = FluxUnionControlNetInputEmbedder(self.in_channels, self.hidden_size, operations=operations, depth=depth_single_blocks_controlnet).to(device=device, dtype=dtype)
self.controlnet_mode_token_embedder = nn.Sequential(nn.LayerNorm(self.hidden_size, eps=1e-6), nn.Linear(self.hidden_size, self.hidden_size)).to(device=device, dtype=dtype)
else:
self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size)).to(device=device, dtype=dtype)
self.gradient_checkpointing = False
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def set_hint_latents(self, hint_latents):
vae_shift_factor = 0.1159
vae_scaling_factor = 0.3611
num_channels_latents = self.in_channels // 4
hint_latents = (hint_latents - vae_shift_factor) * vae_scaling_factor
height, width = hint_latents.shape[2:]
hint_latents = self._pack_latents(
hint_latents,
hint_latents.shape[0],
num_channels_latents,
height,
width,
)
self.hint_latents = hint_latents.to(device=self.device, dtype=self.dtype)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
controlnet_mode=None
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
batch_size = img.shape[0]
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
if self.params.guidance_embed:
vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
vec.add_(self.vector_in(y))
if self.union:
if controlnet_mode is None:
raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty')
controlnet_mode = torch.tensor([[controlnet_mode]], device=self.device)
emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
vec = vec + emb_controlnet_mode
img = img + self.controlnet_x_embedder(controlnet_cond, emb_controlnet_mode)
else:
img = img + self.controlnet_x_embedder(controlnet_cond)
txt = self.txt_in(txt)
if self.union:
token_controlnet_mode = self.controlnet_mode_token_embedder(emb_controlnet_mode)[:, None]
token_controlnet_mode = token_controlnet_mode.expand(txt.size(0), -1, -1)
txt = torch.cat([token_controlnet_mode, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids).to(dtype=self.dtype, device=self.device)
block_res_samples = ()
for block in self.transformer_blocks:
txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
block_res_samples = block_res_samples + (img,)
img = torch.cat([txt, img], dim=1)
single_block_res_samples = ()
for block in self.single_transformer_blocks:
img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
controlnet_single_block_res_samples = ()
for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
single_block_res_sample = single_controlnet_block(single_block_res_sample)
controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)
n_single_blocks = 38
n_double_blocks = 19
# Expand controlnet_block_res_samples to match n_double_blocks
expanded_controlnet_block_res_samples = []
interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
for i in range(n_double_blocks):
index = i // interval_control_double
expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])
# Expand controlnet_single_block_res_samples to match n_single_blocks
expanded_controlnet_single_block_res_samples = []
interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
for i in range(n_single_blocks):
index = i // interval_control_single
expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])
return {
"input": expanded_controlnet_block_res_samples,
"output": expanded_controlnet_single_block_res_samples
}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
height_control_image, width_control_image = hint.shape[2:]
num_channels_latents = self.in_channels // 4
hint = self._pack_latents(
hint,
hint.shape[0],
num_channels_latents,
height_control_image,
width_control_image,
)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)


@@ -1,227 +0,0 @@
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import numbers
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn
from .layers import timestep_embedding
from .model import Flux
from ..common_dit import pad_to_patch_size
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
else:
self.weight = None
self.bias = None
def forward(self, input):
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
class InstantXControlNetFluxFormat2(Flux):
def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
kwargs["depth"] = 0
kwargs["depth_single_blocks"] = 0
depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.hidden_size,
num_attention_heads=24,
attention_head_dim=128,
).to(dtype=dtype)
for i in range(5)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.hidden_size,
num_attention_heads=24,
attention_head_dim=128,
).to(dtype=dtype)
for i in range(10)
]
)
self.require_vae = True
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
controlnet_block = zero_module(controlnet_block)
self.controlnet_single_blocks.append(controlnet_block)
# TODO support both union and unimodal
self.union = True # num_mode is not None
num_mode = 10
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.hidden_size)
self.controlnet_x_embedder = zero_module(operations.Linear(self.in_channels, self.hidden_size).to(device=device, dtype=dtype))
self.gradient_checkpointing = False
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
controlnet_mode=None
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
batch_size = img.shape[0]
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
if self.params.guidance_embed:
vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
vec.add_(self.vector_in(y))
txt = self.txt_in(txt)
if self.union:
if controlnet_mode is None:
raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty')
controlnet_mode = torch.tensor(controlnet_mode).to(self.device, dtype=torch.long)
controlnet_mode = controlnet_mode.reshape([-1, 1])
emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
txt = torch.cat([emb_controlnet_mode, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)
img = img + self.controlnet_x_embedder(controlnet_cond)
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples = ()
for block in self.transformer_blocks:
txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
block_res_samples = block_res_samples + (img,)
img = torch.cat([txt, img], dim=1)
single_block_res_samples = ()
for block in self.single_transformer_blocks:
img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
controlnet_single_block_res_samples = ()
for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
single_block_res_sample = single_controlnet_block(single_block_res_sample)
controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)
n_single_blocks = 38
n_double_blocks = 19
# Expand controlnet_block_res_samples to match n_double_blocks
expanded_controlnet_block_res_samples = []
interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
for i in range(n_double_blocks):
index = i // interval_control_double
expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])
# Expand controlnet_single_block_res_samples to match n_single_blocks
expanded_controlnet_single_block_res_samples = []
interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
for i in range(n_single_blocks):
index = i // interval_control_single
expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])
return {
"input": expanded_controlnet_block_res_samples,
"output": expanded_controlnet_single_block_res_samples
}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = pad_to_patch_size(x, (patch_size, patch_size))
height_control_image, width_control_image = hint.shape[2:]
num_channels_latents = self.in_channels // 4
hint = self._pack_latents(
hint,
hint.shape[0],
num_channels_latents,
height_control_image,
width_control_image,
)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)


@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)


@@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k]
key_map["transformer.{}".format(k[:-len(".weight")])] = to # simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
return key_map
@@ -528,20 +529,40 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True
if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:


@@ -61,6 +61,7 @@ cpu_state = CPUState.GPU
total_vram = 0
xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
@@ -419,13 +420,11 @@ def offloaded_memory(loaded_models, device):
return offloaded_mem
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
if any(platform.win32_ver()):
EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 # Windows is higher because of the shared vram issue
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 # Windows is higher because of the shared vram issue
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -436,6 +435,10 @@ def extra_reserved_memory():
return EXTRA_RESERVED_VRAM
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
with model_management_lock:
return _unload_model_clones(model, unload_weights_only, force_unload)
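A hedged arithmetic sketch of the reservation values above, assuming args.reserve_vram is unset: the new defaults are 400 MiB of extra reserved VRAM (600 MiB on Windows, because of the shared VRAM issue noted in the comment), and minimum_inference_memory() becomes 0.8 GiB plus that reservation, roughly 1219 MiB on Linux and 1419 MiB on Windows.

MIB = 1024 * 1024
GIB = 1024 * MIB

def reserved_vram(windows: bool, reserve_vram_gb: float | None = None) -> int:
    extra = 600 * MIB if windows else 400 * MIB
    if reserve_vram_gb is not None:  # the reserve_vram argument overrides the default, in GiB
        extra = int(reserve_vram_gb * GIB)
    return extra

def min_inference_memory(windows: bool) -> float:
    return 0.8 * GIB + reserved_vram(windows)

print(min_inference_memory(False) / MIB)  # ~1219.2
print(min_inference_memory(True) / MIB)   # ~1419.2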
@ -1119,7 +1122,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series:
if x in props.name.lower():
return True
if WINDOWS or manual_cast:
return True
else:
return False # weird linux behavior where fp32 is faster
if manual_cast:
free_model_memory = maximum_vram_for_weights(device)

File diff suppressed because it is too large

comfy/web/assets/index-BD-Ia1C4.js.map generated vendored Normal file (1 changed line)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

comfy/web/assets/index-CI3N807S.js.map generated vendored Normal file (1 changed line)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

comfy/web/assets/userSelection-CyXKCVy3.js generated vendored Normal file (120 changed lines)

@@ -0,0 +1,120 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { j as createSpinner, h as api, $ as $el } from "./index-CI3N807S.js";
class UserSelectionScreen {
static {
__name(this, "UserSelectionScreen");
}
async show(users, user) {
const userSelection = document.getElementById("comfy-user-selection");
userSelection.style.display = "";
return new Promise((resolve) => {
const input = userSelection.getElementsByTagName("input")[0];
const select = userSelection.getElementsByTagName("select")[0];
const inputSection = input.closest("section");
const selectSection = select.closest("section");
const form = userSelection.getElementsByTagName("form")[0];
const error = userSelection.getElementsByClassName("comfy-user-error")[0];
const button = userSelection.getElementsByClassName(
"comfy-user-button-next"
)[0];
let inputActive = null;
input.addEventListener("focus", () => {
inputSection.classList.add("selected");
selectSection.classList.remove("selected");
inputActive = true;
});
select.addEventListener("focus", () => {
inputSection.classList.remove("selected");
selectSection.classList.add("selected");
inputActive = false;
select.style.color = "";
});
select.addEventListener("blur", () => {
if (!select.value) {
select.style.color = "var(--descrip-text)";
}
});
form.addEventListener("submit", async (e) => {
e.preventDefault();
if (inputActive == null) {
error.textContent = "Please enter a username or select an existing user.";
} else if (inputActive) {
const username = input.value.trim();
if (!username) {
error.textContent = "Please enter a username.";
return;
}
input.disabled = select.disabled = // @ts-expect-error
input.readonly = // @ts-expect-error
select.readonly = true;
const spinner = createSpinner();
button.prepend(spinner);
try {
const resp = await api.createUser(username);
if (resp.status >= 300) {
let message = "Error creating user: " + resp.status + " " + resp.statusText;
try {
const res = await resp.json();
if (res.error) {
message = res.error;
}
} catch (error2) {
}
throw new Error(message);
}
resolve({ username, userId: await resp.json(), created: true });
} catch (err) {
spinner.remove();
error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred.";
input.disabled = select.disabled = // @ts-expect-error
input.readonly = // @ts-expect-error
select.readonly = false;
return;
}
} else if (!select.value) {
error.textContent = "Please select an existing user.";
return;
} else {
resolve({
username: users[select.value],
userId: select.value,
created: false
});
}
});
if (user) {
const name = localStorage["Comfy.userName"];
if (name) {
input.value = name;
}
}
if (input.value) {
input.focus();
}
const userIds = Object.keys(users ?? {});
if (userIds.length) {
for (const u of userIds) {
$el("option", { textContent: users[u], value: u, parent: select });
}
select.style.color = "var(--descrip-text)";
if (select.value) {
select.focus();
}
} else {
userSelection.classList.add("no-users");
input.focus();
}
}).then((r) => {
userSelection.remove();
return r;
});
}
}
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.userSelection = window.comfyAPI.userSelection || {};
window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
UserSelectionScreen
};
//# sourceMappingURL=userSelection-CyXKCVy3.js.map

comfy/web/assets/userSelection-CyXKCVy3.js.map generated vendored Normal file (1 changed line)

File diff suppressed because one or more lines are too long


@@ -1,142 +0,0 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
var __async = (__this, __arguments, generator) => {
return new Promise((resolve, reject) => {
var fulfilled = (value) => {
try {
step(generator.next(value));
} catch (e) {
reject(e);
}
};
var rejected = (value) => {
try {
step(generator.throw(value));
} catch (e) {
reject(e);
}
};
var step = (x) => x.done ? resolve(x.value) : Promise.resolve(x.value).then(fulfilled, rejected);
step((generator = generator.apply(__this, __arguments)).next());
});
};
import { j as createSpinner, g as api, $ as $el } from "./index-CaD4RONs.js";
const _UserSelectionScreen = class _UserSelectionScreen {
show(users, user) {
return __async(this, null, function* () {
const userSelection = document.getElementById("comfy-user-selection");
userSelection.style.display = "";
return new Promise((resolve) => {
const input = userSelection.getElementsByTagName("input")[0];
const select = userSelection.getElementsByTagName("select")[0];
const inputSection = input.closest("section");
const selectSection = select.closest("section");
const form = userSelection.getElementsByTagName("form")[0];
const error = userSelection.getElementsByClassName("comfy-user-error")[0];
const button = userSelection.getElementsByClassName(
"comfy-user-button-next"
)[0];
let inputActive = null;
input.addEventListener("focus", () => {
inputSection.classList.add("selected");
selectSection.classList.remove("selected");
inputActive = true;
});
select.addEventListener("focus", () => {
inputSection.classList.remove("selected");
selectSection.classList.add("selected");
inputActive = false;
select.style.color = "";
});
select.addEventListener("blur", () => {
if (!select.value) {
select.style.color = "var(--descrip-text)";
}
});
form.addEventListener("submit", (e) => __async(this, null, function* () {
var _a, _b, _c;
e.preventDefault();
if (inputActive == null) {
error.textContent = "Please enter a username or select an existing user.";
} else if (inputActive) {
const username = input.value.trim();
if (!username) {
error.textContent = "Please enter a username.";
return;
}
input.disabled = select.disabled = // @ts-expect-error
input.readonly = // @ts-expect-error
select.readonly = true;
const spinner = createSpinner();
button.prepend(spinner);
try {
const resp = yield api.createUser(username);
if (resp.status >= 300) {
let message = "Error creating user: " + resp.status + " " + resp.statusText;
try {
const res = yield resp.json();
if (res.error) {
message = res.error;
}
} catch (error2) {
}
throw new Error(message);
}
resolve({ username, userId: yield resp.json(), created: true });
} catch (err) {
spinner.remove();
error.textContent = (_c = (_b = (_a = err.message) != null ? _a : err.statusText) != null ? _b : err) != null ? _c : "An unknown error occurred.";
input.disabled = select.disabled = // @ts-expect-error
input.readonly = // @ts-expect-error
select.readonly = false;
return;
}
} else if (!select.value) {
error.textContent = "Please select an existing user.";
return;
} else {
resolve({
username: users[select.value],
userId: select.value,
created: false
});
}
}));
if (user) {
const name = localStorage["Comfy.userName"];
if (name) {
input.value = name;
}
}
if (input.value) {
input.focus();
}
const userIds = Object.keys(users != null ? users : {});
if (userIds.length) {
for (const u of userIds) {
$el("option", { textContent: users[u], value: u, parent: select });
}
select.style.color = "var(--descrip-text)";
if (select.value) {
select.focus();
}
} else {
userSelection.classList.add("no-users");
input.focus();
}
}).then((r) => {
userSelection.remove();
return r;
});
});
}
};
__name(_UserSelectionScreen, "UserSelectionScreen");
let UserSelectionScreen = _UserSelectionScreen;
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.userSelection = window.comfyAPI.userSelection || {};
window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
UserSelectionScreen
};
//# sourceMappingURL=userSelection-GRU1gtOt.js.map

File diff suppressed because one or more lines are too long


@@ -14,8 +14,8 @@
</style> -->
<link rel="stylesheet" type="text/css" href="user.css" />
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
<script type="module" crossorigin src="./assets/index-CaD4RONs.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-DAK31IJJ.css">
<script type="module" crossorigin src="./assets/index-CI3N807S.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-_5czGnTA.css">
</head>
<body class="litegraph">
<div id="vue-app"></div>

File diff suppressed because one or more lines are too long