diff --git a/comfy/app/frontend_management.py b/comfy/app/frontend_management.py index 57ba56d0e..aedf3dcd9 100644 --- a/comfy/app/frontend_management.py +++ b/comfy/app/frontend_management.py @@ -9,7 +9,7 @@ import zipfile from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import TypedDict +from typing import TypedDict, Optional import requests from typing_extensions import NotRequired @@ -135,12 +135,13 @@ class FrontendManager: return match_result.group(1), match_result.group(2), match_result.group(3) @classmethod - def init_frontend_unsafe(cls, version_string: str) -> str: + def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: """ Initializes the frontend for the specified version. Args: version_string (str): The version string. + provider (FrontEndProvider, optional): The provider to use. Defaults to None. Returns: str: The path to the initialized frontend. @@ -153,7 +154,7 @@ class FrontendManager: return cls.DEFAULT_FRONTEND_PATH repo_owner, repo_name, version = cls.parse_version_string(version_string) - provider = FrontEndProvider(repo_owner, repo_name) + provider = provider or FrontEndProvider(repo_owner, repo_name) release = provider.get_release(version) semantic_version = release["tag_name"].lstrip("v") @@ -161,15 +162,21 @@ class FrontendManager: Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version ) if not os.path.exists(web_root): - os.makedirs(web_root, exist_ok=True) - logging.info( - "Downloading frontend(%s) version(%s) to (%s)", - provider.folder_name, - semantic_version, - web_root, - ) - logging.debug(release) - download_release_asset_zip(release, destination_path=web_root) + try: + os.makedirs(web_root, exist_ok=True) + logging.info( + "Downloading frontend(%s) version(%s) to (%s)", + provider.folder_name, + semantic_version, + web_root, + ) + logging.debug(release) + download_release_asset_zip(release, 
destination_path=web_root) + finally: + # Clean up the directory if it is empty, i.e. the download failed + if not os.listdir(web_root): + os.rmdir(web_root) + return web_root @classmethod diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index a9ddb9cd1..dbd29fd1f 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -603,7 +603,9 @@ class PromptServer(ExecutorToClientProgress): @routes.post("/internal/models/download") async def download_handler(request): async def report_progress(filename: str, status: DownloadModelStatus): - await self.send_json("download_progress", status.to_dict()) + payload = status.to_dict() + payload['download_path'] = filename + await self.send_json("download_progress", payload) data = await request.json() url = data.get('url') diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0959faf0a..255b52316 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -32,7 +32,7 @@ from . import utils from .cldm import cldm, mmdit from .ldm import hydit from .ldm.cascade import controlnet as cascade_controlnet -from .ldm.flux import controlnet_xlabs +from .ldm.flux import controlnet as controlnet_flux from .ldm.flux.controlnet_instantx import InstantXControlNetFlux from .ldm.flux.controlnet_instantx_format2 import InstantXControlNetFluxFormat2 from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES @@ -152,7 +152,7 @@ class ControlBase: elif self.strength_type == StrengthType.LINEAR_UP: x *= (self.strength ** float(len(control_output) - i)) - if x.dtype != output_dtype: + if output_dtype is not None and x.dtype != output_dtype: x = x.to(output_dtype) out[key].append(x) @@ -211,7 +211,6 @@ class ControlNet(ControlBase): if self.manual_cast_dtype is not None: dtype = self.manual_cast_dtype - output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del 
self.cond_hint @@ -241,7 +240,7 @@ class ControlNet(ControlBase): x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) - return self.control_merge(control, control_prev, output_dtype) + return self.control_merge(control, control_prev, output_dtype=None) def copy(self): c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) @@ -441,7 +440,7 @@ def load_controlnet_hunyuandit(controlnet_data): return control -def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype, full_path): +def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full_path): keys_to_keep = [ "controlnet_", "single_transformer_blocks", @@ -498,19 +497,32 @@ def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype, full_path) def load_controlnet_flux_xlabs(sd): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd) - control_model = controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_flux.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control +def load_controlnet_flux_instantx(sd): + new_sd = model_detection.convert_diffusers_mmdit(sd, "") + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) + for k in sd: + new_sd[k] = sd[k] + + control_model = controlnet_flux.ControlNetFlux(latent_input=True, operations=operations, 
device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, new_sd) + + latent_format = latent_formats.Flux() + extra_conds = ['y', 'guidance'] + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + return control def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]): controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) if "controlnet_mode_embedder.weight" in controlnet_data: - return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path) + return load_controlnet_flux_instantx_union(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path) if "controlnet_mode_embedder.fc.weight" in controlnet_data: - return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path) + return load_controlnet_flux_instantx_union(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path) if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT return load_controlnet_hunyuandit(controlnet_data) if "lora_controlnet" in controlnet_data: @@ -572,8 +584,10 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]): elif "controlnet_blocks.0.weight" in controlnet_data: # SD3 diffusers format if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: return load_controlnet_flux_xlabs(controlnet_data) - else: + elif "pos_embed_input.proj.weight" in controlnet_data: return load_controlnet_mmdit(controlnet_data) + elif "controlnet_x_embedder.weight" in controlnet_data: + return load_controlnet_flux_instantx(controlnet_data) pth_key = 'control_model.zero_convs.0.0.weight' pth = False diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index 990025521..9016abc44 100644 --- a/comfy/ldm/common_dit.py +++ 
b/comfy/ldm/common_dit.py @@ -1,4 +1,5 @@ import torch +import comfy.ops def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): @@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0] pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) + +try: + rms_norm_torch = torch.nn.functional.rms_norm +except: + rms_norm_torch = None + +def rms_norm(x, weight, eps=1e-6): + if rms_norm_torch is not None: + return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) + else: + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) + return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device) diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py new file mode 100644 index 000000000..2c658a4b1 --- /dev/null +++ b/comfy/ldm/flux/controlnet.py @@ -0,0 +1,140 @@ +#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py + +import torch +import math +from torch import Tensor, nn +from einops import rearrange, repeat + +from .layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + +from .model import Flux +import comfy.ldm.common_dit + + +class ControlNetFlux(Flux): + def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs): + super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + + self.main_model_double = 19 + self.main_model_single = 38 + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(self.params.depth): + controlnet_block = 
operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + self.controlnet_blocks.append(controlnet_block) + + self.controlnet_single_blocks = nn.ModuleList([]) + for _ in range(self.params.depth_single_blocks): + self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)) + + self.gradient_checkpointing = False + self.latent_input = latent_input + self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + if not self.latent_input: + self.input_hint_block = nn.Sequential( + operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) + ) + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + if not self.latent_input: + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = 
self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + controlnet_double = () + + for i in range(len(self.double_blocks)): + img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe) + controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),) + + img = torch.cat((txt, img), 1) + + controlnet_single = () + + for i in range(len(self.single_blocks)): + img = self.single_blocks[i](img, vec=vec, pe=pe) + controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),) + + repeat = math.ceil(self.main_model_double / len(controlnet_double)) + if self.latent_input: + out_input = () + for x in controlnet_double: + out_input += (x,) * repeat + else: + out_input = (controlnet_double * repeat) + + out = {"input": out_input[:self.main_model_double]} + if len(controlnet_single) > 0: + repeat = math.ceil(self.main_model_single / len(controlnet_single)) + out_output = () + if self.latent_input: + for x in controlnet_single: + out_output += (x,) * repeat + else: + out_output = (controlnet_single * repeat) + out["output"] = out_output[:self.main_model_single] + return out + + def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): + patch_size = 2 + if self.latent_input: + hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) + hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + else: + hint = hint * 2.0 - 1.0 + + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) 
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance) diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet_xlabs.py deleted file mode 100644 index 5d700f16c..000000000 --- a/comfy/ldm/flux/controlnet_xlabs.py +++ /dev/null @@ -1,104 +0,0 @@ -#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py - -import torch -from torch import Tensor, nn -from einops import rearrange, repeat - -from .layers import (DoubleStreamBlock, EmbedND, LastLayer, - MLPEmbedder, SingleStreamBlock, - timestep_embedding) - -from .model import Flux -import comfy.ldm.common_dit - - -class ControlNetFlux(Flux): - def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs): - super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) - - # add ControlNet blocks - self.controlnet_blocks = nn.ModuleList([]) - for _ in range(self.params.depth): - controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) - # controlnet_block = zero_module(controlnet_block) - self.controlnet_blocks.append(controlnet_block) - self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) - self.gradient_checkpointing = False - self.input_hint_block = nn.Sequential( - operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, 
device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) - ) - - def forward_orig( - self, - img: Tensor, - img_ids: Tensor, - controlnet_cond: Tensor, - txt: Tensor, - txt_ids: Tensor, - timesteps: Tensor, - y: Tensor, - guidance: Tensor = None, - ) -> Tensor: - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") - - # running on sequences img - img = self.img_in(img) - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_cond = self.pos_embed_input(controlnet_cond) - img = img + controlnet_cond - vec = self.time_in(timestep_embedding(timesteps, 256)) - if self.params.guidance_embed: - vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) - vec = vec + self.vector_in(y) - txt = self.txt_in(txt) - - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) - - block_res_samples = () - - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe) - block_res_samples = block_res_samples + (img,) - - controlnet_block_res_samples = () - for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): - block_res_sample = controlnet_block(block_res_sample) - controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) - - return {"input": (controlnet_block_res_samples * 10)[:19]} - - def forward(self, x, timesteps, context, 
y, guidance=None, hint=None, **kwargs): - hint = hint * 2.0 - 1.0 - - bs, c, h, w = x.shape - patch_size = 2 - x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) - - img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - - h_len = ((h + (patch_size // 2)) // patch_size) - w_len = ((w + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) - - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 10b87e517..0964086b5 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -5,7 +5,7 @@ import torch from torch import Tensor, nn from .math import attention, rope -from ... 
import ops +from ..common_dit import rms_norm class EmbedND(nn.Module): @@ -63,10 +63,7 @@ class RMSNorm(torch.nn.Module): self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) def forward(self, x: Tensor): - x_dtype = x.dtype - x = x.float() - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms).to(dtype=x_dtype) * ops.cast_to(self.scale, dtype=x_dtype, device=x.device) + return rms_norm(x, self.scale, 1e-6) class QKNorm(torch.nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 459892463..dc4ff87a8 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -356,29 +356,9 @@ class RMSNorm(torch.nn.Module): else: self.register_parameter("weight", None) - def _norm(self, x): - """ - Apply the RMSNorm normalization to the input tensor. - Args: - x (torch.Tensor): The input tensor. - Returns: - torch.Tensor: The normalized tensor. - """ - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - def forward(self, x): - """ - Forward pass through the RMSNorm layer. - Args: - x (torch.Tensor): The input tensor. - Returns: - torch.Tensor: The output tensor after applying RMSNorm. - """ - x = self._norm(x) - if self.learnable_scale: - return x * self.weight.to(device=x.device, dtype=x.dtype) - else: - return x + return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps) + class SwiGLUFeedForward(nn.Module): diff --git a/comfy/lora.py b/comfy/lora.py index 9687f9b65..c793ae1b9 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -15,12 +15,15 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +from __future__ import annotations import logging + import torch -from . import utils + from . import model_base from . import model_management +from . 
import utils LORA_CLIP_MAP = { "mlp.fc1": "mlp_fc1", @@ -71,7 +74,7 @@ def load_lora(lora, to_load): B_name = "{}.lora.down.weight".format(x) elif transformers_lora in lora.keys(): A_name = transformers_lora - B_name ="{}.lora_linear_layer.down.weight".format(x) + B_name = "{}.lora_linear_layer.down.weight".format(x) if A_name is not None: mid = None @@ -82,7 +85,6 @@ def load_lora(lora, to_load): loaded_keys.add(A_name) loaded_keys.add(B_name) - ######## loha hada_w1_a_name = "{}.hada_w1_a".format(x) hada_w1_b_name = "{}.hada_w1_b".format(x) @@ -105,7 +107,6 @@ def load_lora(lora, to_load): loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_b_name) - ######## lokr lokr_w1_name = "{}.lokr_w1".format(x) lokr_w2_name = "{}.lokr_w2".format(x) @@ -153,7 +154,7 @@ def load_lora(lora, to_load): if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) - #glora + # glora a1_name = "{}.a1.weight".format(x) a2_name = "{}.a2.weight".format(x) b1_name = "{}.b1.weight".format(x) @@ -195,12 +196,13 @@ def load_lora(lora, to_load): return patch_dict + def model_lora_keys_clip(model, key_map={}): sdk = model.state_dict().keys() text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False - for b in range(32): #TODO: clean up + for b in range(32): # TODO: clean up for c in LORA_CLIP_MAP: k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: @@ -208,58 +210,58 @@ def model_lora_keys_clip(model, key_map={}): key_map[lora_key] = k lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) # diffusers lora key_map[lora_key] = k k = 
"clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) key_map[lora_key] = k - lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # SDXL base key_map[lora_key] = k clip_l_present = True - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) # diffusers lora key_map[lora_key] = k k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: if clip_l_present: - lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base + lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # SDXL base key_map[lora_key] = k - lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) # diffusers lora key_map[lora_key] = k else: - lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner + lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # TODO: test if this is correct for SDXL-Refiner key_map[lora_key] = k - lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora + lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) # diffusers lora key_map[lora_key] = k - lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config + lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) # cascade lora: TODO put lora key prefix in the model config key_map[lora_key] = k for k in sdk: if k.endswith(".weight"): - if 
k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora + if k.startswith("t5xxl.transformer."): # OneTrainer SD3 lora l_key = k[len("t5xxl.transformer."):-len(".weight")] lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) key_map[lora_key] = k - elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora + elif k.startswith("hydit_clip.transformer.bert."): # HunyuanDiT Lora l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")] lora_key = "lora_te1_{}".format(l_key.replace(".", "_")) key_map[lora_key] = k - k = "clip_g.transformer.text_projection.weight" if k in sdk: - key_map["lora_prior_te_text_projection"] = k #cascade lora? + key_map["lora_prior_te_text_projection"] = k # cascade lora? # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too - key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora + key_map["lora_te2_text_projection"] = k # OneTrainer SD3 lora k = "clip_l.transformer.text_projection.weight" if k in sdk: - key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning + key_map["lora_te1_text_projection"] = k # OneTrainer SD3 lora, not necessary but omits warning return key_map + def model_lora_keys_unet(model, key_map={}): sd = model.state_dict() sdk = sd.keys() @@ -268,8 +270,8 @@ def model_lora_keys_unet(model, key_map={}): if k.startswith("diffusion_model.") and k.endswith(".weight"): key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") key_map["lora_unet_{}".format(key_lora)] = k - key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config - key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + key_map["lora_prior_unet_{}".format(key_lora)] = k # cascade lora: TODO put lora key prefix in the model config + key_map["{}".format(k[:-len(".weight")])] = k # generic lora format without any weird key names diffusers_keys = 
utils.unet_to_diffusers(model.model_config.unet_config) for k in diffusers_keys: @@ -285,41 +287,41 @@ def model_lora_keys_unet(model, key_map={}): diffusers_lora_key = diffusers_lora_key[:-2] key_map[diffusers_lora_key] = unet_key - if isinstance(model, model_base.SD3): #Diffusers lora SD3 + if isinstance(model, model_base.SD3): # Diffusers lora SD3 diffusers_keys = utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: if k.endswith(".weight"): to = diffusers_keys[k] - key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format + key_lora = "transformer.{}".format(k[:-len(".weight")]) # regular diffusers sd3 lora format key_map[key_lora] = to - key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others? + key_lora = "base_model.model.{}".format(k[:-len(".weight")]) # format for flash-sd3 lora and others? key_map[key_lora] = to - key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora + key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) # OneTrainer lora key_map[key_lora] = to - if isinstance(model, model_base.AuraFlow): #Diffusers lora AuraFlow + if isinstance(model, model_base.AuraFlow): # Diffusers lora AuraFlow diffusers_keys = utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: if k.endswith(".weight"): to = diffusers_keys[k] - key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format + key_lora = "transformer.{}".format(k[:-len(".weight")]) # simpletrainer and probably regular diffusers lora format key_map[key_lora] = to if isinstance(model, model_base.HunyuanDiT): for k in sdk: if k.startswith("diffusion_model.") and k.endswith(".weight"): key_lora = k[len("diffusion_model."):-len(".weight")] - 
key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format + key_map["base_model.model.{}".format(key_lora)] = k # official hunyuan lora format - if isinstance(model, model_base.Flux): #Diffusers lora Flux + if isinstance(model, model_base.Flux): # Diffusers lora Flux diffusers_keys = utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: if k.endswith(".weight"): to = diffusers_keys[k] - key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format - key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris + key_map["transformer.{}".format(k[:-len(".weight")])] = to # simpletrainer and probably regular diffusers flux lora format + key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to # simpletrainer lycoris return key_map @@ -344,6 +346,41 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat weight[:] = weight_calc return weight + +def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: + """ + Pad a tensor to a new shape with zeros. + + Args: + tensor (torch.Tensor): The original tensor to be padded. + new_shape (List[int]): The desired shape of the padded tensor. + + Returns: + torch.Tensor: A new tensor padded with zeros to the specified shape. + + Note: + If the new shape is smaller than the original tensor in any dimension, + the original tensor will be truncated in that dimension. 
+ """ + if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): + raise ValueError("The new shape must be larger than the original tensor in all dimensions") + + if len(new_shape) != len(tensor.shape): + raise ValueError("The new shape must have the same number of dimensions as the original tensor") + + # Create a new tensor filled with zeros + padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) + + # Create slicing tuples for both tensors + orig_slices = tuple(slice(0, dim) for dim in tensor.shape) + new_slices = tuple(slice(0, dim) for dim in tensor.shape) + + # Copy the original tensor into the new tensor + padded_tensor[new_slices] = tensor[orig_slices] + + return padded_tensor + + def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): for p in patches: strength = p[0] @@ -363,7 +400,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): weight *= strength_model if isinstance(v, list): - v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), ) + v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype),) patch_type = "" if len(v) == 1: @@ -373,13 +410,19 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): v = v[1] if patch_type == "diff": - w1 = v[0] + diff: torch.Tensor = v[0] + # An extra flag to pad the weight if the diff's shape is larger than the weight + do_pad_weight = len(v) > 1 and v[1]['pad_weight'] + if do_pad_weight and diff.shape != weight.shape: + logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape)) + weight = pad_tensor_to_shape(weight, diff.shape) + if strength != 0.0: - if w1.shape != weight.shape: - logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + if diff.shape != weight.shape: + logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, 
diff.shape, weight.shape)) else: - weight += function(strength * model_management.cast_to_device(w1, weight.device, weight.dtype)) - elif patch_type == "lora": #lora/locon + weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype)) + elif patch_type == "lora": # lora/locon mat1 = model_management.cast_to_device(v[0], weight.device, intermediate_dtype) mat2 = model_management.cast_to_device(v[1], weight.device, intermediate_dtype) dora_scale = v[4] @@ -389,7 +432,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): alpha = 1.0 if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it + # locon mid weights, hopefully the math is fine because I didn't properly test it mat3 = model_management.cast_to_device(v[3], weight.device, intermediate_dtype) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) @@ -415,7 +458,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): if w1 is None: dim = w1_b.shape[0] w1 = torch.mm(model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) else: w1 = model_management.cast_to_device(w1, weight.device, intermediate_dtype) @@ -423,12 +466,12 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): dim = w2_b.shape[0] if t2 is None: w2 = torch.mm(model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) else: w2 = torch.einsum('i j k l, j r, i p -> p r k l', - 
model_management.cast_to_device(t2, weight.device, intermediate_dtype), - model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + model_management.cast_to_device(t2, weight.device, intermediate_dtype), + model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), + model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) else: w2 = model_management.cast_to_device(w2, weight.device, intermediate_dtype) @@ -458,23 +501,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32): w2a = v[3] w2b = v[4] dora_scale = v[7] - if v[5] is not None: #cp decomposition + if v[5] is not None: # cp decomposition t1 = v[5] t2 = v[6] m1 = torch.einsum('i j k l, j r, i p -> p r k l', - model_management.cast_to_device(t1, weight.device, intermediate_dtype), - model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + model_management.cast_to_device(t1, weight.device, intermediate_dtype), + model_management.cast_to_device(w1b, weight.device, intermediate_dtype), + model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) m2 = torch.einsum('i j k l, j r, i p -> p r k l', - model_management.cast_to_device(t2, weight.device, intermediate_dtype), - model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + model_management.cast_to_device(t2, weight.device, intermediate_dtype), + model_management.cast_to_device(w2b, weight.device, intermediate_dtype), + model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) else: m1 = torch.mm(model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) + model_management.cast_to_device(w1b, weight.device, 
intermediate_dtype)) m2 = torch.mm(model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) try: lora_diff = (m1 * m2).reshape(weight.shape) diff --git a/comfy/model_management.py b/comfy/model_management.py index f2abd6a97..eb7827c8f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -463,6 +463,8 @@ def _unload_model_clones(model, unload_weights_only=True, force_unload=True) -> if not force_unload: if unload_weights_only and unload_weight == False: return None + else: + unload_weight = True for i in to_unload: logging.debug("unload clone {} {}".format(i, unload_weight)) diff --git a/comfy/utils.py b/comfy/utils.py index f633e332e..d410c59f5 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -566,6 +566,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"), ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift), ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift), + ("pos_embed_input.bias", "controlnet_x_embedder.bias"), + ("pos_embed_input.weight", "controlnet_x_embedder.weight"), } for k in MAP_BASIC: diff --git a/tests/unit/app_test/frontend_manager_test.py b/tests/unit/app_test/frontend_manager_test.py index dd44527b4..c9f292cac 100644 --- a/tests/unit/app_test/frontend_manager_test.py +++ b/tests/unit/app_test/frontend_manager_test.py @@ -2,6 +2,7 @@ import argparse import pytest from requests.exceptions import HTTPError +from unittest.mock import patch from comfy.app.frontend_management import ( FrontendManager, @@ -84,6 +85,35 @@ def test_init_frontend_invalid_provider(): with pytest.raises(HTTPError): FrontendManager.init_frontend_unsafe(version_string) +@pytest.fixture +def mock_os_functions(): + with 
patch('comfy.app.frontend_management.os.makedirs') as mock_makedirs, \ + patch('comfy.app.frontend_management.os.listdir') as mock_listdir, \ + patch('comfy.app.frontend_management.os.rmdir') as mock_rmdir: + mock_listdir.return_value = [] # Simulate empty directory + yield mock_makedirs, mock_listdir, mock_rmdir + +@pytest.fixture +def mock_download(): + with patch('comfy.app.frontend_management.download_release_asset_zip') as mock: + mock.side_effect = Exception("Download failed") # Simulate download failure + yield mock + +def test_finally_block(mock_os_functions, mock_download, mock_provider): + # Arrange + mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions + version_string = 'test-owner/test-repo@1.0.0' + + # Act & Assert + with pytest.raises(Exception): + FrontendManager.init_frontend_unsafe(version_string, mock_provider) + + # Assert + mock_makedirs.assert_called_once() + mock_download.assert_called_once() + mock_listdir.assert_called_once() + mock_rmdir.assert_called_once() + def test_parse_version_string(): version_string = "owner/repo@1.0.0"